Commit 1c7ec449 authored by Ömer Sinan Ağacan's avatar Ömer Sinan Ağacan

Return results of Cmm streams in backends

This generalizes code generators (outputAsm, outputLlvm, outputC, and
the call site codeOutput) so that they'll return the return values of
the passed Cmm streams.

This allows accumulating data during Cmm generation and returning it to
the call site in HscMain.

Previously the Cmm streams were assumed to return (), so the code
generators returned () as well.

This change is required by !1304 and !1530.

Skipping CI as this was tested before and I only updated the commit
message.

[skip ci]
parent ee2fad9e
...@@ -67,16 +67,17 @@ mkEmptyContInfoTable info_lbl ...@@ -67,16 +67,17 @@ mkEmptyContInfoTable info_lbl
, cit_srt = Nothing , cit_srt = Nothing
, cit_clo = Nothing } , cit_clo = Nothing }
cmmToRawCmm :: DynFlags -> Stream IO CmmGroup () cmmToRawCmm :: DynFlags -> Stream IO CmmGroup a
-> IO (Stream IO RawCmmGroup ()) -> IO (Stream IO RawCmmGroup a)
cmmToRawCmm dflags cmms cmmToRawCmm dflags cmms
= do { uniqs <- mkSplitUniqSupply 'i' = do { uniqs <- mkSplitUniqSupply 'i'
; let do_one uniqs cmm = ; let do_one :: UniqSupply -> [CmmDecl] -> IO (UniqSupply, [RawCmmDecl])
do_one uniqs cmm =
-- NB. strictness fixes a space leak. DO NOT REMOVE. -- NB. strictness fixes a space leak. DO NOT REMOVE.
withTiming (return dflags) (text "Cmm -> Raw Cmm") forceRes $ withTiming (return dflags) (text "Cmm -> Raw Cmm") forceRes $
case initUs uniqs $ concatMapM (mkInfoTable dflags) cmm of case initUs uniqs $ concatMapM (mkInfoTable dflags) cmm of
(b,uniqs') -> return (uniqs',b) (b,uniqs') -> return (uniqs',b)
; return (Stream.mapAccumL do_one uniqs cmms >> return ()) ; return (snd <$> Stream.mapAccumL_ do_one uniqs cmms)
} }
where forceRes (uniqs, rawcmms) = where forceRes (uniqs, rawcmms) =
......
...@@ -42,8 +42,8 @@ import System.IO ...@@ -42,8 +42,8 @@ import System.IO
-- | Top-level of the LLVM Code generator -- | Top-level of the LLVM Code generator
-- --
llvmCodeGen :: DynFlags -> Handle -> UniqSupply llvmCodeGen :: DynFlags -> Handle -> UniqSupply
-> Stream.Stream IO RawCmmGroup () -> Stream.Stream IO RawCmmGroup a
-> IO () -> IO a
llvmCodeGen dflags h us cmm_stream llvmCodeGen dflags h us cmm_stream
= withTiming (pure dflags) (text "LLVM CodeGen") (const ()) $ do = withTiming (pure dflags) (text "LLVM CodeGen") (const ()) $ do
bufh <- newBufHandle h bufh <- newBufHandle h
...@@ -66,12 +66,14 @@ llvmCodeGen dflags h us cmm_stream ...@@ -66,12 +66,14 @@ llvmCodeGen dflags h us cmm_stream
$+$ text "We will try though...") $+$ text "We will try though...")
-- run code generation -- run code generation
runLlvm dflags ver bufh us $ a <- runLlvm dflags ver bufh us $
llvmCodeGen' (liftStream cmm_stream) llvmCodeGen' (liftStream cmm_stream)
bFlush bufh bFlush bufh
llvmCodeGen' :: Stream.Stream LlvmM RawCmmGroup () -> LlvmM () return a
llvmCodeGen' :: Stream.Stream LlvmM RawCmmGroup a -> LlvmM a
llvmCodeGen' cmm_stream llvmCodeGen' cmm_stream
= do -- Preamble = do -- Preamble
renderLlvm header renderLlvm header
...@@ -79,13 +81,15 @@ llvmCodeGen' cmm_stream ...@@ -79,13 +81,15 @@ llvmCodeGen' cmm_stream
cmmMetaLlvmPrelude cmmMetaLlvmPrelude
-- Procedures -- Procedures
() <- Stream.consume cmm_stream llvmGroupLlvmGens a <- Stream.consume cmm_stream llvmGroupLlvmGens
-- Declare aliases for forward references -- Declare aliases for forward references
renderLlvm . pprLlvmData =<< generateExternDecls renderLlvm . pprLlvmData =<< generateExternDecls
-- Postamble -- Postamble
cmmUsedLlvmGens cmmUsedLlvmGens
return a
where where
header :: SDoc header :: SDoc
header = sdocWithDynFlags $ \dflags -> header = sdocWithDynFlags $ \dflags ->
......
...@@ -253,10 +253,10 @@ liftIO m = LlvmM $ \env -> do x <- m ...@@ -253,10 +253,10 @@ liftIO m = LlvmM $ \env -> do x <- m
return (x, env) return (x, env)
-- | Get initial Llvm environment. -- | Get initial Llvm environment.
runLlvm :: DynFlags -> LlvmVersion -> BufHandle -> UniqSupply -> LlvmM () -> IO () runLlvm :: DynFlags -> LlvmVersion -> BufHandle -> UniqSupply -> LlvmM a -> IO a
runLlvm dflags ver out us m = do runLlvm dflags ver out us m = do
_ <- runLlvmM m env (a, _) <- runLlvmM m env
return () return a
where env = LlvmEnv { envFunMap = emptyUFM where env = LlvmEnv { envFunMap = emptyUFM
, envVarMap = emptyUFM , envVarMap = emptyUFM
, envStackRegs = [] , envStackRegs = []
......
...@@ -54,10 +54,11 @@ codeOutput :: DynFlags ...@@ -54,10 +54,11 @@ codeOutput :: DynFlags
-> [(ForeignSrcLang, FilePath)] -> [(ForeignSrcLang, FilePath)]
-- ^ additional files to be compiled with with the C compiler -- ^ additional files to be compiled with with the C compiler
-> [InstalledUnitId] -> [InstalledUnitId]
-> Stream IO RawCmmGroup () -- Compiled C-- -> Stream IO RawCmmGroup a -- Compiled C--
-> IO (FilePath, -> IO (FilePath,
(Bool{-stub_h_exists-}, Maybe FilePath{-stub_c_exists-}), (Bool{-stub_h_exists-}, Maybe FilePath{-stub_c_exists-}),
[(ForeignSrcLang, FilePath)]{-foreign_fps-}) [(ForeignSrcLang, FilePath)]{-foreign_fps-},
a)
codeOutput dflags this_mod filenm location foreign_stubs foreign_fps pkg_deps codeOutput dflags this_mod filenm location foreign_stubs foreign_fps pkg_deps
cmm_stream cmm_stream
...@@ -87,15 +88,14 @@ codeOutput dflags this_mod filenm location foreign_stubs foreign_fps pkg_deps ...@@ -87,15 +88,14 @@ codeOutput dflags this_mod filenm location foreign_stubs foreign_fps pkg_deps
} }
; stubs_exist <- outputForeignStubs dflags this_mod location foreign_stubs ; stubs_exist <- outputForeignStubs dflags this_mod location foreign_stubs
; case hscTarget dflags of { ; a <- case hscTarget dflags of
HscAsm -> outputAsm dflags this_mod location filenm HscAsm -> outputAsm dflags this_mod location filenm
linted_cmm_stream; linted_cmm_stream
HscC -> outputC dflags filenm linted_cmm_stream pkg_deps; HscC -> outputC dflags filenm linted_cmm_stream pkg_deps
HscLlvm -> outputLlvm dflags filenm linted_cmm_stream; HscLlvm -> outputLlvm dflags filenm linted_cmm_stream
HscInterpreted -> panic "codeOutput: HscInterpreted"; HscInterpreted -> panic "codeOutput: HscInterpreted"
HscNothing -> panic "codeOutput: HscNothing" HscNothing -> panic "codeOutput: HscNothing"
} ; return (filenm, stubs_exist, foreign_fps, a)
; return (filenm, stubs_exist, foreign_fps)
} }
doOutput :: String -> (Handle -> IO a) -> IO a doOutput :: String -> (Handle -> IO a) -> IO a
...@@ -111,13 +111,13 @@ doOutput filenm io_action = bracket (openFile filenm WriteMode) hClose io_action ...@@ -111,13 +111,13 @@ doOutput filenm io_action = bracket (openFile filenm WriteMode) hClose io_action
outputC :: DynFlags outputC :: DynFlags
-> FilePath -> FilePath
-> Stream IO RawCmmGroup () -> Stream IO RawCmmGroup a
-> [InstalledUnitId] -> [InstalledUnitId]
-> IO () -> IO a
outputC dflags filenm cmm_stream packages outputC dflags filenm cmm_stream packages
= do = do
withTiming (return dflags) (text "C codegen") id $ do withTiming (return dflags) (text "C codegen") (\a -> seq a () {- FIXME -}) $ do
-- figure out which header files to #include in the generated .hc file: -- figure out which header files to #include in the generated .hc file:
-- --
...@@ -150,18 +150,17 @@ outputC dflags filenm cmm_stream packages ...@@ -150,18 +150,17 @@ outputC dflags filenm cmm_stream packages
-} -}
outputAsm :: DynFlags -> Module -> ModLocation -> FilePath outputAsm :: DynFlags -> Module -> ModLocation -> FilePath
-> Stream IO RawCmmGroup () -> Stream IO RawCmmGroup a
-> IO () -> IO a
outputAsm dflags this_mod location filenm cmm_stream outputAsm dflags this_mod location filenm cmm_stream
| platformMisc_ghcWithNativeCodeGen $ platformMisc dflags | platformMisc_ghcWithNativeCodeGen $ platformMisc dflags
= do ncg_uniqs <- mkSplitUniqSupply 'n' = do ncg_uniqs <- mkSplitUniqSupply 'n'
debugTraceMsg dflags 4 (text "Outputing asm to" <+> text filenm) debugTraceMsg dflags 4 (text "Outputing asm to" <+> text filenm)
_ <- {-# SCC "OutputAsm" #-} doOutput filenm $ {-# SCC "OutputAsm" #-} doOutput filenm $
\h -> {-# SCC "NativeCodeGen" #-} \h -> {-# SCC "NativeCodeGen" #-}
nativeCodeGen dflags this_mod location h ncg_uniqs cmm_stream nativeCodeGen dflags this_mod location h ncg_uniqs cmm_stream
return ()
| otherwise | otherwise
= panic "This compiler was built without a native code generator" = panic "This compiler was built without a native code generator"
...@@ -174,7 +173,7 @@ outputAsm dflags this_mod location filenm cmm_stream ...@@ -174,7 +173,7 @@ outputAsm dflags this_mod location filenm cmm_stream
************************************************************************ ************************************************************************
-} -}
outputLlvm :: DynFlags -> FilePath -> Stream IO RawCmmGroup () -> IO () outputLlvm :: DynFlags -> FilePath -> Stream IO RawCmmGroup a -> IO a
outputLlvm dflags filenm cmm_stream outputLlvm dflags filenm cmm_stream
= do ncg_uniqs <- mkSplitUniqSupply 'n' = do ncg_uniqs <- mkSplitUniqSupply 'n'
......
...@@ -1426,7 +1426,7 @@ hscGenHardCode hsc_env cgguts mod_summary output_filename = do ...@@ -1426,7 +1426,7 @@ hscGenHardCode hsc_env cgguts mod_summary output_filename = do
return a return a
rawcmms1 = Stream.mapM dump rawcmms0 rawcmms1 = Stream.mapM dump rawcmms0
(output_filename, (_stub_h_exists, stub_c_exists), foreign_fps) (output_filename, (_stub_h_exists, stub_c_exists), foreign_fps, ())
<- {-# SCC "codeOutput" #-} <- {-# SCC "codeOutput" #-}
codeOutput dflags this_mod output_filename location codeOutput dflags this_mod output_filename location
foreign_stubs foreign_files dependencies rawcmms1 foreign_stubs foreign_files dependencies rawcmms1
......
...@@ -157,14 +157,14 @@ The machine-dependent bits break down as follows: ...@@ -157,14 +157,14 @@ The machine-dependent bits break down as follows:
-} -}
-------------------- --------------------
nativeCodeGen :: DynFlags -> Module -> ModLocation -> Handle -> UniqSupply nativeCodeGen :: forall a . DynFlags -> Module -> ModLocation -> Handle -> UniqSupply
-> Stream IO RawCmmGroup () -> Stream IO RawCmmGroup a
-> IO UniqSupply -> IO a
nativeCodeGen dflags this_mod modLoc h us cmms nativeCodeGen dflags this_mod modLoc h us cmms
= let platform = targetPlatform dflags = let platform = targetPlatform dflags
nCG' :: ( Outputable statics, Outputable instr nCG' :: ( Outputable statics, Outputable instr
, Outputable jumpDest, Instruction instr) , Outputable jumpDest, Instruction instr)
=> NcgImpl statics instr jumpDest -> IO UniqSupply => NcgImpl statics instr jumpDest -> IO a
nCG' ncgImpl = nativeCodeGen' dflags this_mod modLoc ncgImpl h us cmms nCG' ncgImpl = nativeCodeGen' dflags this_mod modLoc ncgImpl h us cmms
in case platformArch platform of in case platformArch platform of
ArchX86 -> nCG' (x86NcgImpl dflags) ArchX86 -> nCG' (x86NcgImpl dflags)
...@@ -314,8 +314,8 @@ nativeCodeGen' :: (Outputable statics, Outputable instr,Outputable jumpDest, ...@@ -314,8 +314,8 @@ nativeCodeGen' :: (Outputable statics, Outputable instr,Outputable jumpDest,
-> NcgImpl statics instr jumpDest -> NcgImpl statics instr jumpDest
-> Handle -> Handle
-> UniqSupply -> UniqSupply
-> Stream IO RawCmmGroup () -> Stream IO RawCmmGroup a
-> IO UniqSupply -> IO a
nativeCodeGen' dflags this_mod modLoc ncgImpl h us cmms nativeCodeGen' dflags this_mod modLoc ncgImpl h us cmms
= do = do
-- BufHandle is a performance hack. We could hide it inside -- BufHandle is a performance hack. We could hide it inside
...@@ -323,9 +323,10 @@ nativeCodeGen' dflags this_mod modLoc ncgImpl h us cmms ...@@ -323,9 +323,10 @@ nativeCodeGen' dflags this_mod modLoc ncgImpl h us cmms
-- printDocs here (in order to do codegen in constant space). -- printDocs here (in order to do codegen in constant space).
bufh <- newBufHandle h bufh <- newBufHandle h
let ngs0 = NGS [] [] [] [] [] [] emptyUFM mapEmpty let ngs0 = NGS [] [] [] [] [] [] emptyUFM mapEmpty
(ngs, us') <- cmmNativeGenStream dflags this_mod modLoc ncgImpl bufh us (ngs, us', a) <- cmmNativeGenStream dflags this_mod modLoc ncgImpl bufh us
cmms ngs0 cmms ngs0
finishNativeGen dflags modLoc bufh us' ngs _ <- finishNativeGen dflags modLoc bufh us' ngs
return a
finishNativeGen :: Instruction instr finishNativeGen :: Instruction instr
=> DynFlags => DynFlags
...@@ -386,20 +387,21 @@ cmmNativeGenStream :: (Outputable statics, Outputable instr ...@@ -386,20 +387,21 @@ cmmNativeGenStream :: (Outputable statics, Outputable instr
-> NcgImpl statics instr jumpDest -> NcgImpl statics instr jumpDest
-> BufHandle -> BufHandle
-> UniqSupply -> UniqSupply
-> Stream IO RawCmmGroup () -> Stream IO RawCmmGroup a
-> NativeGenAcc statics instr -> NativeGenAcc statics instr
-> IO (NativeGenAcc statics instr, UniqSupply) -> IO (NativeGenAcc statics instr, UniqSupply, a)
cmmNativeGenStream dflags this_mod modLoc ncgImpl h us cmm_stream ngs cmmNativeGenStream dflags this_mod modLoc ncgImpl h us cmm_stream ngs
= do r <- Stream.runStream cmm_stream = do r <- Stream.runStream cmm_stream
case r of case r of
Left () -> Left a ->
return (ngs { ngs_imports = reverse $ ngs_imports ngs return (ngs { ngs_imports = reverse $ ngs_imports ngs
, ngs_natives = reverse $ ngs_natives ngs , ngs_natives = reverse $ ngs_natives ngs
, ngs_colorStats = reverse $ ngs_colorStats ngs , ngs_colorStats = reverse $ ngs_colorStats ngs
, ngs_linearStats = reverse $ ngs_linearStats ngs , ngs_linearStats = reverse $ ngs_linearStats ngs
}, },
us) us,
a)
Right (cmms, cmm_stream') -> do Right (cmms, cmm_stream') -> do
(us', ngs'') <- (us', ngs'') <-
withTiming (return dflags) withTiming (return dflags)
......
...@@ -7,8 +7,8 @@ ...@@ -7,8 +7,8 @@
-- ----------------------------------------------------------------------------- -- -----------------------------------------------------------------------------
module Stream ( module Stream (
Stream(..), yield, liftIO, Stream(..), yield, liftIO,
collect, consume, fromList, collect, collect_, consume, fromList,
Stream.map, Stream.mapM, Stream.mapAccumL Stream.map, Stream.mapM, Stream.mapAccumL, Stream.mapAccumL_
) where ) where
import GhcPrelude import GhcPrelude
...@@ -71,6 +71,16 @@ collect str = go str [] ...@@ -71,6 +71,16 @@ collect str = go str []
Left () -> return (reverse acc) Left () -> return (reverse acc)
Right (a, str') -> go str' (a:acc) Right (a, str') -> go str' (a:acc)
-- | Turn a Stream into an ordinary list, by demanding all the elements.
collect_ :: Monad m => Stream m a r -> m ([a], r)
collect_ str = go str []
where
go str acc = do
r <- runStream str
case r of
Left r -> return (reverse acc, r)
Right (a, str') -> go str' (a:acc)
consume :: Monad m => Stream m a b -> (a -> m ()) -> m b consume :: Monad m => Stream m a b -> (a -> m ()) -> m b
consume str f = do consume str f = do
r <- runStream str r <- runStream str
...@@ -113,3 +123,13 @@ mapAccumL f c str = Stream $ do ...@@ -113,3 +123,13 @@ mapAccumL f c str = Stream $ do
Right (a, str') -> do Right (a, str') -> do
(c',b) <- f c a (c',b) <- f c a
return (Right (b, mapAccumL f c' str')) return (Right (b, mapAccumL f c' str'))
mapAccumL_ :: Monad m => (c -> a -> m (c,b)) -> c -> Stream m a r
-> Stream m b (c, r)
mapAccumL_ f c str = Stream $ do
r <- runStream str
case r of
Left r -> return (Left (c, r))
Right (a, str') -> do
(c',b) <- f c a
return (Right (b, mapAccumL_ f c' str'))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment