Commit b77da25e authored by chak@cse.unsw.edu.au

Rewrote vectorisation avoidance (based on the HS paper)

* Vectorisation avoidance is now the default
* Types and values from unvectorised modules are permitted in scalar code
* Simplified the VECTORISE pragmas (see http://hackage.haskell.org/trac/ghc/wiki/DataParallel/VectPragma for the spec)
* Vectorisation information is now included in the annotated Core AST
parent 2a7217e3
......@@ -328,8 +328,7 @@ breaker, which is perfectly inlinable.
vectsFreeVars :: [CoreVect] -> VarSet
vectsFreeVars = foldr (unionVarSet . vectFreeVars) emptyVarSet
where
vectFreeVars (Vect _ Nothing) = noFVs
vectFreeVars (Vect _ (Just rhs)) = expr_fvs rhs isLocalId emptyVarSet
vectFreeVars (Vect _ rhs) = expr_fvs rhs isLocalId emptyVarSet
vectFreeVars (NoVect _) = noFVs
vectFreeVars (VectType _ _ _) = noFVs
vectFreeVars (VectClass _) = noFVs
......
......@@ -749,8 +749,7 @@ substVects subst = map (substVect subst)
------------------
substVect :: Subst -> CoreVect -> CoreVect
substVect _subst (Vect v Nothing) = Vect v Nothing
substVect subst (Vect v (Just rhs)) = Vect v (Just (simpleOptExprWith subst rhs))
substVect subst (Vect v rhs) = Vect v (simpleOptExprWith subst rhs)
substVect _subst vd@(NoVect _) = vd
substVect _subst vd@(VectType _ _ _) = vd
substVect _subst vd@(VectClass _) = vd
......
......@@ -592,11 +592,11 @@ Representation of desugared vectorisation declarations that are fed to the vecto
'ModGuts').
\begin{code}
data CoreVect = Vect Id (Maybe CoreExpr)
data CoreVect = Vect Id CoreExpr
| NoVect Id
| VectType Bool TyCon (Maybe TyCon)
| VectClass TyCon -- class tycon
| VectInst Id -- instance dfun (always SCALAR)
| VectInst Id -- instance dfun (always SCALAR) !!!FIXME: should be superfluous now
\end{code}
......
......@@ -494,8 +494,7 @@ instance Outputable id => Outputable (Tickish id) where
\begin{code}
instance Outputable CoreVect where
ppr (Vect var Nothing) = ptext (sLit "VECTORISE SCALAR") <+> ppr var
ppr (Vect var (Just e)) = hang (ptext (sLit "VECTORISE") <+> ppr var <+> char '=')
ppr (Vect var e) = hang (ptext (sLit "VECTORISE") <+> ppr var <+> char '=')
4 (pprCoreExpr e)
ppr (NoVect var) = ptext (sLit "NOVECTORISE") <+> ppr var
ppr (VectType False var Nothing) = ptext (sLit "VECTORISE type") <+> ppr var
......
......@@ -432,7 +432,7 @@ the rule is precisely to optimise them:
dsVect :: LVectDecl Id -> DsM CoreVect
dsVect (L loc (HsVect (L _ v) rhs))
= putSrcSpanDs loc $
do { rhs' <- fmapMaybeM dsLExpr rhs
do { rhs' <- dsLExpr rhs
; return $ Vect v rhs'
}
dsVect (L _loc (HsNoVect (L _ v)))
......
......@@ -1111,7 +1111,7 @@ type LVectDecl name = Located (VectDecl name)
data VectDecl name
= HsVect
(Located name)
(Maybe (LHsExpr name)) -- 'Nothing' => SCALAR declaration
(LHsExpr name)
| HsNoVect
(Located name)
| HsVectTypeIn -- pre type-checking
......@@ -1126,9 +1126,9 @@ data VectDecl name
(Located name)
| HsVectClassOut -- post type-checking
Class
| HsVectInstIn -- pre type-checking (always SCALAR)
| HsVectInstIn -- pre type-checking (always SCALAR) !!!FIXME: should be superfluous now
(LHsType name)
| HsVectInstOut -- post type-checking (always SCALAR)
| HsVectInstOut -- post type-checking (always SCALAR) !!!FIXME: should be superfluous now
ClsInst
deriving (Data, Typeable)
......@@ -1148,9 +1148,7 @@ lvectInstDecl (L _ (HsVectInstOut _)) = True
lvectInstDecl _ = False
instance OutputableBndr name => Outputable (VectDecl name) where
ppr (HsVect v Nothing)
= sep [text "{-# VECTORISE SCALAR" <+> ppr v <+> text "#-}" ]
ppr (HsVect v (Just rhs))
ppr (HsVect v rhs)
= sep [text "{-# VECTORISE" <+> ppr v,
nest 4 $
pprExpr (unLoc rhs) <+> text "#-}" ]
......
......@@ -753,15 +753,15 @@ pprVectInfo :: IfaceVectInfo -> SDoc
pprVectInfo (IfaceVectInfo { ifaceVectInfoVar = vars
, ifaceVectInfoTyCon = tycons
, ifaceVectInfoTyConReuse = tyconsReuse
, ifaceVectInfoScalarVars = scalarVars
, ifaceVectInfoScalarTyCons = scalarTyCons
, ifaceVectInfoParallelVars = parallelVars
, ifaceVectInfoParallelTyCons = parallelTyCons
}) =
vcat
[ ptext (sLit "vectorised variables:") <+> hsep (map ppr vars)
, ptext (sLit "vectorised tycons:") <+> hsep (map ppr tycons)
, ptext (sLit "vectorised reused tycons:") <+> hsep (map ppr tyconsReuse)
, ptext (sLit "scalar variables:") <+> hsep (map ppr scalarVars)
, ptext (sLit "scalar tycons:") <+> hsep (map ppr scalarTyCons)
, ptext (sLit "parallel variables:") <+> hsep (map ppr parallelVars)
, ptext (sLit "parallel tycons:") <+> hsep (map ppr parallelTyCons)
]
pprTrustInfo :: IfaceTrustInfo -> SDoc
......
......@@ -375,15 +375,15 @@ mkIface_ hsc_env maybe_old_fingerprint
flattenVectInfo (VectInfo { vectInfoVar = vVar
, vectInfoTyCon = vTyCon
, vectInfoScalarVars = vScalarVars
, vectInfoScalarTyCons = vScalarTyCons
, vectInfoParallelVars = vParallelVars
, vectInfoParallelTyCons = vParallelTyCons
}) =
IfaceVectInfo
{ ifaceVectInfoVar = [Var.varName v | (v, _ ) <- varEnvElts vVar]
, ifaceVectInfoTyCon = [tyConName t | (t, t_v) <- nameEnvElts vTyCon, t /= t_v]
, ifaceVectInfoTyConReuse = [tyConName t | (t, t_v) <- nameEnvElts vTyCon, t == t_v]
, ifaceVectInfoScalarVars = [Var.varName v | v <- varSetElems vScalarVars]
, ifaceVectInfoScalarTyCons = nameSetToList vScalarTyCons
, ifaceVectInfoParallelVars = [Var.varName v | v <- varSetElems vParallelVars]
, ifaceVectInfoParallelTyCons = nameSetToList vParallelTyCons
}
-----------------------------
......
......@@ -751,22 +751,22 @@ tcIfaceVectInfo mod typeEnv (IfaceVectInfo
{ ifaceVectInfoVar = vars
, ifaceVectInfoTyCon = tycons
, ifaceVectInfoTyConReuse = tyconsReuse
, ifaceVectInfoScalarVars = scalarVars
, ifaceVectInfoScalarTyCons = scalarTyCons
, ifaceVectInfoParallelVars = parallelVars
, ifaceVectInfoParallelTyCons = parallelTyCons
})
= do { let scalarTyConsSet = mkNameSet scalarTyCons
= do { let parallelTyConsSet = mkNameSet parallelTyCons
; vVars <- mapM vectVarMapping vars
; let varsSet = mkVarSet (map fst vVars)
; tyConRes1 <- mapM (vectTyConVectMapping varsSet) tycons
; tyConRes2 <- mapM (vectTyConReuseMapping varsSet) tyconsReuse
; vScalarVars <- mapM vectVar scalarVars
; vParallelVars <- mapM vectVar parallelVars
; let (vTyCons, vDataCons, vScSels) = unzip3 (tyConRes1 ++ tyConRes2)
; return $ VectInfo
{ vectInfoVar = mkVarEnv vVars `extendVarEnvList` concat vScSels
, vectInfoTyCon = mkNameEnv vTyCons
, vectInfoDataCon = mkNameEnv (concat vDataCons)
, vectInfoScalarVars = mkVarSet vScalarVars
, vectInfoScalarTyCons = scalarTyConsSet
, vectInfoParallelVars = mkVarSet vParallelVars
, vectInfoParallelTyCons = parallelTyConsSet
}
}
where
......
......@@ -1971,8 +1971,8 @@ data VectInfo
{ vectInfoVar :: VarEnv (Var , Var ) -- ^ @(f, f_v)@ keyed on @f@
, vectInfoTyCon :: NameEnv (TyCon , TyCon) -- ^ @(T, T_v)@ keyed on @T@
, vectInfoDataCon :: NameEnv (DataCon, DataCon) -- ^ @(C, C_v)@ keyed on @C@
, vectInfoScalarVars :: VarSet -- ^ set of purely scalar variables
, vectInfoScalarTyCons :: NameSet -- ^ set of scalar type constructors
, vectInfoParallelVars :: VarSet -- ^ set of parallel variables
, vectInfoParallelTyCons :: NameSet -- ^ set of parallel type constructors
}
-- |Vectorisation information for 'ModIface'; i.e., the vectorisation information propagated
......@@ -1996,8 +1996,8 @@ data IfaceVectInfo
, ifaceVectInfoTyConReuse :: [Name] -- ^ The vectorised form of all the 'TyCon's in here
-- coincides with the unconverted form; the name of the
-- isomorphisms is determined by 'OccName.mkVectIsoOcc'
, ifaceVectInfoScalarVars :: [Name] -- iface version of 'vectInfoScalarVar'
, ifaceVectInfoScalarTyCons :: [Name] -- iface version of 'vectInfoScalarTyCon'
, ifaceVectInfoParallelVars :: [Name] -- iface version of 'vectInfoParallelVar'
, ifaceVectInfoParallelTyCons :: [Name] -- iface version of 'vectInfoParallelTyCon'
}
noVectInfo :: VectInfo
......@@ -2009,8 +2009,8 @@ plusVectInfo vi1 vi2 =
VectInfo (vectInfoVar vi1 `plusVarEnv` vectInfoVar vi2)
(vectInfoTyCon vi1 `plusNameEnv` vectInfoTyCon vi2)
(vectInfoDataCon vi1 `plusNameEnv` vectInfoDataCon vi2)
(vectInfoScalarVars vi1 `unionVarSet` vectInfoScalarVars vi2)
(vectInfoScalarTyCons vi1 `unionNameSets` vectInfoScalarTyCons vi2)
(vectInfoParallelVars vi1 `unionVarSet` vectInfoParallelVars vi2)
(vectInfoParallelTyCons vi1 `unionNameSets` vectInfoParallelTyCons vi2)
concatVectInfo :: [VectInfo] -> VectInfo
concatVectInfo = foldr plusVectInfo noVectInfo
......@@ -2027,8 +2027,8 @@ instance Outputable VectInfo where
[ ptext (sLit "variables :") <+> ppr (vectInfoVar info)
, ptext (sLit "tycons :") <+> ppr (vectInfoTyCon info)
, ptext (sLit "datacons :") <+> ppr (vectInfoDataCon info)
, ptext (sLit "scalar vars :") <+> ppr (vectInfoScalarVars info)
, ptext (sLit "scalar tycons :") <+> ppr (vectInfoScalarTyCons info)
, ptext (sLit "parallel vars :") <+> ppr (vectInfoParallelVars info)
, ptext (sLit "parallel tycons :") <+> ppr (vectInfoParallelTyCons info)
]
\end{code}
......
......@@ -542,10 +542,10 @@ tidyInstances tidy_dfun ispecs
\begin{code}
tidyVectInfo :: TidyEnv -> VectInfo -> VectInfo
tidyVectInfo (_, var_env) info@(VectInfo { vectInfoVar = vars
, vectInfoScalarVars = scalarVars
, vectInfoParallelVars = parallelVars
})
= info { vectInfoVar = tidy_vars
, vectInfoScalarVars = tidy_scalarVars
, vectInfoParallelVars = tidy_parallelVars
}
where
-- we only export mappings whose domain and co-domain are exported (otherwise, the iface is
......@@ -559,8 +559,8 @@ tidyVectInfo (_, var_env) info@(VectInfo { vectInfoVar = vars
, isDataConWorkId var || not (isImplicitId var)
]
tidy_scalarVars = mkVarSet [ lookup_var var
| var <- varSetElems scalarVars
tidy_parallelVars = mkVarSet [ lookup_var var
| var <- varSetElems parallelVars
, isGlobalId var || isExportedId var]
lookup_var var = lookupWithDefaultVarEnv var_env var var
......
......@@ -577,8 +577,7 @@ topdecl :: { OrdList (LHsDecl RdrName) }
| '{-# DEPRECATED' deprecations '#-}' { $2 }
| '{-# WARNING' warnings '#-}' { $2 }
| '{-# RULES' rules '#-}' { $2 }
| '{-# VECTORISE_SCALAR' qvar '#-}' { unitOL $ LL $ VectD (HsVect $2 Nothing) }
| '{-# VECTORISE' qvar '=' exp '#-}' { unitOL $ LL $ VectD (HsVect $2 (Just $4)) }
| '{-# VECTORISE' qvar '=' exp '#-}' { unitOL $ LL $ VectD (HsVect $2 $4) }
| '{-# NOVECTORISE' qvar '#-}' { unitOL $ LL $ VectD (HsNoVect $2) }
| '{-# VECTORISE' 'type' gtycon '#-}'
{ unitOL $ LL $
......@@ -593,8 +592,6 @@ topdecl :: { OrdList (LHsDecl RdrName) }
{ unitOL $ LL $
VectD (HsVectTypeIn True $3 (Just $5)) }
| '{-# VECTORISE' 'class' gtycon '#-}' { unitOL $ LL $ VectD (HsVectClassIn $3) }
| '{-# VECTORISE_SCALAR' 'instance' type '#-}'
{ unitOL $ LL $ VectD (HsVectInstIn $3) }
| annotation { unitOL $1 }
| decl { unLoc $1 }
......
......@@ -723,18 +723,14 @@ badRuleLhsErr name lhs bad_e
\begin{code}
rnHsVectDecl :: VectDecl RdrName -> RnM (VectDecl Name, FreeVars)
rnHsVectDecl (HsVect var Nothing)
= do { var' <- lookupLocatedOccRn var
; return (HsVect var' Nothing, unitFV (unLoc var'))
}
-- FIXME: For the moment, the right-hand side is restricted to be a variable as we cannot properly
-- typecheck a complex right-hand side without invoking 'vectType' from the vectoriser.
rnHsVectDecl (HsVect var (Just rhs@(L _ (HsVar _))))
rnHsVectDecl (HsVect var rhs@(L _ (HsVar _)))
= do { var' <- lookupLocatedOccRn var
; (rhs', fv_rhs) <- rnLExpr rhs
; return (HsVect var' (Just rhs'), fv_rhs `addOneFV` unLoc var')
; return (HsVect var' rhs', fv_rhs `addOneFV` unLoc var')
}
rnHsVectDecl (HsVect _var (Just _rhs))
rnHsVectDecl (HsVect _var _rhs)
= failWith $ vcat
[ ptext (sLit "IMPLEMENTATION RESTRICTION: right-hand side of a VECTORISE pragma")
, ptext (sLit "must be an identifier")
......
......@@ -739,17 +739,12 @@ tcVect :: VectDecl Name -> TcM (VectDecl TcId)
-- during type checking. Instead, constrain the rhs of a vectorisation declaration to be a single
-- identifier (this is checked in 'rnHsVectDecl'). Fix this by enabling the use of 'vectType'
-- from the vectoriser here.
tcVect (HsVect name Nothing)
= addErrCtxt (vectCtxt name) $
do { var <- wrapLocM tcLookupId name
; return $ HsVect var Nothing
}
tcVect (HsVect name (Just rhs))
tcVect (HsVect name rhs)
= addErrCtxt (vectCtxt name) $
do { var <- wrapLocM tcLookupId name
; let L rhs_loc (HsVar rhs_var_name) = rhs
; rhs_id <- tcLookupId rhs_var_name
; return $ HsVect var (Just $ L rhs_loc (HsVar rhs_id))
; return $ HsVect var (L rhs_loc (HsVar rhs_id))
}
{- OLD CODE:
......
......@@ -1081,7 +1081,7 @@ zonkVects env = mappM (wrapLocM (zonkVect env))
zonkVect :: ZonkEnv -> VectDecl TcId -> TcM (VectDecl Id)
zonkVect env (HsVect v e)
= do { v' <- wrapLocM (zonkIdBndr env) v
; e' <- fmapMaybeM (zonkLExpr env) e
; e' <- zonkLExpr env e
; return $ HsVect v' e'
}
zonkVect env (HsNoVect v)
......
......@@ -13,26 +13,22 @@ import Vectorise.Type.Type
import Vectorise.Convert
import Vectorise.Utils.Hoisting
import Vectorise.Exp
import Vectorise.Vect
import Vectorise.Env
import Vectorise.Monad
import HscTypes hiding ( MonadThings(..) )
import CoreUnfold ( mkInlineUnfolding )
import CoreFVs
import PprCore
import CoreSyn
import CoreMonad ( CoreM, getHscEnv )
import Type
import Id
import DynFlags
import BasicTypes ( isStrongLoopBreaker )
import Outputable
import Util ( zipLazy )
import MonadUtils
import Control.Monad
import Data.Maybe
-- |Vectorise a single module.
......@@ -69,7 +65,7 @@ vectModule guts@(ModGuts { mg_tcs = tycons
= do { dumpOptVt Opt_D_dump_vt_trace "Before vectorisation" $
pprCoreBindings binds
-- Pick out all 'VECTORISE type' and 'VECTORISE class' pragmas
-- Pick out all 'VECTORISE [SCALAR] type' and 'VECTORISE class' pragmas
; let ty_vect_decls = [vd | vd@(VectType _ _ _) <- vect_decls]
cls_vect_decls = [vd | vd@(VectClass _) <- vect_decls]
......@@ -87,8 +83,7 @@ vectModule guts@(ModGuts { mg_tcs = tycons
-- Vectorise all the top level bindings and VECTORISE declarations on imported identifiers
-- NB: Need to vectorise the imported bindings first (local bindings may depend on them).
; let impBinds = [imp_id | Vect imp_id _ <- vect_decls, isGlobalId imp_id] ++
[imp_id | VectInst imp_id <- vect_decls, isGlobalId imp_id]
; let impBinds = [(imp_id, expr) | Vect imp_id expr <- vect_decls, isGlobalId imp_id]
; binds_imp <- mapM vectImpBind impBinds
; binds_top <- mapM vectTopBind binds
......@@ -101,7 +96,8 @@ vectModule guts@(ModGuts { mg_tcs = tycons
}
}
-- Try to vectorise a top-level binding. If it doesn't vectorise then return it unharmed.
-- Try to vectorise a top-level binding. If it doesn't vectorise, or if it is entirely scalar, then
-- omit vectorisation of that binding.
--
-- For example, for the binding
--
......@@ -125,129 +121,173 @@ vectModule guts@(ModGuts { mg_tcs = tycons
-- lfoo = ...
-- @
--
-- @vfoo@ is the "vectorised", or scalar, version that does the same as the original
-- function foo, but takes an explicit environment.
-- @vfoo@ is the "vectorised", or scalar, version that does the same as the original function foo,
-- but takes an explicit environment.
--
-- @lfoo@ is the "lifted" version that works on arrays.
--
-- @v_foo@ combines both of these into a `Closure` that also contains the
-- environment.
-- @v_foo@ combines both of these into a `Closure` that also contains the environment.
--
-- The original binding @foo@ is rewritten to call the vectorised version
-- present in the closure.
-- The original binding @foo@ is rewritten to call the vectorised version present in the closure.
--
-- Vectorisation may be suppressed by annotating a binding with a 'NOVECTORISE' pragma. If this
-- pragma is used in a group of mutually recursive bindings, either all or no binding must have
-- the pragma. If only some bindings are annotated, a fatal error is raised.
-- the pragma. If only some bindings are annotated, a fatal error is raised. (In the case of
-- scalar bindings, we only omit vectorisation if all bindings in a group are scalar.)
--
-- FIXME: Once we support partial vectorisation, we may be able to vectorise parts of a group, or
-- we may emit a warning and refrain from vectorising the entire group.
--
vectTopBind :: CoreBind -> VM CoreBind
vectTopBind b@(NonRec var expr)
= unlessNoVectDecl $
do { -- Vectorise the right-hand side, create an appropriate top-level binding and add it
-- to the vectorisation map.
; (inline, isScalar, expr') <- vectTopRhs [] var expr
= do
{ traceVt "= Vectorise non-recursive top-level variable" (ppr var)
; (hasNoVect, vectDecl) <- lookupVectDecl var
; if hasNoVect
then do
{ -- 'NOVECTORISE' pragma => leave this binding as it is
; traceVt "NOVECTORISE" $ ppr var
; return b
}
else do
{ vectRhs <- case vectDecl of
Just (_, expr') ->
-- 'VECTORISE' pragma => just use the provided vectorised rhs
do
{ traceVt "VECTORISE" $ ppr var
; return $ Just (False, inlineMe, expr')
}
Nothing ->
-- no pragma => standard vectorisation of rhs
do
{ traceVt "[Vanilla]" $ ppr var <+> char '=' <+> ppr expr
; vectTopExpr var expr
}
; hs <- takeHoisted -- make sure we clean those out (even if we skip)
; case vectRhs of
{ Nothing ->
-- scalar binding => leave this binding as it is
do
{ traceVt "scalar binding [skip]" $ ppr var
; return b
}
; Just (parBind, inline, expr') -> do
{
-- vanilla case => create an appropriate top-level binding & add it to the vectorisation map
; when parBind $
addGlobalParallelVar var
; var' <- vectTopBinder var inline expr'
; when isScalar $
addGlobalScalarVar var
-- We replace the original top-level binding by a value projected from the vectorised
-- closure and add any newly created hoisted top-level bindings.
; cexpr <- tryConvert var var' expr
; hs <- takeHoisted
; return . Rec $ (var, cexpr) : (var', expr') : hs
}
} } } }
`orElseErrV`
do { emitVt " Could NOT vectorise top-level binding" $ ppr var
do
{ emitVt " Could NOT vectorise top-level binding" $ ppr var
; return b
}
where
unlessNoVectDecl vectorise
= do { hasNoVectDecl <- noVectDecl var
; when hasNoVectDecl $
traceVt "NOVECTORISE" $ ppr var
; if hasNoVectDecl then return b else vectorise
vectTopBind b@(Rec binds)
= do
{ traceVt "= Vectorise recursive top-level variables" $ ppr vars
; vectDecls <- mapM lookupVectDecl vars
; let hasNoVects = map fst vectDecls
; if and hasNoVects
then do
{ -- 'NOVECTORISE' pragmas => leave this entire binding group as it is
; traceVt "NOVECTORISE" $ ppr vars
; return b
}
vectTopBind b@(Rec bs)
= unlessSomeNoVectDecl $
do { (vars', _, exprs', hs) <- fixV $
\ ~(_, inlines, rhss, _) ->
do { -- Vectorise the right-hand sides, create an appropriate top-level bindings
-- and add them to the vectorisation map.
; vars' <- sequence [vectTopBinder var inline rhs
| (var, ~(inline, rhs)) <- zipLazy vars (zip inlines rhss)]
; (inlines, areScalars, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs vars) bs
; hs <- takeHoisted
; if and areScalars
then -- (1) Entire recursive group is scalar
-- => add all variables to the global set of scalars
do { mapM_ addGlobalScalarVar vars
; return (vars', inlines, exprs', hs)
else do
{ if or hasNoVects
then do
{ -- Inconsistent 'NOVECTORISE' pragmas => bail out
; dflags <- getDynFlags
; cantVectorise dflags noVectoriseErr (ppr b)
}
else -- (2) At least one binding is not scalar
-- => vectorise again with empty set of local scalars
do { (inlines, _, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs []) bs
; hs <- takeHoisted
; return (vars', inlines, exprs', hs)
else do
{ -- For all bindings *with* a pragma, just use the pragma-supplied vectorised expression
; newBindsWPragma <- concat <$>
sequence [ vectTopBindAndConvert bind inlineMe expr'
| (bind, (_, Just (_, expr'))) <- zip binds vectDecls]
-- Standard vectorisation of all rhses that are *without* a pragma.
-- NB: The reason for 'fixV' is rather subtle: 'vectTopBindAndConvert' adds entries for
-- the bound variables in the recursive group to the vectorisation map, which in turn
-- are needed by 'vectPolyExprs' (unless it returns 'Nothing').
; let bindsWOPragma = [bind | (bind, (_, Nothing)) <- zip binds vectDecls]
; (newBinds, _) <- fixV $
\ ~(_, exprs') ->
do
{ -- Create appropriate top-level bindings, enter them into the vectorisation map, and
-- vectorise the right-hand sides
; newBindsWOPragma <- concat <$>
sequence [vectTopBindAndConvert bind inline expr
| (bind, ~(inline, expr)) <- zipLazy bindsWOPragma exprs']
-- irrefutable pattern and 'zipLazy' to tie the knot;
-- hence, can't use 'zipWithM'
; vectRhses <- vectTopExprs bindsWOPragma
; hs <- takeHoisted -- make sure we clean those out (even if we skip)
; case vectRhses of
Nothing ->
-- scalar bindings => skip all bindings except those with pragmas and retract the
-- entries into the vectorisation map for the scalar bindings
do
{ traceVt "scalar bindings [skip]" $ ppr vars
; mapM_ (undefGlobalVar . fst) bindsWOPragma
; return (bindsWOPragma ++ newBindsWPragma, exprs')
}
Just (parBind, exprs') ->
-- vanilla case => record parallel variables and return the final bindings
do
{ when parBind $
mapM_ addGlobalParallelVar vars
; return (newBindsWOPragma ++ newBindsWPragma ++ hs, exprs')
}
-- Replace the original top-level bindings by values projected from the vectorised
-- closures and add any newly created hoisted top-level bindings to the group.
; cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
; return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
}
; return $ Rec newBinds
} } }
`orElseErrV`
return b
do
{ emitVt " Could NOT vectorise top-level bindings" $ ppr vars
; return b
}
where
(vars, exprs) = unzip bs
vars = map fst binds
noVectoriseErr = "NOVECTORISE must be used on all or no bindings of a recursive group"
unlessSomeNoVectDecl vectorise
= do { hasNoVectDecls <- mapM noVectDecl vars
; when (and hasNoVectDecls) $
traceVt "NOVECTORISE" $ ppr vars
; if and hasNoVectDecls
then return b -- all bindings have 'NOVECTORISE'
else if or hasNoVectDecls
then do dflags <- getDynFlags
cantVectorise dflags noVectoriseErr (ppr b) -- some (but not all) have 'NOVECTORISE'
else vectorise -- no binding has a 'NOVECTORISE' decl
-- Replace the original top-level bindings by values projected from the vectorised
-- closures and add any newly created hoisted top-level bindings to the group.
vectTopBindAndConvert (var, expr) inline expr'
= do
{ var' <- vectTopBinder var inline expr'
; cexpr <- tryConvert var var' expr
; return [(var, cexpr), (var', expr')]
}
noVectoriseErr = "NOVECTORISE must be used on all or no bindings of a recursive group"
-- Add a vectorised binding to an imported top-level variable that has a VECTORISE [SCALAR] pragma
-- Add a vectorised binding to an imported top-level variable that has a VECTORISE pragma
-- in this module.
--
-- RESTRICTION: Currently, we cannot use the pragma vor mutually recursive definitions.
-- RESTRICTION: Currently, we cannot use the pragma for mutually recursive definitions.
--
vectImpBind :: Id -> VM CoreBind
vectImpBind var
= do { -- Vectorise the right-hand side, create an appropriate top-level binding and add it
-- to the vectorisation map. For the non-lifted version, we refer to the original
-- definition — i.e., 'Var var'.
-- NB: To support recursive definitions, we tie a lazy knot.
; (var', _, expr') <- fixV $
\ ~(_, inline, rhs) ->
do { var' <- vectTopBinder var inline rhs
; (inline, isScalar, expr') <- vectTopRhs [] var (Var var)
vectImpBind :: (Id, CoreExpr) -> VM CoreBind
vectImpBind (var, expr)
= do
{ traceVt "= Add vectorised binding to imported variable" (ppr var)
; when isScalar $
addGlobalScalarVar var
; return (var', inline, expr')
; var' <- vectTopBinder var inlineMe expr
; return $ NonRec var' expr
}
-- We add any newly created hoisted top-level bindings.
; hs <- takeHoisted
; return . Rec $ (var', expr') : hs
}
-- | Make the vectorised version of this top level binder, and add the mapping
-- between it and the original to the state. For some binder @foo@ the vectorised
-- version is @$v_foo@
-- |Make the vectorised version of this top level binder, and add the mapping between it and the
-- original to the state. For some binder @foo@ the vectorised version is @$v_foo@
--
-- NOTE: 'vectTopBinder' *MUST* be lazy in inline and expr because of how it is
-- used inside of 'fixV' in 'vectTopBind'.
-- NOTE: 'vectTopBinder' *MUST* be lazy in inline and expr because of how it is used inside of
-- 'fixV' in 'vectTopBind'.
--
vectTopBinder :: Var -- ^ Name of the binding.
-> Inline -- ^ Whether it should be inlined, used to annotate it.
......@@ -257,20 +297,20 @@ vectTopBinder var inline expr
= do { -- Vectorise the type attached to the var.
; vty <- vectType (idType var)
-- If there is a vectorisation declaration for this binding, make sure that its type
-- matches
; vectDecl <- lookupVectDecl var
-- If there is a vectorisation declaration for this binding, make sure its type matches
; (_, vectDecl) <- lookupVectDecl var
; case vectDecl of
Nothing -> return ()
Just (vdty, _)
| eqType vty vdty -> return ()
| otherwise ->
do dflags <- getDynFlags
cantVectorise dflags ("Type mismatch in vectorisation pragma for " ++ showPpr dflags var) $
do
{ dflags <- getDynFlags
; cantVectorise dflags ("Type mismatch in vectorisation pragma for " ++ showPpr dflags var) $
(text "Expected type" <+> ppr vty)
$$
(text "Inferred type" <+> ppr vdty)
}
-- Make the vectorised version of binding's name, and set the unfolding used for inlining
; var' <- liftM (`setIdUnfoldingLazily` unfolding)