Commit a7628dcd authored by Simon Peyton Jones's avatar Simon Peyton Jones

Deal with join points with RULES

Trac #13900 showed that when we have a join point that
has a RULE, we must push the continuation into the RHS
of the RULE.

See Note [Rules and unfolding for join points]

It's hard to tickle this bug, so I have not added a regression test.
parent 9cc6a182
......@@ -767,7 +767,7 @@ simplifyPgmIO pass@(CoreDoSimplify max_iterations mode)
-- for imported Ids. Eg RULE map my_f = blah
-- If we have a substitution my_f :-> other_f, we'd better
-- apply it to the rule to, or it'll never match
; rules1 <- simplRules env1 Nothing rules
; rules1 <- simplRules env1 Nothing rules Nothing
; return (getTopFloatBinds floats, rules1) } ;
......
......@@ -24,7 +24,7 @@ import Id
import MkId ( seqId )
import MkCore ( mkImpossibleExpr, castBottomExpr )
import IdInfo
import Name ( Name, mkSystemVarName, isExternalName, getOccFS )
import Name ( mkSystemVarName, isExternalName, getOccFS )
import Coercion hiding ( substCo, substCoVar )
import OptCoercion ( optCoercion )
import FamInstEnv ( topNormaliseType_maybe )
......@@ -143,11 +143,11 @@ simplTopBinds env0 binds0
; (floats, env2) <- simpl_binds env1 binds
; return (float `addFloats` floats, env2) }
simpl_bind env (Rec pairs) = simplRecBind env TopLevel Nothing pairs
simpl_bind env (NonRec b r) = do { (env', b') <- addBndrRules env b (lookupRecBndr env b)
; simplRecOrTopPair env' TopLevel
NonRecursive Nothing
b b' r }
simpl_bind env (Rec pairs)
= simplRecBind env TopLevel Nothing pairs
simpl_bind env (NonRec b r)
= do { (env', b') <- addBndrRules env b (lookupRecBndr env b) Nothing
; simplRecOrTopPair env' TopLevel NonRecursive Nothing b b' r }
{-
************************************************************************
......@@ -160,7 +160,7 @@ simplRecBind is used for
* recursive bindings only
-}
simplRecBind :: SimplEnv -> TopLevelFlag -> Maybe SimplCont
simplRecBind :: SimplEnv -> TopLevelFlag -> MaybeJoinCont
-> [(InId, InExpr)]
-> SimplM (SimplFloats, SimplEnv)
simplRecBind env0 top_lvl mb_cont pairs0
......@@ -171,7 +171,7 @@ simplRecBind env0 top_lvl mb_cont pairs0
add_rules :: SimplEnv -> (InBndr,InExpr) -> SimplM (SimplEnv, (InBndr, OutBndr, InExpr))
-- Add the (substituted) rules to the binder
add_rules env (bndr, rhs)
= do { (env', bndr') <- addBndrRules env bndr (lookupRecBndr env bndr)
= do { (env', bndr') <- addBndrRules env bndr (lookupRecBndr env bndr) mb_cont
; return (env', (bndr, bndr', rhs)) }
go env [] = return (emptyFloats env, env)
......@@ -191,7 +191,7 @@ It assumes the binder has already been simplified, but not its IdInfo.
-}
simplRecOrTopPair :: SimplEnv
-> TopLevelFlag -> RecFlag -> Maybe SimplCont
-> TopLevelFlag -> RecFlag -> MaybeJoinCont
-> InId -> OutBndr -> InExpr -- Binder and rhs
-> SimplM (SimplFloats, SimplEnv)
......@@ -616,7 +616,7 @@ Nor does it do the atomic-argument thing
completeBind :: SimplEnv
-> TopLevelFlag -- Flag stuck into unfolding
-> Maybe SimplCont -- Required only for join point
-> MaybeJoinCont -- Required only for join point
-> InId -- Old binder
-> OutId -> OutExpr -- New binder and RHS
-> SimplM (SimplFloats, SimplEnv)
......@@ -645,7 +645,7 @@ completeBind env top_lvl mb_cont old_bndr new_bndr new_rhs
-- Simplify the unfolding
; new_unfolding <- simplLetUnfolding env top_lvl mb_cont old_bndr
final_rhs old_unf
final_rhs (idType new_bndr) old_unf
; let final_bndr = addLetBndrInfo new_bndr new_arity is_bot new_unfolding
......@@ -1319,7 +1319,8 @@ simplLamBndr :: SimplEnv -> InBndr -> SimplM (SimplEnv, OutBndr)
simplLamBndr env bndr
| isId bndr && isFragileUnfolding old_unf -- Special case
= do { (env1, bndr1) <- simplBinder env bndr
; unf' <- simplStableUnfolding env1 NotTopLevel Nothing bndr old_unf
; unf' <- simplStableUnfolding env1 NotTopLevel Nothing bndr
old_unf (idType bndr1)
; let bndr2 = bndr1 `setIdUnfolding` unf'
; return (modifyInScope env1 bndr2, bndr2) }
......@@ -1378,7 +1379,7 @@ simplNonRecE env bndr (rhs, rhs_se) (bndrs, body) cont
| otherwise
= ASSERT( not (isTyVar bndr) )
do { (env1, bndr1) <- simplNonRecBndr env bndr
; (env2, bndr2) <- addBndrRules env1 bndr bndr1
; (env2, bndr2) <- addBndrRules env1 bndr bndr1 Nothing
; (floats1, env3) <- simplLazyBind env2 NotTopLevel NonRecursive bndr bndr2 rhs rhs_se
; (floats2, expr') <- simplLam env3 bndrs body cont
; return (floats1 `addFloats` floats2, expr') }
......@@ -1450,6 +1451,33 @@ Here it'd be far better to drop the unfolding and use the actual RHS.
* *
********************************************************************* -}
{- Note [Rules and unfolding for join points]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Suppose we have
simplExpr (join j x = rhs ) cont
( {- RULE j (p:ps) = blah -} )
( {- StableUnfolding j = blah -} )
(in blah )
Then we will push 'cont' into the rhs of 'j'. But we should *also* push
'cont' into the RHS of
* Any RULEs for j, e.g. generated by SpecConstr
* Any stable unfolding for j, e.g. the result of an INLINE pragma
Simplifying rules and stable-unfoldings happens a bit after
simplifying the right-hand side, so we remember whether or not it
is a join point, and what 'cont' is, in a value of type MaybeJoinCont
Trac #13900 wsa caused by forgetting to push 'cont' into the RHS
of a SpecConstr-generated RULE for a join point.
-}
type MaybeJoinCont = Maybe SimplCont
-- Nothing => Not a join point
-- Just k => This is a join binding with continuation k
-- See Note [Rules and unfolding for join points]
simplNonRecJoinPoint :: SimplEnv -> InId -> InExpr
-> InExpr -> SimplCont
-> SimplM (SimplFloats, OutExpr)
......@@ -1465,7 +1493,7 @@ simplNonRecJoinPoint env bndr rhs body cont
-- and wrap wrap_cont around the whole thing
; let res_ty = contResultType cont
; (env1, bndr1) <- simplNonRecJoinBndr env res_ty bndr
; (env2, bndr2) <- addBndrRules env1 bndr bndr1
; (env2, bndr2) <- addBndrRules env1 bndr bndr1 (Just cont)
; (floats1, env3) <- simplJoinBind env2 cont bndr bndr2 rhs env
; (floats2, body') <- simplExprF env3 body cont
; return (floats1 `addFloats` floats2, body') }
......@@ -3235,13 +3263,13 @@ because we don't know its usage in each RHS separately
-}
simplLetUnfolding :: SimplEnv-> TopLevelFlag
-> Maybe SimplCont
-> MaybeJoinCont
-> InId
-> OutExpr
-> OutExpr -> OutType
-> Unfolding -> SimplM Unfolding
simplLetUnfolding env top_lvl cont_mb id new_rhs unf
simplLetUnfolding env top_lvl cont_mb id new_rhs rhs_ty unf
| isStableUnfolding unf
= simplStableUnfolding env top_lvl cont_mb id unf
= simplStableUnfolding env top_lvl cont_mb id unf rhs_ty
| isExitJoinId id
= return noUnfolding -- see Note [Do not inline exit join points]
| otherwise
......@@ -3265,26 +3293,26 @@ mkLetUnfolding dflags top_lvl src id new_rhs
-------------------
simplStableUnfolding :: SimplEnv -> TopLevelFlag
-> Maybe SimplCont -- Just k => a join point with continuation k
-> MaybeJoinCont -- Just k => a join point with continuation k
-> InId
-> Unfolding -> SimplM Unfolding
-> Unfolding -> OutType -> SimplM Unfolding
-- Note [Setting the new unfolding]
simplStableUnfolding env top_lvl mb_cont id unf
simplStableUnfolding env top_lvl mb_cont id unf rhs_ty
= case unf of
NoUnfolding -> return unf
BootUnfolding -> return unf
OtherCon {} -> return unf
DFunUnfolding { df_bndrs = bndrs, df_con = con, df_args = args }
-> do { (env', bndrs') <- simplBinders rule_env bndrs
-> do { (env', bndrs') <- simplBinders unf_env bndrs
; args' <- mapM (simplExpr env') args
; return (mkDFunUnfolding bndrs' con args') }
CoreUnfolding { uf_tmpl = expr, uf_src = src, uf_guidance = guide }
| isStableSource src
-> do { expr' <- case mb_cont of
Just cont -> simplJoinRhs rule_env id expr cont
Nothing -> simplExpr rule_env expr
-> do { expr' <- case mb_cont of -- See Note [Rules and unfolding for join points]
Just cont -> simplJoinRhs unf_env id expr cont
Nothing -> simplExprC unf_env expr (mkBoringStop rhs_ty)
; case guide of
UnfWhen { ug_arity = arity, ug_unsat_ok = sat_ok } -- Happens for INLINE things
-> let guide' = UnfWhen { ug_arity = arity, ug_unsat_ok = sat_ok
......@@ -3308,7 +3336,7 @@ simplStableUnfolding env top_lvl mb_cont id unf
dflags = seDynFlags env
is_top_lvl = isTopLevel top_lvl
act = idInlineActivation id
rule_env = updMode (updModeForStableUnfoldings act) env
unf_env = updMode (updModeForStableUnfoldings act) env
-- See Note [Simplifying inside stable unfoldings] in SimplUtils
{-
......@@ -3350,20 +3378,24 @@ to apply in that function's own right-hand side.
See Note [Forming Rec groups] in OccurAnal
-}
addBndrRules :: SimplEnv -> InBndr -> OutBndr -> SimplM (SimplEnv, OutBndr)
addBndrRules :: SimplEnv -> InBndr -> OutBndr
-> MaybeJoinCont -- Just k for a join point binder
-- Nothing otherwise
-> SimplM (SimplEnv, OutBndr)
-- Rules are added back into the bin
addBndrRules env in_id out_id
addBndrRules env in_id out_id mb_cont
| null old_rules
= return (env, out_id)
| otherwise
= do { new_rules <- simplRules env (Just (idName out_id)) old_rules
= do { new_rules <- simplRules env (Just out_id) old_rules mb_cont
; let final_id = out_id `setIdSpecialisation` mkRuleInfo new_rules
; return (modifyInScope env final_id, final_id) }
where
old_rules = ruleInfoRules (idSpecialisation in_id)
simplRules :: SimplEnv -> Maybe Name -> [CoreRule] -> SimplM [CoreRule]
simplRules env mb_new_nm rules
simplRules :: SimplEnv -> Maybe OutId -> [CoreRule]
-> MaybeJoinCont -> SimplM [CoreRule]
simplRules env mb_new_id rules mb_cont
= mapM simpl_rule rules
where
simpl_rule rule@(BuiltinRule {})
......@@ -3373,11 +3405,29 @@ simplRules env mb_new_nm rules
, ru_fn = fn_name, ru_rhs = rhs })
= do { (env', bndrs') <- simplBinders env bndrs
; let rhs_ty = substTy env' (exprType rhs)
rule_cont = mkBoringStop rhs_ty
rule_env = updMode updModeForRules env'
rhs_cont = case mb_cont of -- See Note [Rules and unfolding for join points]
Nothing -> mkBoringStop rhs_ty
Just cont -> ASSERT2( join_ok, bad_join_msg )
cont
rule_env = updMode updModeForRules env'
fn_name' = case mb_new_id of
Just id -> idName id
Nothing -> fn_name
-- join_ok is an assertion check that the join-arity of the
-- binder matches that of the rule, so that pushing the
-- continuation into the RHS makes sense
join_ok = case mb_new_id of
Just id | Just join_arity <- isJoinId_maybe id
-> length args == join_arity
_ -> False
bad_join_msg = vcat [ ppr mb_new_id, ppr rule
, ppr (fmap isJoinId_maybe mb_new_id) ]
; args' <- mapM (simplExpr rule_env) args
; rhs' <- simplExprC rule_env rhs rule_cont
; rhs' <- simplExprC rule_env rhs rhs_cont
; return (rule { ru_bndrs = bndrs'
, ru_fn = mb_new_nm `orElse` fn_name
, ru_fn = fn_name'
, ru_args = args'
, ru_rhs = rhs' }) }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment