From e3d1bc27902a074c1430d4afde79a3db000db076 Mon Sep 17 00:00:00 2001
From: Cheng Shao <terrorjack@type.dance>
Date: Tue, 1 Oct 2024 19:21:54 +0000
Subject: [PATCH] compiler: add PIC support to wasm backend NCG

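This commit adds support for generating position-independent code
(PIC) to the wasm backend NCG. The assembler monad now carries a
WasmAsmConfig instead of a bare tail-call flag. When PIC is enabled,
symbol addresses are emitted as MBREL/TBREL relocations against
__memory_base/__table_base for symbols private to the compilation
unit and as GOT relocations otherwise, defined symbols are no longer
marked .hidden, and the Cmm global register globals are imported from
the "regs" module. ncgPIC is no longer forced off on wasm32.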

(cherry picked from commit 61f5baa5bd6e8d0daa20af4dc7c3213a48f99019)
---
 compiler/GHC/CmmToAsm/Wasm.hs          |  7 +--
 compiler/GHC/CmmToAsm/Wasm/Asm.hs      | 79 +++++++++++++++++---------
 compiler/GHC/CmmToAsm/Wasm/Types.hs    | 38 ++++++++++++-
 compiler/GHC/Driver/Config/CmmToAsm.hs |  3 +-
 4 files changed, 93 insertions(+), 34 deletions(-)

diff --git a/compiler/GHC/CmmToAsm/Wasm.hs b/compiler/GHC/CmmToAsm/Wasm.hs
index 511af517a68..faa611e6a16 100644
--- a/compiler/GHC/CmmToAsm/Wasm.hs
+++ b/compiler/GHC/CmmToAsm/Wasm.hs
@@ -40,12 +40,11 @@ ncgWasm ::
 ncgWasm ncg_config logger platform ts us loc h cmms = do
   (r, s) <- streamCmmGroups ncg_config platform us cmms
   outputWasm $ "# " <> string7 (fromJust $ ml_hs_file loc) <> "\n\n"
-  outputWasm $ execWasmAsmM do_tail_call $ asmTellEverything TagI32 s
+  -- See Note [WasmTailCall]
+  let cfg = (defaultWasmAsmConfig s) { pic = ncgPIC ncg_config, tailcall = doTailCall ts }
+  outputWasm $ execWasmAsmM cfg $ asmTellEverything TagI32 s
   pure r
   where
-    -- See Note [WasmTailCall]
-    do_tail_call = doTailCall ts
-
     outputWasm builder = do
       putDumpFileMaybe
         logger
diff --git a/compiler/GHC/CmmToAsm/Wasm/Asm.hs b/compiler/GHC/CmmToAsm/Wasm/Asm.hs
index cc28b65e946..b03d203dcab 100644
--- a/compiler/GHC/CmmToAsm/Wasm/Asm.hs
+++ b/compiler/GHC/CmmToAsm/Wasm/Asm.hs
@@ -35,13 +35,13 @@ import GHC.Utils.Outputable hiding ((<>))
 import GHC.Utils.Panic (panic)
 
 -- | Reads current indentation, appends result to state
-newtype WasmAsmM a = WasmAsmM (Bool -> Builder -> State Builder a)
+newtype WasmAsmM a = WasmAsmM (WasmAsmConfig -> Builder -> State Builder a)
   deriving
     ( Functor,
       Applicative,
       Monad
     )
-    via (ReaderT Bool (ReaderT Builder (State Builder)))
+    via (ReaderT WasmAsmConfig (ReaderT Builder (State Builder)))
 
 instance Semigroup a => Semigroup (WasmAsmM a) where
   (<>) = liftA2 (<>)
@@ -49,19 +49,18 @@ instance Semigroup a => Semigroup (WasmAsmM a) where
 instance Monoid a => Monoid (WasmAsmM a) where
   mempty = pure mempty
 
--- | To tail call or not, that is the question
-doTailCall :: WasmAsmM Bool
-doTailCall = WasmAsmM $ \do_tail_call _ -> pure do_tail_call
+getConf :: WasmAsmM WasmAsmConfig
+getConf = WasmAsmM $ \conf _ -> pure conf
 
 -- | Default indent level is none
-execWasmAsmM :: Bool -> WasmAsmM a -> Builder
-execWasmAsmM do_tail_call (WasmAsmM m) =
-  execState (m do_tail_call mempty) mempty
+execWasmAsmM :: WasmAsmConfig -> WasmAsmM a -> Builder
+execWasmAsmM conf (WasmAsmM m) =
+  execState (m conf mempty) mempty
 
 -- | Increase indent level by a tab
 asmWithTab :: WasmAsmM a -> WasmAsmM a
 asmWithTab (WasmAsmM m) =
-  WasmAsmM $ \do_tail_call t -> m do_tail_call $! char7 '\t' <> t
+  WasmAsmM $ \conf t -> m conf $! char7 '\t' <> t
 
 -- | Writes a single line starting with the current indent
 asmTellLine :: Builder -> WasmAsmM ()
@@ -113,7 +112,8 @@ asmFromSymName = shortByteString . coerce fastStringToShortByteString
 
 asmTellDefSym :: SymName -> WasmAsmM ()
 asmTellDefSym sym = do
-  asmTellTabLine $ ".hidden " <> asm_sym
+  WasmAsmConfig {..} <- getConf
+  unless pic $ asmTellTabLine $ ".hidden " <> asm_sym
   asmTellTabLine $ ".globl " <> asm_sym
   where
     asm_sym = asmFromSymName sym
@@ -136,7 +136,7 @@ asmTellDataSectionContent ty_word c = asmTellTabLine $ case c of
       <> ( case compare o 0 of
              EQ -> mempty
              GT -> "+" <> intDec o
-             LT -> intDec o
+             LT -> panic "asmTellDataSectionContent: negative offset"
          )
   DataSkip i -> ".skip " <> intDec i
   DataASCII s
@@ -245,14 +245,27 @@ asmTellWasmInstr ty_word instr = case instr of
   WasmConst TagI32 i -> asmTellLine $ "i32.const " <> integerDec i
   WasmConst TagI64 i -> asmTellLine $ "i64.const " <> integerDec i
   WasmConst {} -> panic "asmTellWasmInstr: unreachable"
-  WasmSymConst sym ->
-    asmTellLine $
-      ( case ty_word of
-          TagI32 -> "i32.const "
-          TagI64 -> "i64.const "
-          _ -> panic "asmTellWasmInstr: unreachable"
-      )
-        <> asmFromSymName sym
+  WasmSymConst sym -> do
+    WasmAsmConfig {..} <- getConf
+    let
+      asm_sym = asmFromSymName sym
+      (ty_const, ty_add) = case ty_word of
+        TagI32 -> ("i32.const ", "i32.add")
+        TagI64 -> ("i64.const ", "i64.add")
+        _ -> panic "asmTellWasmInstr: invalid word type"
+    traverse_ asmTellLine $ if
+      | pic, getKey (getUnique sym) `WS.member` mbrelSyms -> [
+          "global.get __memory_base",
+          ty_const <> asm_sym <> "@MBREL",
+          ty_add
+        ]
+      | pic, getKey (getUnique sym) `WS.member` tbrelSyms -> [
+          "global.get __table_base",
+          ty_const <> asm_sym <> "@TBREL",
+          ty_add
+        ]
+      | pic -> [ "global.get " <> asm_sym <> "@GOT" ]
+      | otherwise -> [ ty_const <> asm_sym ]
   WasmLoad ty (Just w) s o align ->
     asmTellLine $
       asmFromWasmType ty
@@ -398,12 +411,12 @@ asmTellWasmControl ty_word c = case c of
     asmTellLine $ "br_table {" <> builderCommas intDec (ts <> [t]) <> "}"
   -- See Note [WasmTailCall]
   WasmTailCall (WasmExpr e) -> do
-    do_tail_call <- doTailCall
+    WasmAsmConfig {..} <- getConf
     if
-        | do_tail_call,
+        | tailcall,
           WasmSymConst sym <- e ->
             asmTellLine $ "return_call " <> asmFromSymName sym
-        | do_tail_call ->
+        | tailcall ->
             do
               asmTellWasmInstr ty_word e
               asmTellLine $
@@ -440,13 +453,25 @@ asmTellFunc ty_word def_syms sym (func_ty, FuncBody {..}) = do
 
 asmTellGlobals :: WasmTypeTag w -> WasmAsmM ()
 asmTellGlobals ty_word = do
+  WasmAsmConfig {..} <- getConf
+  when pic $ traverse_ asmTellTabLine [
+      ".globaltype __memory_base, i32, immutable",
+      ".globaltype __table_base, i32, immutable"
+    ]
   for_ supportedCmmGlobalRegs $ \reg ->
-    let (sym, ty) = fromJust $ globalInfoFromCmmGlobalReg ty_word reg
-     in asmTellTabLine $
+    let
+      (sym, ty) = fromJust $ globalInfoFromCmmGlobalReg ty_word reg
+      asm_sym = asmFromSymName sym
+     in do
+      asmTellTabLine $
           ".globaltype "
-            <> asmFromSymName sym
+            <> asm_sym
             <> ", "
             <> asmFromSomeWasmType ty
+      when pic $ traverse_ asmTellTabLine [
+          ".import_module " <> asm_sym <> ", regs",
+          ".import_name " <> asm_sym <> ", " <> asm_sym
+        ]
   asmTellLF
 
 asmTellCtors :: WasmTypeTag w -> [SymName] -> WasmAsmM ()
@@ -494,14 +519,14 @@ asmTellProducers = do
 
 asmTellTargetFeatures :: WasmAsmM ()
 asmTellTargetFeatures = do
-  do_tail_call <- doTailCall
+  WasmAsmConfig {..} <- getConf
   asmTellSectionHeader ".custom_section.target_features"
   asmTellVec
     [ do
         asmTellTabLine ".int8 0x2b"
         asmTellBS feature
       | feature <-
-          ["tail-call" | do_tail_call]
+          ["tail-call" | tailcall]
             <> [ "bulk-memory",
                  "mutable-globals",
                  "nontrapping-fptoint",
diff --git a/compiler/GHC/CmmToAsm/Wasm/Types.hs b/compiler/GHC/CmmToAsm/Wasm/Types.hs
index e00c0352601..900cdbaa958 100644
--- a/compiler/GHC/CmmToAsm/Wasm/Types.hs
+++ b/compiler/GHC/CmmToAsm/Wasm/Types.hs
@@ -45,6 +45,8 @@ module GHC.CmmToAsm.Wasm.Types
     wasmStateM,
     wasmModifyM,
     wasmExecM,
+    WasmAsmConfig (..),
+    defaultWasmAsmConfig
   )
 where
 
@@ -136,7 +138,9 @@ data SymVisibility
     SymStatic
   | -- | Defined, visible to other compilation units.
     --
-    -- Adds @.hidden@ & @.globl@ directives in the output assembly.
+    -- Adds @.globl@ directives in the output assembly. Also adds
+    -- @.hidden@ when not generating PIC, similar to
+    -- @-fvisibility=hidden@ in clang.
     --
     -- @[ binding=global vis=hidden ]@
     SymDefault
@@ -480,3 +484,35 @@ instance MonadUnique (WasmCodeGenM w) where
     u <- getUniqueM
     s <- WasmCodeGenM get
     pure $ u:(wasmEvalM getUniquesM s)
+
+data WasmAsmConfig = WasmAsmConfig
+  {
+    pic, tailcall :: Bool,
+    -- | Data/function symbols with 'SymStatic' visibility (defined
+    -- but not visible to other compilation units). When doing PIC
+    -- codegen, these private symbols must be emitted as @MBREL@ (data)
+    -- or @TBREL@ (function) relocations in the code section. Public
+    -- symbols, defined here or elsewhere, are emitted as @GOT@ instead.
+    mbrelSyms, tbrelSyms :: ~SymSet
+  }
+
+-- | The default 'WasmAsmConfig' ('pic' and 'tailcall' both off) must
+-- be extracted from the final 'WasmCodeGenState'.
+defaultWasmAsmConfig :: WasmCodeGenState w -> WasmAsmConfig
+defaultWasmAsmConfig WasmCodeGenState {..} =
+  WasmAsmConfig
+    { pic = False,
+      tailcall = False,
+      mbrelSyms = mk_rel_syms dataSections,
+      tbrelSyms = mk_rel_syms funcBodies
+    }
+  where
+    mk_rel_syms :: SymMap a -> SymSet
+    mk_rel_syms =
+      nonDetFoldUniqMap
+        ( \(sym, _) acc ->
+            if getKey (getUnique sym) `WS.member` defaultSyms
+              then acc
+              else WS.insert (getKey (getUnique sym)) acc
+        )
+        WS.empty
diff --git a/compiler/GHC/Driver/Config/CmmToAsm.hs b/compiler/GHC/Driver/Config/CmmToAsm.hs
index e3452c117df..2ba14caf4bd 100644
--- a/compiler/GHC/Driver/Config/CmmToAsm.hs
+++ b/compiler/GHC/Driver/Config/CmmToAsm.hs
@@ -21,8 +21,7 @@ initNCGConfig dflags this_mod = NCGConfig
    , ncgAsmContext            = initSDocContext dflags PprCode
    , ncgProcAlignment         = cmmProcAlignment dflags
    , ncgExternalDynamicRefs   = gopt Opt_ExternalDynamicRefs dflags
-   -- no PIC on wasm32 for now
-   , ncgPIC                   = positionIndependent dflags && not (platformArch (targetPlatform dflags) == ArchWasm32)
+   , ncgPIC                   = positionIndependent dflags
    , ncgInlineThresholdMemcpy = fromIntegral $ maxInlineMemcpyInsns dflags
    , ncgInlineThresholdMemset = fromIntegral $ maxInlineMemsetInsns dflags
    , ncgSplitSections         = gopt Opt_SplitSections dflags
-- 
GitLab