Commit b5b7d820 authored by Simon Peyton Jones

Improve demand analysis for join points

I realised (Trac #13543) that we can improve demand analysis for
join points quite straightforwardly.

The idea is explained in
    Note [Demand analysis for join points]
in DmdAnal
parent 751996e9
......@@ -64,20 +64,20 @@ dmdAnalProgram dflags fam_envs binds
dmdAnalTopBind :: AnalEnv
-> CoreBind
-> (AnalEnv, CoreBind)
dmdAnalTopBind sigs (NonRec id rhs)
= (extendAnalEnv TopLevel sigs id2 (idStrictness id2), NonRec id2 rhs2)
dmdAnalTopBind env (NonRec id rhs)
= (extendAnalEnv TopLevel env id2 (idStrictness id2), NonRec id2 rhs2)
where
( _, _, rhs1) = dmdAnalRhsLetDown TopLevel Nothing sigs id rhs
( _, id2, rhs2) = dmdAnalRhsLetDown TopLevel Nothing (nonVirgin sigs) id rhs1
( _, _, rhs1) = dmdAnalRhsLetDown TopLevel Nothing env cleanEvalDmd id rhs
( _, id2, rhs2) = dmdAnalRhsLetDown TopLevel Nothing (nonVirgin env) cleanEvalDmd id rhs1
-- Do two passes to improve CPR information
-- See Note [CPR for thunks]
-- See Note [Optimistic CPR in the "virgin" case]
-- See Note [Initial CPR for strict binders]
dmdAnalTopBind sigs (Rec pairs)
= (sigs', Rec pairs')
dmdAnalTopBind env (Rec pairs)
= (env', Rec pairs')
where
(sigs', _, pairs') = dmdFix TopLevel sigs pairs
(env', _, pairs') = dmdFix TopLevel env cleanEvalDmd pairs
-- We get two iterations automatically
-- c.f. the NonRec case above
......@@ -308,7 +308,7 @@ dmdAnal' env dmd (Let (NonRec id rhs) body)
dmdAnal' env dmd (Let (NonRec id rhs) body)
= (body_ty2, Let (NonRec id2 rhs') body')
where
(lazy_fv, id1, rhs') = dmdAnalRhsLetDown NotTopLevel Nothing env id rhs
(lazy_fv, id1, rhs') = dmdAnalRhsLetDown NotTopLevel Nothing env dmd id rhs
env1 = extendAnalEnv NotTopLevel env id1 (idStrictness id1)
(body_ty, body') = dmdAnal env1 dmd body
(body_ty1, id2) = annotateBndr env body_ty id1
......@@ -329,7 +329,7 @@ dmdAnal' env dmd (Let (NonRec id rhs) body)
dmdAnal' env dmd (Let (Rec pairs) body)
= let
(env', lazy_fv, pairs') = dmdFix NotTopLevel env pairs
(env', lazy_fv, pairs') = dmdFix NotTopLevel env dmd pairs
(body_ty, body') = dmdAnal env' dmd body
body_ty1 = deleteFVs body_ty (map fst pairs)
body_ty2 = addLazyFVs body_ty1 lazy_fv -- see Note [Lazy and unleasheable free variables]
......@@ -509,17 +509,17 @@ dmdTransform env var dmd
-- Recursive bindings
dmdFix :: TopLevelFlag
-> AnalEnv -- Does not include bindings for this binding
-> CleanDemand
-> [(Id,CoreExpr)]
-> (AnalEnv, DmdEnv, [(Id,CoreExpr)]) -- Binders annotated with strictness info
dmdFix top_lvl env orig_pairs
dmdFix top_lvl env let_dmd orig_pairs
= loop 1 initial_pairs
where
bndrs = map fst orig_pairs
-- See Note [Initialising strictness]
initial_pairs | ae_virgin env = [(setIdStrictness id botSig, rhs) | (id, rhs) <- orig_pairs ]
| otherwise = orig_pairs
-- If fixed-point iteration does not yield a result we use this instead
......@@ -562,7 +562,7 @@ dmdFix top_lvl env orig_pairs
my_downRhs (env, lazy_fv) (id,rhs)
= ((env', lazy_fv'), (id', rhs'))
where
(lazy_fv1, id', rhs') = dmdAnalRhsLetDown top_lvl (Just bndrs) env id rhs
(lazy_fv1, id', rhs') = dmdAnalRhsLetDown top_lvl (Just bndrs) env let_dmd id rhs
lazy_fv' = plusVarEnv_C bothDmd lazy_fv lazy_fv1
env' = extendAnalEnv top_lvl env id (idStrictness id')
......@@ -621,18 +621,27 @@ dmdAnalTrivialRhs env id rhs fn
-- This is the LetDown rule in the paper “Higher-Order Cardinality Analysis”.
dmdAnalRhsLetDown :: TopLevelFlag
-> Maybe [Id] -- Just bs <=> recursive, Nothing <=> non-recursive
-> AnalEnv -> Id -> CoreExpr
-> AnalEnv -> CleanDemand
-> Id -> CoreExpr
-> (DmdEnv, Id, CoreExpr)
-- Process the RHS of the binding, add the strictness signature
-- to the Id, and augment the environment with the signature as well.
dmdAnalRhsLetDown top_lvl rec_flag env id rhs
dmdAnalRhsLetDown top_lvl rec_flag env let_dmd id rhs
| Just fn <- unpackTrivial rhs -- See Note [Demand analysis for trivial right-hand sides]
= dmdAnalTrivialRhs env id rhs fn
| otherwise
= (lazy_fv, id', mkLams bndrs' body')
where
(bndrs, body) = collectBinders rhs
(bndrs, body, body_dmd)
= case isJoinId_maybe id of
Just join_arity -- See Note [Demand analysis for join points]
| (bndrs, body) <- collectNBinders join_arity rhs
-> (bndrs, body, let_dmd)
Nothing | (bndrs, body) <- collectBinders rhs
-> (bndrs, body, mkBodyDmd env body)
env_body = foldl extendSigsWithLam env bndrs
(body_ty, body') = dmdAnal env_body body_dmd body
body_ty' = removeDmdTyArgs body_ty -- zap possible deep CPR info
......@@ -642,10 +651,6 @@ dmdAnalRhsLetDown top_lvl rec_flag env id rhs
id' = set_idStrictness env id sig_ty
-- See Note [NOINLINE and strictness]
-- See Note [Product demands for function body]
body_dmd = case deepSplitProductType_maybe (ae_fam_envs env) (exprType body) of
Nothing -> cleanEvalDmd
Just (dc, _, _, _) -> cleanEvalProdDmd (dataConRepArity dc)
-- See Note [Aggregated demand for cardinality]
rhs_fv1 = case rec_flag of
......@@ -667,6 +672,13 @@ dmdAnalRhsLetDown top_lvl rec_flag env id rhs
|| not (isStrictDmd (idDemandInfo id) || ae_virgin env)
-- See Note [Optimistic CPR in the "virgin" case]
mkBodyDmd :: AnalEnv -> CoreExpr -> CleanDemand
-- See Note [Product demands for function body]
mkBodyDmd env body
= case deepSplitProductType_maybe (ae_fam_envs env) (exprType body) of
Nothing -> cleanEvalDmd
Just (dc, _, _, _) -> cleanEvalProdDmd (dataConRepArity dc)
unpackTrivial :: CoreExpr -> Maybe Id
-- Returns (Just v) if the arg is really equal to v, modulo
-- casts, type applications etc
......@@ -691,7 +703,37 @@ useLetUp _ (Lam _ _) = False
useLetUp _ _ = True
{-
{- Note [Demand analysis for join points]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Consider
g :: (Int,Int) -> Int
g (p,q) = p+q
f :: T -> Int -> Int
f x p = g (join j y = (p,y)
in case x of
A -> j 3
B -> j 4
C -> (p,7))
If j was a vanilla function definition, we'd analyse its body with
evalDmd, and think that it was lazy in p. But for join points we can
do better! We know that j's body will (if called at all) be evaluated
with the demand that consumes the entire join-binding, in this case
the argument demand from g. Whizzo! g evaluates both components of
its argument pair, so p will certainly be evaluated if j is called.
For f to be strict in p, we need /all/ paths to evaluate p; in this
case the C branch does so too, so we are fine. So, as usual, we need
to transport demands on free variables to the call site(s). Compare
Note [Lazy and unleasheable free variables].
The implementation is easy. When analysing a join point, we can
analyse its body with the demand from the entire join-binding (written
let_dmd here).
Another win for join points! Trac #13543.
Note [Demand analysis for trivial right-hand sides]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Consider
......
{-# LANGUAGE RankNTypes, GADTs #-}
module Foo where
g :: (Int, Int) -> Int
{-# NOINLINE g #-}
g (p,q) = p+q
f :: Int -> Int -> Int -> Int
f x p q
= g (let j y = (p,q)
{-# NOINLINE j #-}
in
case x of
2 -> j 3
_ -> j 4)
......@@ -259,3 +259,4 @@ test('T13468',
normal,
run_command,
['$MAKE -s --no-print-directory T13468'])
test('T13543', normal, compile, ['-ddump-str-signatures'])
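
The idea in Note [Demand analysis for join points] can be illustrated with a toy model. The sketch below is not GHC's DmdAnal code: the types `Dmd`, `Expr`, `Mode` and the function `strictVars` are hypothetical names invented for this illustration, and the "demand" is reduced to two cases. It only shows why analysing a join point's body with the demand on the whole join-binding (`let_dmd` in the patch) finds `p` strict, whereas analysing it as an ordinary function body does not.

```haskell
module Main where

import qualified Data.Set as S

-- Toy demand on a pair-typed expression (hypothetical; far simpler than
-- GHC's CleanDemand).
data Dmd
  = HeadOnly   -- evaluate to weak head normal form only
  | PairBoth   -- evaluate to WHNF, then evaluate both components
  deriving (Eq, Show)

-- Toy expression language with a single join point (hypothetical).
data Expr
  = Var String
  | Lit Int
  | Pair Expr Expr
  | Jump Expr              -- call of the join point; the argument is ignored here
  | Case String [Expr]     -- case x of { alt1; alt2; ... }
  | Join String Expr Expr  -- join j param = rhs in body

-- How to treat the right-hand side of a Join binding.
data Mode = AsFunction | AsJoinPoint
  deriving (Eq, Show)

-- Free variables that are certainly evaluated when the expression is
-- evaluated under the given demand. The Set argument holds the strict
-- free variables of the enclosing join point's body.
strictVars :: Mode -> S.Set String -> Dmd -> Expr -> S.Set String
strictVars _ _ _ (Var v) = S.singleton v
strictVars _ _ _ (Lit _) = S.empty
strictVars _ _ HeadOnly (Pair _ _) = S.empty
strictVars m j PairBoth (Pair a b) =
  strictVars m j HeadOnly a `S.union` strictVars m j HeadOnly b
strictVars _ j _ (Jump _) = j
strictVars m j d (Case x alts) =
  -- A variable is strict only if /all/ alternatives evaluate it.
  S.insert x (meet (map (strictVars m j d) alts))
  where
    meet [] = S.empty
    meet ss = foldr1 S.intersection ss
strictVars m j d (Join param rhs body) = strictVars m j' d body
  where
    -- The key step: a join point's body sees the demand on the whole
    -- join-binding; an ordinary function body sees only a head demand.
    rhsDmd = case m of
      AsJoinPoint -> d
      AsFunction  -> HeadOnly
    j' = S.delete param (strictVars m j rhsDmd rhs)

-- The argument of g from the Note:
--   join j y = (p,y) in case x of { A -> j 3; B -> j 4; C -> (p,7) }
example :: Expr
example =
  Join "y" (Pair (Var "p") (Var "y"))
    (Case "x" [Jump (Lit 3), Jump (Lit 4), Pair (Var "p") (Lit 7)])

main :: IO ()
main = do
  -- g demands both components of its argument, hence PairBoth.
  print (S.toList (strictVars AsFunction  S.empty PairBoth example))  -- ["x"]
  print (S.toList (strictVars AsJoinPoint S.empty PairBoth example))  -- ["p","x"]
```

In this model, treating `j` as a vanilla function loses the strictness of `p` because the two jump alternatives contribute nothing to the intersection; treating it as a join point recovers it, which mirrors the argument made in the Note.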