Commit 2bb19fad authored by Joachim Breitner's avatar Joachim Breitner

Make worker-wrapper unbox data families

by passing the FamInstEnvs all the way down. This closes #7619.
parent fe3740bd
...@@ -39,7 +39,7 @@ module HscTypes ( ...@@ -39,7 +39,7 @@ module HscTypes (
PackageTypeEnv, PackageIfaceTable, emptyPackageIfaceTable, PackageTypeEnv, PackageIfaceTable, emptyPackageIfaceTable,
lookupIfaceByModule, emptyModIface, lookupIfaceByModule, emptyModIface,
PackageInstEnv, PackageRuleBase, PackageInstEnv, PackageFamInstEnv, PackageRuleBase,
mkSOName, mkHsSOName, soExt, mkSOName, mkHsSOName, soExt,
......
...@@ -34,7 +34,7 @@ module CoreMonad ( ...@@ -34,7 +34,7 @@ module CoreMonad (
-- ** Reading from the monad -- ** Reading from the monad
getHscEnv, getRuleBase, getModule, getHscEnv, getRuleBase, getModule,
getDynFlags, getOrigNameCache, getDynFlags, getOrigNameCache, getPackageFamInstEnv,
-- ** Writing to the monad -- ** Writing to the monad
addSimplCount, addSimplCount,
...@@ -953,6 +953,12 @@ getOrigNameCache :: CoreM OrigNameCache ...@@ -953,6 +953,12 @@ getOrigNameCache :: CoreM OrigNameCache
getOrigNameCache = do getOrigNameCache = do
nameCacheRef <- fmap hsc_NC getHscEnv nameCacheRef <- fmap hsc_NC getHscEnv
liftIO $ fmap nsNames $ readIORef nameCacheRef liftIO $ fmap nsNames $ readIORef nameCacheRef
getPackageFamInstEnv :: CoreM PackageFamInstEnv
getPackageFamInstEnv = do
hsc_env <- getHscEnv
eps <- liftIO $ hscEPS hsc_env
return $ eps_fam_inst_env eps
\end{code} \end{code}
%************************************************************************ %************************************************************************
......
...@@ -36,7 +36,7 @@ import LiberateCase ( liberateCase ) ...@@ -36,7 +36,7 @@ import LiberateCase ( liberateCase )
import SAT ( doStaticArgs ) import SAT ( doStaticArgs )
import Specialise ( specProgram) import Specialise ( specProgram)
import SpecConstr ( specConstrProgram) import SpecConstr ( specConstrProgram)
import DmdAnal ( dmdAnalProgram ) import DmdAnal ( dmdAnalProgram )
import WorkWrap ( wwTopBinds ) import WorkWrap ( wwTopBinds )
import Vectorise ( vectorise ) import Vectorise ( vectorise )
import FastString import FastString
...@@ -387,8 +387,8 @@ doCorePass _ CoreCSE = {-# SCC "CommonSubExpr" #-} ...@@ -387,8 +387,8 @@ doCorePass _ CoreCSE = {-# SCC "CommonSubExpr" #-}
doCorePass _ CoreLiberateCase = {-# SCC "LiberateCase" #-} doCorePass _ CoreLiberateCase = {-# SCC "LiberateCase" #-}
doPassD liberateCase doPassD liberateCase
doCorePass dflags CoreDoFloatInwards = {-# SCC "FloatInwards" #-} doCorePass _ CoreDoFloatInwards = {-# SCC "FloatInwards" #-}
doPass (floatInwards dflags) doPassD floatInwards
doCorePass _ (CoreDoFloatOutwards f) = {-# SCC "FloatOutwards" #-} doCorePass _ (CoreDoFloatOutwards f) = {-# SCC "FloatOutwards" #-}
doPassDUM (floatOutwards f) doPassDUM (floatOutwards f)
...@@ -397,10 +397,10 @@ doCorePass _ CoreDoStaticArgs = {-# SCC "StaticArgs" #-} ...@@ -397,10 +397,10 @@ doCorePass _ CoreDoStaticArgs = {-# SCC "StaticArgs" #-}
doPassU doStaticArgs doPassU doStaticArgs
doCorePass _ CoreDoStrictness = {-# SCC "NewStranal" #-} doCorePass _ CoreDoStrictness = {-# SCC "NewStranal" #-}
doPassDM dmdAnalProgram doPassDFM dmdAnalProgram
doCorePass dflags CoreDoWorkerWrapper = {-# SCC "WorkWrap" #-} doCorePass _ CoreDoWorkerWrapper = {-# SCC "WorkWrap" #-}
doPassU (wwTopBinds dflags) doPassDFU wwTopBinds
doCorePass dflags CoreDoSpecialising = {-# SCC "Specialise" #-} doCorePass dflags CoreDoSpecialising = {-# SCC "Specialise" #-}
specProgram dflags specProgram dflags
...@@ -462,6 +462,21 @@ doPassDU do_pass = doPassDUM (\dflags us -> return . do_pass dflags us) ...@@ -462,6 +462,21 @@ doPassDU do_pass = doPassDUM (\dflags us -> return . do_pass dflags us)
doPassU :: (UniqSupply -> CoreProgram -> CoreProgram) -> ModGuts -> CoreM ModGuts doPassU :: (UniqSupply -> CoreProgram -> CoreProgram) -> ModGuts -> CoreM ModGuts
doPassU do_pass = doPassDU (const do_pass) doPassU do_pass = doPassDU (const do_pass)
doPassDFM :: (DynFlags -> FamInstEnvs -> CoreProgram -> IO CoreProgram) -> ModGuts -> CoreM ModGuts
doPassDFM do_pass guts = do
dflags <- getDynFlags
p_fam_env <- getPackageFamInstEnv
let fam_envs = (p_fam_env, mg_fam_inst_env guts)
doPassM (liftIO . do_pass dflags fam_envs) guts
doPassDFU :: (DynFlags -> FamInstEnvs -> UniqSupply -> CoreProgram -> CoreProgram) -> ModGuts -> CoreM ModGuts
doPassDFU do_pass guts = do
dflags <- getDynFlags
us <- getUniqueSupplyM
p_fam_env <- getPackageFamInstEnv
let fam_envs = (p_fam_env, mg_fam_inst_env guts)
doPass (do_pass dflags fam_envs us) guts
-- Most passes return no stats and don't change rules: these combinators -- Most passes return no stats and don't change rules: these combinators
-- let us lift them to the full blown ModGuts+CoreM world -- let us lift them to the full blown ModGuts+CoreM world
doPassM :: Monad m => (CoreProgram -> m CoreProgram) -> ModGuts -> m ModGuts doPassM :: Monad m => (CoreProgram -> m CoreProgram) -> ModGuts -> m ModGuts
......
...@@ -31,6 +31,7 @@ import TyCon ...@@ -31,6 +31,7 @@ import TyCon
import Type ( eqType ) import Type ( eqType )
-- import Pair -- import Pair
-- import Coercion ( coercionKind ) -- import Coercion ( coercionKind )
import FamInstEnv
import Util import Util
import Maybes ( isJust ) import Maybes ( isJust )
import TysWiredIn ( unboxedPairDataCon ) import TysWiredIn ( unboxedPairDataCon )
...@@ -47,8 +48,8 @@ import Data.Function ( on ) ...@@ -47,8 +48,8 @@ import Data.Function ( on )
%************************************************************************ %************************************************************************
\begin{code} \begin{code}
dmdAnalProgram :: DynFlags -> CoreProgram -> IO CoreProgram dmdAnalProgram :: DynFlags -> FamInstEnvs -> CoreProgram -> IO CoreProgram
dmdAnalProgram dflags binds dmdAnalProgram dflags fam_envs binds
= do { = do {
let { binds_plus_dmds = do_prog binds } ; let { binds_plus_dmds = do_prog binds } ;
dumpIfSet_dyn dflags Opt_D_dump_strsigs "Strictness signatures" $ dumpIfSet_dyn dflags Opt_D_dump_strsigs "Strictness signatures" $
...@@ -57,7 +58,7 @@ dmdAnalProgram dflags binds ...@@ -57,7 +58,7 @@ dmdAnalProgram dflags binds
} }
where where
do_prog :: CoreProgram -> CoreProgram do_prog :: CoreProgram -> CoreProgram
do_prog binds = snd $ mapAccumL dmdAnalTopBind (emptyAnalEnv dflags) binds do_prog binds = snd $ mapAccumL dmdAnalTopBind (emptyAnalEnv dflags fam_envs) binds
-- Analyse a (group of) top-level binding(s) -- Analyse a (group of) top-level binding(s)
dmdAnalTopBind :: AnalEnv dmdAnalTopBind :: AnalEnv
...@@ -611,7 +612,7 @@ dmdAnalRhs top_lvl rec_flag env id rhs ...@@ -611,7 +612,7 @@ dmdAnalRhs top_lvl rec_flag env id rhs
-- See Note [NOINLINE and strictness] -- See Note [NOINLINE and strictness]
-- See Note [Product demands for function body] -- See Note [Product demands for function body]
body_dmd = case deepSplitProductType_maybe (exprType body) of body_dmd = case deepSplitProductType_maybe (ae_fam_envs env) (exprType body) of
Nothing -> cleanEvalDmd Nothing -> cleanEvalDmd
Just (dc, _, _, _) -> cleanEvalProdDmd (dataConRepArity dc) Just (dc, _, _, _) -> cleanEvalProdDmd (dataConRepArity dc)
...@@ -1006,6 +1007,7 @@ data AnalEnv ...@@ -1006,6 +1007,7 @@ data AnalEnv
, ae_virgin :: Bool -- True on first iteration only , ae_virgin :: Bool -- True on first iteration only
-- See Note [Initialising strictness] -- See Note [Initialising strictness]
, ae_rec_tc :: RecTcChecker , ae_rec_tc :: RecTcChecker
, ae_fam_envs :: FamInstEnvs
} }
-- We use the se_env to tell us whether to -- We use the se_env to tell us whether to
...@@ -1023,9 +1025,14 @@ instance Outputable AnalEnv where ...@@ -1023,9 +1025,14 @@ instance Outputable AnalEnv where
[ ptext (sLit "ae_virgin =") <+> ppr virgin [ ptext (sLit "ae_virgin =") <+> ppr virgin
, ptext (sLit "ae_sigs =") <+> ppr env ]) , ptext (sLit "ae_sigs =") <+> ppr env ])
emptyAnalEnv :: DynFlags -> AnalEnv emptyAnalEnv :: DynFlags -> FamInstEnvs -> AnalEnv
emptyAnalEnv dflags = AE { ae_dflags = dflags, ae_sigs = emptySigEnv emptyAnalEnv dflags fam_envs
, ae_virgin = True, ae_rec_tc = initRecTc } = AE { ae_dflags = dflags
, ae_sigs = emptySigEnv
, ae_virgin = True
, ae_rec_tc = initRecTc
, ae_fam_envs = fam_envs
}
emptySigEnv :: SigEnv emptySigEnv :: SigEnv
emptySigEnv = emptyVarEnv emptySigEnv = emptyVarEnv
...@@ -1071,7 +1078,7 @@ extendSigsWithLam env id ...@@ -1071,7 +1078,7 @@ extendSigsWithLam env id
, isStrictDmd (idDemandInfo id) || ae_virgin env , isStrictDmd (idDemandInfo id) || ae_virgin env
-- See Note [Optimistic CPR in the "virgin" case] -- See Note [Optimistic CPR in the "virgin" case]
-- See Note [Initial CPR for strict binders] -- See Note [Initial CPR for strict binders]
, Just (dc,_,_,_) <- deepSplitProductType_maybe $ idType id , Just (dc,_,_,_) <- deepSplitProductType_maybe (ae_fam_envs env) $ idType id
= extendAnalEnv NotTopLevel env id (cprProdSig (dataConRepArity dc)) = extendAnalEnv NotTopLevel env id (cprProdSig (dataConRepArity dc))
| otherwise | otherwise
......
...@@ -28,6 +28,7 @@ import Demand ...@@ -28,6 +28,7 @@ import Demand
import WwLib import WwLib
import Util import Util
import Outputable import Outputable
import FamInstEnv
import MonadUtils import MonadUtils
#include "HsVersions.h" #include "HsVersions.h"
...@@ -60,11 +61,11 @@ info for exported values). ...@@ -60,11 +61,11 @@ info for exported values).
\end{enumerate} \end{enumerate}
\begin{code} \begin{code}
wwTopBinds :: DynFlags -> UniqSupply -> CoreProgram -> CoreProgram wwTopBinds :: DynFlags -> FamInstEnvs -> UniqSupply -> CoreProgram -> CoreProgram
wwTopBinds dflags us top_binds wwTopBinds dflags fam_envs us top_binds
= initUs_ us $ do = initUs_ us $ do
top_binds' <- mapM (wwBind dflags) top_binds top_binds' <- mapM (wwBind dflags fam_envs) top_binds
return (concat top_binds') return (concat top_binds')
\end{code} \end{code}
...@@ -79,23 +80,24 @@ turn. Non-recursive case first, then recursive... ...@@ -79,23 +80,24 @@ turn. Non-recursive case first, then recursive...
\begin{code} \begin{code}
wwBind :: DynFlags wwBind :: DynFlags
-> FamInstEnvs
-> CoreBind -> CoreBind
-> UniqSM [CoreBind] -- returns a WwBinding intermediate form; -> UniqSM [CoreBind] -- returns a WwBinding intermediate form;
-- the caller will convert to Expr/Binding, -- the caller will convert to Expr/Binding,
-- as appropriate. -- as appropriate.
wwBind dflags (NonRec binder rhs) = do wwBind dflags fam_envs (NonRec binder rhs) = do
new_rhs <- wwExpr dflags rhs new_rhs <- wwExpr dflags fam_envs rhs
new_pairs <- tryWW dflags NonRecursive binder new_rhs new_pairs <- tryWW dflags fam_envs NonRecursive binder new_rhs
return [NonRec b e | (b,e) <- new_pairs] return [NonRec b e | (b,e) <- new_pairs]
-- Generated bindings must be non-recursive -- Generated bindings must be non-recursive
-- because the original binding was. -- because the original binding was.
wwBind dflags (Rec pairs) wwBind dflags fam_envs (Rec pairs)
= return . Rec <$> concatMapM do_one pairs = return . Rec <$> concatMapM do_one pairs
where where
do_one (binder, rhs) = do new_rhs <- wwExpr dflags rhs do_one (binder, rhs) = do new_rhs <- wwExpr dflags fam_envs rhs
tryWW dflags Recursive binder new_rhs tryWW dflags fam_envs Recursive binder new_rhs
\end{code} \end{code}
@wwExpr@ basically just walks the tree, looking for appropriate @wwExpr@ basically just walks the tree, looking for appropriate
...@@ -104,36 +106,36 @@ matching by looking for strict arguments of the correct type. ...@@ -104,36 +106,36 @@ matching by looking for strict arguments of the correct type.
@wwExpr@ is a version that just returns the ``Plain'' Tree. @wwExpr@ is a version that just returns the ``Plain'' Tree.
\begin{code} \begin{code}
wwExpr :: DynFlags -> CoreExpr -> UniqSM CoreExpr wwExpr :: DynFlags -> FamInstEnvs -> CoreExpr -> UniqSM CoreExpr
wwExpr _ e@(Type {}) = return e wwExpr _ _ e@(Type {}) = return e
wwExpr _ e@(Coercion {}) = return e wwExpr _ _ e@(Coercion {}) = return e
wwExpr _ e@(Lit {}) = return e wwExpr _ _ e@(Lit {}) = return e
wwExpr _ e@(Var {}) = return e wwExpr _ _ e@(Var {}) = return e
wwExpr dflags (Lam binder expr) wwExpr dflags fam_envs (Lam binder expr)
= Lam binder <$> wwExpr dflags expr = Lam binder <$> wwExpr dflags fam_envs expr
wwExpr dflags (App f a) wwExpr dflags fam_envs (App f a)
= App <$> wwExpr dflags f <*> wwExpr dflags a = App <$> wwExpr dflags fam_envs f <*> wwExpr dflags fam_envs a
wwExpr dflags (Tick note expr) wwExpr dflags fam_envs (Tick note expr)
= Tick note <$> wwExpr dflags expr = Tick note <$> wwExpr dflags fam_envs expr
wwExpr dflags (Cast expr co) = do wwExpr dflags fam_envs (Cast expr co) = do
new_expr <- wwExpr dflags expr new_expr <- wwExpr dflags fam_envs expr
return (Cast new_expr co) return (Cast new_expr co)
wwExpr dflags (Let bind expr) wwExpr dflags fam_envs (Let bind expr)
= mkLets <$> wwBind dflags bind <*> wwExpr dflags expr = mkLets <$> wwBind dflags fam_envs bind <*> wwExpr dflags fam_envs expr
wwExpr dflags (Case expr binder ty alts) = do wwExpr dflags fam_envs (Case expr binder ty alts) = do
new_expr <- wwExpr dflags expr new_expr <- wwExpr dflags fam_envs expr
new_alts <- mapM ww_alt alts new_alts <- mapM ww_alt alts
return (Case new_expr binder ty new_alts) return (Case new_expr binder ty new_alts)
where where
ww_alt (con, binders, rhs) = do ww_alt (con, binders, rhs) = do
new_rhs <- wwExpr dflags rhs new_rhs <- wwExpr dflags fam_envs rhs
return (con, binders, new_rhs) return (con, binders, new_rhs)
\end{code} \end{code}
...@@ -238,6 +240,7 @@ it appears in the first place in the defining module. ...@@ -238,6 +240,7 @@ it appears in the first place in the defining module.
\begin{code} \begin{code}
tryWW :: DynFlags tryWW :: DynFlags
-> FamInstEnvs
-> RecFlag -> RecFlag
-> Id -- The fn binder -> Id -- The fn binder
-> CoreExpr -- The bound rhs; its innards -> CoreExpr -- The bound rhs; its innards
...@@ -247,7 +250,7 @@ tryWW :: DynFlags ...@@ -247,7 +250,7 @@ tryWW :: DynFlags
-- the orig "wrapper" lives on); -- the orig "wrapper" lives on);
-- if two, then a worker and a -- if two, then a worker and a
-- wrapper. -- wrapper.
tryWW dflags is_rec fn_id rhs tryWW dflags fam_envs is_rec fn_id rhs
| isNeverActive inline_act | isNeverActive inline_act
-- No point in worker/wrappering if the thing is never inlined! -- No point in worker/wrappering if the thing is never inlined!
-- Because the no-inline prag will prevent the wrapper ever -- Because the no-inline prag will prevent the wrapper ever
...@@ -258,8 +261,8 @@ tryWW dflags is_rec fn_id rhs ...@@ -258,8 +261,8 @@ tryWW dflags is_rec fn_id rhs
| otherwise | otherwise
= do = do
let doSplit | is_fun = splitFun dflags new_fn_id fn_info wrap_dmds res_info rhs let doSplit | is_fun = splitFun dflags fam_envs new_fn_id fn_info wrap_dmds res_info rhs
| is_thunk = splitThunk dflags is_rec new_fn_id rhs | is_thunk = splitThunk dflags fam_envs is_rec new_fn_id rhs
-- See Note [Thunk splitting] -- See Note [Thunk splitting]
| otherwise = return Nothing | otherwise = return Nothing
try <- doSplit try <- doSplit
...@@ -309,12 +312,12 @@ checkSize dflags fn_id rhs thing_inside ...@@ -309,12 +312,12 @@ checkSize dflags fn_id rhs thing_inside
inline_rule = mkInlineUnfolding Nothing rhs inline_rule = mkInlineUnfolding Nothing rhs
--------------------- ---------------------
splitFun :: DynFlags -> Id -> IdInfo -> [Demand] -> DmdResult -> CoreExpr splitFun :: DynFlags -> FamInstEnvs -> Id -> IdInfo -> [Demand] -> DmdResult -> CoreExpr
-> UniqSM (Maybe [(Id, CoreExpr)]) -> UniqSM (Maybe [(Id, CoreExpr)])
splitFun dflags fn_id fn_info wrap_dmds res_info rhs splitFun dflags fam_envs fn_id fn_info wrap_dmds res_info rhs
= WARN( not (wrap_dmds `lengthIs` arity), ppr fn_id <+> (ppr arity $$ ppr wrap_dmds $$ ppr res_info) ) do = WARN( not (wrap_dmds `lengthIs` arity), ppr fn_id <+> (ppr arity $$ ppr wrap_dmds $$ ppr res_info) ) do
-- The arity should match the signature -- The arity should match the signature
stuff <- mkWwBodies dflags fun_ty wrap_dmds res_info one_shots stuff <- mkWwBodies dflags fam_envs fun_ty wrap_dmds res_info one_shots
case stuff of case stuff of
Just (work_demands, wrap_fn, work_fn) -> do Just (work_demands, wrap_fn, work_fn) -> do
work_uniq <- getUniqueM work_uniq <- getUniqueM
...@@ -449,9 +452,9 @@ then the splitting will go deeper too. ...@@ -449,9 +452,9 @@ then the splitting will go deeper too.
-- --> x = let x = e in -- --> x = let x = e in
-- case x of (a,b) -> let x = (a,b) in x -- case x of (a,b) -> let x = (a,b) in x
splitThunk :: DynFlags -> RecFlag -> Var -> Expr Var -> UniqSM (Maybe [(Var, Expr Var)]) splitThunk :: DynFlags -> FamInstEnvs -> RecFlag -> Var -> Expr Var -> UniqSM (Maybe [(Var, Expr Var)])
splitThunk dflags is_rec fn_id rhs = do splitThunk dflags fam_envs is_rec fn_id rhs = do
(useful,_, wrap_fn, work_fn) <- mkWWstr dflags [fn_id] (useful,_, wrap_fn, work_fn) <- mkWWstr dflags fam_envs [fn_id]
let res = [ (fn_id, Let (NonRec fn_id rhs) (wrap_fn (work_fn (Var fn_id)))) ] let res = [ (fn_id, Let (NonRec fn_id rhs) (wrap_fn (work_fn (Var fn_id)))) ]
if useful then ASSERT2( isNonRec is_rec, ppr fn_id ) -- The thunk must be non-recursive if useful then ASSERT2( isNonRec is_rec, ppr fn_id ) -- The thunk must be non-recursive
return (Just res) return (Just res)
......
...@@ -23,6 +23,7 @@ import TysPrim ( voidPrimTy ) ...@@ -23,6 +23,7 @@ import TysPrim ( voidPrimTy )
import TysWiredIn ( tupleCon ) import TysWiredIn ( tupleCon )
import Type import Type
import Coercion hiding ( substTy, substTyVarBndr ) import Coercion hiding ( substTy, substTyVarBndr )
import FamInstEnv
import BasicTypes ( TupleSort(..), OneShotInfo(..), worstOneShot ) import BasicTypes ( TupleSort(..), OneShotInfo(..), worstOneShot )
import Literal ( absentLiteralOf ) import Literal ( absentLiteralOf )
import TyCon import TyCon
...@@ -105,6 +106,7 @@ the unusable strictness-info into the interfaces. ...@@ -105,6 +106,7 @@ the unusable strictness-info into the interfaces.
\begin{code} \begin{code}
mkWwBodies :: DynFlags mkWwBodies :: DynFlags
-> FamInstEnvs
-> Type -- Type of original function -> Type -- Type of original function
-> [Demand] -- Strictness of original function -> [Demand] -- Strictness of original function
-> DmdResult -- Info about function result -> DmdResult -- Info about function result
...@@ -124,14 +126,14 @@ mkWwBodies :: DynFlags ...@@ -124,14 +126,14 @@ mkWwBodies :: DynFlags
-- let x = (a,b) in -- let x = (a,b) in
-- E -- E
mkWwBodies dflags fun_ty demands res_info one_shots mkWwBodies dflags fam_envs fun_ty demands res_info one_shots
= do { let arg_info = demands `zip` (one_shots ++ repeat NoOneShotInfo) = do { let arg_info = demands `zip` (one_shots ++ repeat NoOneShotInfo)
all_one_shots = foldr (worstOneShot . snd) OneShotLam arg_info all_one_shots = foldr (worstOneShot . snd) OneShotLam arg_info
; (wrap_args, wrap_fn_args, work_fn_args, res_ty) <- mkWWargs emptyTvSubst fun_ty arg_info ; (wrap_args, wrap_fn_args, work_fn_args, res_ty) <- mkWWargs emptyTvSubst fun_ty arg_info
; (useful1, work_args, wrap_fn_str, work_fn_str) <- mkWWstr dflags wrap_args ; (useful1, work_args, wrap_fn_str, work_fn_str) <- mkWWstr dflags fam_envs wrap_args
-- Do CPR w/w. See Note [Always do CPR w/w] -- Do CPR w/w. See Note [Always do CPR w/w]
; (useful2, wrap_fn_cpr, work_fn_cpr, cpr_res_ty) <- mkWWcpr res_ty res_info ; (useful2, wrap_fn_cpr, work_fn_cpr, cpr_res_ty) <- mkWWcpr fam_envs res_ty res_info
; let (work_lam_args, work_call_args) = mkWorkerArgs dflags work_args all_one_shots cpr_res_ty ; let (work_lam_args, work_call_args) = mkWorkerArgs dflags work_args all_one_shots cpr_res_ty
worker_args_dmds = [idDemandInfo v | v <- work_call_args, isId v] worker_args_dmds = [idDemandInfo v | v <- work_call_args, isId v]
...@@ -371,6 +373,7 @@ That's why we carry the TvSubst through mkWWargs ...@@ -371,6 +373,7 @@ That's why we carry the TvSubst through mkWWargs
\begin{code} \begin{code}
mkWWstr :: DynFlags mkWWstr :: DynFlags
-> FamInstEnvs
-> [Var] -- Wrapper args; have their demand info on them -> [Var] -- Wrapper args; have their demand info on them
-- *Includes type variables* -- *Includes type variables*
-> UniqSM (Bool, -- Is this useful -> UniqSM (Bool, -- Is this useful
...@@ -382,12 +385,12 @@ mkWWstr :: DynFlags ...@@ -382,12 +385,12 @@ mkWWstr :: DynFlags
CoreExpr -> CoreExpr) -- Worker body, lacking the original body of the function, CoreExpr -> CoreExpr) -- Worker body, lacking the original body of the function,
-- and lacking its lambdas. -- and lacking its lambdas.
-- This fn does the reboxing -- This fn does the reboxing
mkWWstr _ [] mkWWstr _ _ []
= return (False, [], nop_fn, nop_fn) = return (False, [], nop_fn, nop_fn)
mkWWstr dflags (arg : args) = do mkWWstr dflags fam_envs (arg : args) = do
(useful1, args1, wrap_fn1, work_fn1) <- mkWWstr_one dflags arg (useful1, args1, wrap_fn1, work_fn1) <- mkWWstr_one dflags fam_envs arg
(useful2, args2, wrap_fn2, work_fn2) <- mkWWstr dflags args (useful2, args2, wrap_fn2, work_fn2) <- mkWWstr dflags fam_envs args
return (useful1 || useful2, args1 ++ args2, wrap_fn1 . wrap_fn2, work_fn1 . work_fn2) return (useful1 || useful2, args1 ++ args2, wrap_fn1 . wrap_fn2, work_fn1 . work_fn2)
\end{code} \end{code}
...@@ -426,8 +429,9 @@ as-yet-un-filled-in pkgState files. ...@@ -426,8 +429,9 @@ as-yet-un-filled-in pkgState files.
-- brings into scope work_args (via cases) -- brings into scope work_args (via cases)
-- * work_fn assumes work_args are in scope, a -- * work_fn assumes work_args are in scope, a
-- brings into scope wrap_arg (via lets) -- brings into scope wrap_arg (via lets)
mkWWstr_one :: DynFlags -> Var -> UniqSM (Bool, [Var], CoreExpr -> CoreExpr, CoreExpr -> CoreExpr) mkWWstr_one :: DynFlags -> FamInstEnvs -> Var
mkWWstr_one dflags arg -> UniqSM (Bool, [Var], CoreExpr -> CoreExpr, CoreExpr -> CoreExpr)
mkWWstr_one dflags fam_envs arg
| isTyVar arg | isTyVar arg
= return (False, [arg], nop_fn, nop_fn) = return (False, [arg], nop_fn, nop_fn)
...@@ -463,7 +467,7 @@ mkWWstr_one dflags arg ...@@ -463,7 +467,7 @@ mkWWstr_one dflags arg
, Just cs <- splitProdDmd_maybe dmd , Just cs <- splitProdDmd_maybe dmd
-- See Note [Unpacking arguments with product and polymorphic demands] -- See Note [Unpacking arguments with product and polymorphic demands]
, Just (data_con, inst_tys, inst_con_arg_tys, co) , Just (data_con, inst_tys, inst_con_arg_tys, co)
<- deepSplitProductType_maybe (idType arg) <- deepSplitProductType_maybe fam_envs (idType arg)
, cs `equalLength` inst_con_arg_tys , cs `equalLength` inst_con_arg_tys
-- See Note [mkWWstr and unsafeCoerce] -- See Note [mkWWstr and unsafeCoerce]
= do { (uniq1:uniqs) <- getUniquesM = do { (uniq1:uniqs) <- getUniquesM
...@@ -473,7 +477,7 @@ mkWWstr_one dflags arg ...@@ -473,7 +477,7 @@ mkWWstr_one dflags arg
data_con unpk_args data_con unpk_args
rebox_fn = Let (NonRec arg con_app) rebox_fn = Let (NonRec arg con_app)
con_app = mkConApp2 data_con inst_tys unpk_args `mkCast` mkSymCo co con_app = mkConApp2 data_con inst_tys unpk_args `mkCast` mkSymCo co
; (_, worker_args, wrap_fn, work_fn) <- mkWWstr dflags unpk_args_w_ds ; (_, worker_args, wrap_fn, work_fn) <- mkWWstr dflags fam_envs unpk_args_w_ds
; return (True, worker_args, unbox_fn . wrap_fn, work_fn . rebox_fn) } ; return (True, worker_args, unbox_fn . wrap_fn, work_fn . rebox_fn) }
-- Don't pass the arg, rebox instead -- Don't pass the arg, rebox instead
...@@ -503,29 +507,31 @@ If so, the worker/wrapper split doesn't work right and we get a Core Lint ...@@ -503,29 +507,31 @@ If so, the worker/wrapper split doesn't work right and we get a Core Lint
bug. The fix here is simply to decline to do w/w if that happens. bug. The fix here is simply to decline to do w/w if that happens.
\begin{code} \begin{code}
deepSplitProductType_maybe :: Type -> Maybe (DataCon, [Type], [Type], Coercion) deepSplitProductType_maybe :: FamInstEnvs -> Type -> Maybe (DataCon, [Type], [Type], Coercion)
-- If deepSplitProductType_maybe ty = Just (dc, tys, arg_tys, co) -- If deepSplitProductType_maybe ty = Just (dc, tys, arg_tys, co)
-- then dc @ tys (args::arg_tys) :: rep_ty -- then dc @ tys (args::arg_tys) :: rep_ty
-- co :: ty ~ rep_ty -- co :: ty ~ rep_ty
deepSplitProductType_maybe ty deepSplitProductType_maybe fam_envs ty
| let (co, ty1) = topNormaliseNewType_maybe ty `orElse` (mkReflCo Representational ty, ty) | let (co, ty1) = topNormaliseType_maybe fam_envs ty
`orElse` (mkReflCo Representational ty, ty)
, Just (tc, tc_args) <- splitTyConApp_maybe ty1 , Just (tc, tc_args) <- splitTyConApp_maybe ty1
, Just con <- isDataProductTyCon_maybe tc , Just con <- isDataProductTyCon_maybe tc
= Just (con, tc_args, dataConInstArgTys con tc_args, co) = Just (con, tc_args, dataConInstArgTys con tc_args, co)
deepSplitProductType_maybe _ = Nothing deepSplitProductType_maybe _ _ = Nothing
deepSplitCprType_maybe :: ConTag -> Type -> Maybe (DataCon, [Type], [Type], Coercion) deepSplitCprType_maybe :: FamInstEnvs -> ConTag -> Type -> Maybe (DataCon, [Type], [Type], Coercion)
-- If deepSplitCprType_maybe n ty = Just (dc, tys, arg_tys, co) -- If deepSplitCprType_maybe n ty = Just (dc, tys, arg_tys, co)
-- then dc @ tys (args::arg_tys) :: rep_ty -- then dc @ tys (args::arg_tys) :: rep_ty
-- co :: ty ~ rep_ty -- co :: ty ~ rep_ty
deepSplitCprType_maybe con_tag ty deepSplitCprType_maybe fam_envs con_tag ty
| let (co, ty1) = topNormaliseNewType_maybe ty `orElse` (mkReflCo Representational ty, ty) | let (co, ty1) = topNormaliseType_maybe fam_envs ty
`orElse` (mkReflCo Representational ty, ty)
, Just (tc, tc_args) <- splitTyConApp_maybe ty1 , Just (tc, tc_args) <- splitTyConApp_maybe ty1
, isDataTyCon tc , isDataTyCon tc
, let cons = tyConDataCons tc , let cons = tyConDataCons tc
con = ASSERT( cons `lengthAtLeast` con_tag ) cons !! (con_tag - fIRST_TAG) con = ASSERT( cons `lengthAtLeast` con_tag ) cons !! (con_tag - fIRST_TAG)
= Just (con, tc_args, dataConInstArgTys con tc_args, co) = Just (con, tc_args, dataConInstArgTys con tc_args, co)
deepSplitCprType_maybe _ _ = Nothing deepSplitCprType_maybe _ _ _ = Nothing
\end{code} \end{code}
...@@ -546,17 +552,18 @@ left-to-right traversal of the result structure. ...@@ -546,17 +552,18 @@ left-to-right traversal of the result structure.
\begin{code} \begin{code}
mkWWcpr :: Type -- function body type mkWWcpr :: FamInstEnvs
-> Type -- function body type
-> DmdResult -- CPR analysis results -> DmdResult -- CPR analysis results
-> UniqSM (Bool, -- Is w/w'ing useful? -> UniqSM (Bool, -- Is w/w'ing useful?
CoreExpr -> CoreExpr, -- New wrapper CoreExpr -> CoreExpr, -- New wrapper
CoreExpr -> CoreExpr, -- New worker CoreExpr -> CoreExpr, -- New worker
Type) -- Type of worker's body Type) -- Type of worker's body
mkWWcpr body_ty res mkWWcpr fam_envs body_ty res
= case returnsCPR_maybe res of = case returnsCPR_maybe res of
Nothing -> return (False, id, id, body_ty) -- No CPR info Nothing -> return (False, id, id, body_ty) -- No CPR info
Just con_tag | Just stuff <- deepSplitCprType_maybe con_tag body_ty Just con_tag | Just stuff <- deepSplitCprType_maybe fam_envs con_tag body_ty
-> mkWWcpr_help stuff -> mkWWcpr_help stuff
| otherwise | otherwise
-- See Note [non-algebraic or open body type warning] -- See Note [non-algebraic or open body type warning]
......
...@@ -1185,6 +1185,12 @@ topNormaliseNewType_maybe :: Type -> Maybe (Coercion, Type) ...@@ -1185,6 +1185,12 @@ topNormaliseNewType_maybe :: Type -> Maybe (Coercion, Type)
-- --
-- The function returns @Nothing@ for non-@newtypes@, -- The function returns @Nothing@ for non-@newtypes@,
-- or unsaturated applications -- or unsaturated applications
--
-- This function does *not* look through type families, because it has no access to
-- the type family environment. If you do have that at hand, consider to use
-- topNormaliseType_maybe, which should be a drop-in replacement for
-- topNormaliseNewType_maybe
--
topNormaliseNewType_maybe ty