Commit 777b7707 authored by Simon Peyton Jones's avatar Simon Peyton Jones Committed by David Feuer
Browse files

Mark non-recursive join lambdas as one-shot

When we have

  join j x y = rhs in ...

we know that the lambdas for 'x' and 'y' are one-shot.
Let's mark them as such!

This doesn't fix a specific bug, but it feels right to me.

Reviewers: austin, bgamari

Reviewed By: bgamari

Subscribers: lukemaurer, thomie

Differential Revision: https://phabricator.haskell.org/D3196
parent 6eb52cfc
......@@ -732,7 +732,6 @@ add this analysis if necessary.
------------------------------------------------------------
Note [Adjusting for lambdas]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
There's a bit of a dance we need to do after analysing a lambda expression or
a right-hand side. In particular, we need to
......@@ -802,28 +801,33 @@ occAnalNonRecBind env lvl imp_rule_edges binder rhs body_usage
| otherwise -- It's mentioned in the body
= (body_usage' +++ rhs_usage', [NonRec tagged_binder rhs'])
where
(bndrs, body) = collectBinders rhs
(body_usage', tagged_binder) = tagNonRecBinder lvl body_usage binder
mb_join_arity = willBeJoinId_maybe tagged_binder
(bndrs, body) = collectBinders rhs
(rhs_usage1, bndrs', body') = occAnalNonRecRhs env tagged_binder bndrs body
rhs' = mkLams bndrs' body'
rhs' = mkLams (markJoinOneShots mb_join_arity bndrs') body'
-- For a /non-recursive/ join point we can mark all
-- its join-lambda as one-shot; and it's a good idea to do so
-- Unfoldings
-- See Note [Unfoldings and join points]
rhs_usage2 = case occAnalUnfolding env NonRecursive binder of
Just unf_usage -> rhs_usage1 +++ unf_usage
Nothing -> rhs_usage1
-- See Note [Unfoldings and join points]
mb_join_arity = willBeJoinId_maybe tagged_binder
-- Rules
-- See Note [Rules are extra RHSs] and Note [Rule dependency info]
rules_w_uds = occAnalRules env mb_join_arity NonRecursive tagged_binder
rhs_usage3 = rhs_usage2 +++ combineUsageDetailsList
(map (\(_, l, r) -> l +++ r) rules_w_uds)
-- See Note [Rules are extra RHSs] and Note [Rule dependency info]
rhs_usage4 = maybe rhs_usage3 (addManyOccsSet rhs_usage3) $
lookupVarEnv imp_rule_edges binder
-- See Note [Preventing loops due to imported functions rules]
rhs_usage' = adjustRhsUsage (willBeJoinId_maybe tagged_binder) NonRecursive
bndrs' rhs_usage4
-- Final adjustment
rhs_usage' = adjustRhsUsage mb_join_arity NonRecursive bndrs' rhs_usage4
-----------------
occAnalRecBind :: OccEnv -> TopLevelFlag -> ImpRuleEdges -> [(Var,CoreExpr)]
......@@ -1550,7 +1554,6 @@ occAnalNonRecRhs env bndr bndrs body
-- See Note [Sources of one-shot information]
rhs_env = env1 { occ_one_shots = argOneShots dmd }
certainly_inline -- See Note [Cascading inlines]
= case idOccInfo bndr of
OneOcc { occ_in_lam = in_lam, occ_one_br = one_br }
......@@ -1731,7 +1734,8 @@ occAnal env app@(App _ _)
-- (a) occurrences inside type lambdas only not marked as InsideLam
-- (b) type variables not in environment
occAnal env (Lam x body) | isTyVar x
occAnal env (Lam x body)
| isTyVar x
= case occAnal env body of { (body_usage, body') ->
(markAllNonTailCalled body_usage, Lam x body')
}
......@@ -1749,14 +1753,14 @@ occAnal env expr@(Lam _ _)
= case occAnalLamOrRhs env binders body of { (usage, tagged_binders, body') ->
let
expr' = mkLams tagged_binders body'
final_usage | all isOneShotBndr tagged_binders
= markAllNonTailCalled usage
| otherwise
= markAllInsideLam $ markAllNonTailCalled usage
usage1 = markAllNonTailCalled usage
one_shot_gp = all isOneShotBndr tagged_binders
final_usage | one_shot_gp = usage1
| otherwise = markAllInsideLam usage1
in
(final_usage, expr') }
where
(binders, body) = collectBinders expr
(binders, body) = collectBinders expr
occAnal env (Case scrut bndr ty alts)
= case occ_anal_scrut scrut alts of { (scrut_usage, scrut') ->
......@@ -2130,21 +2134,31 @@ oneShotGroup env@(OccEnv { occ_one_shots = ctxt }) bndrs
= ( env { occ_one_shots = [], occ_encl = OccVanilla }
, reverse rev_bndrs ++ bndrs )
go ctxt (bndr:bndrs) rev_bndrs
| isId bndr
= case ctxt of
[] -> go [] bndrs (bndr : rev_bndrs)
(one_shot : ctxt) -> go ctxt bndrs (bndr': rev_bndrs)
where
bndr' = updOneShotInfo bndr one_shot
go ctxt@(one_shot : ctxt') (bndr : bndrs) rev_bndrs
| isId bndr = go ctxt' bndrs (bndr': rev_bndrs)
| otherwise = go ctxt bndrs (bndr : rev_bndrs)
where
bndr' = updOneShotInfo bndr one_shot
-- Use updOneShotInfo, not setOneShotInfo, as pre-existing
-- one-shot info might be better than what we can infer, e.g.
-- due to explicit use of the magic 'oneShot' function.
-- See Note [The oneShot function]
| otherwise
= go ctxt bndrs (bndr:rev_bndrs)
markJoinOneShots :: Maybe JoinArity -> [Var] -> [Var]
-- Mark the lambdas of a non-recursive join point as one-shot.
-- This is good to prevent gratuitous float-out etc
markJoinOneShots mb_join_arity bndrs
= case mb_join_arity of
Nothing -> bndrs
Just n -> go n bndrs
where
go 0 bndrs = bndrs
go _ [] = WARN( True, ppr mb_join_arity <+> ppr bndrs ) []
go n (b:bs) = b' : go (n-1) bs
where
b' | isId b = setOneShotLambda b
| otherwise = b
addAppCtxt :: OccEnv -> [Arg CoreBndr] -> OccEnv
addAppCtxt env@(OccEnv { occ_one_shots = ctxt }) args
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment