Commit b77da25e authored by chak@cse.unsw.edu.au.'s avatar chak@cse.unsw.edu.au.

Rewrote vectorisation avoidance (based on the HS paper)

* Vectorisation avoidance is now the default
* Types and values from unvectorised modules are permitted in scalar code
* Simplified the VECTORISE pragmas (see http://hackage.haskell.org/trac/ghc/wiki/DataParallel/VectPragma for the spec)
* Vectorisation information is now included in the annotated Core AST
parent 2a7217e3
...@@ -328,8 +328,7 @@ breaker, which is perfectly inlinable. ...@@ -328,8 +328,7 @@ breaker, which is perfectly inlinable.
vectsFreeVars :: [CoreVect] -> VarSet vectsFreeVars :: [CoreVect] -> VarSet
vectsFreeVars = foldr (unionVarSet . vectFreeVars) emptyVarSet vectsFreeVars = foldr (unionVarSet . vectFreeVars) emptyVarSet
where where
vectFreeVars (Vect _ Nothing) = noFVs vectFreeVars (Vect _ rhs) = expr_fvs rhs isLocalId emptyVarSet
vectFreeVars (Vect _ (Just rhs)) = expr_fvs rhs isLocalId emptyVarSet
vectFreeVars (NoVect _) = noFVs vectFreeVars (NoVect _) = noFVs
vectFreeVars (VectType _ _ _) = noFVs vectFreeVars (VectType _ _ _) = noFVs
vectFreeVars (VectClass _) = noFVs vectFreeVars (VectClass _) = noFVs
......
...@@ -749,8 +749,7 @@ substVects subst = map (substVect subst) ...@@ -749,8 +749,7 @@ substVects subst = map (substVect subst)
------------------ ------------------
substVect :: Subst -> CoreVect -> CoreVect substVect :: Subst -> CoreVect -> CoreVect
substVect _subst (Vect v Nothing) = Vect v Nothing substVect subst (Vect v rhs) = Vect v (simpleOptExprWith subst rhs)
substVect subst (Vect v (Just rhs)) = Vect v (Just (simpleOptExprWith subst rhs))
substVect _subst vd@(NoVect _) = vd substVect _subst vd@(NoVect _) = vd
substVect _subst vd@(VectType _ _ _) = vd substVect _subst vd@(VectType _ _ _) = vd
substVect _subst vd@(VectClass _) = vd substVect _subst vd@(VectClass _) = vd
......
...@@ -592,11 +592,11 @@ Representation of desugared vectorisation declarations that are fed to the vecto ...@@ -592,11 +592,11 @@ Representation of desugared vectorisation declarations that are fed to the vecto
'ModGuts'). 'ModGuts').
\begin{code} \begin{code}
data CoreVect = Vect Id (Maybe CoreExpr) data CoreVect = Vect Id CoreExpr
| NoVect Id | NoVect Id
| VectType Bool TyCon (Maybe TyCon) | VectType Bool TyCon (Maybe TyCon)
| VectClass TyCon -- class tycon | VectClass TyCon -- class tycon
| VectInst Id -- instance dfun (always SCALAR) | VectInst Id -- instance dfun (always SCALAR) !!!FIXME: should be superfluous now
\end{code} \end{code}
......
...@@ -494,8 +494,7 @@ instance Outputable id => Outputable (Tickish id) where ...@@ -494,8 +494,7 @@ instance Outputable id => Outputable (Tickish id) where
\begin{code} \begin{code}
instance Outputable CoreVect where instance Outputable CoreVect where
ppr (Vect var Nothing) = ptext (sLit "VECTORISE SCALAR") <+> ppr var ppr (Vect var e) = hang (ptext (sLit "VECTORISE") <+> ppr var <+> char '=')
ppr (Vect var (Just e)) = hang (ptext (sLit "VECTORISE") <+> ppr var <+> char '=')
4 (pprCoreExpr e) 4 (pprCoreExpr e)
ppr (NoVect var) = ptext (sLit "NOVECTORISE") <+> ppr var ppr (NoVect var) = ptext (sLit "NOVECTORISE") <+> ppr var
ppr (VectType False var Nothing) = ptext (sLit "VECTORISE type") <+> ppr var ppr (VectType False var Nothing) = ptext (sLit "VECTORISE type") <+> ppr var
......
...@@ -432,7 +432,7 @@ the rule is precisly to optimise them: ...@@ -432,7 +432,7 @@ the rule is precisly to optimise them:
dsVect :: LVectDecl Id -> DsM CoreVect dsVect :: LVectDecl Id -> DsM CoreVect
dsVect (L loc (HsVect (L _ v) rhs)) dsVect (L loc (HsVect (L _ v) rhs))
= putSrcSpanDs loc $ = putSrcSpanDs loc $
do { rhs' <- fmapMaybeM dsLExpr rhs do { rhs' <- dsLExpr rhs
; return $ Vect v rhs' ; return $ Vect v rhs'
} }
dsVect (L _loc (HsNoVect (L _ v))) dsVect (L _loc (HsNoVect (L _ v)))
......
...@@ -1111,7 +1111,7 @@ type LVectDecl name = Located (VectDecl name) ...@@ -1111,7 +1111,7 @@ type LVectDecl name = Located (VectDecl name)
data VectDecl name data VectDecl name
= HsVect = HsVect
(Located name) (Located name)
(Maybe (LHsExpr name)) -- 'Nothing' => SCALAR declaration (LHsExpr name)
| HsNoVect | HsNoVect
(Located name) (Located name)
| HsVectTypeIn -- pre type-checking | HsVectTypeIn -- pre type-checking
...@@ -1126,9 +1126,9 @@ data VectDecl name ...@@ -1126,9 +1126,9 @@ data VectDecl name
(Located name) (Located name)
| HsVectClassOut -- post type-checking | HsVectClassOut -- post type-checking
Class Class
| HsVectInstIn -- pre type-checking (always SCALAR) | HsVectInstIn -- pre type-checking (always SCALAR) !!!FIXME: should be superfluous now
(LHsType name) (LHsType name)
| HsVectInstOut -- post type-checking (always SCALAR) | HsVectInstOut -- post type-checking (always SCALAR) !!!FIXME: should be superfluous now
ClsInst ClsInst
deriving (Data, Typeable) deriving (Data, Typeable)
...@@ -1148,9 +1148,7 @@ lvectInstDecl (L _ (HsVectInstOut _)) = True ...@@ -1148,9 +1148,7 @@ lvectInstDecl (L _ (HsVectInstOut _)) = True
lvectInstDecl _ = False lvectInstDecl _ = False
instance OutputableBndr name => Outputable (VectDecl name) where instance OutputableBndr name => Outputable (VectDecl name) where
ppr (HsVect v Nothing) ppr (HsVect v rhs)
= sep [text "{-# VECTORISE SCALAR" <+> ppr v <+> text "#-}" ]
ppr (HsVect v (Just rhs))
= sep [text "{-# VECTORISE" <+> ppr v, = sep [text "{-# VECTORISE" <+> ppr v,
nest 4 $ nest 4 $
pprExpr (unLoc rhs) <+> text "#-}" ] pprExpr (unLoc rhs) <+> text "#-}" ]
......
...@@ -753,15 +753,15 @@ pprVectInfo :: IfaceVectInfo -> SDoc ...@@ -753,15 +753,15 @@ pprVectInfo :: IfaceVectInfo -> SDoc
pprVectInfo (IfaceVectInfo { ifaceVectInfoVar = vars pprVectInfo (IfaceVectInfo { ifaceVectInfoVar = vars
, ifaceVectInfoTyCon = tycons , ifaceVectInfoTyCon = tycons
, ifaceVectInfoTyConReuse = tyconsReuse , ifaceVectInfoTyConReuse = tyconsReuse
, ifaceVectInfoScalarVars = scalarVars , ifaceVectInfoParallelVars = parallelVars
, ifaceVectInfoScalarTyCons = scalarTyCons , ifaceVectInfoParallelTyCons = parallelTyCons
}) = }) =
vcat vcat
[ ptext (sLit "vectorised variables:") <+> hsep (map ppr vars) [ ptext (sLit "vectorised variables:") <+> hsep (map ppr vars)
, ptext (sLit "vectorised tycons:") <+> hsep (map ppr tycons) , ptext (sLit "vectorised tycons:") <+> hsep (map ppr tycons)
, ptext (sLit "vectorised reused tycons:") <+> hsep (map ppr tyconsReuse) , ptext (sLit "vectorised reused tycons:") <+> hsep (map ppr tyconsReuse)
, ptext (sLit "scalar variables:") <+> hsep (map ppr scalarVars) , ptext (sLit "parallel variables:") <+> hsep (map ppr parallelVars)
, ptext (sLit "scalar tycons:") <+> hsep (map ppr scalarTyCons) , ptext (sLit "parallel tycons:") <+> hsep (map ppr parallelTyCons)
] ]
pprTrustInfo :: IfaceTrustInfo -> SDoc pprTrustInfo :: IfaceTrustInfo -> SDoc
......
...@@ -375,15 +375,15 @@ mkIface_ hsc_env maybe_old_fingerprint ...@@ -375,15 +375,15 @@ mkIface_ hsc_env maybe_old_fingerprint
flattenVectInfo (VectInfo { vectInfoVar = vVar flattenVectInfo (VectInfo { vectInfoVar = vVar
, vectInfoTyCon = vTyCon , vectInfoTyCon = vTyCon
, vectInfoScalarVars = vScalarVars , vectInfoParallelVars = vParallelVars
, vectInfoScalarTyCons = vScalarTyCons , vectInfoParallelTyCons = vParallelTyCons
}) = }) =
IfaceVectInfo IfaceVectInfo
{ ifaceVectInfoVar = [Var.varName v | (v, _ ) <- varEnvElts vVar] { ifaceVectInfoVar = [Var.varName v | (v, _ ) <- varEnvElts vVar]
, ifaceVectInfoTyCon = [tyConName t | (t, t_v) <- nameEnvElts vTyCon, t /= t_v] , ifaceVectInfoTyCon = [tyConName t | (t, t_v) <- nameEnvElts vTyCon, t /= t_v]
, ifaceVectInfoTyConReuse = [tyConName t | (t, t_v) <- nameEnvElts vTyCon, t == t_v] , ifaceVectInfoTyConReuse = [tyConName t | (t, t_v) <- nameEnvElts vTyCon, t == t_v]
, ifaceVectInfoScalarVars = [Var.varName v | v <- varSetElems vScalarVars] , ifaceVectInfoParallelVars = [Var.varName v | v <- varSetElems vParallelVars]
, ifaceVectInfoScalarTyCons = nameSetToList vScalarTyCons , ifaceVectInfoParallelTyCons = nameSetToList vParallelTyCons
} }
----------------------------- -----------------------------
......
...@@ -751,22 +751,22 @@ tcIfaceVectInfo mod typeEnv (IfaceVectInfo ...@@ -751,22 +751,22 @@ tcIfaceVectInfo mod typeEnv (IfaceVectInfo
{ ifaceVectInfoVar = vars { ifaceVectInfoVar = vars
, ifaceVectInfoTyCon = tycons , ifaceVectInfoTyCon = tycons
, ifaceVectInfoTyConReuse = tyconsReuse , ifaceVectInfoTyConReuse = tyconsReuse
, ifaceVectInfoScalarVars = scalarVars , ifaceVectInfoParallelVars = parallelVars
, ifaceVectInfoScalarTyCons = scalarTyCons , ifaceVectInfoParallelTyCons = parallelTyCons
}) })
= do { let scalarTyConsSet = mkNameSet scalarTyCons = do { let parallelTyConsSet = mkNameSet parallelTyCons
; vVars <- mapM vectVarMapping vars ; vVars <- mapM vectVarMapping vars
; let varsSet = mkVarSet (map fst vVars) ; let varsSet = mkVarSet (map fst vVars)
; tyConRes1 <- mapM (vectTyConVectMapping varsSet) tycons ; tyConRes1 <- mapM (vectTyConVectMapping varsSet) tycons
; tyConRes2 <- mapM (vectTyConReuseMapping varsSet) tyconsReuse ; tyConRes2 <- mapM (vectTyConReuseMapping varsSet) tyconsReuse
; vScalarVars <- mapM vectVar scalarVars ; vParallelVars <- mapM vectVar parallelVars
; let (vTyCons, vDataCons, vScSels) = unzip3 (tyConRes1 ++ tyConRes2) ; let (vTyCons, vDataCons, vScSels) = unzip3 (tyConRes1 ++ tyConRes2)
; return $ VectInfo ; return $ VectInfo
{ vectInfoVar = mkVarEnv vVars `extendVarEnvList` concat vScSels { vectInfoVar = mkVarEnv vVars `extendVarEnvList` concat vScSels
, vectInfoTyCon = mkNameEnv vTyCons , vectInfoTyCon = mkNameEnv vTyCons
, vectInfoDataCon = mkNameEnv (concat vDataCons) , vectInfoDataCon = mkNameEnv (concat vDataCons)
, vectInfoScalarVars = mkVarSet vScalarVars , vectInfoParallelVars = mkVarSet vParallelVars
, vectInfoScalarTyCons = scalarTyConsSet , vectInfoParallelTyCons = parallelTyConsSet
} }
} }
where where
......
...@@ -1971,8 +1971,8 @@ data VectInfo ...@@ -1971,8 +1971,8 @@ data VectInfo
{ vectInfoVar :: VarEnv (Var , Var ) -- ^ @(f, f_v)@ keyed on @f@ { vectInfoVar :: VarEnv (Var , Var ) -- ^ @(f, f_v)@ keyed on @f@
, vectInfoTyCon :: NameEnv (TyCon , TyCon) -- ^ @(T, T_v)@ keyed on @T@ , vectInfoTyCon :: NameEnv (TyCon , TyCon) -- ^ @(T, T_v)@ keyed on @T@
, vectInfoDataCon :: NameEnv (DataCon, DataCon) -- ^ @(C, C_v)@ keyed on @C@ , vectInfoDataCon :: NameEnv (DataCon, DataCon) -- ^ @(C, C_v)@ keyed on @C@
, vectInfoScalarVars :: VarSet -- ^ set of purely scalar variables , vectInfoParallelVars :: VarSet -- ^ set of parallel variables
, vectInfoScalarTyCons :: NameSet -- ^ set of scalar type constructors , vectInfoParallelTyCons :: NameSet -- ^ set of parallel type constructors
} }
-- |Vectorisation information for 'ModIface'; i.e, the vectorisation information propagated -- |Vectorisation information for 'ModIface'; i.e, the vectorisation information propagated
...@@ -1996,8 +1996,8 @@ data IfaceVectInfo ...@@ -1996,8 +1996,8 @@ data IfaceVectInfo
, ifaceVectInfoTyConReuse :: [Name] -- ^ The vectorised form of all the 'TyCon's in here , ifaceVectInfoTyConReuse :: [Name] -- ^ The vectorised form of all the 'TyCon's in here
-- coincides with the unconverted form; the name of the -- coincides with the unconverted form; the name of the
-- isomorphisms is determined by 'OccName.mkVectIsoOcc' -- isomorphisms is determined by 'OccName.mkVectIsoOcc'
, ifaceVectInfoScalarVars :: [Name] -- iface version of 'vectInfoScalarVar' , ifaceVectInfoParallelVars :: [Name] -- iface version of 'vectInfoParallelVar'
, ifaceVectInfoScalarTyCons :: [Name] -- iface version of 'vectInfoScalarTyCon' , ifaceVectInfoParallelTyCons :: [Name] -- iface version of 'vectInfoParallelTyCon'
} }
noVectInfo :: VectInfo noVectInfo :: VectInfo
...@@ -2009,8 +2009,8 @@ plusVectInfo vi1 vi2 = ...@@ -2009,8 +2009,8 @@ plusVectInfo vi1 vi2 =
VectInfo (vectInfoVar vi1 `plusVarEnv` vectInfoVar vi2) VectInfo (vectInfoVar vi1 `plusVarEnv` vectInfoVar vi2)
(vectInfoTyCon vi1 `plusNameEnv` vectInfoTyCon vi2) (vectInfoTyCon vi1 `plusNameEnv` vectInfoTyCon vi2)
(vectInfoDataCon vi1 `plusNameEnv` vectInfoDataCon vi2) (vectInfoDataCon vi1 `plusNameEnv` vectInfoDataCon vi2)
(vectInfoScalarVars vi1 `unionVarSet` vectInfoScalarVars vi2) (vectInfoParallelVars vi1 `unionVarSet` vectInfoParallelVars vi2)
(vectInfoScalarTyCons vi1 `unionNameSets` vectInfoScalarTyCons vi2) (vectInfoParallelTyCons vi1 `unionNameSets` vectInfoParallelTyCons vi2)
concatVectInfo :: [VectInfo] -> VectInfo concatVectInfo :: [VectInfo] -> VectInfo
concatVectInfo = foldr plusVectInfo noVectInfo concatVectInfo = foldr plusVectInfo noVectInfo
...@@ -2027,8 +2027,8 @@ instance Outputable VectInfo where ...@@ -2027,8 +2027,8 @@ instance Outputable VectInfo where
[ ptext (sLit "variables :") <+> ppr (vectInfoVar info) [ ptext (sLit "variables :") <+> ppr (vectInfoVar info)
, ptext (sLit "tycons :") <+> ppr (vectInfoTyCon info) , ptext (sLit "tycons :") <+> ppr (vectInfoTyCon info)
, ptext (sLit "datacons :") <+> ppr (vectInfoDataCon info) , ptext (sLit "datacons :") <+> ppr (vectInfoDataCon info)
, ptext (sLit "scalar vars :") <+> ppr (vectInfoScalarVars info) , ptext (sLit "parallel vars :") <+> ppr (vectInfoParallelVars info)
, ptext (sLit "scalar tycons :") <+> ppr (vectInfoScalarTyCons info) , ptext (sLit "parallel tycons :") <+> ppr (vectInfoParallelTyCons info)
] ]
\end{code} \end{code}
......
...@@ -542,10 +542,10 @@ tidyInstances tidy_dfun ispecs ...@@ -542,10 +542,10 @@ tidyInstances tidy_dfun ispecs
\begin{code} \begin{code}
tidyVectInfo :: TidyEnv -> VectInfo -> VectInfo tidyVectInfo :: TidyEnv -> VectInfo -> VectInfo
tidyVectInfo (_, var_env) info@(VectInfo { vectInfoVar = vars tidyVectInfo (_, var_env) info@(VectInfo { vectInfoVar = vars
, vectInfoScalarVars = scalarVars , vectInfoParallelVars = parallelVars
}) })
= info { vectInfoVar = tidy_vars = info { vectInfoVar = tidy_vars
, vectInfoScalarVars = tidy_scalarVars , vectInfoParallelVars = tidy_parallelVars
} }
where where
-- we only export mappings whose domain and co-domain is exported (otherwise, the iface is -- we only export mappings whose domain and co-domain is exported (otherwise, the iface is
...@@ -559,8 +559,8 @@ tidyVectInfo (_, var_env) info@(VectInfo { vectInfoVar = vars ...@@ -559,8 +559,8 @@ tidyVectInfo (_, var_env) info@(VectInfo { vectInfoVar = vars
, isDataConWorkId var || not (isImplicitId var) , isDataConWorkId var || not (isImplicitId var)
] ]
tidy_scalarVars = mkVarSet [ lookup_var var tidy_parallelVars = mkVarSet [ lookup_var var
| var <- varSetElems scalarVars | var <- varSetElems parallelVars
, isGlobalId var || isExportedId var] , isGlobalId var || isExportedId var]
lookup_var var = lookupWithDefaultVarEnv var_env var var lookup_var var = lookupWithDefaultVarEnv var_env var var
......
...@@ -577,8 +577,7 @@ topdecl :: { OrdList (LHsDecl RdrName) } ...@@ -577,8 +577,7 @@ topdecl :: { OrdList (LHsDecl RdrName) }
| '{-# DEPRECATED' deprecations '#-}' { $2 } | '{-# DEPRECATED' deprecations '#-}' { $2 }
| '{-# WARNING' warnings '#-}' { $2 } | '{-# WARNING' warnings '#-}' { $2 }
| '{-# RULES' rules '#-}' { $2 } | '{-# RULES' rules '#-}' { $2 }
| '{-# VECTORISE_SCALAR' qvar '#-}' { unitOL $ LL $ VectD (HsVect $2 Nothing) } | '{-# VECTORISE' qvar '=' exp '#-}' { unitOL $ LL $ VectD (HsVect $2 $4) }
| '{-# VECTORISE' qvar '=' exp '#-}' { unitOL $ LL $ VectD (HsVect $2 (Just $4)) }
| '{-# NOVECTORISE' qvar '#-}' { unitOL $ LL $ VectD (HsNoVect $2) } | '{-# NOVECTORISE' qvar '#-}' { unitOL $ LL $ VectD (HsNoVect $2) }
| '{-# VECTORISE' 'type' gtycon '#-}' | '{-# VECTORISE' 'type' gtycon '#-}'
{ unitOL $ LL $ { unitOL $ LL $
...@@ -593,8 +592,6 @@ topdecl :: { OrdList (LHsDecl RdrName) } ...@@ -593,8 +592,6 @@ topdecl :: { OrdList (LHsDecl RdrName) }
{ unitOL $ LL $ { unitOL $ LL $
VectD (HsVectTypeIn True $3 (Just $5)) } VectD (HsVectTypeIn True $3 (Just $5)) }
| '{-# VECTORISE' 'class' gtycon '#-}' { unitOL $ LL $ VectD (HsVectClassIn $3) } | '{-# VECTORISE' 'class' gtycon '#-}' { unitOL $ LL $ VectD (HsVectClassIn $3) }
| '{-# VECTORISE_SCALAR' 'instance' type '#-}'
{ unitOL $ LL $ VectD (HsVectInstIn $3) }
| annotation { unitOL $1 } | annotation { unitOL $1 }
| decl { unLoc $1 } | decl { unLoc $1 }
......
...@@ -723,18 +723,14 @@ badRuleLhsErr name lhs bad_e ...@@ -723,18 +723,14 @@ badRuleLhsErr name lhs bad_e
\begin{code} \begin{code}
rnHsVectDecl :: VectDecl RdrName -> RnM (VectDecl Name, FreeVars) rnHsVectDecl :: VectDecl RdrName -> RnM (VectDecl Name, FreeVars)
rnHsVectDecl (HsVect var Nothing)
= do { var' <- lookupLocatedOccRn var
; return (HsVect var' Nothing, unitFV (unLoc var'))
}
-- FIXME: For the moment, the right-hand side is restricted to be a variable as we cannot properly -- FIXME: For the moment, the right-hand side is restricted to be a variable as we cannot properly
-- typecheck a complex right-hand side without invoking 'vectType' from the vectoriser. -- typecheck a complex right-hand side without invoking 'vectType' from the vectoriser.
rnHsVectDecl (HsVect var (Just rhs@(L _ (HsVar _)))) rnHsVectDecl (HsVect var rhs@(L _ (HsVar _)))
= do { var' <- lookupLocatedOccRn var = do { var' <- lookupLocatedOccRn var
; (rhs', fv_rhs) <- rnLExpr rhs ; (rhs', fv_rhs) <- rnLExpr rhs
; return (HsVect var' (Just rhs'), fv_rhs `addOneFV` unLoc var') ; return (HsVect var' rhs', fv_rhs `addOneFV` unLoc var')
} }
rnHsVectDecl (HsVect _var (Just _rhs)) rnHsVectDecl (HsVect _var _rhs)
= failWith $ vcat = failWith $ vcat
[ ptext (sLit "IMPLEMENTATION RESTRICTION: right-hand side of a VECTORISE pragma") [ ptext (sLit "IMPLEMENTATION RESTRICTION: right-hand side of a VECTORISE pragma")
, ptext (sLit "must be an identifier") , ptext (sLit "must be an identifier")
......
...@@ -739,17 +739,12 @@ tcVect :: VectDecl Name -> TcM (VectDecl TcId) ...@@ -739,17 +739,12 @@ tcVect :: VectDecl Name -> TcM (VectDecl TcId)
-- during type checking. Instead, constrain the rhs of a vectorisation declaration to be a single -- during type checking. Instead, constrain the rhs of a vectorisation declaration to be a single
-- identifier (this is checked in 'rnHsVectDecl'). Fix this by enabling the use of 'vectType' -- identifier (this is checked in 'rnHsVectDecl'). Fix this by enabling the use of 'vectType'
-- from the vectoriser here. -- from the vectoriser here.
tcVect (HsVect name Nothing) tcVect (HsVect name rhs)
= addErrCtxt (vectCtxt name) $
do { var <- wrapLocM tcLookupId name
; return $ HsVect var Nothing
}
tcVect (HsVect name (Just rhs))
= addErrCtxt (vectCtxt name) $ = addErrCtxt (vectCtxt name) $
do { var <- wrapLocM tcLookupId name do { var <- wrapLocM tcLookupId name
; let L rhs_loc (HsVar rhs_var_name) = rhs ; let L rhs_loc (HsVar rhs_var_name) = rhs
; rhs_id <- tcLookupId rhs_var_name ; rhs_id <- tcLookupId rhs_var_name
; return $ HsVect var (Just $ L rhs_loc (HsVar rhs_id)) ; return $ HsVect var (L rhs_loc (HsVar rhs_id))
} }
{- OLD CODE: {- OLD CODE:
......
...@@ -1081,7 +1081,7 @@ zonkVects env = mappM (wrapLocM (zonkVect env)) ...@@ -1081,7 +1081,7 @@ zonkVects env = mappM (wrapLocM (zonkVect env))
zonkVect :: ZonkEnv -> VectDecl TcId -> TcM (VectDecl Id) zonkVect :: ZonkEnv -> VectDecl TcId -> TcM (VectDecl Id)
zonkVect env (HsVect v e) zonkVect env (HsVect v e)
= do { v' <- wrapLocM (zonkIdBndr env) v = do { v' <- wrapLocM (zonkIdBndr env) v
; e' <- fmapMaybeM (zonkLExpr env) e ; e' <- zonkLExpr env e
; return $ HsVect v' e' ; return $ HsVect v' e'
} }
zonkVect env (HsNoVect v) zonkVect env (HsNoVect v)
......
This diff is collapsed.
...@@ -84,16 +84,16 @@ identityConv (AppTy {}) = noV $ text "identityConv: type appl. changes under ...@@ -84,16 +84,16 @@ identityConv (AppTy {}) = noV $ text "identityConv: type appl. changes under
identityConv (FunTy {}) = noV $ text "identityConv: function type changes under vectorisation" identityConv (FunTy {}) = noV $ text "identityConv: function type changes under vectorisation"
identityConv (ForAllTy {}) = noV $ text "identityConv: quantified type changes under vectorisation" identityConv (ForAllTy {}) = noV $ text "identityConv: quantified type changes under vectorisation"
-- |Check that this type constructor is neutral under type vectorisation — i.e., it is not altered -- |Check that this type constructor is not changed by vectorisation — i.e., it does not embed any
-- by vectorisation as they contain no parallel arrays. -- parallel arrays.
-- --
identityConvTyCon :: TyCon -> VM () identityConvTyCon :: TyCon -> VM ()
identityConvTyCon tc identityConvTyCon tc
| isBoxedTupleTyCon tc = return () = do
| isUnLiftedTyCon tc = return () { tc' <- lookupTyCon tc
| otherwise ; case tc' of
= do tc' <- maybeV notVectErr (lookupTyCon tc) Nothing -> return ()
if tc == tc' then return () else noV idErr Just _ -> noV idErr
}
where where
notVectErr = text "identityConvTyCon: no vectorised version for type constructor" <+> ppr tc
idErr = text "identityConvTyCon: type constructor contains parallel arrays" <+> ppr tc idErr = text "identityConvTyCon: type constructor contains parallel arrays" <+> ppr tc
...@@ -31,7 +31,7 @@ import Name ...@@ -31,7 +31,7 @@ import Name
import NameEnv import NameEnv
import FastString import FastString
import TysPrim import TysPrim
import TysWiredIn --import TysWiredIn
import Data.Maybe import Data.Maybe
...@@ -60,7 +60,8 @@ data LocalEnv ...@@ -60,7 +60,8 @@ data LocalEnv
-- ^Mapping from tyvars to their PA dictionaries. -- ^Mapping from tyvars to their PA dictionaries.
, local_bind_name :: FastString , local_bind_name :: FastString
-- ^Local binding name. -- ^Local binding name. This is only used to generate better names for hoisted
-- expressions.
} }
-- |Create an empty local environment. -- |Create an empty local environment.
...@@ -84,34 +85,33 @@ data GlobalEnv ...@@ -84,34 +85,33 @@ data GlobalEnv
-- ^Mapping from global variables to their vectorised versions — aka the /vectorisation -- ^Mapping from global variables to their vectorised versions — aka the /vectorisation
-- map/. -- map/.
, global_vect_decls :: VarEnv (Type, CoreExpr) , global_parallel_vars :: VarSet
-- ^Mapping from global variables that have a vectorisation declaration to the right-hand -- ^The domain of 'global_vars'.
-- side of that declaration and its type. This mapping only applies to non-scalar
-- vectorisation declarations. All variables with a scalar vectorisation declaration are
-- mentioned in 'global_scalars_vars'.
, global_scalar_vars :: VarSet
-- ^Purely scalar variables. Code which mentions only these variables doesn't have to be
-- lifted. This includes variables from the current module that have a scalar
-- vectorisation declaration and those that the vectoriser determines to be scalar.
, global_scalar_tycons :: NameSet
-- ^Type constructors whose values can only contain scalar data. This includes type
-- constructors that appear in a 'VECTORISE SCALAR type' pragma or 'VECTORISE type' pragma
-- *without* a right-hand side in the current or an imported module as well as type
-- constructors that are automatically identified as scalar by the vectoriser (in
-- 'Vectorise.Type.Env'). Scalar code may only operate on such data.
-- --
-- NB: Not all type constructors in that set are members of the 'Scalar' type class -- This information is not redundant as it is impossible to extract the domain from a
-- (which can be trivially marshalled across scalar code boundaries). -- 'VarEnv' (which is keyed on uniques alone). Moreover, we have mapped variables that
-- do not involve parallelism — e.g., the workers of vectorised, but scalar data types.
-- In addition, workers of parallel data types that we could not vectorise also need to
-- be tracked.
, global_novect_vars :: VarSet , global_vect_decls :: VarEnv (Maybe (Type, CoreExpr))
-- ^Variables that are not vectorised. (They may be referenced in the right-hand sides -- ^Mapping from global variables that have a vectorisation declaration to the right-hand
-- of vectorisation declarations, though.) -- side of that declaration and its type and mapping variables that have NOVECTORISE
-- declarations to 'Nothing'.
, global_tycons :: NameEnv TyCon , global_tycons :: NameEnv TyCon
-- ^Mapping from TyCons to their vectorised versions. -- ^Mapping from TyCons to their vectorised versions. The vectorised version will be
-- TyCons which do not have to be vectorised are mapped to themselves. -- identical to the original version if it is not changed by vectorisation. In any case,
-- if a tycon appears in the domain of this mapping, it was successfully vectorised.
, global_parallel_tycons :: NameSet
-- ^Type constructors whose definition directly or indirectly includes a parallel type,
-- such as '[::]'.
--
-- NB: This information is not redundant as some types have got a mapping in
-- 'global_tycons' (to a type other than themselves) and are still not parallel. An
-- example is '(->)'. Moreover, some types have *not* got a mapping in 'global_tycons'
-- (because they couldn't be vectorised), but still contain parallel types.
, global_datacons :: NameEnv DataCon , global_datacons :: NameEnv DataCon
-- ^Mapping from DataCons to their vectorised versions. -- ^Mapping from DataCons to their vectorised versions.
...@@ -129,7 +129,7 @@ data GlobalEnv ...@@ -129,7 +129,7 @@ data GlobalEnv
-- ^External package inst-env & home-package inst-env for family instances. -- ^External package inst-env & home-package inst-env for family instances.
, global_bindings :: [(Var, CoreExpr)] , global_bindings :: [(Var, CoreExpr)]
-- ^Hoisted bindings. -- ^Hoisted bindings — temporary storage for toplevel bindings during code gen.
} }
-- |Create an initial global environment. -- |Create an initial global environment.
...@@ -143,9 +143,8 @@ initGlobalEnv info vectDecls instEnvs famInstEnvs ...@@ -143,9 +143,8 @@ initGlobalEnv info vectDecls instEnvs famInstEnvs
= GlobalEnv = GlobalEnv
{ global_vars = mapVarEnv snd $ vectInfoVar info { global_vars = mapVarEnv snd $ vectInfoVar info
, global_vect_decls = mkVarEnv vects , global_vect_decls = mkVarEnv vects
, global_scalar_vars = vectInfoScalarVars info `extendVarSetList` scalar_vars , global_parallel_vars = vectInfoParallelVars info
, global_scalar_tycons = vectInfoScalarTyCons info `addListToNameSet` scalar_tycons , global_parallel_tycons = vectInfoParallelTyCons info
, global_novect_vars = mkVarSet novects
, global_tycons = mapNameEnv snd $ vectInfoTyCon info , global_tycons = mapNameEnv snd $ vectInfoTyCon info
, global_datacons = mapNameEnv snd $ vectInfoDataCon info , global_datacons = mapNameEnv snd $ vectInfoDataCon info
, global_pa_funs = emptyNameEnv , global_pa_funs = emptyNameEnv
...@@ -155,23 +154,12 @@ initGlobalEnv info vectDecls instEnvs famInstEnvs ...@@ -155,23 +154,12 @@ initGlobalEnv info vectDecls instEnvs famInstEnvs
, global_bindings = [] , global_bindings = []
} }
where where
vects = [(var, (ty, exp)) | Vect var (Just exp@(Var rhs_var)) <- vectDecls vects = [(var, Just (ty, exp)) | Vect var exp@(Var rhs_var) <- vectDecls
, let ty = varType rhs_var] , let ty = varType rhs_var] ++
-- FIXME: we currently only allow RHSes consisting of a -- FIXME: we currently only allow RHSes consisting of a
-- single variable to be able to obtain the type without -- single variable to be able to obtain the type without
-- inference — see also 'TcBinds.tcVect' -- inference — see also 'TcBinds.tcVect'
scalar_vars = [var | Vect var Nothing <- vectDecls] ++ [(var, Nothing) | NoVect var <- vectDecls]
[var | VectInst var <- vectDecls] ++
[dataConWrapId doubleDataCon, dataConWrapId floatDataCon, dataConWrapId intDataCon] -- TODO: fix this hack
novects = [var | NoVect var <- vectDecls]
scalar_tycons = [tyConName tycon | VectType True tycon Nothing <- vectDecls] ++
[tyConName tycon | VectType _ tycon (Just tycon') <- vectDecls
, tycon == tycon'] ++
map tyConName [doublePrimTyCon, intPrimTyCon, floatPrimTyCon] -- TODO: fix this hack
-- - for 'VectType True tycon Nothing', we checked that the type does not
-- contain arrays (or type variables that could be instatiated to arrays)
-- - for 'VectType _ tycon (Just tycon')', where the two tycons are the same,
-- we also know that there can be no embedded arrays
-- Operators on Global Environments ------------------------------------------- -- Operators on Global Environments -------------------------------------------
...@@ -213,8 +201,8 @@ modVectInfo env mg_ids mg_tyCons vectDecls info ...@@ -213,8 +201,8 @@ modVectInfo env mg_ids mg_tyCons vectDecls info
{ vectInfoVar = mk_env ids (global_vars env) { vectInfoVar = mk_env ids (global_vars env)
, vectInfoTyCon = mk_env tyCons (global_tycons env) , vectInfoTyCon = mk_env tyCons (global_tycons env)
, vectInfoDataCon = mk_env dataCons (global_datacons env) , vectInfoDataCon = mk_env dataCons (global_datacons env)
, vectInfoScalarVars = global_scalar_vars env `minusVarSet` vectInfoScalarVars info , vectInfoParallelVars = global_parallel_vars env `minusVarSet` vectInfoParallelVars info
, vectInfoScalarTyCons = global_scalar_tycons env `minusNameSet` vectInfoScalarTyCons info , vectInfoParallelTyCons = global_parallel_tycons env `minusNameSet` vectInfoParallelTyCons info
} }
where where
vectIds = [id | Vect id _ <- vectDecls] ++ vectIds = [id | Vect id _ <- vectDecls] ++
......
This diff is collapsed.
...@@ -14,8 +14,8 @@ module Vectorise.Monad ( ...@@ -14,8 +14,8 @@ module Vectorise.Monad (
-- * Variables -- * Variables
lookupVar, lookupVar,
lookupVar_maybe, lookupVar_maybe,