From 89f53f7556e5331cfaa0d258ec7b85744edaadfb Mon Sep 17 00:00:00 2001 From: Max Bolingbroke <batterseapower@hotmail.com> Date: Mon, 1 Aug 2011 23:42:38 +0100 Subject: [PATCH] More comprehensive treatment of unlifted let bindings: in particular, stop residualising them as LetRecs --- compiler/supercompile/Supercompile.hs | 24 ++++++++++++++----- .../supercompile/Supercompile/Core/Syntax.hs | 4 ---- .../Supercompile/Drive/Process.hs | 2 +- .../supercompile/Supercompile/Drive/Split.hs | 4 ++-- .../Supercompile/Evaluator/Residualise.hs | 2 +- .../Supercompile/Evaluator/Syntax.hs | 22 ++++++++++++++++- .../supercompile/Supercompile/Utilities.hs | 7 ++++++ 7 files changed, 50 insertions(+), 15 deletions(-) diff --git a/compiler/supercompile/Supercompile.hs b/compiler/supercompile/Supercompile.hs index 792583a7e739..70f1ff936e51 100644 --- a/compiler/supercompile/Supercompile.hs +++ b/compiler/supercompile/Supercompile.hs @@ -15,12 +15,12 @@ import DataCon (dataConWorkId, dataConAllTyVars, dataConRepArgTys) import VarSet import Name (localiseName) import Var (Var, isTyVar, varName, setVarName) -import Id (Id, mkSysLocal, mkSysLocalM, realIdUnfolding, idInlinePragma, isPrimOpId_maybe, isDataConWorkId_maybe, setIdNotExported, isExportedId) +import Id (Id, mkSysLocal, idType, mkSysLocalM, realIdUnfolding, idInlinePragma, isPrimOpId_maybe, isDataConWorkId_maybe, setIdNotExported, isExportedId) import MkId (mkPrimOpId) import MkCore (mkBigCoreVarTup, mkTupleSelector, mkWildValBinder) import FastString (mkFastString, fsLit) import PrimOp (primOpSig) -import Type (mkTyVarTy) +import Type (mkTyVarTy, isUnLiftedType) import Control.Monad import qualified Data.Map as M @@ -95,7 +95,12 @@ bindFloats :: ParseM S.Term -> ParseM S.Term bindFloats = bindFloatsWith . fmap ((,) []) bindFloatsWith :: ParseM ([(Var, S.Term)], S.Term) -> ParseM S.Term -bindFloatsWith act = ParseM $ \s -> case unParseM act s of (s, floats, (xes, e)) -> (s, [], S.letRecSmart (xes ++ floats) e) +bindFloatsWith act = ParseM $ \s -> case unParseM act s of (s, floats, (xes, e)) -> (s, [], S.bindManyMixedLiftedness S.termFreeVars (xes ++ floats) e) + +bindUnliftedFloats :: ParseM S.Term -> ParseM S.Term +bindUnliftedFloats act = ParseM $ \s -> case unParseM act s of (s, floats, e) -> if any (isUnLiftedType . idType . fst) floats + then (s, [], S.bindManyMixedLiftedness S.termFreeVars floats e) + else (s, floats, e) appE :: S.Term -> S.Term -> ParseM S.Term appE e1 e2 @@ -105,16 +110,16 @@ appE e1 e2 coreExprToTerm :: CoreExpr -> S.Term -coreExprToTerm = uncurry S.letRecSmart . runParseM . term +coreExprToTerm = uncurry (S.bindManyMixedLiftedness S.termFreeVars) . runParseM . term where -- PrimOp and Data are dealt with later on by generating appropriate unfoldings term (Var x) = return $ S.var x term (Lit l) = return $ S.value (S.Literal l) term (App e_fun (Type ty_arg)) = fmap (flip S.tyApp ty_arg) (term e_fun) - term (App e_fun e_arg) = join $ liftM2 appE (term e_fun) (term e_arg) + term (App e_fun e_arg) = join $ liftM2 appE (term e_fun) (maybeUnLiftedTerm (exprType e_arg) e_arg) term (Lam x e) | isTyVar x = fmap (S.value . S.TyLambda x) (bindFloats (term e)) | otherwise = fmap (S.value . S.Lambda x) (bindFloats (term e)) - term (Let (NonRec x e1) e2) = liftM2 (S.let_ x) (term e1) (bindFloats (term e2)) + term (Let (NonRec x e1) e2) = liftM2 (S.let_ x) (maybeUnLiftedTerm (idType x) e1) (bindFloats (term e2)) term (Let (Rec xes) e) = bindFloatsWith (liftM2 (,) (mapM (secondM term) xes) (term e)) term (Case e x ty alts) = liftM2 (\e alts -> S.case_ e x ty alts) (term e) (mapM alt alts) term (Cast e co) = fmap (flip S.cast co) (term e) @@ -122,6 +127,13 @@ coreExprToTerm = uncurry S.letRecSmart . runParseM . term term (Type ty) = pprPanic "termToCoreExpr" (ppr ty) term (Coercion co) = return $ S.value (S.Coercion co) + -- We can float unlifted bindings out of an unlifted argument/let + -- because they were certain to be evaluated anyway. Otherwise we have + -- to residualise all the floats if any of them were unlifted. + maybeUnLiftedTerm ty e + | isUnLiftedType ty = term e + | otherwise = bindUnliftedFloats (term e) + alt (DEFAULT, [], e) = fmap ((,) S.DefaultAlt) $ bindFloats (term e) alt (LitAlt l, [], e) = fmap ((,) (S.LiteralAlt l)) $ bindFloats (term e) alt (DataAlt dc, xs, e) = fmap ((,) (S.DataAlt dc as qs zs)) $ bindFloats (term e) diff --git a/compiler/supercompile/Supercompile/Core/Syntax.hs b/compiler/supercompile/Supercompile/Core/Syntax.hs index 5ced926b18cc..d02b4b6f67ca 100644 --- a/compiler/supercompile/Supercompile/Core/Syntax.hs +++ b/compiler/supercompile/Supercompile/Core/Syntax.hs @@ -263,10 +263,6 @@ tyVarIdApps = foldl tyVarIdApp where tyVarIdApp e x | isTyVar x = tyApp e (mkTyVarTy x) | otherwise = app e x -letRecSmart :: Symantics ann => [(Var, ann (TermF ann))] -> ann (TermF ann) -> ann (TermF ann) -letRecSmart [] = id -letRecSmart xes = letRec xes - {- strictLet :: Symantics ann => Var -> ann (TermF ann) -> ann (TermF ann) -> ann (TermF ann) strictLet x e1 e2 = case_ e1 [(DefaultAlt (Just x), e2)] diff --git a/compiler/supercompile/Supercompile/Drive/Process.hs b/compiler/supercompile/Supercompile/Drive/Process.hs index 18c2b36a6892..d301382baeef 100644 --- a/compiler/supercompile/Supercompile/Drive/Process.hs +++ b/compiler/supercompile/Supercompile/Drive/Process.hs @@ -480,7 +480,7 @@ instance Monad ScpM where (!mx) >>= fxmy = ScpM $ \e s k -> unScpM mx e s (\x s -> unScpM (fxmy x) e s k) runScpM :: ScpM (Out FVedTerm) -> (SCStats, Out FVedTerm) -runScpM me = unScpM me init_e init_s (\e' s -> (stats s, letRecSmart (fulfilmentsToBinds $ fst $ partitionFulfilments fulfilmentReferredTo unionVarSets (fvedTermFreeVars e') (fulfilments s)) e')) +runScpM me = unScpM me init_e init_s (\e' s -> (stats s, bindManyMixedLiftedness fvedTermFreeVars (fulfilmentsToBinds $ fst $ partitionFulfilments fulfilmentReferredTo unionVarSets (fvedTermFreeVars e') (fulfilments s)) e')) where init_e = ScpEnv { promises = [], fulfilmentStack = [], depth = 0 } init_s = ScpState { names = h_names, fulfilments = [], stats = mempty } diff --git a/compiler/supercompile/Supercompile/Drive/Split.hs b/compiler/supercompile/Supercompile/Drive/Split.hs index a44e6cafa1e4..6c932447b78a 100644 --- a/compiler/supercompile/Supercompile/Drive/Split.hs +++ b/compiler/supercompile/Supercompile/Drive/Split.hs @@ -409,7 +409,7 @@ optimiseBracketed :: MonadStatics m -> (Deeds, Bracketed State) -> m (Deeds, Out FVedTerm) optimiseBracketed opt (deeds, b) = liftM (second (rebuild b)) $ optimiseMany optimise_one (deeds, extraBvs b `zip` fillers b) - where optimise_one (deeds, (extra_bvs, (s_deeds, s_heap, s_k, s_e))) = liftM (\(xes, (deeds, e)) -> (deeds, letRecSmart xes e)) $ bindCapturedFloats (mkVarSet extra_bvs) $ opt (deeds + s_deeds, s_heap, s_k, s_e) + where optimise_one (deeds, (extra_bvs, (s_deeds, s_heap, s_k, s_e))) = liftM (\(xes, (deeds, e)) -> (deeds, bindManyMixedLiftedness fvedTermFreeVars xes e)) $ bindCapturedFloats (mkVarSet extra_bvs) $ opt (deeds + s_deeds, s_heap, s_k, s_e) -- Because h-functions might potentially refer to the lambda/case-alt bound variables around this hole, -- we use bindCapturedFloats to residualise such bindings within exactly this context. -- See Note [When to bind captured floats] @@ -474,7 +474,7 @@ optimiseSplit opt deeds bracketeds_heap bracketed_focus = do -- 3) Combine the residualised let bindings with the let body return (sumMap (releaseBracketedDeeds releaseStateDeed) bracketeds_deeded_heap + leftover_deeds, - letRecSmart xes e_focus) + bindManyMixedLiftedness fvedTermFreeVars xes e_focus) where -- TODO: clean up this incomprehensible loop -- TODO: investigate the possibility of just fusing in the optimiseLetBinds loop with this one diff --git a/compiler/supercompile/Supercompile/Evaluator/Residualise.hs b/compiler/supercompile/Supercompile/Evaluator/Residualise.hs index d3bb76f42a27..3106ee77030b 100644 --- a/compiler/supercompile/Supercompile/Evaluator/Residualise.hs +++ b/compiler/supercompile/Supercompile/Evaluator/Residualise.hs @@ -27,7 +27,7 @@ residualiseTerm :: InScopeSet -> In AnnedTerm -> Out FVedTerm residualiseTerm ids = detagAnnedTerm . renameIn (renameAnnedTerm ids) residualiseHeap :: Heap -> (InScopeSet -> ((Out [(Var, PrettyFunction)], Out [(Var, FVedTerm)]), Out FVedTerm)) -> (Out [(Var, PrettyFunction)], Out FVedTerm) -residualiseHeap (Heap h ids) resid_body = (floats_static_h ++ floats_static_k, letRecSmart (floats_nonstatic_h ++ floats_nonstatic_k) e) +residualiseHeap (Heap h ids) resid_body = (floats_static_h ++ floats_static_k, bindManyMixedLiftedness fvedTermFreeVars (floats_nonstatic_h ++ floats_nonstatic_k) e) where (floats_static_h, floats_nonstatic_h) = residualisePureHeap ids h ((floats_static_k, floats_nonstatic_k), e) = resid_body ids diff --git a/compiler/supercompile/Supercompile/Evaluator/Syntax.hs b/compiler/supercompile/Supercompile/Evaluator/Syntax.hs index 16d2be2a50d9..a991ca503a93 100644 --- a/compiler/supercompile/Supercompile/Evaluator/Syntax.hs +++ b/compiler/supercompile/Supercompile/Evaluator/Syntax.hs @@ -15,7 +15,7 @@ import Supercompile.Utilities import Id (Id, idType) import PrimOp (primOpType) -import Type (applyTy, applyTys, mkForAllTy, mkFunTy, splitFunTy, eqType) +import Type (applyTy, applyTys, mkForAllTy, mkFunTy, splitFunTy, eqType, isUnLiftedType) import Pair (pSnd) import DataCon (dataConWorkId) import Literal (literalType) @@ -329,3 +329,23 @@ releaseUnnormalisedStateDeed (deeds, Heap h _, k, (_, e)) = releaseStackDeeds (r releaseStateDeed :: State -> Deeds releaseStateDeed (deeds, Heap h _, k, a) = releaseStackDeeds (releasePureHeapDeeds (deeds + annedSize a) h) k + + +-- Unlifted bindings are irritating. They mean that the PureHeap has an implicit order that we need to carefully +-- preserve when we turn it back into a term: unlifted bindings must be bound by a "let". +-- +-- An alternative to this would be to record the binding struture in the PureHeap itself, but that would get pretty +-- fiddly (in particuar, update frames would need to hold a "cursor" saying where in the PureHeap to update upon +-- completion). It's probably better to take the complexity hit here and now. +bindManyMixedLiftedness :: Symantics ann => (ann (TermF ann) -> FreeVars) -> [(Var, ann (TermF ann))] -> ann (TermF ann) -> ann (TermF ann) +bindManyMixedLiftedness get_fvs = go + where go [] = id + go xes = case takeFirst (\(x, _) -> isUnLiftedType (idType x)) xes of + Nothing -> letRec xes + Just ((x, e), rest_xes) -> go xes_above . let_ x e . go xes_below + where (xes_above, xes_below) = partition_one (get_fvs e) rest_xes + + partition_one bvs_below xes | bvs_below' == bvs_below = (xes_above, xes_below) + | otherwise = second (xes_below ++) $ partition_one bvs_below' xes_above + where (xes_below, xes_above) = partition (\(x, _) -> x `elemVarSet` bvs_below) xes + bvs_below' = bvs_below `unionVarSet` mkVarSet (map fst xes_below) diff --git a/compiler/supercompile/Supercompile/Utilities.hs b/compiler/supercompile/Supercompile/Utilities.hs index 189214603698..42329879a432 100644 --- a/compiler/supercompile/Supercompile/Utilities.hs +++ b/compiler/supercompile/Supercompile/Utilities.hs @@ -409,6 +409,13 @@ bagContexts xs = [(x, is ++ ts) | (is, x, ts) <- listContexts xs] dropLastWhile :: (a -> Bool) -> [a] -> [a] dropLastWhile p = reverse . dropWhile p . reverse +takeFirst :: (a -> Bool) -> [a] -> Maybe (a, [a]) +takeFirst f = go [] + where go _ [] = Nothing + go acc (x:xs) = if f x + then Just (x, reverse acc ++ xs) + else go (x:acc) xs + takeWhileJust :: (a -> Maybe b) -> [a] -> ([b], [a]) takeWhileJust f = go where -- GitLab