Commit b14c0373 authored by Joachim Breitner's avatar Joachim Breitner

Some cleanup of the Exitification code

based on a thorough review by Simon in comments
through 37.

The changes are:

 * `isExitJoinId` is moved to `SimplUtils`, because
   it is only valid when occurrence information is up-to-date.
 * Abstracted variables are properly sorted using `sortQuantVars`
 * Exitification does not set occ info.

 And then minor quibles to notes and avoiding some unhelpful shadowing
 of local names.

 Differential Revision:
parent 8b823f27
......@@ -74,7 +74,7 @@ module Id (
DictId, isDictId, isEvVar,
-- ** Join variables
JoinId, isJoinId, isJoinId_maybe, idJoinArity, isExitJoinId,
JoinId, isJoinId, isJoinId_maybe, idJoinArity,
asJoinId, asJoinId_maybe, zapJoinId,
-- ** Inline pragma stuff
......@@ -498,10 +498,6 @@ isJoinId_maybe id
_ -> Nothing
| otherwise = Nothing
-- See Note [Exitification] and Note [Do not inline exit join points] in Exitify.hs
isExitJoinId :: Var -> Bool
isExitJoinId id = isJoinId id && isOneOcc (idOccInfo id) && occ_in_lam (idOccInfo id)
idDataCon :: Id -> DataCon
-- ^ Get from either the worker or the wrapper 'Id' to the 'DataCon'. Currently used only in the desugarer.
......@@ -48,16 +48,19 @@ import VarEnv
import CoreFVs
import FastString
import Type
import MkCore ( sortQuantVars )
import Data.Bifunctor
import Control.Monad
-- | Traverses the AST, simply to find all joinrecs and call 'exitify' on them.
-- The really interesting function is exitify
exitifyProgram :: CoreProgram -> CoreProgram
exitifyProgram binds = map goTopLvl binds
goTopLvl (NonRec v e) = NonRec v (go in_scope_toplvl e)
goTopLvl (Rec pairs) = Rec (map (second (go in_scope_toplvl)) pairs)
-- Top-level bindings are never join points
in_scope_toplvl = emptyInScopeSet `extendInScopeSetList` bindersOfBinds binds
......@@ -91,6 +94,10 @@ exitifyProgram binds = map goTopLvl binds
is_join_rec = any (isJoinId . fst) pairs
in_scope' = in_scope `extendInScopeSetList` bindersOf (Rec pairs)
-- | State Monad used inside `exitify`
type ExitifyM = State [(JoinId, CoreExpr)]
-- | Given a recursive group of a joinrec, identifies “exit paths” and binds them as
-- join-points outside the joinrec.
exitify :: InScopeSet -> [(Var,CoreExpr)] -> (CoreExpr -> CoreExpr)
......@@ -120,11 +127,13 @@ exitify in_scope pairs =
-- checks if there are no more recursive calls, if so, abstracts over
-- variables bound on the way and lifts it out as a join point.
-- It uses a state monad to keep track of floated binds
-- ExitifyM is a state monad to keep track of floated binds
go :: [Var] -- ^ variables to abstract over
-> CoreExprWithFVs -- ^ current expression in tail position
-> State [(Id, CoreExpr)] CoreExpr
-> ExitifyM CoreExpr
-- We first look at the expression (no matter what it shape is)
-- and determine if we can turn it into a exit join point
go captured ann_e
-- Do not touch an expression that is already a join jump where all arguments
-- are captured variables. See Note [Idempotency]
......@@ -145,13 +154,13 @@ exitify in_scope pairs =
-- We have something to float out!
| is_exit = do
-- Assemble the RHS of the exit join point
let rhs = mkLams args e
let rhs = mkLams abs_vars e
ty = exprType rhs
let avoid = in_scope `extendInScopeSetList` captured
-- Remember this binding under a suitable name
v <- addExit avoid ty (length args) rhs
v <- addExit avoid ty (length abs_vars) rhs
-- And jump to it from here
return $ mkVarApps (Var v) args
return $ mkVarApps (Var v) abs_vars
-- An exit expression has no recursive calls
is_exit = disjointVarSet fvs recursive_calls
......@@ -166,14 +175,17 @@ exitify in_scope pairs =
is_interesting = anyVarSet isLocalId (fvs `minusVarSet` mkVarSet captured)
-- The possible arguments of this exit join point
args = filter (`elemVarSet` fvs) captured
abs_vars = sortQuantVars $ filter (`elemVarSet` fvs) captured
-- We cannot abstract over join points
captures_join_points = any isJoinId args
captures_join_points = any isJoinId abs_vars
e = deAnnotate ann_e
fvs = dVarSetToVarSet (freeVarsOf ann_e)
-- We could not turn it into a exit joint point. So now recurse
-- into all expression where eligible exit join points might sit,
-- i.e. into all tail-call positions:
-- Case right hand sides are in tail-call position
go captured (_, AnnCase scrut bndr ty alts) = do
......@@ -211,6 +223,8 @@ exitify in_scope pairs =
return $ Let bind body'
where bind = deAnnBind ann_bind
-- Cannot be turned into an exit join point, but also has no
-- tail-call subexpression. Nothing to do here.
go _ ann_e = return (deAnnotate ann_e)
......@@ -227,14 +241,6 @@ mkExitJoinId in_scope ty join_arity = do
exit_id_tmpl = mkSysLocal (fsLit "exit") initExitJoinUnique ty
`asJoinId` join_arity
`setIdOccInfo` exit_occ_info
-- See Note [Do not inline exit join points]
exit_occ_info =
OneOcc { occ_in_lam = True
, occ_one_br = True
, occ_int_cxt = False
, occ_tail = AlwaysTailCalled join_arity }
addExit :: InScopeSet -> Type -> JoinArity -> CoreExpr -> ExitifyM JoinId
addExit in_scope ty join_arity rhs = do
......@@ -245,8 +251,6 @@ addExit in_scope ty join_arity rhs = do
return v
type ExitifyM = State [(JoinId, CoreExpr)]
Note [Interesting expression]
......@@ -381,6 +385,8 @@ joinrecs are nested.
Further downside of A: If the exitify function returns annotated expressions,
it would have to ensure that the annotations are correct.
We therefore choose B, and calculate the free variables in `exitify`.
Note [Do not inline exit join points]
......@@ -399,7 +405,8 @@ To prevent this, we need to recognize exit join points, and then disable
Exit join points, recognizeable using `isExitJoinId` are join points with an
occurence in a recursive group, and can be recognized using `isExitJoinId`.
occurence in a recursive group, and can be recognized (after the occurence
analyzer ran!) using `isExitJoinId`.
This function detects joinpoints with `occ_in_lam (idOccinfo id) == True`,
because the lambdas of a non-recursive join point are not considered for
`occ_in_lam`. For example, in the following code, `j1` is /not/ marked
......@@ -408,8 +415,6 @@ occ_in_lam, because `j2` is called only once.
join j1 x = x+1
join j2 y = join j1 (y+2)
We create exit join point ids with such an `OccInfo`, see `exit_occ_info`.
To prevent inlining, we check for isExitJoinId
* In `preInlineUnconditionally` directly.
* In `simplLetUnfolding` we simply give exit join points no unfolding, which
......@@ -30,7 +30,10 @@ module SimplUtils (
addValArgTo, addCastTo, addTyArgTo,
argInfoExpr, argInfoAppArgs, pushSimplifiedArgs,
-- Utilities
) where
#include "HsVersions.h"
......@@ -2199,6 +2202,13 @@ in PrelRules)
mkCase3 _dflags scrut bndr alts_ty alts
= return (Case scrut bndr alts_ty alts)
-- See Note [Exitification] and Note [Do not inline exit join points] in Exitify.hs
-- This lives here (and not in Id) becuase occurrence info is only valid on
-- InIds, so it's crucial that isExitJoinId is only called on freshly
-- occ-analysed code. It's not a generic function you can call anywhere.
isExitJoinId :: Var -> Bool
isExitJoinId id = isJoinId id && isOneOcc (idOccInfo id) && occ_in_lam (idOccInfo id)
Note [Dead binders]
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment