Commit f1a90f54 authored by simonpj@microsoft.com's avatar simonpj@microsoft.com
Browse files

Fix an egregious strictness analyser bug (Trac #4924)

The "virgin" flag was being threaded rather than treated
like an environment.  As a result, the second and subsequent
recursive definitions in a module were not getting a
correctly-initialised fixpoint loop, causing much worse
strictness analysis results.  Indeed the symptoms in
Trac #4924 were quite bizarre.

Anyway, it's easily fixed.  Merge to stable branch.
parent a0f6d307
......@@ -74,35 +74,33 @@ dmdAnalTopBind :: SigEnv
-> CoreBind
-> (SigEnv, CoreBind)
dmdAnalTopBind sigs (NonRec id rhs)
= let
( _, _, (_, rhs1)) = dmdAnalRhs TopLevel NonRecursive sigs (id, rhs)
(sigs2, _, (id2, rhs2)) = dmdAnalRhs TopLevel NonRecursive sigs (id, rhs1)
-- Do two passes to improve CPR information
-- See comments with ignore_cpr_info in mk_sig_ty
-- and with extendSigsWithLam
in
(sigs2, NonRec id2 rhs2)
= (sigs2, NonRec id2 rhs2)
where
( _, _, (_, rhs1)) = dmdAnalRhs TopLevel NonRecursive (virgin sigs) (id, rhs)
(sigs2, _, (id2, rhs2)) = dmdAnalRhs TopLevel NonRecursive (nonVirgin sigs) (id, rhs1)
-- Do two passes to improve CPR information
-- See comments with ignore_cpr_info in mk_sig_ty
-- and with extendSigsWithLam
dmdAnalTopBind sigs (Rec pairs)
= let
(sigs', _, pairs') = dmdFix TopLevel sigs pairs
= (sigs', Rec pairs')
where
(sigs', _, pairs') = dmdFix TopLevel (virgin sigs) pairs
-- We get two iterations automatically
-- c.f. the NonRec case above
in
(sigs', Rec pairs')
\end{code}
\begin{code}
dmdAnalTopRhs :: CoreExpr -> (StrictSig, CoreExpr)
-- Analyse the RHS and return
-- a) appropriate strictness info
-- b) the unfolding (decorated with stricntess info)
-- b) the unfolding (decorated with strictness info)
dmdAnalTopRhs rhs
= (sig, rhs2)
where
call_dmd = vanillaCall (exprArity rhs)
(_, rhs1) = dmdAnal emptySigEnv call_dmd rhs
(rhs_ty, rhs2) = dmdAnal emptySigEnv call_dmd rhs1
(_, rhs1) = dmdAnal (virgin emptySigEnv) call_dmd rhs
(rhs_ty, rhs2) = dmdAnal (nonVirgin emptySigEnv) call_dmd rhs1
sig = mkTopSigTy rhs rhs_ty
-- Do two passes; see notes with extendSigsWithLam
-- Otherwise we get bogus CPR info for constructors like
......@@ -119,14 +117,14 @@ dmdAnalTopRhs rhs
%************************************************************************
\begin{code}
dmdAnal :: SigEnv -> Demand -> CoreExpr -> (DmdType, CoreExpr)
dmdAnal :: AnalEnv -> Demand -> CoreExpr -> (DmdType, CoreExpr)
dmdAnal _ Abs e = (topDmdType, e)
dmdAnal sigs dmd e
dmdAnal env dmd e
| not (isStrictDmd dmd)
= let
(res_ty, e') = dmdAnal sigs evalDmd e
(res_ty, e') = dmdAnal env evalDmd e
in
(deferType res_ty, e')
-- It's important not to analyse e with a lazy demand because
......@@ -147,13 +145,13 @@ dmdAnal sigs dmd e
dmdAnal _ _ (Lit lit) = (topDmdType, Lit lit)
dmdAnal _ _ (Type ty) = (topDmdType, Type ty) -- Doesn't happen, in fact
dmdAnal sigs dmd (Var var)
= (dmdTransform sigs var dmd, Var var)
dmdAnal env dmd (Var var)
= (dmdTransform env var dmd, Var var)
dmdAnal sigs dmd (Cast e co)
dmdAnal env dmd (Cast e co)
= (dmd_ty, Cast e' co)
where
(dmd_ty, e') = dmdAnal sigs dmd' e
(dmd_ty, e') = dmdAnal env dmd' e
to_co = snd (coercionKind co)
dmd'
| Just (tc, _) <- splitTyConApp_maybe to_co
......@@ -165,55 +163,55 @@ dmdAnal sigs dmd (Cast e co)
-- inside recursive products -- we might not reach
-- a fixpoint. So revert to a vanilla Eval demand
dmdAnal sigs dmd (Note n e)
dmdAnal env dmd (Note n e)
= (dmd_ty, Note n e')
where
(dmd_ty, e') = dmdAnal sigs dmd e
(dmd_ty, e') = dmdAnal env dmd e
dmdAnal sigs dmd (App fun (Type ty))
dmdAnal env dmd (App fun (Type ty))
= (fun_ty, App fun' (Type ty))
where
(fun_ty, fun') = dmdAnal sigs dmd fun
(fun_ty, fun') = dmdAnal env dmd fun
-- Lots of the other code is there to make this
-- beautiful, compositional, application rule :-)
dmdAnal sigs dmd (App fun arg) -- Non-type arguments
dmdAnal env dmd (App fun arg) -- Non-type arguments
= let -- [Type arg handled above]
(fun_ty, fun') = dmdAnal sigs (Call dmd) fun
(arg_ty, arg') = dmdAnal sigs arg_dmd arg
(fun_ty, fun') = dmdAnal env (Call dmd) fun
(arg_ty, arg') = dmdAnal env arg_dmd arg
(arg_dmd, res_ty) = splitDmdTy fun_ty
in
(res_ty `bothType` arg_ty, App fun' arg')
dmdAnal sigs dmd (Lam var body)
dmdAnal env dmd (Lam var body)
| isTyCoVar var
= let
(body_ty, body') = dmdAnal sigs dmd body
(body_ty, body') = dmdAnal env dmd body
in
(body_ty, Lam var body')
| Call body_dmd <- dmd -- A call demand: good!
= let
sigs' = extendSigsWithLam sigs var
(body_ty, body') = dmdAnal sigs' body_dmd body
(lam_ty, var') = annotateLamIdBndr sigs body_ty var
env' = extendSigsWithLam env var
(body_ty, body') = dmdAnal env' body_dmd body
(lam_ty, var') = annotateLamIdBndr env body_ty var
in
(lam_ty, Lam var' body')
| otherwise -- Not enough demand on the lambda; but do the body
= let -- anyway to annotate it and gather free var info
(body_ty, body') = dmdAnal sigs evalDmd body
(lam_ty, var') = annotateLamIdBndr sigs body_ty var
(body_ty, body') = dmdAnal env evalDmd body
(lam_ty, var') = annotateLamIdBndr env body_ty var
in
(deferType lam_ty, Lam var' body')
dmdAnal sigs dmd (Case scrut case_bndr ty [alt@(DataAlt dc, _, _)])
dmdAnal env dmd (Case scrut case_bndr ty [alt@(DataAlt dc, _, _)])
| let tycon = dataConTyCon dc
, isProductTyCon tycon
, not (isRecursiveTyCon tycon)
= let
sigs_alt = extendSigEnv NotTopLevel sigs case_bndr case_bndr_sig
(alt_ty, alt') = dmdAnalAlt sigs_alt dmd alt
env_alt = extendAnalEnv NotTopLevel env case_bndr case_bndr_sig
(alt_ty, alt') = dmdAnalAlt env_alt dmd alt
(alt_ty1, case_bndr') = annotateBndr alt_ty case_bndr
(_, bndrs', _) = alt'
case_bndr_sig = cprSig
......@@ -251,23 +249,23 @@ dmdAnal sigs dmd (Case scrut case_bndr ty [alt@(DataAlt dc, _, _)])
scrut_dmd = alt_dmd `both`
idDemandInfo case_bndr'
(scrut_ty, scrut') = dmdAnal sigs scrut_dmd scrut
(scrut_ty, scrut') = dmdAnal env scrut_dmd scrut
in
(alt_ty1 `bothType` scrut_ty, Case scrut' case_bndr' ty [alt'])
dmdAnal sigs dmd (Case scrut case_bndr ty alts)
dmdAnal env dmd (Case scrut case_bndr ty alts)
= let
(alt_tys, alts') = mapAndUnzip (dmdAnalAlt sigs dmd) alts
(scrut_ty, scrut') = dmdAnal sigs evalDmd scrut
(alt_tys, alts') = mapAndUnzip (dmdAnalAlt env dmd) alts
(scrut_ty, scrut') = dmdAnal env evalDmd scrut
(alt_ty, case_bndr') = annotateBndr (foldr1 lubType alt_tys) case_bndr
in
-- pprTrace "dmdAnal:Case" (ppr alts $$ ppr alt_tys)
(alt_ty `bothType` scrut_ty, Case scrut' case_bndr' ty alts')
dmdAnal sigs dmd (Let (NonRec id rhs) body)
dmdAnal env dmd (Let (NonRec id rhs) body)
= let
(sigs', lazy_fv, (id1, rhs')) = dmdAnalRhs NotTopLevel NonRecursive sigs (id, rhs)
(body_ty, body') = dmdAnal sigs' dmd body
(sigs', lazy_fv, (id1, rhs')) = dmdAnalRhs NotTopLevel NonRecursive env (id, rhs)
(body_ty, body') = dmdAnal (updSigEnv env sigs') dmd body
(body_ty1, id2) = annotateBndr body_ty id1
body_ty2 = addLazyFVs body_ty1 lazy_fv
in
......@@ -285,11 +283,11 @@ dmdAnal sigs dmd (Let (NonRec id rhs) body)
-- bother to re-analyse the RHS.
(body_ty2, Let (NonRec id2 rhs') body')
dmdAnal sigs dmd (Let (Rec pairs) body)
dmdAnal env dmd (Let (Rec pairs) body)
= let
bndrs = map fst pairs
(sigs', lazy_fv, pairs') = dmdFix NotTopLevel sigs pairs
(body_ty, body') = dmdAnal sigs' dmd body
(sigs', lazy_fv, pairs') = dmdFix NotTopLevel env pairs
(body_ty, body') = dmdAnal (updSigEnv env sigs') dmd body
body_ty1 = addLazyFVs body_ty lazy_fv
in
sigs' `seq` body_ty `seq`
......@@ -303,10 +301,10 @@ dmdAnal sigs dmd (Let (Rec pairs) body)
(body_ty2, Let (Rec pairs') body')
dmdAnalAlt :: SigEnv -> Demand -> Alt Var -> (DmdType, Alt Var)
dmdAnalAlt sigs dmd (con,bndrs,rhs)
dmdAnalAlt :: AnalEnv -> Demand -> Alt Var -> (DmdType, Alt Var)
dmdAnalAlt env dmd (con,bndrs,rhs)
= let
(rhs_ty, rhs') = dmdAnal sigs dmd rhs
(rhs_ty, rhs') = dmdAnal env dmd rhs
rhs_ty' = addDataConPatDmds con bndrs rhs_ty
(alt_ty, bndrs') = annotateBndrs rhs_ty' bndrs
final_alt_ty | io_hack_reqd = alt_ty `lubType` topDmdType
......@@ -388,14 +386,14 @@ argument, and pass an Int to $wfoo!
%************************************************************************
\begin{code}
dmdTransform :: SigEnv -- The strictness environment
dmdTransform :: AnalEnv -- The strictness environment
-> Id -- The function
-> Demand -- The demand on the function
-> DmdType -- The demand type of the function in this context
-- Returned DmdEnv includes the demand on
-- this function plus demand on its free variables
dmdTransform sigs var dmd
dmdTransform env var dmd
------ DATA CONSTRUCTOR
| isDataConWorkId var -- Data constructor
......@@ -439,7 +437,7 @@ dmdTransform sigs var dmd
topDmdType
------ LOCAL LET/REC BOUND THING
| Just (StrictSig dmd_ty, top_lvl) <- lookupSigEnv sigs var
| Just (StrictSig dmd_ty, top_lvl) <- lookupSigEnv env var
= let
fn_ty | dmdTypeDepth dmd_ty <= call_depth = dmd_ty
| otherwise = deferType dmd_ty
......@@ -467,22 +465,26 @@ dmdTransform sigs var dmd
\begin{code}
dmdFix :: TopLevelFlag
-> SigEnv -- Does not include bindings for this binding
-> AnalEnv -- Does not include bindings for this binding
-> [(Id,CoreExpr)]
-> (SigEnv, DmdEnv,
[(Id,CoreExpr)]) -- Binders annotated with stricness info
dmdFix top_lvl sigs orig_pairs
= loop 1 initial_sigs orig_pairs
dmdFix top_lvl env orig_pairs
= loop 1 initial_env orig_pairs
where
bndrs = map fst orig_pairs
initial_sigs = addInitialSigs top_lvl sigs bndrs
initial_env = addInitialSigs top_lvl env bndrs
loop :: Int
-> SigEnv -- Already contains the current sigs
-> AnalEnv -- Already contains the current sigs
-> [(Id,CoreExpr)]
-> (SigEnv, DmdEnv, [(Id,CoreExpr)])
loop n sigs pairs
loop n env pairs
= -- pprTrace "dmd loop" (ppr n <+> ppr bndrs $$ ppr env) $
loop' n env pairs
loop' n env pairs
| found_fixpoint
= (sigs', lazy_fv, pairs')
-- Note: return pairs', not pairs. pairs' is the result of
......@@ -492,11 +494,11 @@ dmdFix top_lvl sigs orig_pairs
| n >= 10
= pprTrace "dmdFix loop" (ppr n <+> (vcat
[ text "Sigs:" <+> ppr [ (id,lookupSigEnv sigs id, lookupSigEnv sigs' id)
[ text "Sigs:" <+> ppr [ (id,lookupVarEnv sigs id, lookupVarEnv sigs' id)
| (id,_) <- pairs],
text "env:" <+> ppr sigs,
text "env:" <+> ppr env,
text "binds:" <+> pprCoreBinding (Rec pairs)]))
(emptySigEnv, lazy_fv, orig_pairs) -- Safe output
(sigEnv env, lazy_fv, orig_pairs) -- Safe output
-- The lazy_fv part is really important! orig_pairs has no strictness
-- info, including nothing about free vars. But if we have
-- letrec f = ....y..... in ...f...
......@@ -504,42 +506,45 @@ dmdFix top_lvl sigs orig_pairs
-- otherwise y will get recorded as absent altogether
| otherwise
= loop (n+1) (setNonVirgin sigs') pairs'
= loop (n+1) (nonVirgin sigs') pairs'
where
sigs = sigEnv env
found_fixpoint = all (same_sig sigs sigs') bndrs
-- Use the new signature to do the next pair
((sigs',lazy_fv), pairs') = mapAccumL my_downRhs (sigs, emptyDmdEnv) pairs
-- mapAccumL: Use the new signature to do the next pair
-- The occurrence analyser has arranged them in a good order
-- so this can significantly reduce the number of iterations needed
((sigs',lazy_fv), pairs') = mapAccumL my_downRhs (sigs, emptyDmdEnv) pairs
my_downRhs (sigs,lazy_fv) (id,rhs) = ((sigs', lazy_fv'), pair')
where
(sigs', lazy_fv1, pair') = dmdAnalRhs top_lvl Recursive sigs (id,rhs)
lazy_fv' = plusVarEnv_C both lazy_fv lazy_fv1
my_downRhs (sigs,lazy_fv) (id,rhs)
= ((sigs', lazy_fv'), pair')
where
(sigs', lazy_fv1, pair') = dmdAnalRhs top_lvl Recursive (updSigEnv env sigs) (id,rhs)
lazy_fv' = plusVarEnv_C both lazy_fv lazy_fv1
same_sig sigs sigs' var = lookup sigs var == lookup sigs' var
lookup sigs var = case lookupSigEnv sigs var of
lookup sigs var = case lookupVarEnv sigs var of
Just (sig,_) -> sig
Nothing -> pprPanic "dmdFix" (ppr var)
dmdAnalRhs :: TopLevelFlag -> RecFlag
-> SigEnv -> (Id, CoreExpr)
-> AnalEnv -> (Id, CoreExpr)
-> (SigEnv, DmdEnv, (Id, CoreExpr))
-- Process the RHS of the binding, add the strictness signature
-- to the Id, and augment the environment with the signature as well.
dmdAnalRhs top_lvl rec_flag sigs (id, rhs)
dmdAnalRhs top_lvl rec_flag env (id, rhs)
= (sigs', lazy_fv, (id', rhs'))
where
arity = idArity id -- The idArity should be up to date
-- The simplifier was run just beforehand
(rhs_dmd_ty, rhs') = dmdAnal sigs (vanillaCall arity) rhs
(rhs_dmd_ty, rhs') = dmdAnal env (vanillaCall arity) rhs
(lazy_fv, sig_ty) = WARN( arity /= dmdTypeDepth rhs_dmd_ty && not (exprIsTrivial rhs), ppr id )
-- The RHS can be eta-reduced to just a variable,
-- in which case we should not complain.
mkSigTy top_lvl rec_flag id rhs rhs_dmd_ty
id' = id `setIdStrictness` sig_ty
sigs' = extendSigEnv top_lvl sigs id sig_ty
sigs' = extendSigEnv top_lvl (sigEnv env) id sig_ty
\end{code}
......@@ -841,13 +846,13 @@ annotateBndr dmd_ty@(DmdType fv ds res) var
annotateBndrs :: DmdType -> [Var] -> (DmdType, [Var])
annotateBndrs = mapAccumR annotateBndr
annotateLamIdBndr :: SigEnv
annotateLamIdBndr :: AnalEnv
-> DmdType -- Demand type of body
-> Id -- Lambda binder
-> (DmdType, -- Demand type of lambda
Id) -- and binder annotated with demand
annotateLamIdBndr sigs (DmdType fv ds res) id
annotateLamIdBndr env (DmdType fv ds res) id
-- For lambdas we add the demand to the argument demands
-- Only called for Ids
= ASSERT( isId id )
......@@ -858,7 +863,7 @@ annotateLamIdBndr sigs (DmdType fv ds res) id
Nothing -> main_ty
Just unf -> main_ty `bothType` unf_ty
where
(unf_ty, _) = dmdAnal sigs dmd unf
(unf_ty, _) = dmdAnal env dmd unf
main_ty = DmdType fv' (hacked_dmd:ds) res
......@@ -906,9 +911,9 @@ forget that fact, otherwise we might make 'x' absent when it isn't.
%************************************************************************
\begin{code}
data SigEnv
= SE { se_env :: VarEnv (StrictSig, TopLevelFlag)
, se_virgin :: Bool } -- True on first iteration only
data AnalEnv
= AE { ae_sigs :: SigEnv
, ae_virgin :: Bool } -- True on first iteration only
-- See Note [Initialising strictness]
-- We use the se_env to tell us whether to
-- record info about a variable in the DmdEnv
......@@ -917,36 +922,48 @@ data SigEnv
-- The DmdEnv gives the demand on the free vars of the function
-- when it is given enough args to satisfy the strictness signature
instance Outputable SigEnv where
ppr (SE { se_env = env, se_virgin = virgin })
= ptext (sLit "SE") <+> braces (vcat
[ ptext (sLit "se_virgin =") <+> ppr virgin
, ptext (sLit "se_env =") <+> ppr env ])
type SigEnv = VarEnv (StrictSig, TopLevelFlag)
instance Outputable AnalEnv where
ppr (AE { ae_sigs = env, ae_virgin = virgin })
= ptext (sLit "AE") <+> braces (vcat
[ ptext (sLit "ae_virgin =") <+> ppr virgin
, ptext (sLit "ae_sigs =") <+> ppr env ])
emptySigEnv :: SigEnv
emptySigEnv = SE { se_env = emptyVarEnv, se_virgin = True }
emptySigEnv = emptyVarEnv
sigEnv :: AnalEnv -> SigEnv
sigEnv = ae_sigs
updSigEnv :: AnalEnv -> SigEnv -> AnalEnv
updSigEnv env sigs = env { ae_sigs = sigs }
extendAnalEnv :: TopLevelFlag -> AnalEnv -> Id -> StrictSig -> AnalEnv
extendAnalEnv top_lvl env var sig
= env { ae_sigs = extendSigEnv top_lvl (ae_sigs env) var sig }
extendSigEnv :: TopLevelFlag -> SigEnv -> Id -> StrictSig -> SigEnv
extendSigEnv top_lvl sigs var sig
= sigs { se_env = extendVarEnv (se_env sigs) var (sig, top_lvl) }
extendSigEnv top_lvl sigs var sig = extendVarEnv sigs var (sig, top_lvl)
lookupSigEnv :: SigEnv -> Id -> Maybe (StrictSig, TopLevelFlag)
lookupSigEnv sigs id = lookupVarEnv (se_env sigs) id
lookupSigEnv :: AnalEnv -> Id -> Maybe (StrictSig, TopLevelFlag)
lookupSigEnv env id = lookupVarEnv (ae_sigs env) id
addInitialSigs :: TopLevelFlag -> SigEnv -> [Id] -> SigEnv
addInitialSigs :: TopLevelFlag -> AnalEnv -> [Id] -> AnalEnv
-- See Note [Initialising strictness]
addInitialSigs top_lvl sigs@(SE { se_env = env, se_virgin = virgin }) ids
= sigs { se_env = extendVarEnvList env [ (id, (init_sig id, top_lvl))
| id <- ids ] }
addInitialSigs top_lvl env@(AE { ae_sigs = sigs, ae_virgin = virgin }) ids
= env { ae_sigs = extendVarEnvList sigs [ (id, (init_sig id, top_lvl))
| id <- ids ] }
where
init_sig | virgin = \_ -> botSig
| otherwise = idStrictness
setNonVirgin :: SigEnv -> SigEnv
setNonVirgin sigs = sigs { se_virgin = False }
virgin, nonVirgin :: SigEnv -> AnalEnv
virgin sigs = AE { ae_sigs = sigs, ae_virgin = True }
nonVirgin sigs = AE { ae_sigs = sigs, ae_virgin = False }
extendSigsWithLam :: SigEnv -> Id -> SigEnv
-- Extend the SigEnv when we meet a lambda binder
extendSigsWithLam :: AnalEnv -> Id -> AnalEnv
-- Extend the AnalEnv when we meet a lambda binder
-- If the binder is marked demanded with a product demand, then give it a CPR
-- signature, because in the likely event that this is a lambda on a fn defn
-- [we only use this when the lambda is being consumed with a call demand],
......@@ -961,13 +978,13 @@ extendSigsWithLam :: SigEnv -> Id -> SigEnv
-- definitely has product type, else we may get over-optimistic
-- CPR results (e.g. from \x -> x!).
extendSigsWithLam sigs id
extendSigsWithLam env id
= case idDemandInfo_maybe id of
Nothing -> extendSigEnv NotTopLevel sigs id cprSig
Nothing -> extendAnalEnv NotTopLevel env id cprSig
-- Optimistic in the Nothing case;
-- See notes [CPR-AND-STRICTNESS]
Just (Eval (Prod _)) -> extendSigEnv NotTopLevel sigs id cprSig
_ -> sigs
Just (Eval (Prod _)) -> extendAnalEnv NotTopLevel env id cprSig
_ -> env
\end{code}
Note [Initialising strictness]
......@@ -986,8 +1003,8 @@ plan.)
But on the *first* iteration we want to *ignore* the current strictness
of the Id, and start from "bottom". Nowadays the Id can have a current
strictness, because interface files record strictness for nested bindings.
To know when we are in the first iteration, we look at the se_virgin
field of the SigEnv.
To know when we are in the first iteration, we look at the ae_virgin
field of the AnalEnv.
%************************************************************************
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment