...
 
Commits (2)
......@@ -49,6 +49,7 @@ import CoreFVs
import FastString
import Type
import Util( mapSnd )
import qualified DList as DL
import Data.Bifunctor
import Control.Monad
......@@ -120,7 +121,7 @@ exitifyRec in_scope pairs
forM ann_pairs $ \(x,rhs) -> do
-- go past the lambdas of the join point
let (args, body) = collectNAnnBndrs (idJoinArity x) rhs
body' <- go args body
body' <- go (DL.fromList args) body
let rhs' = mkLams args body'
return (x, rhs')
......@@ -131,7 +132,7 @@ exitifyRec in_scope pairs
-- variables bound on the way and lifts it out as a join point.
--
-- ExitifyM is a state monad to keep track of floated binds
go :: [Var] -- ^ Variables that are in-scope here, but
go :: DL.DList Var -- ^ Variables that are in-scope here, but
-- not in scope at the joinrec; that is,
-- we must potentially abstract over them.
-- Invariant: they are kept in dependency order
......@@ -144,7 +145,7 @@ exitifyRec in_scope pairs
| -- An exit expression has no recursive calls
let fvs = dVarSetToVarSet (freeVarsOf ann_e)
, disjointVarSet fvs recursive_calls
= go_exit captured (deAnnotate ann_e) fvs
= go_exit (DL.toList captured) (deAnnotate ann_e) fvs
-- We could not turn it into a exit joint point. So now recurse
-- into all expression where eligible exit join points might sit,
......@@ -153,7 +154,7 @@ exitifyRec in_scope pairs
-- Case right hand sides are in tail-call position
go captured (_, AnnCase scrut bndr ty alts) = do
alts' <- forM alts $ \(dc, pats, rhs) -> do
rhs' <- go (captured ++ [bndr] ++ pats) rhs
rhs' <- go (DL.snoc captured bndr DL.++: pats) rhs
return (dc, pats, rhs')
return $ Case (deAnnotate scrut) bndr ty alts'
......@@ -162,9 +163,9 @@ exitifyRec in_scope pairs
| AnnNonRec j rhs <- ann_bind
, Just join_arity <- isJoinId_maybe j
= do let (params, join_body) = collectNAnnBndrs join_arity rhs
join_body' <- go (captured ++ params) join_body
join_body' <- go (captured DL.++: params) join_body
let rhs' = mkLams params join_body'
body' <- go (captured ++ [j]) body
body' <- go (DL.snoc captured j) body
return $ Let (NonRec j rhs') body'
-- rec join point, RHSs and body are in tail-call position
......@@ -174,15 +175,15 @@ exitifyRec in_scope pairs
pairs' <- forM pairs $ \(j,rhs) -> do
let join_arity = idJoinArity j
(params, join_body) = collectNAnnBndrs join_arity rhs
join_body' <- go (captured ++ js ++ params) join_body
join_body' <- go (captured DL.++: js DL.++: params) join_body
let rhs' = mkLams params join_body'
return (j, rhs')
body' <- go (captured ++ js) body
body' <- go (captured DL.++: js) body
return $ Let (Rec pairs') body'
-- normal Let, only the body is in tail-call position
| otherwise
= do body' <- go (captured ++ bindersOf bind ) body
= do body' <- go (captured DL.++: bindersOf bind) body
return $ Let bind body'
where bind = deAnnBind ann_bind
......
......@@ -68,6 +68,7 @@ import MonadUtils
import Outputable
import PrelRules
import FastString ( fsLit )
import qualified DList as DL
import Control.Monad ( when )
import Data.List ( sortBy )
......@@ -1402,20 +1403,21 @@ mkLam _env [] body _cont
= return body
mkLam env bndrs body cont
= do { dflags <- getDynFlags
; mkLam' dflags bndrs body }
; mkLam' dflags (DL.fromList bndrs) body }
where
mkLam' :: DynFlags -> [OutBndr] -> OutExpr -> SimplM OutExpr
mkLam' :: DynFlags -> DL.DList OutBndr -> OutExpr -> SimplM OutExpr
mkLam' dflags bndrs (Cast body co)
| not (any bad bndrs)
| let bndrs' = DL.toList bndrs
, not (any bad bndrs')
-- Note [Casts and lambdas]
= do { lam <- mkLam' dflags bndrs body
; return (mkCast lam (mkPiCos Representational bndrs co)) }
; return (mkCast lam (mkPiCos Representational bndrs' co)) }
where
co_vars = tyCoVarsOfCo co
bad bndr = isCoVar bndr && bndr `elemVarSet` co_vars
mkLam' dflags bndrs body@(Lam {})
= mkLam' dflags (bndrs ++ bndrs1) body1
= mkLam' dflags (bndrs DL.++: bndrs1) body1
where
(bndrs1, body1) = collectBinders body
......@@ -1425,8 +1427,9 @@ mkLam env bndrs body cont
mkLam' dflags bndrs body
| gopt Opt_DoEtaReduction dflags
, Just etad_lam <- tryEtaReduce bndrs body
= do { tick (EtaReduction (head bndrs))
, bndrs'@(bndr':_) <- DL.toList bndrs
, Just etad_lam <- tryEtaReduce bndrs' body
= do { tick (EtaReduction bndr')
; return etad_lam }
| not (contIsRhs cont) -- See Note [Eta-expanding lambdas]
......@@ -1434,14 +1437,15 @@ mkLam env bndrs body cont
, any isRuntimeVar bndrs
, let body_arity = exprEtaExpandArity dflags body
, body_arity > 0
= do { tick (EtaExpansion (head bndrs))
; let res = mkLams bndrs (etaExpand body_arity body)
; traceSmpl "eta expand" (vcat [text "before" <+> ppr (mkLams bndrs body)
, bndrs'@(bndr':_) <- DL.toList bndrs
= do { tick (EtaExpansion bndr')
; let res = mkLams bndrs' (etaExpand body_arity body)
; traceSmpl "eta expand" (vcat [text "before" <+> ppr (mkLams bndrs' body)
, text "after" <+> ppr res])
; return res }
| otherwise
= return (mkLams bndrs body)
= return $ mkLams (DL.toList bndrs) body
{-
Note [Eta expanding lambdas]
......
......@@ -37,6 +37,7 @@ module DList
, snoc
, append
, (++)
, (++:)
, concat
, replicate
, list
......@@ -154,6 +155,12 @@ append xs ys = DL (unDL xs . unDL ys)
{-# INLINE (++) #-}
infixr 5 ++
-- | /O(1)/. Append a list to an existing DList.
(++:) :: DList a -> [a] -> DList a
xs ++: ys = DL (unDL xs . (List.++ ys))
{-# INLINE (++:) #-}
infixl 5 ++:
-- | /O(spine)/. Concatenate dlists
concat :: [DList a] -> DList a
concat = List.foldr append empty
......