Commit e5adcaf8 authored by simonpj@microsoft.com's avatar simonpj@microsoft.com
Browse files

Improve SpecConstr for local bindings: seed specialisation from the calls

This patch makes a significant improvement to SpecConstr, based on
Roman's experience with using it for stream fusion.  The main change is
this:

  * For local (not-top-level) declarations, seed the specialisation 
    loop from the calls in the body of the 'let'.

See Note [Local recursive groups] for discussion and example.  Top-level
declarations are treated just as before.

Other changes in this patch:

  * New flag -fspec-constr-count=N sets the maximum number of specialisations
    for any single function to N.  -fno-spec-constr-count removes the limit.

  * Refactoring in specLoop and friends; new algebraic data types 
    OneSpec and SpecInfo instead of the tuples that were there before

  * Be less keen to specialise on variables that are simply in scope.
    Example
      f p q = letrec g a y = ...g....  in g q p
    We probably do not want to specialise 'g' for calls with exactly
    the arguments 'q' and 'p', since we know nothing about them.
parent 6dc702e8
......@@ -305,6 +305,7 @@ data DynFlags = DynFlags {
ruleCheck :: Maybe String,
specConstrThreshold :: Maybe Int, -- Threshold for SpecConstr
specConstrCount :: Maybe Int, -- Max number of specialisations for any one function
liberateCaseThreshold :: Maybe Int, -- Threshold for LiberateCase
stolen_x86_regs :: Int,
......@@ -496,6 +497,7 @@ defaultDynFlags =
shouldDumpSimplPhase = const False,
ruleCheck = Nothing,
specConstrThreshold = Just 200,
specConstrCount = Just 3,
liberateCaseThreshold = Just 200,
stolen_x86_regs = 4,
cmdlineHcIncludes = [],
......@@ -1185,6 +1187,10 @@ dynamic_flags = [
upd (\dfs -> dfs{ specConstrThreshold = Just n })))
, ( "fno-spec-constr-threshold", NoArg (
upd (\dfs -> dfs{ specConstrThreshold = Nothing })))
, ( "fspec-constr-count", IntSuffix (\n ->
upd (\dfs -> dfs{ specConstrCount = Just n })))
, ( "fno-spec-constr-count", NoArg (
upd (\dfs -> dfs{ specConstrCount = Nothing })))
, ( "fliberate-case-threshold", IntSuffix (\n ->
upd (\dfs -> dfs{ liberateCaseThreshold = Just n })))
, ( "fno-liberate-case-threshold", NoArg (
......
......@@ -40,7 +40,7 @@ import ErrUtils ( dumpIfSet_dyn )
import DynFlags ( DynFlags(..), DynFlag(..) )
import StaticFlags ( opt_SpecInlineJoinPoints )
import BasicTypes ( Activation(..) )
import Maybes ( orElse, catMaybes, isJust )
import Maybes ( orElse, catMaybes, isJust, isNothing )
import Util
import List ( nubBy, partition )
import UniqSupply
......@@ -48,6 +48,7 @@ import Outputable
import FastString
import UniqFM
import MonadUtils
import Control.Monad ( zipWithM )
\end{code}
-----------------------------------------------------
......@@ -344,6 +345,29 @@ The recursive call ends up looking like
So we want to spot the construtor application inside the cast.
That's why we have the Cast case in argToPat
Note [Local recursive groups]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
For a *local* recursive group, we can see all the calls to the
function, so we seed the specialisation loop from the calls in the
body, not from the calls in the RHS. Consider:
bar m n = foo n (n,n) (n,n) (n,n) (n,n)
where
foo n p q r s
| n == 0 = m
| n > 3000 = case p of { (p1,p2) -> foo (n-1) (p2,p1) q r s }
| n > 2000 = case q of { (q1,q2) -> foo (n-1) p (q2,q1) r s }
| n > 1000 = case r of { (r1,r2) -> foo (n-1) p q (r2,r1) s }
| otherwise = case s of { (s1,s2) -> foo (n-1) p q r (s2,s1) }
If we start with the RHSs of 'foo', we get lots and lots of specialisations,
most of which are not needed. But if we start with the (single) call
in the rhs of 'bar' we get exactly one fully-specialised copy, and all
the recursive calls go to this fully-specialised copy. Indeed, the original
function is later collected as dead code. This is very important in
specialising the loops arising from stream fusion, for example in NDP where
we were getting literally hundreds of (mostly unused) specialisations of
a local function.
-----------------------------------------------------
Stuff not yet handled
......@@ -444,7 +468,7 @@ specConstrProgram dflags us binds
return binds'
where
go _ [] = return []
go env (bind:binds) = do (env', _, bind') <- scBind env bind
go env (bind:binds) = do (env', bind') <- scTopBind env bind
binds' <- go env' binds
return (bind' : binds')
\end{code}
......@@ -457,7 +481,8 @@ specConstrProgram dflags us binds
%************************************************************************
\begin{code}
data ScEnv = SCE { sc_size :: Maybe Int, -- Size threshold
data ScEnv = SCE { sc_size :: Maybe Int, -- Size threshold
sc_count :: Maybe Int, -- Max # of specialisations for any one fn
sc_subst :: Subst, -- Current substitution
-- Maps InIds to OutExprs
......@@ -495,6 +520,7 @@ instance Outputable Value where
initScEnv :: DynFlags -> ScEnv
initScEnv dflags
= SCE { sc_size = specConstrThreshold dflags,
sc_count = specConstrCount dflags,
sc_subst = emptySubst,
sc_how_bound = emptyVarEnv,
sc_vals = emptyVarEnv }
......@@ -805,15 +831,28 @@ scExpr' env (Let (NonRec bndr rhs) body)
; return (body_usg { scu_calls = scu_calls body_usg `delVarEnv` bndr' }
`combineUsage` rhs_usg `combineUsage` spec_usg,
mkLets [NonRec b r | (b,r) <- addRules rhs_info specs] body')
mkLets [NonRec b r | (b,r) <- specInfoBinds rhs_info specs] body')
}
-}
-- A *local* recursive group: see Note [Local recursive groups]
scExpr' env (Let (Rec prs) body)
= do { (env', bind_usg, bind') <- scBind env (Rec prs)
; (body_usg, body') <- scExpr env' body
; return (bind_usg `combineUsage` body_usg, Let bind' body') }
= do { let (bndrs,rhss) = unzip prs
(rhs_env1,bndrs') = extendRecBndrs env bndrs
rhs_env2 = extendHowBound rhs_env1 bndrs' RecFun
; (rhs_usgs, rhs_infos) <- mapAndUnzipM (scRecRhs rhs_env2) (bndrs' `zip` rhss)
; (body_usg, body') <- scExpr rhs_env2 body
-- NB: start specLoop from body_usg
; (spec_usg, specs) <- specLoop rhs_env2 (scu_calls body_usg) rhs_infos nullUsage
[SI [] 0 (Just usg) | usg <- rhs_usgs]
; let all_usg = spec_usg `combineUsage` body_usg
bind' = Rec (concat (zipWith specInfoBinds rhs_infos specs))
; return (all_usg { scu_calls = scu_calls all_usg `delVarEnvList` bndrs' },
Let bind' body') }
-----------------------------------
scApp :: ScEnv -> (InExpr, [InExpr]) -> UniqSM (ScUsage, CoreExpr)
......@@ -857,14 +896,14 @@ scApp env (other_fn, args)
; return (combineUsages arg_usgs `combineUsage` fn_usg, mkApps fn' args') }
----------------------
scBind :: ScEnv -> CoreBind -> UniqSM (ScEnv, ScUsage, CoreBind)
scBind env (Rec prs)
scTopBind :: ScEnv -> CoreBind -> UniqSM (ScEnv, CoreBind)
scTopBind env (Rec prs)
| Just threshold <- sc_size env
, not (all (couldBeSmallEnoughToInline threshold) rhss)
-- No specialisation
= do { let (rhs_env,bndrs') = extendRecBndrs env bndrs
; (rhs_usgs, rhss') <- mapAndUnzipM (scExpr rhs_env) rhss
; return (rhs_env, combineUsages rhs_usgs, Rec (bndrs' `zip` rhss')) }
; (_, rhss') <- mapAndUnzipM (scExpr rhs_env) rhss
; return (rhs_env, Rec (bndrs' `zip` rhss')) }
| otherwise -- Do specialisation
= do { let (rhs_env1,bndrs') = extendRecBndrs env bndrs
rhs_env2 = extendHowBound rhs_env1 bndrs' RecFun
......@@ -872,38 +911,19 @@ scBind env (Rec prs)
; (rhs_usgs, rhs_infos) <- mapAndUnzipM (scRecRhs rhs_env2) (bndrs' `zip` rhss)
; let rhs_usg = combineUsages rhs_usgs
; (spec_usg, specs) <- spec_loop rhs_env2 (scu_calls rhs_usg)
(repeat [] `zip` rhs_infos)
; let all_usg = rhs_usg `combineUsage` spec_usg
; (_, specs) <- specLoop rhs_env2 (scu_calls rhs_usg) rhs_infos nullUsage
[SI [] 0 Nothing | _ <- bndrs]
; return (rhs_env1, -- For the body of the letrec, delete the RecFun business
all_usg { scu_calls = scu_calls rhs_usg `delVarEnvList` bndrs' },
Rec (concat (zipWith addRules rhs_infos specs))) }
Rec (concat (zipWith specInfoBinds rhs_infos specs))) }
where
(bndrs,rhss) = unzip prs
spec_loop :: ScEnv
-> CallEnv
-> [([CallPat], RhsInfo)] -- One per binder
-> UniqSM (ScUsage, [[SpecInfo]]) -- One list per binder
spec_loop env all_calls rhs_stuff
= do { (spec_usg_s, new_pats_s, specs) <- mapAndUnzip3M (specialise env all_calls) rhs_stuff
; let spec_usg = combineUsages spec_usg_s
; if all null new_pats_s then
return (spec_usg, specs) else do
{ (spec_usg1, specs1) <- spec_loop env (scu_calls spec_usg)
(zipWith add_pats new_pats_s rhs_stuff)
; return (spec_usg `combineUsage` spec_usg1, zipWith (++) specs specs1) } }
add_pats :: [CallPat] -> ([CallPat], RhsInfo) -> ([CallPat], RhsInfo)
add_pats new_pats (done_pats, rhs_info) = (done_pats ++ new_pats, rhs_info)
scBind env (NonRec bndr rhs)
= do { (usg, rhs') <- scExpr env rhs
scTopBind env (NonRec bndr rhs)
= do { (_, rhs') <- scExpr env rhs
; let (env1, bndr') = extendBndr env bndr
env2 = extendValEnv env1 bndr' (isValue (sc_vals env) rhs')
; return (env2, usg, NonRec bndr' rhs') }
; return (env2, NonRec bndr' rhs') }
----------------------
scRecRhs :: ScEnv -> (OutId, InExpr) -> UniqSM (ScUsage, RhsInfo)
......@@ -920,12 +940,12 @@ scRecRhs env (bndr,rhs)
-- Two pats are the same if they match both ways
----------------------
addRules :: RhsInfo -> [SpecInfo] -> [(Id,CoreExpr)]
addRules (fn, args, body, _) specs
= [(id,rhs) | (_,id,rhs) <- specs] ++
specInfoBinds :: RhsInfo -> SpecInfo -> [(Id,CoreExpr)]
specInfoBinds (fn, args, body, _) (SI specs _ _)
= [(id,rhs) | OS _ _ id rhs <- specs] ++
[(fn `addIdSpecialisations` rules, mkLams args body)]
where
rules = [r | (r,_,_) <- specs]
rules = [r | OS _ r _ _ <- specs]
----------------------
varUsage :: ScEnv -> OutVar -> ArgOcc -> ScUsage
......@@ -948,35 +968,80 @@ type RhsInfo = (OutId, [OutVar], OutExpr, [ArgOcc])
-- Original binding f = \xs.body
-- Plus info about usage of arguments
type SpecInfo = (CoreRule, OutId, OutExpr)
-- One specialisation: Rule plus definition
data SpecInfo = SI [OneSpec] -- The specialisations we have generated
Int -- Length of specs; used for numbering them
(Maybe ScUsage) -- Nothing => we have generated specialisations
-- from calls in the *original* RHS
-- Just cs => we haven't, and this is the usage
-- of the original RHS
-- One specialisation: Rule plus definition
data OneSpec = OS CallPat -- Call pattern that generated this specialisation
CoreRule -- Rule connecting original id with the specialisation
OutId OutExpr -- Spec id + its rhs
specLoop :: ScEnv
-> CallEnv
-> [RhsInfo]
-> ScUsage -> [SpecInfo] -- One per binder; acccumulating parameter
-> UniqSM (ScUsage, [SpecInfo]) -- ...ditto...
specLoop env all_calls rhs_infos usg_so_far specs_so_far
= do { specs_w_usg <- zipWithM (specialise env all_calls) rhs_infos specs_so_far
; let (new_usg_s, all_specs) = unzip specs_w_usg
new_usg = combineUsages new_usg_s
new_calls = scu_calls new_usg
all_usg = usg_so_far `combineUsage` new_usg
; if isEmptyVarEnv new_calls then
return (all_usg, all_specs)
else
specLoop env new_calls rhs_infos all_usg all_specs }
specialise
:: ScEnv
-> CallEnv -- Info on calls
-> ([CallPat], RhsInfo) -- Original RHS plus patterns dealt with
-> UniqSM (ScUsage, [CallPat], [SpecInfo]) -- Specialised calls
-> RhsInfo
-> SpecInfo -- Original RHS plus patterns dealt with
-> UniqSM (ScUsage, SpecInfo) -- New specialised versions and their usage
-- Note: the rhs here is the optimised version of the original rhs
-- So when we make a specialised copy of the RHS, we're starting
-- from an RHS whose nested functions have been optimised already.
specialise env bind_calls (done_pats, (fn, arg_bndrs, body, arg_occs))
specialise env bind_calls (fn, arg_bndrs, body, arg_occs)
spec_info@(SI specs spec_count mb_unspec)
| notNull arg_bndrs, -- Only specialise functions
Just all_calls <- lookupVarEnv bind_calls fn
= do { pats <- callsToPats env done_pats arg_occs all_calls
= do { (boring_call, pats) <- callsToPats env specs arg_occs all_calls
-- ; pprTrace "specialise" (vcat [ppr fn <+> ppr arg_occs,
-- text "calls" <+> ppr all_calls,
-- text "good pats" <+> ppr pats]) $
-- return ()
; (spec_usgs, specs) <- mapAndUnzipM (spec_one env fn arg_bndrs body)
(pats `zip` [length done_pats..])
; return (combineUsages spec_usgs, pats, specs) }
-- Bale out if too many specialisations
-- Rather a hacky way to do so, but it'll do for now
; let spec_count' = length pats + spec_count
; case sc_count env of
Just max | spec_count' > max
-> pprTrace "SpecConstr: too many specialisations for one function (see -fspec-constr-count):"
(vcat [ptext SLIT("Function:") <+> ppr fn,
ptext SLIT("Specialisations:") <+> ppr (pats ++ [p | OS p _ _ _ <- specs])])
return (nullUsage, spec_info)
_normal_case -> do
{ (spec_usgs, new_specs) <- mapAndUnzipM (spec_one env fn arg_bndrs body)
(pats `zip` [spec_count..])
; let spec_usg = combineUsages spec_usgs
(new_usg, mb_unspec')
= case mb_unspec of
Just rhs_usg | boring_call -> (spec_usg `combineUsage` rhs_usg, Nothing)
_ -> (spec_usg, mb_unspec)
; return (new_usg, SI (new_specs ++ specs) spec_count' mb_unspec') } }
| otherwise
= return (nullUsage, [], []) -- The boring case
= return (nullUsage, spec_info) -- The boring case
---------------------
......@@ -984,8 +1049,8 @@ spec_one :: ScEnv
-> OutId -- Function
-> [Var] -- Lambda-binders of RHS; should match patterns
-> CoreExpr -- Body of the original function
-> (([Var], [CoreArg]), Int)
-> UniqSM (ScUsage, SpecInfo) -- Rule and binding
-> (CallPat, Int)
-> UniqSM (ScUsage, OneSpec) -- Rule and binding
-- spec_one creates a specialised copy of the function, together
-- with a rule for using it. I'm very proud of how short this
......@@ -1009,7 +1074,7 @@ spec_one :: ScEnv
f (b,c) ((:) (a,(b,c)) (x,v) hw) = f_spec b c v hw
-}
spec_one env fn arg_bndrs body ((qvars, pats), rule_number)
spec_one env fn arg_bndrs body (call_pat@(qvars, pats), rule_number)
= do { -- Specialise the body
let spec_env = extendScSubstList (extendScInScope env qvars)
(arg_bndrs `zip` pats)
......@@ -1034,7 +1099,7 @@ spec_one env fn arg_bndrs body ((qvars, pats), rule_number)
body_ty = exprType spec_body
rule_rhs = mkVarApps (Var spec_id) spec_call_args
rule = mkLocalRule rule_name specConstrActivation fn_name qvars pats rule_rhs
; return (spec_usg, (rule, spec_id, spec_rhs)) }
; return (spec_usg, OS call_pat rule spec_id spec_rhs) }
-- In which phase should the specialise-constructor rules be active?
-- Originally I made them always-active, but Manuel found that
......@@ -1062,17 +1127,20 @@ they are constructor applications.
type CallPat = ([Var], [CoreExpr]) -- Quantified variables and arguments
callsToPats :: ScEnv -> [CallPat] -> [ArgOcc] -> [Call] -> UniqSM [CallPat]
callsToPats :: ScEnv -> [OneSpec] -> [ArgOcc] -> [Call] -> UniqSM (Bool, [CallPat])
-- Result has no duplicate patterns,
-- nor ones mentioned in done_pats
callsToPats env done_pats bndr_occs calls
-- Bool indicates that there was at least one boring pattern
callsToPats env done_specs bndr_occs calls
= do { mb_pats <- mapM (callToPats env bndr_occs) calls
; let good_pats :: [([Var], [CoreArg])]
good_pats = catMaybes mb_pats
done_pats = [p | OS p _ _ _ <- done_specs]
is_done p = any (samePat p) done_pats
; return (filterOut is_done (nubBy samePat good_pats)) }
; return (any isNothing mb_pats,
filterOut is_done (nubBy samePat good_pats)) }
callToPats :: ScEnv -> [ArgOcc] -> Call -> UniqSM (Maybe CallPat)
-- The [Var] is the variables to quantify over in the rule
......@@ -1085,7 +1153,7 @@ callToPats env bndr_occs (con_env, args)
| otherwise
= do { let in_scope = substInScope (sc_subst env)
; prs <- argsToPats in_scope con_env (args `zip` bndr_occs)
; let (good_pats, pats) = unzip prs
; let (interesting_s, pats) = unzip prs
pat_fvs = varSetElems (exprsFreeVars pats)
qvars = filterOut (`elemInScopeSet` in_scope) pat_fvs
-- Quantify over variables that are not in sccpe
......@@ -1098,7 +1166,7 @@ callToPats env bndr_occs (con_env, args)
-- variable may mention a type variable
; -- pprTrace "callToPats" (ppr args $$ ppr prs $$ ppr bndr_occs) $
if or good_pats
if or interesting_s
then return (Just (qvars', pats))
else return Nothing }
......@@ -1201,12 +1269,18 @@ argToPat in_scope val_env (Var v) arg_occ
-- variables that are in soope, which in turn can
-- expose the weakness in let-matching
-- See Note [Matching lets] in Rules
-- Check for a variable bound inside the function.
-- Don't make a wild-card, because we may usefully share
-- e.g. f a = let x = ... in f (x,x)
-- NB: this case follows the lambda and con-app cases!!
argToPat _in_scope _val_env (Var v) _arg_occ
= return (False, Var v)
-- argToPat _in_scope _val_env (Var v) _arg_occ
-- = return (False, Var v)
-- SLPJ : disabling this to avoid proliferation of versions
-- also works badly when thinking about seeding the loop
-- from the body of the let
-- f x y = letrec g z = ... in g (x,y)
-- We don't want to specialise for that *particular* x,y
-- The default case: make a wild-card
argToPat _in_scope _val_env arg _arg_occ
......
......@@ -1324,6 +1324,15 @@
<entry><option>-fno-spec-constr-threshold</option></entry>
</row>
<row>
<entry><option>-fspec-constr-count</option>=<replaceable>n</replaceable></entry>
<entry>Set to <replaceable>n</replaceable> (default: 3) the maximum number of
specialisations that will be created for any one function
by the SpecConstr transformation</entry>
<entry>static</entry>
<entry><option>-fno-spec-constr-count</option></entry>
</row>
<row>
<entry><option>-fliberate-case</option></entry>
<entry>Turn on the liberate-case transformation. Implied by <option>-O2</option>.</entry>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment