Commit a33dddc9 authored by chak@cse.unsw.edu.au.'s avatar chak@cse.unsw.edu.au.
Browse files

Vectoriser: distinguish vectorised from parallel types and functions

- We sometimes need to vectorise types and functions because they might be needed in a vectorised context, not because they do directly introduce parallelism.
parent 895ff213
...@@ -243,20 +243,29 @@ liftSimpleAndCase aexpr@((fvs, _vi), AnnCase expr bndr t alts) ...@@ -243,20 +243,29 @@ liftSimpleAndCase aexpr@((fvs, _vi), AnnCase expr bndr t alts)
{ vi <- vectAvoidInfoTypeOf expr { vi <- vectAvoidInfoTypeOf expr
; if (vi == VISimple) ; if (vi == VISimple)
then then
return $ liftSimple aexpr -- if the scrutinee is scalar, we need no special treatment liftSimple aexpr -- if the scrutinee is scalar, we need no special treatment
else do else do
{ alts' <- mapM (\(ac, bndrs, aexpr) -> (ac, bndrs,) <$> liftSimpleAndCase aexpr) alts { alts' <- mapM (\(ac, bndrs, aexpr) -> (ac, bndrs,) <$> liftSimpleAndCase aexpr) alts
; return ((fvs, vi), AnnCase expr bndr t alts') ; return ((fvs, vi), AnnCase expr bndr t alts')
} }
} }
liftSimpleAndCase aexpr = return $ liftSimple aexpr liftSimpleAndCase aexpr = liftSimple aexpr
liftSimple :: CoreExprWithVectInfo -> CoreExprWithVectInfo liftSimple :: CoreExprWithVectInfo -> VM CoreExprWithVectInfo
liftSimple ((fvs, vi), expr) liftSimple aexpr@((fvs_orig, VISimple), expr)
= ASSERT(vi == VISimple) = do
mkAnnApps (mkAnnLams vars fvs expr) vars { let liftedExpr = mkAnnApps (mkAnnLams vars fvs expr) vars
; traceVt "encapsulate:" $ ppr (deAnnotate aexpr) $$ text "==>" $$ ppr (deAnnotate liftedExpr)
; return $ liftedExpr
}
where where
vars = varSetElems fvs vars = varSetElems fvs
fvs = filterVarSet isToplevel fvs_orig -- only include 'Id's that are not toplevel
isToplevel v | isId v = not . uf_is_top . realIdUnfolding $ v
| otherwise = False
mkAnnLams :: [Var] -> VarSet -> AnnExpr' Var (VarSet, VectAvoidInfo) -> CoreExprWithVectInfo mkAnnLams :: [Var] -> VarSet -> AnnExpr' Var (VarSet, VectAvoidInfo) -> CoreExprWithVectInfo
mkAnnLams [] fvs expr = ASSERT(isEmptyVarSet fvs) mkAnnLams [] fvs expr = ASSERT(isEmptyVarSet fvs)
...@@ -270,23 +279,31 @@ liftSimple ((fvs, vi), expr) ...@@ -270,23 +279,31 @@ liftSimple ((fvs, vi), expr)
mkAnnApp :: CoreExprWithVectInfo -> Var -> CoreExprWithVectInfo mkAnnApp :: CoreExprWithVectInfo -> Var -> CoreExprWithVectInfo
mkAnnApp aexpr@((fvs, _vi), _expr) v mkAnnApp aexpr@((fvs, _vi), _expr) v
= ((fvs `extendVarSet` v, VISimple), AnnApp aexpr ((unitVarSet v, VISimple), AnnVar v)) = ((fvs `extendVarSet` v, VISimple), AnnApp aexpr ((unitVarSet v, VISimple), AnnVar v))
liftSimple aexpr
= pprPanic "Vectorise.Exp.liftSimple: not simple" $ ppr (deAnnotate aexpr)
-- |Vectorise an expression. -- |Vectorise an expression.
-- --
vectExpr :: CoreExprWithVectInfo -> VM VExpr vectExpr :: CoreExprWithVectInfo -> VM VExpr
vectExpr (_, AnnVar v) -- !!!FIXME: needs to check for VIEncaps regardless of syntactic form first; in case it is of functional type
vectExpr aexpr@(_, AnnVar v)
| (isFunTy . varType $ v) && isVIEncaps aexpr
= vectFnExpr False False aexpr
| otherwise
= vectVar v = vectVar v
vectExpr (_, AnnLit lit) vectExpr (_, AnnLit lit)
= vectConst $ Lit lit = vectConst $ Lit lit
vectExpr e@(_, AnnLam bndr _) vectExpr aexpr@(_, AnnLam bndr _)
| isId bndr = vectFnExpr True False e | isId bndr = vectFnExpr True False aexpr
| otherwise | otherwise
= do = do
{ dflags <- getDynFlags { dflags <- getDynFlags
; cantVectorise dflags "Unexpected type lambda (vectExpr)" $ ppr (deAnnotate e) ; cantVectorise dflags "Unexpected type lambda (vectExpr)" $ ppr (deAnnotate aexpr)
} }
-- SPECIAL CASE: Vectorise/lift 'patError @ ty err' by only vectorising/lifting the type 'ty'; -- SPECIAL CASE: Vectorise/lift 'patError @ ty err' by only vectorising/lifting the type 'ty';
...@@ -408,14 +425,18 @@ vectFnExpr inline loop_breaker expr@(_ann, AnnLam bndr body) ...@@ -408,14 +425,18 @@ vectFnExpr inline loop_breaker expr@(_ann, AnnLam bndr body)
; vbody <- vectFnExpr inline loop_breaker body ; vbody <- vectFnExpr inline loop_breaker body
; return $ mapVect (mkLams [vectorised vBndr]) vbody ; return $ mapVect (mkLams [vectorised vBndr]) vbody
} }
-- non-predicate abstraction: vectorise as a scalar computation -- encapsulated non-predicate abstraction: vectorise as a scalar computation
| isId bndr && isVIEncaps expr | isId bndr && isVIEncaps expr
= vectScalarFun . deAnnotate $ expr = vectScalarFun . deAnnotate $ expr
-- non-predicate abstraction: vectorise as a non-scalar computation -- non-predicate abstraction: vectorise as a non-scalar computation
| isId bndr | isId bndr
= vectLam inline loop_breaker expr = vectLam inline loop_breaker expr
vectFnExpr _ _ expr vectFnExpr _ _ expr
-- not an abstraction: vectorise as a vanilla expression -- encapsulated function: vectorise as a scalar computation
| (isFunTy . annExprType $ expr) && isVIEncaps expr
= vectScalarFun . deAnnotate $ expr
| otherwise
-- not an abstraction: vectorise as a non-scalar vanilla expression
= vectExpr expr = vectExpr expr
-- |Vectorise type and dictionary applications. -- |Vectorise type and dictionary applications.
...@@ -543,7 +564,7 @@ vectDictExpr (Coercion coe) ...@@ -543,7 +564,7 @@ vectDictExpr (Coercion coe)
vectScalarFun :: CoreExpr -> VM VExpr vectScalarFun :: CoreExpr -> VM VExpr
vectScalarFun expr vectScalarFun expr
= do = do
{ traceVt "vectScalarFun" (ppr expr) { traceVt "vectorise scalar functions:" (ppr expr)
; let (arg_tys, res_ty) = splitFunTys (exprType expr) ; let (arg_tys, res_ty) = splitFunTys (exprType expr)
; mkScalarFun arg_tys res_ty expr ; mkScalarFun arg_tys res_ty expr
} }
......
...@@ -47,32 +47,32 @@ classifyTyCons :: UniqFM Bool -- ^type constructor vectorisati ...@@ -47,32 +47,32 @@ classifyTyCons :: UniqFM Bool -- ^type constructor vectorisati
-> [TyCon] -- ^type constructors that need to be classified -> [TyCon] -- ^type constructors that need to be classified
-> ( [TyCon] -- to be converted -> ( [TyCon] -- to be converted
, [TyCon] -- need not be converted (but could be) , [TyCon] -- need not be converted (but could be)
, [TyCon] -- can't be converted, but involve parallel arrays , [TyCon] -- involve parallel arrays (whether converted or not)
, [TyCon] -- can't be converted and have no parallel arrays , [TyCon] -- can't be converted
) )
classifyTyCons convStatus parTyCons tcs = classify [] [] [] [] convStatus parTyCons (tyConGroups tcs) classifyTyCons convStatus parTyCons tcs = classify [] [] [] [] convStatus parTyCons (tyConGroups tcs)
where where
classify conv keep par novect _ _ [] = (conv, keep, par, novect) classify conv keep par novect _ _ [] = (conv, keep, par, novect)
classify conv keep par novect cs pts ((tcs, ds) : rs) classify conv keep par novect cs pts ((tcs, ds) : rs)
| can_convert && must_convert | can_convert && must_convert
= classify (tcs ++ conv) keep par novect (cs `addListToUFM` [(tc, True) | tc <- tcs]) pts' rs = classify (tcs ++ conv) keep (par ++ tcs_par) novect (cs `addListToUFM` [(tc, True) | tc <- tcs]) pts' rs
| can_convert | can_convert
= classify conv (tcs ++ keep) par novect (cs `addListToUFM` [(tc, False) | tc <- tcs]) pts' rs = classify conv (tcs ++ keep) (par ++ tcs_par) novect (cs `addListToUFM` [(tc, False) | tc <- tcs]) pts' rs
| has_parr
= classify conv keep (tcs ++ par) novect cs pts' rs
| otherwise | otherwise
= classify conv keep par (tcs ++ novect) cs pts' rs = classify conv keep (par ++ tcs_par) (tcs ++ novect) cs pts' rs
where where
refs = ds `delListFromUniqSet` tcs refs = ds `delListFromUniqSet` tcs
pts' | has_parr = pts `addListToNameSet` map tyConName tcs -- the tycons that directly or indirectly depend on parallel arrays
| otherwise = pts tcs_par | any ((`elemNameSet` parTyCons) . tyConName) . eltsUFM $ refs = tcs
| otherwise = []
pts' = pts `addListToNameSet` map tyConName tcs_par
can_convert = (isNullUFM (refs `minusUFM` cs) && all convertable tcs) can_convert = (isNullUFM (refs `minusUFM` cs) && all convertable tcs)
|| isShowClass tcs || isShowClass tcs
must_convert = foldUFM (||) False (intersectUFM_C const cs refs) must_convert = foldUFM (||) False (intersectUFM_C const cs refs)
&& (not . isShowClass $ tcs) && (not . isShowClass $ tcs)
has_parr = any ((`elemNameSet` parTyCons) . tyConName) . eltsUFM $ refs
-- We currently admit Haskell 2011-style data and newtype declarations as well as type -- We currently admit Haskell 2011-style data and newtype declarations as well as type
-- constructors representing classes. -- constructors representing classes.
......
...@@ -205,9 +205,11 @@ vectTypeEnv tycons vectTypeDecls vectClassDecls ...@@ -205,9 +205,11 @@ vectTypeEnv tycons vectTypeDecls vectClassDecls
-- these are being handled separately. NB: Some type constructors may be marked SCALAR -- these are being handled separately. NB: Some type constructors may be marked SCALAR
-- /and/ have an explicit right-hand side.) -- /and/ have an explicit right-hand side.)
-- --
-- Furthermore, 'par_tcs' and 'drop_tcs' are those type constructors that we cannot -- Furthermore, 'par_tcs' are those type constructors (converted or not) whose
-- vectorise, and of those, only the 'par_tcs' involve parallel arrays. -- definition, directly or indirectly, depends on parallel arrays. Finally, 'drop_tcs'
; parallelTyCons <- globalParallelTyCons -- are all type constructors that cannot be vectorised.
; parallelTyCons <- (`addListToNameSet` map (tyConName . fst3) vectTyConsWithRHS) <$>
globalParallelTyCons
; let maybeVectoriseTyCons = filter notVectSpecialTyCon tycons ++ impVectTyCons ; let maybeVectoriseTyCons = filter notVectSpecialTyCon tycons ++ impVectTyCons
(conv_tcs, keep_tcs, par_tcs, drop_tcs) (conv_tcs, keep_tcs, par_tcs, drop_tcs)
= classifyTyCons vectTyConFlavour parallelTyCons maybeVectoriseTyCons = classifyTyCons vectTyConFlavour parallelTyCons maybeVectoriseTyCons
...@@ -223,12 +225,12 @@ vectTypeEnv tycons vectTypeDecls vectClassDecls ...@@ -223,12 +225,12 @@ vectTypeEnv tycons vectTypeDecls vectClassDecls
-- warn the user about unvectorised type constructors -- warn the user about unvectorised type constructors
; let explanation = ptext (sLit "(They use unsupported language extensions") $$ ; let explanation = ptext (sLit "(They use unsupported language extensions") $$
ptext (sLit "or depend on type constructors that are not vectorised)") ptext (sLit "or depend on type constructors that are not vectorised)")
drop_tcs_nosyn = filter (not . isSynTyCon) (par_tcs ++ drop_tcs) drop_tcs_nosyn = filter (not . isSynTyCon) drop_tcs
; unless (null drop_tcs_nosyn) $ ; unless (null drop_tcs_nosyn) $
emitVt "Warning: cannot vectorise these type constructors:" $ emitVt "Warning: cannot vectorise these type constructors:" $
pprQuotedList drop_tcs_nosyn $$ explanation pprQuotedList drop_tcs_nosyn $$ explanation
; mapM_ addParallelTyConAndCons $ conv_tcs ++ par_tcs ; mapM_ addParallelTyConAndCons $ par_tcs ++ [tc | (tc, _, False) <- vectTyConsWithRHS]
; let mapping = ; let mapping =
-- Type constructors that we found we don't need to vectorise and those -- Type constructors that we found we don't need to vectorise and those
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment