Commit 60994016 authored by chak@cse.unsw.edu.au.

Added a pragma {-# NOVECTORISE f #-} that suppresses vectorisation of toplevel variable 'f'.

parent a8defd8a
...@@ -332,8 +332,9 @@ Also since rule_fn is a Name, not a Var, we have to use the grungy delUFM. ...@@ -332,8 +332,9 @@ Also since rule_fn is a Name, not a Var, we have to use the grungy delUFM.
vectsFreeVars :: [CoreVect] -> VarSet vectsFreeVars :: [CoreVect] -> VarSet
vectsFreeVars = foldr (unionVarSet . vectFreeVars) emptyVarSet vectsFreeVars = foldr (unionVarSet . vectFreeVars) emptyVarSet
where where
vectFreeVars (Vect _ Nothing) = noFVs vectFreeVars (Vect _ Nothing) = noFVs
vectFreeVars (Vect _ (Just rhs)) = expr_fvs rhs isLocalId emptyVarSet vectFreeVars (Vect _ (Just rhs)) = expr_fvs rhs isLocalId emptyVarSet
vectFreeVars (NoVect _) = noFVs
\end{code} \end{code}
......
...@@ -714,8 +714,9 @@ substVects subst = map (substVect subst) ...@@ -714,8 +714,9 @@ substVects subst = map (substVect subst)
------------------ ------------------
substVect :: Subst -> CoreVect -> CoreVect substVect :: Subst -> CoreVect -> CoreVect
substVect _subst (Vect v Nothing) = Vect v Nothing substVect _subst (Vect v Nothing) = Vect v Nothing
substVect subst (Vect v (Just rhs)) = Vect v (Just (simpleOptExprWith subst rhs)) substVect subst (Vect v (Just rhs)) = Vect v (Just (simpleOptExprWith subst rhs))
substVect _subst (NoVect v) = NoVect v
------------------ ------------------
substVarSet :: Subst -> VarSet -> VarSet substVarSet :: Subst -> VarSet -> VarSet
......
...@@ -417,14 +417,16 @@ Representation of desugared vectorisation declarations that are fed to the vecto ...@@ -417,14 +417,16 @@ Representation of desugared vectorisation declarations that are fed to the vecto
'ModGuts'). 'ModGuts').
\begin{code} \begin{code}
data CoreVect = Vect Id (Maybe CoreExpr) data CoreVect = Vect Id (Maybe CoreExpr)
| NoVect Id
\end{code} \end{code}
%************************************************************************ %************************************************************************
%* * %* *
Unfoldings Unfoldings
%* * %* *
%************************************************************************ %************************************************************************
The @Unfolding@ type is declared here to avoid numerous loops The @Unfolding@ type is declared here to avoid numerous loops
......
...@@ -446,7 +446,7 @@ instance Outputable e => Outputable (DFunArg e) where ...@@ -446,7 +446,7 @@ instance Outputable e => Outputable (DFunArg e) where
\end{code} \end{code}
----------------------------------------------------- -----------------------------------------------------
-- Rules -- Rules
----------------------------------------------------- -----------------------------------------------------
\begin{code} \begin{code}
...@@ -461,11 +461,23 @@ pprRule (BuiltinRule { ru_fn = fn, ru_name = name}) ...@@ -461,11 +461,23 @@ pprRule (BuiltinRule { ru_fn = fn, ru_name = name})
= ptext (sLit "Built in rule for") <+> ppr fn <> colon <+> doubleQuotes (ftext name) = ptext (sLit "Built in rule for") <+> ppr fn <> colon <+> doubleQuotes (ftext name)
pprRule (Rule { ru_name = name, ru_act = act, ru_fn = fn, pprRule (Rule { ru_name = name, ru_act = act, ru_fn = fn,
ru_bndrs = tpl_vars, ru_args = tpl_args, ru_bndrs = tpl_vars, ru_args = tpl_args,
ru_rhs = rhs }) ru_rhs = rhs })
= hang (doubleQuotes (ftext name) <+> ppr act) = hang (doubleQuotes (ftext name) <+> ppr act)
4 (sep [ptext (sLit "forall") <+> braces (sep (map pprTypedBinder tpl_vars)), 4 (sep [ptext (sLit "forall") <+> braces (sep (map pprTypedBinder tpl_vars)),
nest 2 (ppr fn <+> sep (map pprArg tpl_args)), nest 2 (ppr fn <+> sep (map pprArg tpl_args)),
nest 2 (ptext (sLit "=") <+> pprCoreExpr rhs) nest 2 (ptext (sLit "=") <+> pprCoreExpr rhs)
]) ])
\end{code}
-----------------------------------------------------
-- Vectorisation declarations
-----------------------------------------------------
\begin{code}
instance Outputable CoreVect where
ppr (Vect var Nothing) = ptext (sLit "VECTORISE SCALAR") <+> ppr var
ppr (Vect var (Just e)) = hang (ptext (sLit "VECTORISE") <+> ppr var <+> char '=')
4 (pprCoreExpr e)
ppr (NoVect var) = ptext (sLit "NOVECTORISE") <+> ppr var
\end{code} \end{code}
...@@ -394,16 +394,11 @@ the rule is precisly to optimise them: ...@@ -394,16 +394,11 @@ the rule is precisly to optimise them:
\begin{code} \begin{code}
dsVect :: LVectDecl Id -> DsM CoreVect dsVect :: LVectDecl Id -> DsM CoreVect
dsVect (L loc (HsVect v rhs)) dsVect (L loc (HsVect (L _ v) rhs))
= putSrcSpanDs loc $ = putSrcSpanDs loc $
do { rhs' <- fmapMaybeM dsLExpr rhs do { rhs' <- fmapMaybeM dsLExpr rhs
; return $ Vect (unLoc v) rhs' ; return $ Vect v rhs'
} }
-- dsVect (L loc (HsVect v Nothing)) dsVect (L loc (HsNoVect (L _ v)))
-- = return $ Vect v Nothing = return $ NoVect v
-- dsVect (L loc (HsVect v (Just rhs)))
-- = putSrcSpanDs loc $
-- do { rhs' <- dsLExpr rhs
-- ; return $ Vect v (Just rhs')
-- }
\end{code} \end{code}
...@@ -28,6 +28,7 @@ module HsDecls ( ...@@ -28,6 +28,7 @@ module HsDecls (
collectRuleBndrSigTys, collectRuleBndrSigTys,
-- ** @VECTORISE@ declarations -- ** @VECTORISE@ declarations
VectDecl(..), LVectDecl, VectDecl(..), LVectDecl,
lvectDeclName,
-- ** @default@ declarations -- ** @default@ declarations
DefaultDecl(..), LDefaultDecl, DefaultDecl(..), LDefaultDecl,
-- ** Top-level template haskell splice -- ** Top-level template haskell splice
...@@ -1005,10 +1006,11 @@ instance OutputableBndr name => Outputable (RuleBndr name) where ...@@ -1005,10 +1006,11 @@ instance OutputableBndr name => Outputable (RuleBndr name) where
%* * %* *
%************************************************************************ %************************************************************************
A vectorisation pragma A vectorisation pragma, one of
{-# VECTORISE f = closure1 g (scalar_map g) #-} OR {-# VECTORISE f = closure1 g (scalar_map g) #-}
{-# VECTORISE SCALAR f #-} {-# VECTORISE SCALAR f #-}
{-# NOVECTORISE f #-}
Note [Typechecked vectorisation pragmas] Note [Typechecked vectorisation pragmas]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...@@ -1029,14 +1031,23 @@ data VectDecl name ...@@ -1029,14 +1031,23 @@ data VectDecl name
= HsVect = HsVect
(Located name) (Located name)
(Maybe (LHsExpr name)) -- 'Nothing' => SCALAR declaration (Maybe (LHsExpr name)) -- 'Nothing' => SCALAR declaration
| HsNoVect
(Located name)
deriving (Data, Typeable) deriving (Data, Typeable)
lvectDeclName :: LVectDecl name -> name
lvectDeclName (L _ (HsVect (L _ name) _)) = name
lvectDeclName (L _ (HsNoVect (L _ name))) = name
instance OutputableBndr name => Outputable (VectDecl name) where instance OutputableBndr name => Outputable (VectDecl name) where
ppr (HsVect v rhs) ppr (HsVect v Nothing)
= sep [text "{-# VECTORISE SCALAR" <+> ppr v <+> text "#-}" ]
ppr (HsVect v (Just rhs))
= sep [text "{-# VECTORISE" <+> ppr v, = sep [text "{-# VECTORISE" <+> ppr v,
nest 4 (case rhs of nest 4 $
Nothing -> text "SCALAR #-}" pprExpr (unLoc rhs) <+> text "#-}" ]
Just rhs -> pprExpr (unLoc rhs) <+> text "#-}") ] ppr (HsNoVect v)
= sep [text "{-# NOVECTORISE" <+> ppr v <+> text "#-}" ]
\end{code} \end{code}
%************************************************************************ %************************************************************************
......
...@@ -483,6 +483,7 @@ data Token ...@@ -483,6 +483,7 @@ data Token
| ITlanguage_prag | ITlanguage_prag
| ITvect_prag | ITvect_prag
| ITvect_scalar_prag | ITvect_scalar_prag
| ITnovect_prag
| ITdotdot -- reserved symbols | ITdotdot -- reserved symbols
| ITcolon | ITcolon
...@@ -2281,7 +2282,8 @@ oneWordPrags = Map.fromList([("rules", rulePrag), ...@@ -2281,7 +2282,8 @@ oneWordPrags = Map.fromList([("rules", rulePrag),
("core", token ITcore_prag), ("core", token ITcore_prag),
("unpack", token ITunpack_prag), ("unpack", token ITunpack_prag),
("ann", token ITann_prag), ("ann", token ITann_prag),
("vectorize", token ITvect_prag)]) ("vectorize", token ITvect_prag),
("novectorize", token ITnovect_prag)])
twoWordPrags = Map.fromList([("inline conlike", token (ITinline_prag Inline ConLike)), twoWordPrags = Map.fromList([("inline conlike", token (ITinline_prag Inline ConLike)),
("notinline conlike", token (ITinline_prag NoInline ConLike)), ("notinline conlike", token (ITinline_prag NoInline ConLike)),
...@@ -2307,6 +2309,7 @@ clean_pragma prag = canon_ws (map toLower (unprefix prag)) ...@@ -2307,6 +2309,7 @@ clean_pragma prag = canon_ws (map toLower (unprefix prag))
"noinline" -> "notinline" "noinline" -> "notinline"
"specialise" -> "specialize" "specialise" -> "specialize"
"vectorise" -> "vectorize" "vectorise" -> "vectorize"
"novectorise" -> "novectorize"
"constructorlike" -> "conlike" "constructorlike" -> "conlike"
_ -> prag' _ -> prag'
canon_ws s = unwords (map canonical (words s)) canon_ws s = unwords (map canonical (words s))
......
...@@ -252,21 +252,22 @@ incorrect. ...@@ -252,21 +252,22 @@ incorrect.
'by' { L _ ITby } -- for list transform extension 'by' { L _ ITby } -- for list transform extension
'using' { L _ ITusing } -- for list transform extension 'using' { L _ ITusing } -- for list transform extension
'{-# INLINE' { L _ (ITinline_prag _ _) } '{-# INLINE' { L _ (ITinline_prag _ _) }
'{-# SPECIALISE' { L _ ITspec_prag } '{-# SPECIALISE' { L _ ITspec_prag }
'{-# SPECIALISE_INLINE' { L _ (ITspec_inline_prag _) } '{-# SPECIALISE_INLINE' { L _ (ITspec_inline_prag _) }
'{-# SOURCE' { L _ ITsource_prag } '{-# SOURCE' { L _ ITsource_prag }
'{-# RULES' { L _ ITrules_prag } '{-# RULES' { L _ ITrules_prag }
'{-# CORE' { L _ ITcore_prag } -- hdaume: annotated core '{-# CORE' { L _ ITcore_prag } -- hdaume: annotated core
'{-# SCC' { L _ ITscc_prag } '{-# SCC' { L _ ITscc_prag }
'{-# GENERATED' { L _ ITgenerated_prag } '{-# GENERATED' { L _ ITgenerated_prag }
'{-# DEPRECATED' { L _ ITdeprecated_prag } '{-# DEPRECATED' { L _ ITdeprecated_prag }
'{-# WARNING' { L _ ITwarning_prag } '{-# WARNING' { L _ ITwarning_prag }
'{-# UNPACK' { L _ ITunpack_prag } '{-# UNPACK' { L _ ITunpack_prag }
'{-# ANN' { L _ ITann_prag } '{-# ANN' { L _ ITann_prag }
'{-# VECTORISE' { L _ ITvect_prag } '{-# VECTORISE' { L _ ITvect_prag }
'{-# VECTORISE_SCALAR' { L _ ITvect_scalar_prag } '{-# VECTORISE_SCALAR' { L _ ITvect_scalar_prag }
'#-}' { L _ ITclose_prag } '{-# NOVECTORISE' { L _ ITnovect_prag }
'#-}' { L _ ITclose_prag }
'..' { L _ ITdotdot } -- reserved symbols '..' { L _ ITdotdot } -- reserved symbols
':' { L _ ITcolon } ':' { L _ ITcolon }
...@@ -546,33 +547,34 @@ ops :: { Located [Located RdrName] } ...@@ -546,33 +547,34 @@ ops :: { Located [Located RdrName] }
-- Top-Level Declarations -- Top-Level Declarations
topdecls :: { OrdList (LHsDecl RdrName) } topdecls :: { OrdList (LHsDecl RdrName) }
: topdecls ';' topdecl { $1 `appOL` $3 } : topdecls ';' topdecl { $1 `appOL` $3 }
| topdecls ';' { $1 } | topdecls ';' { $1 }
| topdecl { $1 } | topdecl { $1 }
topdecl :: { OrdList (LHsDecl RdrName) } topdecl :: { OrdList (LHsDecl RdrName) }
: cl_decl { unitOL (L1 (TyClD (unLoc $1))) } : cl_decl { unitOL (L1 (TyClD (unLoc $1))) }
| ty_decl { unitOL (L1 (TyClD (unLoc $1))) } | ty_decl { unitOL (L1 (TyClD (unLoc $1))) }
| 'instance' inst_type where_inst | 'instance' inst_type where_inst
{ let (binds, sigs, ats, _) = cvBindsAndSigs (unLoc $3) { let (binds, sigs, ats, _) = cvBindsAndSigs (unLoc $3)
in in
unitOL (L (comb3 $1 $2 $3) (InstD (InstDecl $2 binds sigs ats)))} unitOL (L (comb3 $1 $2 $3) (InstD (InstDecl $2 binds sigs ats)))}
| stand_alone_deriving { unitOL (LL (DerivD (unLoc $1))) } | stand_alone_deriving { unitOL (LL (DerivD (unLoc $1))) }
| 'default' '(' comma_types0 ')' { unitOL (LL $ DefD (DefaultDecl $3)) } | 'default' '(' comma_types0 ')' { unitOL (LL $ DefD (DefaultDecl $3)) }
| 'foreign' fdecl { unitOL (LL (unLoc $2)) } | 'foreign' fdecl { unitOL (LL (unLoc $2)) }
| '{-# DEPRECATED' deprecations '#-}' { $2 } | '{-# DEPRECATED' deprecations '#-}' { $2 }
| '{-# WARNING' warnings '#-}' { $2 } | '{-# WARNING' warnings '#-}' { $2 }
| '{-# RULES' rules '#-}' { $2 } | '{-# RULES' rules '#-}' { $2 }
| '{-# VECTORISE_SCALAR' qvar '#-}' { unitOL $ LL $ VectD (HsVect $2 Nothing) } | '{-# VECTORISE_SCALAR' qvar '#-}' { unitOL $ LL $ VectD (HsVect $2 Nothing) }
| '{-# VECTORISE' qvar '=' exp '#-}' { unitOL $ LL $ VectD (HsVect $2 (Just $4)) } | '{-# VECTORISE' qvar '=' exp '#-}' { unitOL $ LL $ VectD (HsVect $2 (Just $4)) }
| annotation { unitOL $1 } | '{-# NOVECTORISE' qvar '#-}' { unitOL $ LL $ VectD (HsNoVect $2) }
| decl { unLoc $1 } | annotation { unitOL $1 }
| decl { unLoc $1 }
-- Template Haskell Extension
-- The $(..) form is one possible form of infixexp -- Template Haskell Extension
-- but we treat an arbitrary expression just as if -- The $(..) form is one possible form of infixexp
-- it had a $(..) wrapped around it -- but we treat an arbitrary expression just as if
| infixexp { unitOL (LL $ mkTopSpliceDecl $1) } -- it had a $(..) wrapped around it
| infixexp { unitOL (LL $ mkTopSpliceDecl $1) }
-- Type classes -- Type classes
-- --
......
...@@ -666,6 +666,10 @@ rnHsVectDecl (HsVect var (Just rhs)) ...@@ -666,6 +666,10 @@ rnHsVectDecl (HsVect var (Just rhs))
; (rhs', fv_rhs) <- rnLExpr rhs ; (rhs', fv_rhs) <- rnLExpr rhs
; return (HsVect var' (Just rhs'), fv_rhs `addOneFV` unLoc var') ; return (HsVect var' (Just rhs'), fv_rhs `addOneFV` unLoc var')
} }
rnHsVectDecl (HsNoVect var)
= do { var' <- wrapLocM lookupTopBndrRn var
; return (HsNoVect var', unitFV (unLoc var'))
}
\end{code} \end{code}
%********************************************************* %*********************************************************
......
...@@ -29,7 +29,7 @@ import FloatIn ( floatInwards ) ...@@ -29,7 +29,7 @@ import FloatIn ( floatInwards )
import FloatOut ( floatOutwards ) import FloatOut ( floatOutwards )
import FamInstEnv import FamInstEnv
import Id import Id
import BasicTypes ( CompilerPhase, isDefaultInlinePragma ) import BasicTypes
import VarSet import VarSet
import VarEnv import VarEnv
import LiberateCase ( liberateCase ) import LiberateCase ( liberateCase )
...@@ -356,11 +356,18 @@ simplifyPgmIO pass@(CoreDoSimplify max_iterations mode) ...@@ -356,11 +356,18 @@ simplifyPgmIO pass@(CoreDoSimplify max_iterations mode)
-- space usage, especially with -O. JRS, 000620. -- space usage, especially with -O. JRS, 000620.
| let sz = coreBindsSize binds in sz == sz | let sz = coreBindsSize binds in sz == sz
= do { = do {
-- Occurrence analysis -- Occurrence analysis
let { tagged_binds = {-# SCC "OccAnal" #-} let { -- During the 'InitialPhase' (i.e., before vectorisation), we need to make sure
occurAnalysePgm active_rule rules [] binds } ; -- that the right-hand sides of vectorisation declarations are taken into
Err.dumpIfSet_dyn dflags Opt_D_dump_occur_anal "Occurrence analysis" -- account during occurrence analysis.
(pprCoreBindings tagged_binds); maybeVects = case sm_phase mode of
InitialPhase -> mg_vect_decls guts
_ -> []
; tagged_binds = {-# SCC "OccAnal" #-}
occurAnalysePgm active_rule rules maybeVects binds
} ;
Err.dumpIfSet_dyn dflags Opt_D_dump_occur_anal "Occurrence analysis"
(pprCoreBindings tagged_binds);
-- Get any new rules, and extend the rule base -- Get any new rules, and extend the rule base
-- See Note [Overall plumbing for rules] in Rules.lhs -- See Note [Overall plumbing for rules] in Rules.lhs
......
...@@ -591,7 +591,7 @@ impSpecErr name ...@@ -591,7 +591,7 @@ impSpecErr name
tcVectDecls :: [LVectDecl Name] -> TcM ([LVectDecl TcId]) tcVectDecls :: [LVectDecl Name] -> TcM ([LVectDecl TcId])
tcVectDecls decls tcVectDecls decls
= do { decls' <- mapM (wrapLocM tcVect) decls = do { decls' <- mapM (wrapLocM tcVect) decls
; let ids = [unLoc id | L _ (HsVect id _) <- decls'] ; let ids = map lvectDeclName decls'
dups = findDupsEq (==) ids dups = findDupsEq (==) ids
; mapM_ reportVectDups dups ; mapM_ reportVectDups dups
; traceTcConstraints "End of tcVectDecls" ; traceTcConstraints "End of tcVectDecls"
...@@ -642,6 +642,11 @@ tcVect (HsVect name@(L loc _) (Just rhs)) ...@@ -642,6 +642,11 @@ tcVect (HsVect name@(L loc _) (Just rhs))
-- to the vectoriser - see "Note [Typechecked vectorisation pragmas]" in HsDecls -- to the vectoriser - see "Note [Typechecked vectorisation pragmas]" in HsDecls
; return $ HsVect (L loc id') (Just rhsWrapped) ; return $ HsVect (L loc id') (Just rhsWrapped)
} }
tcVect (HsNoVect name)
= addErrCtxt (vectCtxt name) $
do { id <- wrapLocM tcLookupId name
; return $ HsNoVect id
}
vectCtxt :: Located Name -> SDoc vectCtxt :: Located Name -> SDoc
vectCtxt name = ptext (sLit "When checking the vectorisation declaration for") <+> ppr name vectCtxt name = ptext (sLit "When checking the vectorisation declaration for") <+> ppr name
......
...@@ -1027,6 +1027,10 @@ zonkVect env (HsVect v (Just e)) ...@@ -1027,6 +1027,10 @@ zonkVect env (HsVect v (Just e))
; e' <- zonkLExpr env e ; e' <- zonkLExpr env e
; return $ HsVect v' (Just e') ; return $ HsVect v' (Just e')
} }
zonkVect env (HsNoVect v)
= do { v' <- wrapLocM (zonkIdBndr env) v
; return $ HsNoVect v'
}
\end{code} \end{code}
%************************************************************************ %************************************************************************
......
{-# OPTIONS -fno-warn-missing-signatures -fno-warn-unused-do-bind #-}
module Vectorise ( vectorise ) module Vectorise ( vectorise )
where where
...@@ -82,98 +81,124 @@ vectModule guts@(ModGuts { mg_types = types ...@@ -82,98 +81,124 @@ vectModule guts@(ModGuts { mg_types = types
} }
} }
-- | Try to vectorise a top-level binding. -- |Try to vectorise a top-level binding. If it doesn't vectorise then return it unharmed.
-- If it doesn't vectorise then return it unharmed.
-- --
-- For example, for the binding -- For example, for the binding
-- --
-- @ -- @
-- foo :: Int -> Int -- foo :: Int -> Int
-- foo = \x -> x + x -- foo = \x -> x + x
-- @ -- @
--
-- we get
-- @
-- foo :: Int -> Int
-- foo = \x -> vfoo $: x
--
-- v_foo :: Closure void vfoo lfoo
-- v_foo = closure vfoo lfoo void
--
-- vfoo :: Void -> Int -> Int
-- vfoo = ...
-- --
-- lfoo :: PData Void -> PData Int -> PData Int -- we get
-- lfoo = ... -- @
-- @ -- foo :: Int -> Int
-- foo = \x -> vfoo $: x
-- --
-- @vfoo@ is the "vectorised", or scalar, version that does the same as the original -- v_foo :: Closure void vfoo lfoo
-- function foo, but takes an explicit environment. -- v_foo = closure vfoo lfoo void
-- --
-- @lfoo@ is the "lifted" version that works on arrays. -- vfoo :: Void -> Int -> Int
-- vfoo = ...
--
-- lfoo :: PData Void -> PData Int -> PData Int
-- lfoo = ...
-- @
-- --
-- @v_foo@ combines both of these into a `Closure` that also contains the -- @vfoo@ is the "vectorised", or scalar, version that does the same as the original
-- environment. -- function foo, but takes an explicit environment.
-- --
-- The original binding @foo@ is rewritten to call the vectorised version -- @lfoo@ is the "lifted" version that works on arrays.
-- present in the closure. --
-- @v_foo@ combines both of these into a `Closure` that also contains the
-- environment.
--
-- The original binding @foo@ is rewritten to call the vectorised version
-- present in the closure.
--
-- Vectorisation may be suppressed by annotating a binding with a 'NOVECTORISE' pragma. If this
-- pragma is used in a group of mutually recursive bindings, either all or no binding must have
-- the pragma. If only some bindings are annotated, a fatal error is raised.
-- FIXME: Once we support partial vectorisation, we may be able to vectorise parts of a group, or
-- we may emit a warning and refrain from vectorising the entire group.
-- --
vectTopBind :: CoreBind -> VM CoreBind vectTopBind :: CoreBind -> VM CoreBind
vectTopBind b@(NonRec var expr) vectTopBind b@(NonRec var expr)
= do { -- Vectorise the right-hand side, create an appropriate top-level binding and add it to = unlessNoVectDecl $
-- the vectorisation map. do { -- Vectorise the right-hand side, create an appropriate top-level binding and add it
; (inline, isScalar, expr') <- vectTopRhs [] var expr -- to the vectorisation map.
; var' <- vectTopBinder var inline expr' ; (inline, isScalar, expr') <- vectTopRhs [] var expr
; when isScalar $ ; var' <- vectTopBinder var inline expr'
addGlobalScalar var ; when isScalar $
addGlobalScalar var
-- We replace the original top-level binding by a value projected from the vectorised
-- closure and add any newly created hoisted top-level bindings. -- We replace the original top-level binding by a value projected from the vectorised
; cexpr <- tryConvert var var' expr -- closure and add any newly created hoisted top-level bindings.
; hs <- takeHoisted ; cexpr <- tryConvert var var' expr
; return . Rec $ (var, cexpr) : (var', expr') : hs ; hs <- takeHoisted
} ; return . Rec $ (var, cexpr) : (var', expr') : hs
`orElseV` }
return b `orElseV`
return b
where
unlessNoVectDecl vectorise
= do { hasNoVectDecl <- noVectDecl var
; when hasNoVectDecl $
traceVt "NOVECTORISE" $ ppr var
; if hasNoVectDecl then return b else vectorise
}
vectTopBind b@(Rec bs) vectTopBind b@(Rec bs)
= let (vars, exprs) = unzip bs = unlessSomeNoVectDecl $
in do { (vars', _, exprs', hs) <- fixV $
do { (vars', _, exprs', hs) <- fixV $ \ ~(_, inlines, rhss, _) ->
\ ~(_, inlines, rhss, _) -> do { -- Vectorise the right-hand sides, create appropriate top-level bindings
do { -- Vectorise the right-hand sides, create appropriate top-level bindings and -- and add them to the vectorisation map.
-- add them to the vectorisation map. ; vars' <- sequence [vectTopBinder var inline rhs
; vars' <- sequence [vectTopBinder var inline rhs | (var, ~(inline, rhs)) <- zipLazy vars (zip inlines rhss)]
| (var, ~(inline, rhs)) <- zipLazy vars (zip inlines rhss)] ; (inlines, areScalars, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs vars) bs
; (inlines, areScalars, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs vars) bs ; hs <- takeHoisted
; hs <- takeHoisted ; if and areScalars
; if and areScalars then -- (1) Entire recursive group is scalar
then -- (1) Entire recursive group is scalar -- => add all variables to the global set of scalars
-- => add all variables to the global set of scalars do { mapM_ addGlobalScalar vars
do { mapM addGlobalScalar vars ; return (vars', inlines, exprs', hs)
; return (vars', inlines, exprs', hs) }
} else -- (2) At least one binding is not scalar
else -- (2) At least one binding is not scalar -- => vectorise again with empty set of local scalars
-- => vectorise again with empty set of local scalars do { (inlines, _, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs []) bs
do { (inlines, _, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs []) bs ; hs <- takeHoisted
; hs <- takeHoisted ; return (vars', inlines, exprs', hs)
; return (vars', inlines, exprs', hs)