Commit 8d5cf8bf authored by lukemaurer's avatar lukemaurer Committed by David Feuer

Join points

This major patch implements Join Points, as described in
https://ghc.haskell.org/trac/ghc/wiki/SequentCore.  You have
to read that page, and especially the paper it links to, to
understand what's going on; but it is very cool.

It's Luke Maurer's work, but done in close collaboration with Simon PJ.

This Phab is a squash-merge of wip/join-points branch of
http://github.com/lukemaurer/ghc. There are many, many interdependent
changes.

Reviewers: goldfire, mpickering, bgamari, simonmar, dfeuer, austin

Subscribers: simonpj, dfeuer, mpickering, Mikolaj, thomie

Differential Revision: https://phabricator.haskell.org/D2853
parent 4fa439e3
......@@ -606,8 +606,8 @@ rnIfaceConAlt (IfaceDataAlt data_occ) = IfaceDataAlt <$> rnIfaceGlobal data_occ
rnIfaceConAlt alt = pure alt
rnIfaceLetBndr :: Rename IfaceLetBndr
rnIfaceLetBndr (IfLetBndr fs ty info)
= IfLetBndr fs <$> rnIfaceType ty <*> rnIfaceIdInfo info
rnIfaceLetBndr (IfLetBndr fs ty info jpi)
= IfLetBndr fs <$> rnIfaceType ty <*> rnIfaceIdInfo info <*> pure jpi
rnIfaceLamBndr :: Rename IfaceLamBndr
rnIfaceLamBndr (bndr, oneshot) = (,) <$> rnIfaceBndr bndr <*> pure oneshot
......
......@@ -24,7 +24,7 @@ module BasicTypes(
ConTag, ConTagZ, fIRST_TAG,
Arity, RepArity,
Arity, RepArity, JoinArity,
Alignment,
......@@ -64,13 +64,15 @@ module BasicTypes(
noOneShotInfo, hasNoOneShotInfo, isOneShotInfo,
bestOneShot, worstOneShot,
OccInfo(..), seqOccInfo, zapFragileOcc, isOneOcc,
isDeadOcc, isStrongLoopBreaker, isWeakLoopBreaker, isNoOcc,
OccInfo(..), noOccInfo, seqOccInfo, zapFragileOcc, isOneOcc,
isDeadOcc, isStrongLoopBreaker, isWeakLoopBreaker, isManyOccs,
strongLoopBreaker, weakLoopBreaker,
InsideLam, insideLam, notInsideLam,
OneBranch, oneBranch, notOneBranch,
InterestingCxt,
TailCallInfo(..), tailCallInfo, zapOccTailCallInfo,
isAlwaysTailCalled,
EP(..),
......@@ -154,6 +156,12 @@ type Arity = Int
-- \(# x, y #) -> fib (x + y) has representation arity 2
type RepArity = Int
-- | The number of arguments that a join point takes. Unlike the arity of a
-- function, this is a purely syntactic property and is fixed when the join
-- point is created (or converted from a value). Both type and value arguments
-- are counted.
type JoinArity = Int
{-
************************************************************************
* *
......@@ -808,20 +816,23 @@ defn of OccInfo here, safely at the bottom
-- | identifier Occurrence Information
data OccInfo
= NoOccInfo -- ^ There are many occurrences, or unknown occurrences
= ManyOccs { occ_tail :: !TailCallInfo }
-- ^ There are many occurrences, or unknown occurrences
| IAmDead -- ^ Marks unused variables. Sometimes useful for
-- lambda and case-bound variables.
| OneOcc
!InsideLam
!OneBranch
!InterestingCxt -- ^ Occurs exactly once, not inside a rule
| OneOcc { occ_in_lam :: !InsideLam
, occ_one_br :: !OneBranch
, occ_int_cxt :: !InterestingCxt
, occ_tail :: !TailCallInfo }
-- ^ Occurs exactly once (per branch), not inside a rule
-- | This identifier breaks a loop of mutually recursive functions. The field
-- marks whether it is only a loop breaker due to a reference in a rule
| IAmALoopBreaker -- Note [LoopBreaker OccInfo]
!RulesOnly
| IAmALoopBreaker { occ_rules_only :: !RulesOnly
, occ_tail :: !TailCallInfo }
-- Note [LoopBreaker OccInfo]
deriving (Eq)
......@@ -839,9 +850,12 @@ Note [LoopBreaker OccInfo]
See OccurAnal Note [Weak loop breakers]
-}
isNoOcc :: OccInfo -> Bool
isNoOcc NoOccInfo = True
isNoOcc _ = False
noOccInfo :: OccInfo
noOccInfo = ManyOccs { occ_tail = NoTailCallInfo }
isManyOccs :: OccInfo -> Bool
isManyOccs ManyOccs{} = True
isManyOccs _ = False
seqOccInfo :: OccInfo -> ()
seqOccInfo occ = occ `seq` ()
......@@ -868,17 +882,41 @@ oneBranch, notOneBranch :: OneBranch
oneBranch = True
notOneBranch = False
-----------------
data TailCallInfo = AlwaysTailCalled JoinArity -- See Note [TailCallInfo]
| NoTailCallInfo
deriving (Eq)
tailCallInfo :: OccInfo -> TailCallInfo
tailCallInfo IAmDead = NoTailCallInfo
tailCallInfo other = occ_tail other
zapOccTailCallInfo :: OccInfo -> OccInfo
zapOccTailCallInfo IAmDead = IAmDead
zapOccTailCallInfo occ = occ { occ_tail = NoTailCallInfo }
isAlwaysTailCalled :: OccInfo -> Bool
isAlwaysTailCalled occ
= case tailCallInfo occ of AlwaysTailCalled{} -> True
NoTailCallInfo -> False
instance Outputable TailCallInfo where
ppr (AlwaysTailCalled ar) = sep [ text "Tail", int ar ]
ppr _ = empty
-----------------
strongLoopBreaker, weakLoopBreaker :: OccInfo
strongLoopBreaker = IAmALoopBreaker False
weakLoopBreaker = IAmALoopBreaker True
strongLoopBreaker = IAmALoopBreaker False NoTailCallInfo
weakLoopBreaker = IAmALoopBreaker True NoTailCallInfo
isWeakLoopBreaker :: OccInfo -> Bool
isWeakLoopBreaker (IAmALoopBreaker _) = True
isWeakLoopBreaker (IAmALoopBreaker{}) = True
isWeakLoopBreaker _ = False
isStrongLoopBreaker :: OccInfo -> Bool
isStrongLoopBreaker (IAmALoopBreaker False) = True -- Loop-breaker that breaks a non-rule cycle
isStrongLoopBreaker _ = False
isStrongLoopBreaker (IAmALoopBreaker { occ_rules_only = False }) = True
-- Loop-breaker that breaks a non-rule cycle
isStrongLoopBreaker _ = False
isDeadOcc :: OccInfo -> Bool
isDeadOcc IAmDead = True
......@@ -889,16 +927,21 @@ isOneOcc (OneOcc {}) = True
isOneOcc _ = False
zapFragileOcc :: OccInfo -> OccInfo
zapFragileOcc (OneOcc {}) = NoOccInfo
zapFragileOcc occ = occ
-- Keep only the most robust data: deadness, loop-breaker-hood
zapFragileOcc (OneOcc {}) = noOccInfo
zapFragileOcc occ = zapOccTailCallInfo occ
instance Outputable OccInfo where
-- only used for debugging; never parsed. KSW 1999-07
ppr NoOccInfo = empty
ppr (IAmALoopBreaker ro) = text "LoopBreaker" <> if ro then char '!' else empty
ppr (ManyOccs tails) = pprShortTailCallInfo tails
ppr IAmDead = text "Dead"
ppr (OneOcc inside_lam one_branch int_cxt)
= text "Once" <> pp_lam <> pp_br <> pp_args
ppr (IAmALoopBreaker rule_only tails)
= text "LoopBreaker" <> pp_ro <> pprShortTailCallInfo tails
where
pp_ro | rule_only = char '!'
| otherwise = empty
ppr (OneOcc inside_lam one_branch int_cxt tail_info)
= text "Once" <> pp_lam <> pp_br <> pp_args <> pp_tail
where
pp_lam | inside_lam = char 'L'
| otherwise = empty
......@@ -906,8 +949,43 @@ instance Outputable OccInfo where
| otherwise = char '*'
pp_args | int_cxt = char '!'
| otherwise = empty
pp_tail = pprShortTailCallInfo tail_info
pprShortTailCallInfo :: TailCallInfo -> SDoc
pprShortTailCallInfo (AlwaysTailCalled ar) = char 'T' <> brackets (int ar)
pprShortTailCallInfo NoTailCallInfo = empty
{-
Note [TailCallInfo]
~~~~~~~~~~~~~~~~~~~
The occurrence analyser determines what can be made into a join point, but it
doesn't change the binder into a JoinId because then it would be inconsistent
with the occurrences. Thus it's left to the simplifier (or to simpleOptExpr) to
change the IdDetails.
The AlwaysTailCalled marker actually means slightly more than simply that the
function is always tail-called. See Note [Invariants on join points].
This info is quite fragile and should not be relied upon unless the occurrence
analyser has *just* run. Use 'Id.isJoinId_maybe' for the permanent state of
the join-point-hood of a binder; a join id itself will not be marked
AlwaysTailCalled.
Note that there is a 'TailCallInfo' on a 'ManyOccs' value. One might expect that
being tail-called would mean that the variable could only appear once per branch
(thus getting a `OneOcc { occ_one_br = True }` occurrence info), but a join
point can also be invoked from other join points, not just from case branches:
let j1 x = ...
j2 y = ... j1 z {- tail call -} ...
in case w of
A -> j1 v
B -> j2 u
C -> j2 q
Here both 'j1' and 'j2' will get marked AlwaysTailCalled, but j1 will get
ManyOccs and j2 will get `OneOcc { occ_one_br = True }`.
************************************************************************
* *
Default method specification
......
......@@ -304,7 +304,9 @@ splitArgStrProdDmd n (Str _ s) = splitStrProdDmd n s
splitStrProdDmd :: Int -> StrDmd -> Maybe [ArgStr]
splitStrProdDmd n HyperStr = Just (replicate n strBot)
splitStrProdDmd n HeadStr = Just (replicate n strTop)
splitStrProdDmd n (SProd ds) = ASSERT( ds `lengthIs` n) Just ds
splitStrProdDmd n (SProd ds) = WARN( not (ds `lengthIs` n),
text "splitStrProdDmd" $$ ppr n $$ ppr ds )
Just ds
splitStrProdDmd _ (SCall {}) = Nothing
-- This can happen when the programmer uses unsafeCoerce,
-- and we don't then want to crash the compiler (Trac #9208)
......@@ -586,7 +588,9 @@ seqArgUse _ = ()
splitUseProdDmd :: Int -> UseDmd -> Maybe [ArgUse]
splitUseProdDmd n Used = Just (replicate n useTop)
splitUseProdDmd n UHead = Just (replicate n Abs)
splitUseProdDmd n (UProd ds) = ASSERT2( ds `lengthIs` n, text "splitUseProdDmd" $$ ppr n $$ ppr ds )
splitUseProdDmd n (UProd ds) = WARN( not (ds `lengthIs` n),
text "splitUseProdDmd" $$ ppr n
$$ ppr ds )
Just ds
splitUseProdDmd _ (UCall _ _) = Nothing
-- This can happen when the programmer uses unsafeCoerce,
......
......@@ -52,7 +52,7 @@ module Id (
globaliseId, localiseId,
setIdInfo, lazySetIdInfo, modifyIdInfo, maybeModifyIdInfo,
zapLamIdInfo, zapIdDemandInfo, zapIdUsageInfo, zapIdUsageEnvInfo,
zapIdUsedOnceInfo,
zapIdUsedOnceInfo, zapIdTailCallInfo,
zapFragileIdInfo, zapIdStrictness,
transferPolyIdInfo,
......@@ -73,6 +73,10 @@ module Id (
-- ** Evidence variables
DictId, isDictId, isEvVar,
-- ** Join variables
JoinId, isJoinId, isJoinId_maybe, idJoinArity,
asJoinId, asJoinId_maybe, zapJoinId,
-- ** Inline pragma stuff
idInlinePragma, setInlinePragma, modifyInlinePragma,
idInlineActivation, setInlineActivation, idRuleMatchInfo,
......@@ -118,11 +122,12 @@ import IdInfo
import BasicTypes
-- Imported and re-exported
import Var( Id, CoVar, DictId,
import Var( Id, CoVar, DictId, JoinId,
InId, InVar,
OutId, OutVar,
idInfo, idDetails, globaliseId, varType,
isId, isLocalId, isGlobalId, isExportedId )
idInfo, idDetails, setIdDetails, globaliseId, varType,
isId, isLocalId, isGlobalId, isExportedId,
isJoinId, isJoinId_maybe )
import qualified Var
import Type
......@@ -157,7 +162,10 @@ infixl 1 `setIdUnfolding`,
`idCafInfo`,
`setIdDemandInfo`,
`setIdStrictness`
`setIdStrictness`,
`asJoinId`,
`asJoinId_maybe`
{-
************************************************************************
......@@ -543,6 +551,40 @@ isEvVar var = isPredTy (varType var)
isDictId :: Id -> Bool
isDictId id = isDictTy (idType id)
{-
************************************************************************
* *
Join variables
* *
************************************************************************
-}
idJoinArity :: JoinId -> JoinArity
idJoinArity id = isJoinId_maybe id `orElse` pprPanic "idJoinArity" (ppr id)
asJoinId :: Id -> JoinArity -> JoinId
asJoinId id arity = WARN(not (isLocalId id),
text "global id being marked as join var:" <+> ppr id)
WARN(not (is_vanilla_or_join id),
ppr id <+> pprIdDetails (idDetails id))
id `setIdDetails` JoinId arity
where
is_vanilla_or_join id = case Var.idDetails id of
VanillaId -> True
JoinId {} -> True
_ -> False
zapJoinId :: Id -> Id
-- May be a regular id already
zapJoinId jid | isJoinId jid = zapIdTailCallInfo (jid `setIdDetails` VanillaId)
-- Core Lint may complain if still marked
-- as AlwaysTailCalled
| otherwise = jid
asJoinId_maybe :: Id -> Maybe JoinArity -> Id
asJoinId_maybe id (Just arity) = asJoinId id arity
asJoinId_maybe id Nothing = zapJoinId id
{-
************************************************************************
* *
......@@ -590,9 +632,11 @@ zapIdStrictness id = modifyIdInfo (`setStrictnessInfo` nopSig) id
isStrictId :: Id -> Bool
isStrictId id
= ASSERT2( isId id, text "isStrictId: not an id: " <+> ppr id )
not (isJoinId id) && (
(isStrictType (idType id)) ||
-- Take the best of both strictnesses - old and new
(isStrictDmd (idDemandInfo id))
)
---------------------------------
-- UNFOLDING
......@@ -660,7 +704,7 @@ setIdOccInfo :: Id -> OccInfo -> Id
setIdOccInfo id occ_info = modifyIdInfo (`setOccInfo` occ_info) id
zapIdOccInfo :: Id -> Id
zapIdOccInfo b = b `setIdOccInfo` NoOccInfo
zapIdOccInfo b = b `setIdOccInfo` noOccInfo
{-
---------------------------------
......@@ -804,6 +848,9 @@ zapIdUsageEnvInfo = zapInfo zapUsageEnvInfo
zapIdUsedOnceInfo :: Id -> Id
zapIdUsedOnceInfo = zapInfo zapUsedOnceInfo
zapIdTailCallInfo :: Id -> Id
zapIdTailCallInfo = zapInfo zapTailCallInfo
{-
Note [transferPolyIdInfo]
~~~~~~~~~~~~~~~~~~~~~~~~~
......@@ -869,13 +916,14 @@ transferPolyIdInfo old_id abstract_wrt new_id
old_inline_prag = inlinePragInfo old_info
old_occ_info = occInfo old_info
new_arity = old_arity + arity_increase
new_occ_info = zapOccTailCallInfo old_occ_info
old_strictness = strictnessInfo old_info
new_strictness = increaseStrictSigArity arity_increase old_strictness
transfer new_info = new_info `setArityInfo` new_arity
`setInlinePragInfo` old_inline_prag
`setOccInfo` old_occ_info
`setOccInfo` new_occ_info
`setStrictnessInfo` new_strictness
isNeverLevPolyId :: Id -> Bool
......
......@@ -14,6 +14,7 @@ Haskell. [WDP 94/11])
module IdInfo (
-- * The IdDetails type
IdDetails(..), pprIdDetails, coVarDetails, isCoVarDetails,
JoinArity, isJoinIdDetails_maybe,
RecSelParent(..),
-- * The IdInfo type
......@@ -28,6 +29,7 @@ module IdInfo (
-- ** Zapping various forms of Info
zapLamInfo, zapFragileInfo,
zapDemandInfo, zapUsageInfo, zapUsageEnvInfo, zapUsedOnceInfo,
zapTailCallInfo,
-- ** The ArityInfo type
ArityInfo,
......@@ -55,6 +57,9 @@ module IdInfo (
InsideLam, OneBranch,
insideLam, notInsideLam, oneBranch, notOneBranch,
TailCallInfo(..),
tailCallInfo, isAlwaysTailCalled,
-- ** The RuleInfo type
RuleInfo(..),
emptyRuleInfo,
......@@ -153,6 +158,8 @@ data IdDetails
| CoVarId -- ^ A coercion variable
-- This only covers /un-lifted/ coercions, of type
-- (t1 ~# t2) or (t1 ~R# t2), not their lifted variants
| JoinId JoinArity -- ^ An 'Id' for a join point taking n arguments
-- Note [Join points] in CoreSyn
-- | Recursive Selector Parent
data RecSelParent = RecSelData TyCon | RecSelPatSyn PatSyn deriving Eq
......@@ -176,6 +183,10 @@ isCoVarDetails :: IdDetails -> Bool
isCoVarDetails CoVarId = True
isCoVarDetails _ = False
isJoinIdDetails_maybe :: IdDetails -> Maybe JoinArity
isJoinIdDetails_maybe (JoinId join_arity) = Just join_arity
isJoinIdDetails_maybe _ = Nothing
instance Outputable IdDetails where
ppr = pprIdDetails
......@@ -195,6 +206,7 @@ pprIdDetails other = brackets (pp other)
= brackets $ text "RecSel" <>
ppWhen is_naughty (text "(naughty)")
pp CoVarId = text "CoVarId"
pp (JoinId arity) = text "JoinId" <> parens (int arity)
{-
************************************************************************
......@@ -285,7 +297,7 @@ vanillaIdInfo
unfoldingInfo = noUnfolding,
oneShotInfo = NoOneShotInfo,
inlinePragInfo = defaultInlinePragma,
occInfo = NoOccInfo,
occInfo = noOccInfo,
demandInfo = topDmd,
strictnessInfo = nopSig,
callArityInfo = unknownArity,
......@@ -482,12 +494,16 @@ zapLamInfo info@(IdInfo {occInfo = occ, demandInfo = demand})
where
-- The "unsafe" occ info is the ones that say I'm not in a lambda
-- because that might not be true for an unsaturated lambda
is_safe_occ (OneOcc in_lam _ _) = in_lam
is_safe_occ _other = True
is_safe_occ occ | isAlwaysTailCalled occ = False
is_safe_occ (OneOcc { occ_in_lam = in_lam }) = in_lam
is_safe_occ _other = True
safe_occ = case occ of
OneOcc _ once int_cxt -> OneOcc insideLam once int_cxt
_other -> occ
OneOcc{} -> occ { occ_in_lam = True
, occ_tail = NoTailCallInfo }
IAmALoopBreaker{}
-> occ { occ_tail = NoTailCallInfo }
_other -> occ
is_safe_dmd dmd = not (isStrictDmd dmd)
......@@ -529,6 +545,14 @@ zapFragileUnfolding unf
| isFragileUnfolding unf = noUnfolding
| otherwise = unf
zapTailCallInfo :: IdInfo -> Maybe IdInfo
zapTailCallInfo info
= case occInfo info of
occ | isAlwaysTailCalled occ -> Just (info `setOccInfo` safe_occ)
| otherwise -> Nothing
where
safe_occ = occ { occ_tail = NoTailCallInfo }
{-
************************************************************************
* *
......
module IdInfo where
import BasicTypes
import Outputable
data IdInfo
data IdDetails
......@@ -6,5 +7,6 @@ data IdDetails
vanillaIdInfo :: IdInfo
coVarDetails :: IdDetails
isCoVarDetails :: IdDetails -> Bool
isJoinIdDetails_maybe :: IdDetails -> Maybe JoinArity
pprIdDetails :: IdDetails -> SDoc
......@@ -34,7 +34,7 @@
module Var (
-- * The main data type and synonyms
Var, CoVar, Id, NcId, DictId, DFunId, EvVar, EqVar, EvId, IpId,
Var, CoVar, Id, NcId, DictId, DFunId, EvVar, EqVar, EvId, IpId, JoinId,
TyVar, TypeVar, KindVar, TKVar, TyCoVar,
-- * In and Out variants
......@@ -57,6 +57,7 @@ module Var (
-- ** Predicates
isId, isTyVar, isTcTyVar,
isLocalVar, isLocalId, isCoVar, isNonCoVarId, isTyCoVar,
isJoinId, isJoinId_maybe,
isGlobalId, isExportedId,
mustHaveLocalBinding,
......@@ -83,8 +84,11 @@ module Var (
import {-# SOURCE #-} TyCoRep( Type, Kind, pprKind )
import {-# SOURCE #-} TcType( TcTyVarDetails, pprTcTyVarDetails, vanillaSkolemTv )
import {-# SOURCE #-} IdInfo( IdDetails, IdInfo, coVarDetails, isCoVarDetails, vanillaIdInfo, pprIdDetails )
import {-# SOURCE #-} IdInfo( IdDetails, IdInfo, coVarDetails, isCoVarDetails,
isJoinIdDetails_maybe,
vanillaIdInfo, pprIdDetails )
import BasicTypes ( JoinArity )
import Name hiding (varName)
import Unique ( Uniquable, Unique, getKey, getUnique
, mkUniqueGrimily, nonDetCmpUnique )
......@@ -92,6 +96,7 @@ import Util
import Binary
import DynFlags
import Outputable
import Maybes
import Data.Data
......@@ -149,6 +154,7 @@ type IpId = EvId -- A term-level implicit parameter
-- | Equality Variable
type EqVar = EvId -- Boxed equality evidence
type JoinId = Id -- A join variable
-- | Type or Coercion Variable
type TyCoVar = Id -- Type, *or* coercion variable
......@@ -612,6 +618,14 @@ isNonCoVarId :: Var -> Bool
isNonCoVarId (Id { id_details = details }) = not (isCoVarDetails details)
isNonCoVarId _ = False
isJoinId :: Var -> Bool
isJoinId (Id { id_details = details }) = isJust (isJoinIdDetails_maybe details)
isJoinId _ = False
isJoinId_maybe :: Var -> Maybe JoinArity
isJoinId_maybe (Id { id_details = details }) = isJoinIdDetails_maybe details
isJoinId_maybe _ = Nothing
isLocalId :: Var -> Bool
isLocalId (Id { idScope = LocalId _ }) = True
isLocalId _ = False
......
......@@ -12,8 +12,8 @@ module VarEnv (
elemVarEnv,
extendVarEnv, extendVarEnv_C, extendVarEnv_Acc, extendVarEnv_Directly,
extendVarEnvList,
plusVarEnv, plusVarEnv_C, plusVarEnv_CD, plusVarEnvList,
alterVarEnv,
plusVarEnv, plusVarEnv_C, plusVarEnv_CD, plusMaybeVarEnv_C,
plusVarEnvList, alterVarEnv,
delVarEnvList, delVarEnv, delVarEnv_Directly,
minusVarEnv, intersectsVarEnv,
lookupVarEnv, lookupVarEnv_NF, lookupWithDefaultVarEnv,
......@@ -41,6 +41,7 @@ module VarEnv (
unitDVarEnv,
delDVarEnv,
delDVarEnvList,
minusDVarEnv,
partitionDVarEnv,
anyDVarEnv,
......@@ -450,6 +451,7 @@ minusVarEnv :: VarEnv a -> VarEnv b -> VarEnv a
intersectsVarEnv :: VarEnv a -> VarEnv a -> Bool
plusVarEnv_C :: (a -> a -> a) -> VarEnv a -> VarEnv a -> VarEnv a
plusVarEnv_CD :: (a -> a -> a) -> VarEnv a -> a -> VarEnv a -> a -> VarEnv a
plusMaybeVarEnv_C :: (a -> a -> Maybe a) -> VarEnv a -> VarEnv a -> VarEnv a
mapVarEnv :: (a -> b) -> VarEnv a -> VarEnv b
modifyVarEnv :: (a -> a) -> VarEnv a -> Var -> VarEnv a
......@@ -471,6 +473,7 @@ extendVarEnv_Directly = addToUFM_Directly
extendVarEnvList = addListToUFM
plusVarEnv_C = plusUFM_C
plusVarEnv_CD = plusUFM_CD
plusMaybeVarEnv_C = plusMaybeUFM_C
delVarEnvList = delListFromUFM
delVarEnv = delFromUFM
minusVarEnv = minusUFM
......@@ -541,6 +544,9 @@ mkDVarEnv = listToUDFM
extendDVarEnv :: DVarEnv a -> Var -> a -> DVarEnv a
extendDVarEnv = addToUDFM
minusDVarEnv :: DVarEnv a -> DVarEnv a' -> DVarEnv a
minusDVarEnv = minusUDFM
lookupDVarEnv :: DVarEnv a -> Var -> Maybe a
lookupDVarEnv = lookupUDFM
......
......@@ -11,7 +11,8 @@
-- | Arity and eta expansion
module CoreArity (
manifestArity, exprArity, typeArity, exprBotStrictness_maybe,
exprEtaExpandArity, findRhsArity, CheapFun, etaExpand
exprEtaExpandArity, findRhsArity, CheapFun, etaExpand,
etaExpandToJoinPoint, etaExpandToJoinPointRule
) where
#include "HsVersions.h"
......@@ -952,11 +953,17 @@ etaInfoApp subst (Case e b ty alts) eis
etaInfoApp subst (Let b e) eis
= Let b' (etaInfoApp subst' e eis)
where
(subst', b') = subst_bind subst b
(subst', b') = etaInfoAppBind subst b eis
etaInfoApp subst (Tick t e) eis
= Tick (substTickish subst t) (etaInfoApp subst e eis)
etaInfoApp subst expr _
| (Var fun, _) <- collectArgs expr
, Var fun' <- lookupIdSubst (text "etaInfoApp" <+> ppr fun) subst fun
, isJoinId fun'
= subst_expr subst expr
etaInfoApp subst e eis
= go (subst_expr subst e) eis
where
......@@ -964,6 +971,94 @@ etaInfoApp subst e eis
go e (EtaVar v : eis) = go (App e (varToCoreExpr v)) eis
go e (EtaCo co : eis) = go (Cast e co) eis
--------------
-- | Apply the eta info to a local binding. Mostly delegates to
-- `etaInfoAppLocalBndr` and `etaInfoAppRhs`.
etaInfoAppBind :: Subst -> CoreBind -> [EtaInfo] -> (Subst, CoreBind)
etaInfoAppBind subst (NonRec bndr rhs) eis