diff --git a/compiler/supercompile/Supercompile.hs b/compiler/supercompile/Supercompile.hs index 5e52e96a7915d6f746e9015fc9f611666304ab3e..573c308f98cab86a6d7986a4a826d6e7b19c37f9 100644 --- a/compiler/supercompile/Supercompile.hs +++ b/compiler/supercompile/Supercompile.hs @@ -184,7 +184,7 @@ termUnfoldings e = go (S.termFreeVars e) emptyVarSet [] [] DFunUnfolding _ dc es -> Right $ runParseM us2 $ coreExprToTerm $ mkLams as $ mkLams xs $ Var (dataConWorkId dc) `mkTyApps` cls_tys `mkApps` [(e `mkTyApps` map mkTyVarTy as) `mkVarApps` xs | e <- es] where (as, theta, _cls, cls_tys) = tcSplitDFunTy (idType x) xs = zipWith (mkSysLocal (fsLit "x")) bv_uniques theta - CoreUnfolding { uf_tmpl = e } -> Right $ (if super then markSuperinlinable else id) $ runParseM us2 $ coreExprToTerm e + CoreUnfolding { uf_tmpl = e } -> Right $ superinlinableLexically super $ runParseM us2 $ coreExprToTerm e -- NB: it's OK if the unfolding is a non-value, as the evaluator won't inline LetBound non-values primOpUnfolding pop = S.tyLambdas as $ S.lambdas xs $ S.primOp pop (map mkTyVarTy as) (map S.var xs) @@ -217,11 +217,14 @@ termUnfoldings e = go (S.termFreeVars e) emptyVarSet [] [] -- NB: this is used to deal with SUPERINLINABLE bindings which have locally bound loops which -- are *not* marked SUPERINLINABLE -- +-- NB: for this to work properly, coreExprToTerm must not float +-- stuff that was lexically within a binding out of that binding! +-- -- TODO: this is actually useless if we just say that all InternallyBound things are SUPERINLINABLE -markSuperinlinable :: S.Term -> S.Term ---markSuperinlinable = id +superinlinableLexically :: Superinlinable -> S.Term -> S.Term +--superinlinableLexically = id {--} -markSuperinlinable = term +superinlinableLexically ctxt = term where term e = flip fmap e $ \e -> case e of S.Var x -> S.Var x @@ -233,16 +236,19 @@ markSuperinlinable = term S.CoApp e co -> S.CoApp (term e) co S.App e x -> S.App (term e) x S.PrimOp pop tys es -> S.PrimOp pop tys (map term es) - S.Case e x ty alts -> S.Case (term e) x ty (map (second term) alts) - S.Let x e1 e2 -> S.Let (bndr x) (term e1) (term e2) - S.LetRec xes e -> S.LetRec (map (bndr *** term) xes) (term e) + S.Case e x ty alts -> uncurry (flip S.Case) (pair (x, e)) ty (map (second term) alts) + S.Let x e1 e2 -> uncurry S.Let (pair (x, e1)) (term e2) + S.LetRec xes e -> S.LetRec (map pair xes) (term e) S.Cast e co -> S.Cast (term e) co - bndr x = x `setInlinePragma` (idInlinePragma x) { inl_inline = case inl_inline (idInlinePragma x) of - Inline -> Inline - Inlinable _ -> Inlinable True - NoInline -> NoInline - EmptyInlineSpec -> Inlinable True } + pair (x, e) | not ctxt = case S.shouldExposeUnfolding x of Right True -> (x, superinlinableLexically True e) + _ -> (x, term e) + | otherwise = (x', term e) + where x' = x `setInlinePragma` (idInlinePragma x) { inl_inline = case inl_inline (idInlinePragma x) of + Inline -> Inline + Inlinable _ -> Inlinable True + NoInline -> NoInline + EmptyInlineSpec -> Inlinable True } {--} @@ -252,7 +258,9 @@ supercompile e = -- liftM (termToCoreExpr . snd) $ return $ termToCoreExpr $ S.supercompile (M.fromList unfs) e' where unfs = termUnfoldings e' - e' = runParseM anfUniqSupply' (coreExprToTerm e) + -- NB: ensure we mark any child bindings of bindings marked SUPERINLINABLE in *this module* as SUPERINLINABLE, + -- just like we would if we imported a SUPERINLINABLE binding + e' = superinlinableLexically False $ runParseM anfUniqSupply' $ coreExprToTerm e supercompileProgram :: [CoreBind] -> IO [CoreBind] supercompileProgram binds = supercompileProgramSelective selector binds