Commit 731f53de authored by simonpj's avatar simonpj
Browse files

[project @ 1999-09-17 09:15:22 by simonpj]

This bunch of commits represents work in progress on inlining and
worker/wrapper stuff.

Currently, I think it makes the compiler slightly worse than 4.04, for
reasons I don't yet understand.  But it means that Simon and I can
both peer at what is going on.

* Substantially improve handling of coerces in worker/wrapper

* exprIsDupable for an application (f e1 .. en) wasn't calling exprIsDupable
  on the arguments!!  So applications with few, but large, args were being duplicated.

* sizeExpr on an application wasn't doing a nukeScrutDiscount on the arg of
  an application!!  So bogus discounts could accumulate from arguments!

* Improve handling of INLINE pragmas in calcUnfoldingGuidance.  It was really
  wrong before
parent 6e5c95e9
......@@ -48,14 +48,14 @@ import PprCore ( pprCoreExpr )
import OccurAnal ( occurAnalyseGlobalExpr )
import BinderInfo ( )
import CoreUtils ( coreExprType, exprIsTrivial, exprIsValue, exprIsCheap )
import Id ( Id, idType, idUnique, isId,
import Id ( Id, idType, idUnique, isId, getIdWorkerInfo,
getIdSpecialisation, getInlinePragma, getIdUnfolding
)
import VarSet
import Name ( isLocallyDefined )
import Const ( Con(..), isLitLitLit, isWHNFCon )
import PrimOp ( PrimOp(..), primOpIsDupable )
import IdInfo ( ArityInfo(..), InlinePragInfo(..), OccInfo(..) )
import IdInfo ( ArityInfo(..), InlinePragInfo(..), OccInfo(..), workerExists )
import TyCon ( tyConFamilySize )
import Type ( splitAlgTyConApp_maybe, splitFunTy_maybe, isUnLiftedType )
import Const ( isNoRepLit )
......@@ -170,10 +170,8 @@ instance Outputable UnfoldingGuidance where
ppr UnfoldAlways = ptext SLIT("ALWAYS")
ppr UnfoldNever = ptext SLIT("NEVER")
ppr (UnfoldIfGoodArgs v cs size discount)
= hsep [ptext SLIT("IF_ARGS"), int v,
if null cs -- always print *something*
then char 'X'
else hcat (map (text . show) cs),
= hsep [ ptext SLIT("IF_ARGS"), int v,
brackets (hsep (map int cs)),
int size,
int discount ]
\end{code}
......@@ -199,21 +197,33 @@ calcUnfoldingGuidance bOMB_OUT_SIZE expr
= UnfoldAlways
| otherwise
= case collectBinders expr of { (binders, body) ->
let
val_binders = filter isId binders
in
= case collect_val_bndrs expr of { (inline, val_binders, body) ->
case (sizeExpr bOMB_OUT_SIZE val_binders body) of
TooBig -> UnfoldNever
SizeIs size cased_args scrut_discount
-> UnfoldIfGoodArgs
(length val_binders)
n_val_binders
(map discount_for val_binders)
(I# size)
final_size
(I# scrut_discount)
where
boxed_size = I# size
n_val_binders = length val_binders
final_size | inline = boxed_size `min` (n_val_binders + 2)
| otherwise = boxed_size
-- The idea is that if there is an INLINE pragma (inline is True)
-- and there's a big body, we give a size of n_val_binders+2.
-- This is enough to defeat the no-size-increase test in callSiteInline;
-- we don't want to inline an INLINE thing into a totally boring context
--
-- Sometimes, though, an INLINE thing is smaller than n_val_binders+2.
-- A particular case in point is a constructor, which has size 1.
-- We want to inline this regardless, hence the `min`
discount_for b
| num_cases == 0 = 0
| is_fun_ty = num_cases * opt_UF_FunAppDiscount
......@@ -228,6 +238,19 @@ calcUnfoldingGuidance bOMB_OUT_SIZE expr
Nothing -> (False, panic "discount")
Just (tc,_,_) -> (True, tc)
}
where
collect_val_bndrs e = go False [] e
-- We need to be a bit careful about how we collect the
-- value binders. In particular, if we see
-- __inline_me (\x y -> e)
-- We want to say "2 value binders". Why? So that
-- we take account of information given for the arguments
--
-- 'go' returns a triple (inline, val_binders, body):
--   inline      is True iff an InlineMe note was seen on the way down
--   val_binders are the Id binders of the lambdas passed through,
--               in their original order (accumulated reversed)
--   body        is what remains once no more InlineMe notes or
--               lambdas can be peeled off
-- Lambda binders that are not Ids are skipped, not collected.
go inline rev_vbs (Note InlineMe e) = go True rev_vbs e
go inline rev_vbs (Lam b e) | isId b = go inline (b:rev_vbs) e
| otherwise = go inline rev_vbs e
go inline rev_vbs e = (inline, reverse rev_vbs, e)
\end{code}
\begin{code}
......@@ -243,12 +266,6 @@ sizeExpr (I# bOMB_OUT_SIZE) args expr
size_up (Type t) = sizeZero -- Types cost nothing
size_up (Var v) = sizeOne
size_up (Note InlineMe _) = sizeTwo -- The idea is that this is one more
-- than the size of the "call" (i.e. 1)
-- We want to reply "no" to noSizeIncrease
-- for a bare reference (i.e. applied to no args)
-- to an INLINE thing
size_up (Note _ body) = size_up body -- Notes cost nothing
size_up (App fun (Type t)) = size_up fun
......@@ -289,7 +306,7 @@ sizeExpr (I# bOMB_OUT_SIZE) args expr
------------
size_up_app (App fun arg) args = size_up_app fun (arg:args)
size_up_app fun args = foldr (addSize . size_up) (fun_discount fun) args
size_up_app fun args = foldr (addSize . nukeScrutDiscount . size_up) (fun_discount fun) args
-- A function application with at least one value argument
-- so if the function is an argument give it an arg-discount
......@@ -597,7 +614,9 @@ computeDiscount n_vals_wanted arg_discounts res_discount arg_infos result_used
-- we also discount 1 for each argument passed, because these will
-- reduce with the lambdas in the function (we count 1 for a lambda
-- in size_up).
= length (take n_vals_wanted arg_infos) +
= 1 + -- Discount of 1 because the result replaces the call
-- so we count 1 for the function itself
length (take n_vals_wanted arg_infos) +
-- Discount of 1 for each arg supplied, because the
-- result replaces the call
round (opt_UF_KeenessFactor *
......@@ -636,10 +655,21 @@ blackListed :: IdSet -- Used in transformation rules
-- inlined because of the inline phase we are in. This is the sole
-- place that the inline phase number is looked at.
-- ToDo: improve horrible coding style (too much duplication)
-- Phase 0: used for 'no imported inlinings please'
-- This prevents wrappers getting inlined which in turn is bad for full laziness
-- NEW: try using 'not a wrapper' rather than 'not imported' in this phase.
-- This allows a little more inlining, which seems to be important, sometimes.
-- For example PrelArr.newIntArr gets better.
blackListed rule_vars (Just 0)
= \v -> not (isLocallyDefined v)
= \v -> let v_uniq = idUnique v
in
-- not (isLocallyDefined v)
workerExists (getIdWorkerInfo v)
|| v `elemVarSet` rule_vars
|| not (isEmptyCoreRules (getIdSpecialisation v))
|| v_uniq == runSTRepIdKey
-- Phase 1: don't inline any rule-y things or things with specialisations
blackListed rule_vars (Just 1)
......
......@@ -10,7 +10,7 @@ module CoreUtils (
exprIsBottom, exprIsDupable, exprIsTrivial, exprIsCheap,
exprIsValue,
exprOkForSpeculation, exprIsBig, hashExpr,
exprArity, exprGenerousArity,
exprArity, exprEtaExpandArity,
cheapEqExpr, eqExpr, applyTypeToArgs
) where
......@@ -149,7 +149,7 @@ exprIsDupable (Con con args) = conIsDupable con &&
exprIsDupable (Note _ e) = exprIsDupable e
exprIsDupable expr = case collectArgs expr of
(Var f, args) -> valArgCount args <= dupAppSize
(Var f, args) -> all exprIsDupable args && valArgCount args <= dupAppSize
other -> False
dupAppSize :: Int
......@@ -230,7 +230,8 @@ It returns True iff
the expression guarantees to terminate,
soon,
without raising an exceptoin
without raising an exception,
without causing a side effect (e.g. writing a mutable variable)
E.G.
let x = case y# +# 1# of { r# -> I# r# }
......@@ -303,13 +304,24 @@ exprIsValue e@(App _ _) = case collectArgs e of
exprArity :: CoreExpr -> Int -- How many value lambdas are at the top
exprArity (Lam b e) | isTyVar b = exprArity e
| otherwise = 1 + exprArity e
exprArity (Note note e) | ok_note note = exprArity e
exprArity other = 0
where
ok_note (Coerce _ _) = True
-- We *do* look through coerces when getting arities.
-- Reason: arities are to do with *representation* and
-- work duplication.
ok_note InlineMe = True
ok_note InlineCall = True
ok_note other = False
-- SCC and TermUsg might be over-conservative?
exprArity other = 0
\end{code}
\begin{code}
exprGenerousArity :: CoreExpr -> Int -- The number of args the thing can be applied to
exprEtaExpandArity :: CoreExpr -> Int -- The number of args the thing can be applied to
-- without doing much work
-- This is used when eta expanding
-- e ==> \xy -> e x y
......@@ -320,17 +332,36 @@ exprGenerousArity :: CoreExpr -> Int -- The number of args the thing can be app
-- We are prepared to evaluate x each time round the loop in order to get that
-- Hence "generous" arity
exprGenerousArity (Var v) = arityLowerBound (getIdArity v)
exprGenerousArity (Note note e)
| ok_note note = exprGenerousArity e
exprGenerousArity (Lam x e)
| isId x = 1 + exprGenerousArity e
| otherwise = exprGenerousArity e
exprGenerousArity (Let bind body)
| all exprIsCheap (rhssOfBind bind) = exprGenerousArity body
exprGenerousArity (Case scrut _ alts)
| exprIsCheap scrut = min_zero [exprGenerousArity rhs | (_,_,rhs) <- alts]
exprGenerousArity other = 0 -- Could do better for applications
-- A variable: use the lower bound of its recorded arity
exprEtaExpandArity (Var v) = arityLowerBound (getIdArity v)
-- A lambda: each Id binder adds one; other binders add nothing
exprEtaExpandArity (Lam x e)
| isId x = 1 + exprEtaExpandArity e
| otherwise = exprEtaExpandArity e
-- A let: look at the body, but only when every right-hand side is cheap;
-- otherwise the guard fails and we fall through to the catch-all (0)
exprEtaExpandArity (Let bind body)
| all exprIsCheap (rhssOfBind bind) = exprEtaExpandArity body
-- A case: when the scrutinee is cheap, combine the arities of the
-- alternatives' right-hand sides with min_zero; otherwise fall through
exprEtaExpandArity (Case scrut _ alts)
| exprIsCheap scrut = min_zero [exprEtaExpandArity rhs | (_,_,rhs) <- alts]
-- A note: only looked through when ok_note says so (InlineCall only);
-- any other note falls through to the catch-all
exprEtaExpandArity (Note note e)
| ok_note note = exprEtaExpandArity e
where
ok_note InlineCall = True
ok_note other = False
-- Notice that we do not look through __inline_me__
-- This one is a bit more surprising, but consider
-- f = _inline_me (\x -> e)
-- We DO NOT want to eta expand this to
-- f = \x -> (_inline_me (\x -> e)) x
-- because the _inline_me gets dropped now it is applied,
-- giving just
-- f = \x -> e
-- A Bad Idea
--
-- Notice also that we don't look through Coerce
-- This is simply because the etaExpand code in SimplUtils
-- isn't capable of making the alternating lambdas and coerces
-- that would be necessary to exploit it
-- Catch-all: everything else (and every failed guard above) gets arity 0
exprEtaExpandArity other = 0 -- Could do better for applications
min_zero :: [Int] -> Int -- Find the minimum, but zero is the smallest
min_zero (x:xs) = go x xs
......@@ -340,24 +371,6 @@ min_zero (x:xs) = go x xs
go min (x:xs) | x < min = go x xs
| otherwise = go min xs
ok_note (SCC _) = False -- (Over?) conservative
ok_note (TermUsg _) = False -- Doesn't matter much
ok_note (Coerce _ _) = True
-- We *do* look through coerces when getting arities.
-- Reason: arities are to do with *representation* and
-- work duplication.
ok_note InlineCall = True
ok_note InlineMe = False
-- This one is a bit more surprising, but consider
-- f = _inline_me (\x -> e)
-- We DO NOT want to eta expand this to
-- f = \x -> (_inline_me (\x -> e)) x
-- because the _inline_me gets dropped now it is applied,
-- giving just
-- f = \x -> e
-- A Bad Idea
\end{code}
......
......@@ -9,14 +9,14 @@ module ErrUtils (
addShortErrLocLine, addShortWarnLocLine,
addErrLocHdrLine,
dontAddErrLoc,
pprBagOfErrors, pprBagOfWarnings,
printErrorsAndWarnings, pprBagOfErrors, pprBagOfWarnings,
ghcExit,
doIfSet, dumpIfSet
) where
#include "HsVersions.h"
import Bag ( Bag, bagToList )
import Bag ( Bag, bagToList, isEmptyBag )
import SrcLoc ( SrcLoc, noSrcLoc )
import Util ( sortLt )
import Outputable
......@@ -57,6 +57,16 @@ dontAddErrLoc title rest_of_err_msg
| otherwise =
( noSrcLoc, hang (text title <> colon) 4 rest_of_err_msg )
printErrorsAndWarnings :: Bag ErrMsg -> Bag WarnMsg -> IO ()
-- Report the outcome of a pass on the error stream.
-- Errors take priority: if there are any, only the errors are printed
-- and the warnings are suppressed.  With no errors, the warnings (if
-- any) are printed.  With neither, nothing is printed at all.
printErrorsAndWarnings errs warns
  | have_errs  = printErrs (pprBagOfErrors errs)
  | have_warns = printErrs (pprBagOfWarnings warns)
  | otherwise  = return ()
  where
    have_errs  = not (isEmptyBag errs)
    have_warns = not (isEmptyBag warns)
pprBagOfErrors :: Bag ErrMsg -> SDoc
pprBagOfErrors bag_of_errors
= vcat [text "" $$ p | (_,p) <- sorted_errs ]
......
......@@ -27,8 +27,8 @@ import Id ( Id, idType, idInfo, omitIfaceSigForId, isUserExportedId,
import Var ( isId )
import VarSet
import DataCon ( StrictnessMark(..), dataConSig, dataConFieldLabels, dataConStrictMarks )
import IdInfo ( IdInfo, StrictnessInfo, ArityInfo, InlinePragInfo(..), inlinePragInfo,
arityInfo, ppArityInfo,
import IdInfo ( IdInfo, StrictnessInfo(..), ArityInfo, InlinePragInfo(..), inlinePragInfo,
arityInfo, ppArityInfo, arityLowerBound,
strictnessInfo, ppStrictnessInfo, isBottomingStrictness,
cafInfo, ppCafInfo, specInfo,
cprInfo, ppCprInfo,
......@@ -290,7 +290,8 @@ ifaceId get_idinfo needed_ids is_rec id rhs
= Nothing -- Well, that was easy!
ifaceId get_idinfo needed_ids is_rec id rhs
= Just (hsep [sig_pretty, prag_pretty, char ';'], new_needed_ids)
= ASSERT2( arity_matches_strictness, ppr id )
Just (hsep [sig_pretty, prag_pretty, char ';'], new_needed_ids)
where
core_idinfo = idInfo id
stg_idinfo = get_idinfo id
......@@ -310,7 +311,8 @@ ifaceId get_idinfo needed_ids is_rec id rhs
ptext SLIT("##-}")]
------------ Arity --------------
arity_pretty = ppArityInfo (arityInfo stg_idinfo)
arity_info = arityInfo stg_idinfo
arity_pretty = ppArityInfo arity_info
------------ Caf Info --------------
caf_pretty = ppCafInfo (cafInfo stg_idinfo)
......@@ -369,6 +371,15 @@ ifaceId get_idinfo needed_ids is_rec id rhs
find_fvs expr = exprSomeFreeVars interestingId expr
------------ Sanity checking --------------
-- The arity of a wrapper function should match its strictness,
-- or else an importing module will get very confused indeed.
arity_matches_strictness
= not has_worker ||
case strict_info of
StrictnessInfo ds _ -> length ds == arityLowerBound arity_info
other -> True
interestingId id = isId id && isLocallyDefined id &&
not (omitIfaceSigForId id)
\end{code}
......
......@@ -48,7 +48,7 @@ module RdrHsSyn (
RdrNameGenPragmas,
RdrNameInstancePragmas,
extractHsTyRdrNames,
extractHsTyRdrTyVars,
extractHsTyRdrTyVars, extractHsTysRdrTyVars,
extractPatsTyVars,
extractRuleBndrsTyVars,
......@@ -138,6 +138,9 @@ extractHsTyRdrNames ty = nub (extract_ty ty [])
extractHsTyRdrTyVars :: RdrNameHsType -> [RdrName]
extractHsTyRdrTyVars ty = filter isRdrTyVar (extractHsTyRdrNames ty)
extractHsTysRdrTyVars :: [RdrNameHsType] -> [RdrName]
-- Like extractHsTyRdrTyVars, but for a whole list of types at once:
-- gather the names from every type, remove duplicates, and keep only
-- those that satisfy isRdrTyVar
extractHsTysRdrTyVars tys = filter isRdrTyVar distinct_names
  where
    distinct_names = nub (extract_tys tys [])
extractRuleBndrsTyVars :: [RuleBndr RdrName] -> [RdrName]
extractRuleBndrsTyVars bndrs = filter isRdrTyVar (nub (foldr go [] bndrs))
where
......@@ -151,6 +154,8 @@ extract_ctxt ctxt acc = foldr extract_ass acc ctxt
where
extract_ass (cls, tys) acc = foldr extract_ty (cls : acc) tys
-- List version of extract_ty: folds extract_ty over all of tys,
-- threading the accumulator acc through as the starting value
extract_tys tys acc = foldr extract_ty acc tys
extract_ty (MonoTyApp ty1 ty2) acc = extract_ty ty1 (extract_ty ty2 acc)
extract_ty (MonoListTy ty) acc = extract_ty ty acc
extract_ty (MonoTupleTy tys _) acc = foldr extract_ty acc tys
......
......@@ -44,9 +44,7 @@ import PrelMods ( mAIN_Name, pREL_MAIN_Name )
import TysWiredIn ( unitTyCon, intTyCon, doubleTyCon, boolTyCon )
import PrelInfo ( ioTyCon_NAME, numClass_RDR, thinAirIdNames, derivingOccurrences )
import Type ( namesOfType, funTyCon )
import ErrUtils ( pprBagOfErrors, pprBagOfWarnings,
doIfSet, dumpIfSet, ghcExit
)
import ErrUtils ( printErrorsAndWarnings, dumpIfSet, ghcExit )
import BasicTypes ( NewOrData(..) )
import Bag ( isEmptyBag, bagToList )
import FiniteMap ( fmToList, delListFromFM, addToFM, sizeFM, eltsFM )
......@@ -77,14 +75,7 @@ renameModule us this_mod@(HsModule mod_name vers exports imports local_decls loc
\ (maybe_rn_stuff, rn_errs_bag, rn_warns_bag) ->
-- Check for warnings
doIfSet (not (isEmptyBag rn_warns_bag))
(printErrs (pprBagOfWarnings rn_warns_bag)) >>
-- Check for errors; exit if so
doIfSet (not (isEmptyBag rn_errs_bag))
(printErrs (pprBagOfErrors rn_errs_bag) >>
ghcExit 1
) >>
printErrorsAndWarnings rn_errs_bag rn_warns_bag >>
-- Dump output, if any
(case maybe_rn_stuff of
......@@ -95,7 +86,10 @@ renameModule us this_mod@(HsModule mod_name vers exports imports local_decls loc
) >>
-- Return results
return maybe_rn_stuff
if not (isEmptyBag rn_errs_bag) then
ghcExit 1 >> return Nothing
else
return maybe_rn_stuff
\end{code}
......
......@@ -14,7 +14,7 @@ import HsPragmas
import HsTypes ( getTyVarName, pprClassAssertion, cmpHsTypes )
import RdrName ( RdrName, isRdrDataCon, rdrNameOcc, isRdrTyVar )
import RdrHsSyn ( RdrNameContext, RdrNameHsType, RdrNameConDecl,
extractRuleBndrsTyVars, extractHsTyRdrTyVars
extractRuleBndrsTyVars, extractHsTyRdrTyVars, extractHsTysRdrTyVars
)
import RnHsSyn
import HsCore
......@@ -551,7 +551,7 @@ rnHsPolyType doc (HsForAllTy Nothing ctxt ty)
mentioned_in_tau = extractHsTyRdrTyVars ty
forall_tyvars = filter (not . (`elemFM` name_env)) mentioned_in_tau
in
checkConstraints False doc forall_tyvars ctxt ty `thenRn` \ ctxt' ->
checkConstraints doc forall_tyvars mentioned_in_tau ctxt ty `thenRn` \ ctxt' ->
rnForAll doc (map UserTyVar forall_tyvars) ctxt' ty
rnHsPolyType doc (HsForAllTy (Just forall_tyvars) ctxt tau)
......@@ -575,9 +575,9 @@ rnHsPolyType doc (HsForAllTy (Just forall_tyvars) ctxt tau)
forall_tyvar_names = map getTyVarName forall_tyvars
in
mapRn_ (forAllErr doc tau) bad_guys `thenRn_`
mapRn_ (forAllWarn doc tau) warn_guys `thenRn_`
checkConstraints True doc forall_tyvar_names ctxt tau `thenRn` \ ctxt' ->
mapRn_ (forAllErr doc tau) bad_guys `thenRn_`
mapRn_ (forAllWarn doc tau) warn_guys `thenRn_`
checkConstraints doc forall_tyvar_names mentioned_in_tau ctxt tau `thenRn` \ ctxt' ->
rnForAll doc forall_tyvars ctxt' tau
rnHsPolyType doc other_ty = rnHsType doc other_ty
......@@ -587,19 +587,26 @@ rnHsPolyType doc other_ty = rnHsType doc other_ty
-- Since the forall'd type variables are a subset of the free tyvars
-- of the tau-type part, this guarantees that every constraint mentions
-- at least one of the free tyvars in ty
checkConstraints explicit_forall doc forall_tyvars ctxt ty
checkConstraints doc forall_tyvars tau_vars ctxt ty
= mapRn check ctxt `thenRn` \ maybe_ctxt' ->
returnRn (catMaybes maybe_ctxt')
-- Remove problem ones, to avoid duplicate error message.
where
check ct@(_,tys)
| forall_mentioned = returnRn (Just ct)
| otherwise = addErrRn (ctxtErr explicit_forall doc forall_tyvars ct ty)
`thenRn_` returnRn Nothing
| ambiguous = failWithRn Nothing (ambigErr doc ct ty)
| not_univ = failWithRn Nothing (univErr doc ct ty)
| otherwise = returnRn (Just ct)
where
forall_mentioned = foldr ((||) . any (`elem` forall_tyvars) . extractHsTyRdrTyVars)
False
tys
ct_vars = extractHsTysRdrTyVars tys
ambiguous = -- All the universally-quantified tyvars in the constraint must appear in the tau ty
-- (will change when we get functional dependencies)
not (all (\ct_var -> not (ct_var `elem` forall_tyvars) || ct_var `elem` tau_vars) ct_vars)
not_univ = -- At least one of the tyvars in each constraint must
-- be universally quantified. This restriction isn't in Hugs
not (any (`elem` forall_tyvars) ct_vars)
rnForAll doc forall_tyvars ctxt ty
= bindTyVarsFVRn doc forall_tyvars $ \ new_tyvars ->
......@@ -918,17 +925,22 @@ forAllErr doc ty tyvar
$$
(ptext SLIT("In") <+> doc))
ctxtErr explicit_forall doc tyvars constraint ty
= sep [ptext SLIT("None of the type variable(s) in the constraint")
<+> quotes (pprClassAssertion constraint),
if explicit_forall then
nest 4 (ptext SLIT("is universally quantified (i.e. bound by the forall)"))
else
nest 4 (ptext SLIT("appears in the type") <+> quotes (ppr ty))
univErr doc constraint ty
= sep [ptext SLIT("All of the type variable(s) in the constraint")
<+> quotes (pprClassAssertion constraint)
<+> ptext SLIT("are already in scope"),
nest 4 (ptext SLIT("At least one must be universally quantified here"))
]
$$
(ptext SLIT("In") <+> doc)
-- Error message for a constraint that mentions a forall-d type variable
-- which does not occur in the type after the =>.
--   doc        describes where the type appears (printed after "In")
--   constraint the offending class assertion
--   ty         the type in which the constraint was found
ambigErr doc constraint ty
  = sep [headline, nest 4 location, nest 4 advice]
    $$
    (ptext SLIT("In") <+> doc)
  where
    headline = ptext SLIT("Ambiguous constraint") <+> quotes (pprClassAssertion constraint)
    location = ptext SLIT("in the type:") <+> ppr ty
    advice   = ptext SLIT("Each forall-d type variable mentioned by the constraint must appear after the =>.")
unexpectedForAllTy ty
= ptext SLIT("Unexpected forall type:") <+> ppr ty
......
......@@ -25,7 +25,7 @@ import CoreSyn
import CoreFVs ( idRuleVars )
import CoreUtils ( exprIsTrivial )
import Const ( Con(..), Literal(..) )
import Id ( isSpecPragmaId, isOneShotLambda,
import Id ( isSpecPragmaId, isOneShotLambda, setOneShotLambda,
getInlinePragma, setInlinePragma,
isExportedId, modifyIdInfo, idInfo,
getIdSpecialisation,
......@@ -626,6 +626,10 @@ occAnal env expr@(Lam _ _)
= case occAnal (env_body `addNewCands` binders) body of { (body_usage, body') ->
let
(final_usage, tagged_binders) = tagBinders body_usage binders
-- URGH! Sept 99: we don't seem to be able to use binders' here, because
-- we get linear-typed things in the resulting program that we can't handle yet.
-- (e.g. PrelShow) TODO
really_final_usage = if linear then
final_usage
else
......@@ -635,7 +639,7 @@ occAnal env expr@(Lam _ _)
mkLams tagged_binders body') }
where
(binders, body) = collectBinders expr
(linear, env_body) = oneShotGroup env (filter isId binders)
(linear, env_body, binders') = oneShotGroup env binders
occAnal env (Case scrut bndr alts)
= case mapAndUnzip (occAnalAlt alt_env) alts of { (alts_usage_s, alts') ->
......@@ -764,15 +768,31 @@ addNewCand (OccEnv ifun cands ctxt) id
setCtxt :: OccEnv -> CtxtTy -> OccEnv
setCtxt (OccEnv ifun cands _) ctxt = OccEnv ifun cands ctxt
oneShotGroup :: OccEnv -> [Id] -> (Bool, OccEnv) -- True <=> this is a one-shot linear lambda group
-- The [Id] are the binders
oneShotGroup :: OccEnv -> [CoreBndr] -> (Bool, OccEnv, [CoreBndr])
-- True <=> this is a one-shot linear lambda group
-- The [CoreBndr] are the binders.
-- The result binders have one-shot-ness set that they might not have had originally.
-- This happens in (build (\cn -> e)). Here the occurrence analyser
-- linearity context knows that c,n are one-shot, and it records that fact in
-- the binder. This is useful to guide subsequent float-in/float-out transformations
oneShotGroup (OccEnv ifun cands ctxt) bndrs
= (go bndrs ctxt, OccEnv ifun cands (drop (length bndrs) ctxt))
= case go ctxt bndrs [] of
(new_ctxt, new_bndrs) -> (all is_one_shot new_bndrs, OccEnv ifun cands new_ctxt, new_bndrs)
where
-- Only return True if *all* the lambdas are linear
go (bndr:bndrs) (lin:ctxt) = (lin || isOneShotLambda bndr) && go bndrs ctxt
go [] ctxt = True
go bndrs [] = all isOneShotLambda bndrs
is_one_shot b = isId b && isOneShotLambda b
go ctxt [] rev_bndrs = (ctxt, reverse rev_bndrs)
go (lin_ctxt:ctxt) (bndr:bndrs) rev_bndrs
| isId bndr = go ctxt bndrs (bndr':rev_bndrs)
where
bndr' | lin_ctxt = setOneShotLambda bndr
| otherwise = bndr
go ctxt (bndr:bndrs) rev_bndrs = go ctxt bndrs (bndr:rev_bndrs)
zapCtxt env@(OccEnv ifun cands []) = env
zapCtxt (OccEnv ifun cands _ ) = OccEnv ifun cands []
......
......@@ -266,9 +266,9 @@ lvlExpr ctxt_lvl env (_, AnnNote note expr)
-- Why not? Because partial applications are fairly rare, and splitting
-- lambdas makes them more expensive.
lvlExpr ctxt_lvl env (_, AnnLam bndr rhs)
lvlExpr ctxt_lvl env expr@(_, AnnLam bndr rhs)
= lvlMFE incd_lvl new_env body `thenLvl` \ body' ->
returnLvl (mkLams lvld_bndrs body')
returnLvl (mk_lams lvld_bndrs expr body')
where
bndr_is_id = isId bndr
bndr_is_tyvar = isTyVar bndr
......@@ -283,11 +283,21 @@ lvlExpr ctxt_lvl env (_, AnnLam bndr rhs)
lvld_bndrs = [(b,incd_lvl) | b <- bndrs]
new_env = extendLvlEnv env lvld_bndrs
-- Ignore notes, because we don't want to split
-- a lambda like this (\x -> coerce t (\s -> ...))
-- This happens quite a bit in state-transformer programs
go (_, AnnLam bndr rhs) | bndr_is_id && isId bndr
|| bndr_is_tyvar && isTyVar bndr
= case go rhs of { (bndrs, body) -> (bndr:bndrs, body) }
go (_, AnnNote _ rhs) = go rhs
go body = ([], body)
-- Have to reconstruct the right Notes, since we ignored
-- them when gathering the lambdas
mk_lams (lb : lbs) (_, AnnLam _ body) body' = Lam lb (mk_lams lbs body body')
mk_lams lbs (_, AnnNote note body) body' = Note note (mk_lams lbs body body')
mk_lams [] body body' = body'
lvlExpr ctxt_lvl env (_, AnnLet bind body)
= lvlBind NotTopLevel ctxt_lvl env bind `thenLvl` \ (binds', new_env) ->
lvlExpr ctxt_lvl new_env body `thenLvl` \ body' ->
......
......@@ -259,6 +259,11 @@ simplifyPgm (imported_rule_ids, rule_lhs_fvs)
let { (binds', counts') = initSmpl sw_chkr us1 imported_rule_ids
black_list_fn
(simplTopBinds tagged_binds);
-- The imported_rule_ids are used by initSmpl to initialise
-- the in-scope set. That way, the simplifier will change any
-- occurrences of the imported id to the one in the imported_rule_ids
-- set, which are decorated with their rules.
all_counts = counts `plusSimplCount` counts'
} ;
......@@ -447,7 +452,14 @@ postSimplExpr (Let bind body)
returnPM (Let bind' body')
postSimplExpr (Note note body)
= postSimplExprEta body `thenPM` \ body' ->
= postSimplExpr body `thenPM` \ body' ->
-- Do *not* call postSimplExprEta here
-- We don't want to turn f = \x -> coerce t (\y -> f x y)
-- into f = \x -> coerce t (f x)
-- because then f has a lower arity.
-- This is not only bad in general, it causes the arity to
-- not match the [Demand] on an Id,
-- which confuses the importer of this module.
returnPM (Note note body')
postSimplExpr (Case scrut case_bndr alts)
......
......@@ -39,6 +39,7 @@ module SimplMonad (
getEnclosingCC, setEnclosingCC,