From 89f53f7556e5331cfaa0d258ec7b85744edaadfb Mon Sep 17 00:00:00 2001
From: Max Bolingbroke <batterseapower@hotmail.com>
Date: Mon, 1 Aug 2011 23:42:38 +0100
Subject: [PATCH] More comprehensive treatment of unlifted let bindings: in
 particular, stop residualising them as LetRecs

---
 compiler/supercompile/Supercompile.hs         | 24 ++++++++++++++-----
 .../supercompile/Supercompile/Core/Syntax.hs  |  4 ----
 .../Supercompile/Drive/Process.hs             |  2 +-
 .../supercompile/Supercompile/Drive/Split.hs  |  4 ++--
 .../Supercompile/Evaluator/Residualise.hs     |  2 +-
 .../Supercompile/Evaluator/Syntax.hs          | 22 ++++++++++++++++-
 .../supercompile/Supercompile/Utilities.hs    |  7 ++++++
 7 files changed, 50 insertions(+), 15 deletions(-)

diff --git a/compiler/supercompile/Supercompile.hs b/compiler/supercompile/Supercompile.hs
index 792583a7e739..70f1ff936e51 100644
--- a/compiler/supercompile/Supercompile.hs
+++ b/compiler/supercompile/Supercompile.hs
@@ -15,12 +15,12 @@ import DataCon    (dataConWorkId, dataConAllTyVars, dataConRepArgTys)
 import VarSet
 import Name       (localiseName)
 import Var        (Var, isTyVar, varName, setVarName)
-import Id         (Id, mkSysLocal, mkSysLocalM, realIdUnfolding, idInlinePragma, isPrimOpId_maybe, isDataConWorkId_maybe, setIdNotExported, isExportedId)
+import Id         (Id, mkSysLocal, idType, mkSysLocalM, realIdUnfolding, idInlinePragma, isPrimOpId_maybe, isDataConWorkId_maybe, setIdNotExported, isExportedId)
 import MkId       (mkPrimOpId)
 import MkCore     (mkBigCoreVarTup, mkTupleSelector, mkWildValBinder)
 import FastString (mkFastString, fsLit)
 import PrimOp     (primOpSig)
-import Type       (mkTyVarTy)
+import Type       (mkTyVarTy, isUnLiftedType)
 
 import Control.Monad
 import qualified Data.Map as M
@@ -95,7 +95,12 @@ bindFloats :: ParseM S.Term -> ParseM S.Term
 bindFloats = bindFloatsWith . fmap ((,) [])
 
 bindFloatsWith :: ParseM ([(Var, S.Term)], S.Term) -> ParseM S.Term
-bindFloatsWith act = ParseM $ \s -> case unParseM act s of (s, floats, (xes, e)) -> (s, [], S.letRecSmart (xes ++ floats) e)
+bindFloatsWith act = ParseM $ \s -> case unParseM act s of (s, floats, (xes, e)) -> (s, [], S.bindManyMixedLiftedness S.termFreeVars (xes ++ floats) e)
+
+bindUnliftedFloats :: ParseM S.Term -> ParseM S.Term
+bindUnliftedFloats act = ParseM $ \s -> case unParseM act s of (s, floats, e) -> if any (isUnLiftedType . idType . fst) floats
+                                                                                 then (s, [], S.bindManyMixedLiftedness S.termFreeVars floats e)
+                                                                                 else (s, floats, e)
 
 appE :: S.Term -> S.Term -> ParseM S.Term
 appE e1 e2
@@ -105,16 +110,16 @@ appE e1 e2
 
 
 coreExprToTerm :: CoreExpr -> S.Term
-coreExprToTerm = uncurry S.letRecSmart . runParseM . term
+coreExprToTerm = uncurry (S.bindManyMixedLiftedness S.termFreeVars) . runParseM . term
   where
     -- PrimOp and Data are dealt with later on by generating appropriate unfoldings
     term (Var x)                   = return $ S.var x
     term (Lit l)                   = return $ S.value (S.Literal l)
     term (App e_fun (Type ty_arg)) = fmap (flip S.tyApp ty_arg) (term e_fun)
-    term (App e_fun e_arg)         = join $ liftM2 appE (term e_fun) (term e_arg)
+    term (App e_fun e_arg)         = join $ liftM2 appE (term e_fun) (maybeUnLiftedTerm (exprType e_arg) e_arg)
     term (Lam x e) | isTyVar x     = fmap (S.value . S.TyLambda x) (bindFloats (term e))
                    | otherwise     = fmap (S.value . S.Lambda x) (bindFloats (term e))
-    term (Let (NonRec x e1) e2)    = liftM2 (S.let_ x) (term e1) (bindFloats (term e2))
+    term (Let (NonRec x e1) e2)    = liftM2 (S.let_ x) (maybeUnLiftedTerm (idType x) e1) (bindFloats (term e2))
     term (Let (Rec xes) e)         = bindFloatsWith (liftM2 (,) (mapM (secondM term) xes) (term e))
     term (Case e x ty alts)        = liftM2 (\e alts -> S.case_ e x ty alts) (term e) (mapM alt alts)
     term (Cast e co)               = fmap (flip S.cast co) (term e)
@@ -122,6 +127,13 @@ coreExprToTerm = uncurry S.letRecSmart . runParseM . term
     term (Type ty)                 = pprPanic "termToCoreExpr" (ppr ty)
     term (Coercion co)             = return $ S.value (S.Coercion co)
     
+    -- We can float unlifted bindings out of an unlifted argument/let
+    -- because they were certain to be evaluated anyway. Otherwise we have
+    -- to residualise all the floats if any of them were unlifted.
+    maybeUnLiftedTerm ty e
+      | isUnLiftedType ty = term e
+      | otherwise         = bindUnliftedFloats (term e)
+
     alt (DEFAULT,    [], e) = fmap ((,) S.DefaultAlt)            $ bindFloats (term e)
     alt (LitAlt l,   [], e) = fmap ((,) (S.LiteralAlt l))        $ bindFloats (term e)
     alt (DataAlt dc, xs, e) = fmap ((,) (S.DataAlt dc as qs zs)) $ bindFloats (term e)
diff --git a/compiler/supercompile/Supercompile/Core/Syntax.hs b/compiler/supercompile/Supercompile/Core/Syntax.hs
index 5ced926b18cc..d02b4b6f67ca 100644
--- a/compiler/supercompile/Supercompile/Core/Syntax.hs
+++ b/compiler/supercompile/Supercompile/Core/Syntax.hs
@@ -263,10 +263,6 @@ tyVarIdApps = foldl tyVarIdApp
   where tyVarIdApp e x | isTyVar x = tyApp e (mkTyVarTy x)
                        | otherwise = app   e x
 
-letRecSmart :: Symantics ann => [(Var, ann (TermF ann))] -> ann (TermF ann) -> ann (TermF ann)
-letRecSmart []  = id
-letRecSmart xes = letRec xes
-
 {-
 strictLet :: Symantics ann => Var -> ann (TermF ann) -> ann (TermF ann) -> ann (TermF ann)
 strictLet x e1 e2 = case_ e1 [(DefaultAlt (Just x), e2)]
diff --git a/compiler/supercompile/Supercompile/Drive/Process.hs b/compiler/supercompile/Supercompile/Drive/Process.hs
index 18c2b36a6892..d301382baeef 100644
--- a/compiler/supercompile/Supercompile/Drive/Process.hs
+++ b/compiler/supercompile/Supercompile/Drive/Process.hs
@@ -480,7 +480,7 @@ instance Monad ScpM where
     (!mx) >>= fxmy = ScpM $ \e s k -> unScpM mx e s (\x s -> unScpM (fxmy x) e s k)
 
 runScpM :: ScpM (Out FVedTerm) -> (SCStats, Out FVedTerm)
-runScpM me = unScpM me init_e init_s (\e' s -> (stats s, letRecSmart (fulfilmentsToBinds $ fst $ partitionFulfilments fulfilmentReferredTo unionVarSets (fvedTermFreeVars e') (fulfilments s)) e'))
+runScpM me = unScpM me init_e init_s (\e' s -> (stats s, bindManyMixedLiftedness fvedTermFreeVars (fulfilmentsToBinds $ fst $ partitionFulfilments fulfilmentReferredTo unionVarSets (fvedTermFreeVars e') (fulfilments s)) e'))
   where
     init_e = ScpEnv { promises = [], fulfilmentStack = [], depth = 0 }
     init_s = ScpState { names = h_names, fulfilments = [], stats = mempty }
diff --git a/compiler/supercompile/Supercompile/Drive/Split.hs b/compiler/supercompile/Supercompile/Drive/Split.hs
index a44e6cafa1e4..6c932447b78a 100644
--- a/compiler/supercompile/Supercompile/Drive/Split.hs
+++ b/compiler/supercompile/Supercompile/Drive/Split.hs
@@ -409,7 +409,7 @@ optimiseBracketed :: MonadStatics m
                   -> (Deeds, Bracketed State)
                   -> m (Deeds, Out FVedTerm)
 optimiseBracketed opt (deeds, b) = liftM (second (rebuild b)) $ optimiseMany optimise_one (deeds, extraBvs b `zip` fillers b)
-  where optimise_one (deeds, (extra_bvs, (s_deeds, s_heap, s_k, s_e))) = liftM (\(xes, (deeds, e)) -> (deeds, letRecSmart xes e)) $ bindCapturedFloats (mkVarSet extra_bvs) $ opt (deeds + s_deeds, s_heap, s_k, s_e)
+  where optimise_one (deeds, (extra_bvs, (s_deeds, s_heap, s_k, s_e))) = liftM (\(xes, (deeds, e)) -> (deeds, bindManyMixedLiftedness fvedTermFreeVars xes e)) $ bindCapturedFloats (mkVarSet extra_bvs) $ opt (deeds + s_deeds, s_heap, s_k, s_e)
         -- Because h-functions might potentially refer to the lambda/case-alt bound variables around this hole,
         -- we use bindCapturedFloats to residualise such bindings within exactly this context.
         -- See Note [When to bind captured floats]
@@ -474,7 +474,7 @@ optimiseSplit opt deeds bracketeds_heap bracketed_focus = do
     
         -- 3) Combine the residualised let bindings with the let body
         return (sumMap (releaseBracketedDeeds releaseStateDeed) bracketeds_deeded_heap + leftover_deeds,
-                letRecSmart xes e_focus)
+                bindManyMixedLiftedness fvedTermFreeVars xes e_focus)
   where
     -- TODO: clean up this incomprehensible loop
     -- TODO: investigate the possibility of just fusing in the optimiseLetBinds loop with this one
diff --git a/compiler/supercompile/Supercompile/Evaluator/Residualise.hs b/compiler/supercompile/Supercompile/Evaluator/Residualise.hs
index d3bb76f42a27..3106ee77030b 100644
--- a/compiler/supercompile/Supercompile/Evaluator/Residualise.hs
+++ b/compiler/supercompile/Supercompile/Evaluator/Residualise.hs
@@ -27,7 +27,7 @@ residualiseTerm :: InScopeSet -> In AnnedTerm -> Out FVedTerm
 residualiseTerm ids = detagAnnedTerm . renameIn (renameAnnedTerm ids)
 
 residualiseHeap :: Heap -> (InScopeSet -> ((Out [(Var, PrettyFunction)], Out [(Var, FVedTerm)]), Out FVedTerm)) -> (Out [(Var, PrettyFunction)], Out FVedTerm)
-residualiseHeap (Heap h ids) resid_body = (floats_static_h ++ floats_static_k, letRecSmart (floats_nonstatic_h ++ floats_nonstatic_k) e)
+residualiseHeap (Heap h ids) resid_body = (floats_static_h ++ floats_static_k, bindManyMixedLiftedness fvedTermFreeVars (floats_nonstatic_h ++ floats_nonstatic_k) e)
   where (floats_static_h, floats_nonstatic_h) = residualisePureHeap ids h
         ((floats_static_k, floats_nonstatic_k), e) = resid_body ids
 
diff --git a/compiler/supercompile/Supercompile/Evaluator/Syntax.hs b/compiler/supercompile/Supercompile/Evaluator/Syntax.hs
index 16d2be2a50d9..a991ca503a93 100644
--- a/compiler/supercompile/Supercompile/Evaluator/Syntax.hs
+++ b/compiler/supercompile/Supercompile/Evaluator/Syntax.hs
@@ -15,7 +15,7 @@ import Supercompile.Utilities
 
 import Id       (Id, idType)
 import PrimOp   (primOpType)
-import Type     (applyTy, applyTys, mkForAllTy, mkFunTy, splitFunTy, eqType)
+import Type     (applyTy, applyTys, mkForAllTy, mkFunTy, splitFunTy, eqType, isUnLiftedType)
 import Pair     (pSnd)
 import DataCon  (dataConWorkId)
 import Literal  (literalType)
@@ -329,3 +329,23 @@ releaseUnnormalisedStateDeed (deeds, Heap h _, k, (_, e)) = releaseStackDeeds (r
 
 releaseStateDeed :: State -> Deeds
 releaseStateDeed (deeds, Heap h _, k, a) = releaseStackDeeds (releasePureHeapDeeds (deeds + annedSize a) h) k
+
+
+-- Unlifted bindings are irritating. They mean that the PureHeap has an implicit order that we need to carefully
+-- preserve when we turn it back into a term: unlifted bindings must be bound by a "let".
+--
+-- An alternative to this would be to record the binding struture in the PureHeap itself, but that would get pretty
+-- fiddly (in particuar, update frames would need to hold a "cursor" saying where in the PureHeap to update upon
+-- completion). It's probably better to take the complexity hit here and now.
+bindManyMixedLiftedness :: Symantics ann => (ann (TermF ann) -> FreeVars) -> [(Var, ann (TermF ann))] -> ann (TermF ann) -> ann (TermF ann)
+bindManyMixedLiftedness get_fvs = go
+  where go []  = id
+        go xes = case takeFirst (\(x, _) -> isUnLiftedType (idType x)) xes of
+            Nothing                 -> letRec xes
+            Just ((x, e), rest_xes) -> go xes_above . let_ x e . go xes_below
+              where (xes_above, xes_below) = partition_one (get_fvs e) rest_xes
+
+        partition_one bvs_below xes | bvs_below' == bvs_below = (xes_above, xes_below)
+                                    | otherwise               = second (xes_below ++) $ partition_one bvs_below' xes_above
+          where (xes_below, xes_above) = partition (\(x, _) -> x `elemVarSet` bvs_below) xes
+                bvs_below' = bvs_below `unionVarSet` mkVarSet (map fst xes_below)
diff --git a/compiler/supercompile/Supercompile/Utilities.hs b/compiler/supercompile/Supercompile/Utilities.hs
index 189214603698..42329879a432 100644
--- a/compiler/supercompile/Supercompile/Utilities.hs
+++ b/compiler/supercompile/Supercompile/Utilities.hs
@@ -409,6 +409,13 @@ bagContexts xs = [(x, is ++ ts) | (is, x, ts) <- listContexts xs]
 dropLastWhile :: (a -> Bool) -> [a] -> [a]
 dropLastWhile p = reverse . dropWhile p . reverse
 
+takeFirst :: (a -> Bool) -> [a] -> Maybe (a, [a])
+takeFirst f = go []
+  where go _   []     = Nothing
+        go acc (x:xs) = if f x
+                        then Just (x, reverse acc ++ xs)
+                        else go (x:acc) xs
+
 takeWhileJust :: (a -> Maybe b) -> [a] -> ([b], [a])
 takeWhileJust f = go
   where
-- 
GitLab