Commit f347bfea authored by Joachim Breitner's avatar Joachim Breitner

Support mutual recursion

parent d51d7efd
......@@ -229,9 +229,9 @@ arity as for the whole expression.
calls are OnceAndOnly calls, then the body calls *either* the rhs *or* one
of the other mentioned variables. Similarly, the rhs calls *either* itself
again *or* one of the other mentioned variables. This precision is required!
We do not analyse mutually recursive functions. This can be done once we see it
in the wild.
If the recursive function is called by the body, or the rhs, tagged with Many
then we can also just `lubEnv`, because the result will no longer contain
any OnceAndOnly values.
Note [Case and App: Which side to take?]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......@@ -284,7 +284,7 @@ callArityTopLvl exported int1 (b:bs)
where
int2 = interestingBinds b
exported' = filter isExportedId int2 ++ exported
int' = int1 `extendVarSetList` int2
int' = int1 `addInterestingBinds` b
(ae1, bs') = callArityTopLvl exported' int' bs
(ae2, b') = callArityBind ae1 int1 b
......@@ -331,20 +331,20 @@ callArityAnal arity int e@(Var v)
-- Non-value lambdas are ignored
callArityAnal arity int (Lam v e) | not (isId v)
= second (Lam v) $ callArityAnal arity int e
= second (Lam v) $ callArityAnal arity (int `delVarSet` v) e
-- We have a lambda that we are not sure to call. Tail calls therein
-- are no longer OneAndOnly calls
callArityAnal 0 int (Lam v e)
= (ae', Lam v e')
where
(ae, e') = callArityAnal 0 int e
(ae, e') = callArityAnal 0 (int `delVarSet` v) e
ae' = forgetOnceCalls ae
-- We have a lambda that we are calling. decrease arity.
callArityAnal arity int (Lam v e)
= (ae, Lam v e')
where
(ae, e') = callArityAnal (arity - 1) int e
(ae, e') = callArityAnal (arity - 1) (int `delVarSet` v) e
-- For lets, use callArityBind
callArityAnal arity int (Let bind e)
......@@ -352,7 +352,7 @@ callArityAnal arity int (Let bind e)
-- (vcat [ppr v, ppr arity, ppr n, ppr final_ae ])
(final_ae, Let bind' e')
where
int_body = int `extendVarSetList` interestingBinds bind
int_body = int `addInterestingBinds` bind
(ae_body, e') = callArityAnal arity int_body e
(final_ae, bind') = callArityBind ae_body int bind
......@@ -396,6 +396,22 @@ interestingBinds bind =
where
go (v,e) = exprArity e < length (typeArity (idType v))
boringBinds :: CoreBind -> [Var]
boringBinds bind =
map fst $ filter go $ case bind of (NonRec v e) -> [(v,e)]
(Rec ves) -> ves
where
go (v,e) = exprArity e >= length (typeArity (idType v))
addInterestingBinds :: VarSet -> CoreBind -> VarSet
addInterestingBinds int bind
= int `delVarSetList` bindersOf bind -- Possible shadowing
`extendVarSetList` interestingBinds bind
addBoringCalls :: CallArityEnv -> CoreBind -> CallArityEnv
addBoringCalls ae bind
= ae `lubEnv` (mkVarEnv $ zip (boringBinds bind) (repeat topCallCount))
-- Used for both local and top-level binds
-- First argument is the demand from the body
callArityBind :: CallArityEnv -> VarSet -> CoreBind -> (CallArityEnv, CoreBind)
......@@ -412,47 +428,58 @@ callArityBind ae_body int (NonRec v rhs)
v' = v `setIdCallArity` safe_arity
-- Recursive let. See Note [Recursion and fixpointing]
callArityBind ae_body int b@(Rec [(v,rhs)])
= -- pprTrace "callArityBind:Rec"
-- (vcat [ppr v, ppr ae_body, ppr int, ppr ae_rhs, ppr new_arity])
(final_ae, Rec [(v',rhs')])
where
int_body = int `extendVarSetList` interestingBinds b
callcount = lookupWithDefaultVarEnv ae_body topCallCount v
(ae_rhs, new_arity, rhs') = callArityFix callcount int_body v rhs
final_ae = (ae_rhs `lubEnv` ae_body) `delVarEnv` v
v' = v `setIdCallArity` new_arity
-- Mutual recursion. Do nothing serious here, for now
callArityBind ae_body int (Rec binds)
callArityBind ae_body int b@(Rec binds)
= (final_ae, Rec binds')
where
(aes, binds') = unzip $ map go binds
go (i,e) = let (ae, _, e') = callArityBound topCallCount int e
in (ae, (i,e'))
final_ae = foldl lubEnv ae_body aes `delVarEnvList` map fst binds
int_body = int `addInterestingBinds` b
-- We are ignoring calls to boring binds, so we need to pretend them here!
ae_body' = ae_body `addBoringCalls` b
(ae_rhs, binds') = callArityFix ae_body' int_body [(i,Nothing,e) | (i,e) <- binds]
final_ae = ae_rhs `delVarEnvList` interestingBinds b
-- Here we do the fix-pointing for possibly mutually recursive values. The
-- idea is that we start with the calls coming from the body, and analyize
-- every called value with that arity, adding lub these calls into the
-- environment. We also remember for each variable the CallCount we analised it
-- with. Then we check for every variable if in the new envrionment, it is
-- called with a different (i.e. lower) arity. If so, we reanalize that, and
-- lub the result back into the environment. If we had a change for any of the
-- variables, we repeat this step, otherwise we are done.
callArityFix ::
CallArityEnv -> VarSet ->
[(Id, Maybe CallCount, CoreExpr)] ->
(CallArityEnv, [(Id, CoreExpr)])
callArityFix ae int ann_binds
| any_change
= callArityFix ae' int ann_binds'
| otherwise
= (ae', map (\(i, a, e) -> (i `setArity` a, e)) ann_binds')
where
(changes, ae's, ann_binds') = unzip3 $ map rerun ann_binds
any_change = or changes
ae' = foldl lubEnv ae ae's
rerun (i, mbArity, rhs)
callArityFix :: CallCount -> VarSet -> Id -> CoreExpr -> (CallArityEnv, Arity, CoreExpr)
callArityFix arity int v e
| mb_new_arity == mbArity
-- No change. No need to re-analize, and no need to change the arity
-- environment
= (False, emptyVarEnv, (i,mbArity, rhs))
| arity `lteCallCount` min_arity
-- The incoming arity is already lower than the exprArity, so we can
-- ignore the arity coming from the RHS
= (ae `delVarEnv` v, 0, e')
| Just new_arity <- mb_new_arity
-- We previously analized this with a different arity (or not at all)
= let (ae_rhs, safe_arity, rhs') = callArityBound new_arity int rhs
in (True, ae_rhs, (i `setIdCallArity` safe_arity, mb_new_arity, rhs'))
| otherwise
= if new_arity `ltCallCount` arity
-- RHS puts a lower arity on itself, so try that
then callArityFix new_arity int v e
| otherwise
-- No call to this yet, so do nothing
= (False, emptyVarEnv, (i, mbArity, rhs))
where
mb_new_arity = lookupVarEnv ae i
setArity i Nothing = i -- Completely absent value
setArity i (Just (_, a)) = i `setIdCallArity` a
-- RHS calls itself with at least as many arguments as the body of the let: Great!
else (ae `delVarEnv` v, safe_arity, e')
where
(ae, safe_arity, e') = callArityBound arity int e
new_arity = lookupWithDefaultVarEnv ae topCallCount v
min_arity = (Many, exprArity e)
-- This is a variant of callArityAnal that takes a CallCount (i.e. an arity with a
-- cardinality) and adjust the resulting environment accordingly. It is to be used
......@@ -497,13 +524,6 @@ lubCount :: Count -> Count -> Count
lubCount OnceAndOnly OnceAndOnly = OnceAndOnly
lubCount _ _ = Many
lteCallCount :: CallCount -> CallCount -> Bool
lteCallCount (count1, arity1) (count2, arity2)
= count1 <= count2 && arity1 <= arity2
ltCallCount :: CallCount -> CallCount -> Bool
ltCallCount c1 c2 = c1 `lteCallCount` c2 && c1 /= c2
-- Used when combining results from alternative cases; take the minimum
lubEnv :: CallArityEnv -> CallArityEnv -> CallArityEnv
lubEnv = plusVarEnv_C lubCallCount
......
......@@ -126,10 +126,15 @@ exprs =
Let (Rec [ (n, mkACase (mkLams [y] $ mkLit 0) (Var d))
, (d, mkACase (mkLams [y] $ mkLit 0) (Var n))]) $
Var n `mkApps` [Var d `mkApps` [Var d `mkApps` [mkLit 0]]]
, ("mutual recursion (functions), but no thunks (both arity 2 would be good)",) $
, ("mutual recursion (functions), but no thunks",) $
Let (Rec [ (go, mkLams [x] (mkACase (mkLams [y] $ mkLit 0) (Var go2 `mkVarApps` [x])))
, (go2, mkLams [x] (mkACase (mkLams [y] $ mkLit 0) (Var go `mkVarApps` [x])))]) $
Var go `mkApps` [go2 `mkLApps` [0,1], mkLit 0]
, ("mutual recursion (functions), one boring (d 1 would be bad)",) $
mkLet d (f `mkLApps` [0]) $
Let (Rec [ (go, mkLams [x, y] (Var d `mkApps` [go2 `mkLApps` [1,2]]))
, (go2, mkLams [x] (mkACase (mkLams [y] $ mkLit 0) (Var go `mkVarApps` [x])))]) $
Var d `mkApps` [go2 `mkLApps` [0,1]]
]
main = do
......
......@@ -7,7 +7,7 @@ nested_go2:
d 1
n 1
d0:
go 0
go 1
d 0
go2 (in case crut):
go 2
......@@ -50,6 +50,10 @@ two functions (recursive):
mutual recursion (thunks), called mutiple times (both arity 1 would be bad!):
d 0
n 0
mutual recursion (functions), but no thunks (both arity 2 would be good):
mutual recursion (functions), but no thunks:
go 2
go2 2
mutual recursion (functions), one boring (d 1 would be bad):
go 0
go2 0
go2 2
d 0
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment