Commit 981e4f13 authored by gckeller's avatar gckeller
Browse files

Partial Vectoriasation

parent 93212247
......@@ -95,7 +95,7 @@ initBuiltins
; applyVar <- externalVar (fsLit "$:")
; liftedApplyVar <- externalVar (fsLit "liftedApply")
; closures <- mapM externalVar (numbered "closure" 1 mAX_DPH_SCALAR_ARGS)
; let closureCtrFuns = listArray (1, mAX_DPH_COMBINE) closures
; let closureCtrFuns = listArray (1, mAX_DPH_SCALAR_ARGS) closures
-- Types and functions for selectors
; sel_tys <- mapM externalType (numbered "Sel" 2 mAX_DPH_SUM)
......
......@@ -30,6 +30,9 @@ import NameSet
import Name
import NameEnv
import FastString
import TysPrim
import TysWiredIn
import DataCon
import Data.Maybe
......@@ -158,11 +161,13 @@ initGlobalEnv info vectDecls instEnvs famInstEnvs
-- single variable to be able to obtain the type without
-- inference — see also 'TcBinds.tcVect'
scalar_vars = [var | Vect var Nothing <- vectDecls] ++
[var | VectInst var <- vectDecls]
[var | VectInst var <- vectDecls] ++
[dataConWrapId doubleDataCon, dataConWrapId floatDataCon, dataConWrapId intDataCon] -- TODO: fix this hack
novects = [var | NoVect var <- vectDecls]
scalar_tycons = [tyConName tycon | VectType True tycon Nothing <- vectDecls] ++
[tyConName tycon | VectType _ tycon (Just tycon') <- vectDecls
, tycon == tycon']
, tycon == tycon'] ++
map tyConName [doublePrimTyCon, intPrimTyCon, floatPrimTyCon] -- TODO: fix this hack
-- - for 'VectType True tycon Nothing', we checked that the type does not
-- contain arrays (or type variables that could be instatiated to arrays)
-- - for 'VectType _ tycon (Just tycon')', where the two tycons are the same,
......
......@@ -50,24 +50,320 @@ import Data.Maybe
import Data.List
-- |Vectorise a polymorphic expression.
--
vectPolyExpr :: Bool -- ^ When vectorising the RHS of a binding: is that binding a loop breaker?
-> [Var]
-> CoreExprWithFVs
-> VM (Inline, Bool, VExpr)
vectPolyExpr loop_breaker recFns (_, AnnTick tickish expr)
= do { (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr
; return (inline, isScalarFn, vTick tickish expr')
import Debug.Trace
-- For prototyping, the VITree is a separate data structure with the same shape as the corresponding expression
-- tree. This will become part of the annotation
data VectInfo = VIParr
| VISimple
| VIComplex
deriving (Eq, Show)
data VITree = VITNode VectInfo [VITree]
deriving (Show)
viTrace :: CoreExprWithFVs -> VectInfo -> [VITree] -> VM ()
viTrace ce vi vTs =
-- return ()
traceVt ("vitrace " ++ (show vi) ++ "[" ++ (concat $ map (\(VITNode vi _) -> show vi ++ " ") vTs) ++"]") (ppr $ deAnnotate ce)
viOr :: [VITree] -> Bool
viOr = or . (map (\(VITNode vi _) -> vi == VIParr))
vectInfo:: CoreExprWithFVs -> VM VITree
vectInfo ce@(_, AnnVar v)
= do { vi <- vectInfoType $ exprType $ deAnnotate ce
; viTrace ce vi []
; traceVt "vectInfo AnnVar" ((ppr v) <+> (ppr $ exprType $ deAnnotate ce))
; return $ VITNode vi []
}
vectInfo ce@(_, AnnLit _)
= do { vi <- vectInfoType $ exprType $ deAnnotate ce
; viTrace ce vi []
; traceVt "vectInfo AnnLit" (ppr $ exprType $ deAnnotate ce)
; return $ VITNode vi []
}
vectInfo ce@(_, AnnApp e1 e2)
= do { vt1 <- vectInfo e1
; vt2 <- vectInfo e2
; vi <- if viOr [vt1, vt2]
then return VIParr
else vectInfoType $ exprType $ deAnnotate ce
; viTrace ce vi [vt1, vt2]
; return $ VITNode vi [vt1, vt2]
}
vectInfo ce@(_, AnnLam _ body)
= do { vt@(VITNode vi _) <- vectInfo body
; viTrace ce vi [vt]
; return $ VITNode vi [vt]
}
vectInfo ce@(_, AnnLet (AnnNonRec _ expr) body)
= do { vtE <- vectInfo expr
; vtB <- vectInfo body
; vi <- if viOr [vtE, vtB]
then return VIParr
else vectInfoType $ exprType $ deAnnotate ce
; viTrace ce vi [vtE, vtB]
; return $ VITNode vi [vtE, vtB]
}
vectInfo ce@(_, AnnLet (AnnRec bnds) body)
= do { vtB <- vectInfo body
; let exprs = snd $ unzip bnds
; vtBnds <- mapM vectInfo exprs
; ni <- if viOr (vtB : vtBnds)
then return VIParr
else vectInfoType $ exprType $ deAnnotate ce
; viTrace ce ni (vtB : vtBnds)
; return $ VITNode ni (vtB : vtBnds)
}
vectInfo ce@(_, AnnCase expr _var _ty alts)
= do { vtExpr <- vectInfo expr
; vtAlts <- mapM (\(_, _, e) -> vectInfo e) alts
; ni <- if viOr (vtExpr : vtAlts)
then return VIParr
else vectInfoType $ exprType $ deAnnotate ce
; viTrace ce ni (vtExpr : vtAlts)
; return $ VITNode ni (vtExpr : vtAlts)
}
vectInfo (_, AnnCast expr _)
= do { vt@(VITNode vi _) <- vectInfo expr
; return $ VITNode vi [vt]
}
vectInfo (_, AnnTick _ expr )
= do { vt@(VITNode vi _) <- vectInfo expr
; return $ VITNode vi [vt]
}
vectInfo (_, AnnType {})
= return $ VITNode VISimple []
vectInfo (_, AnnCoercion {})
= return $ VITNode VISimple []
vectInfoType:: Type -> VM VectInfo
vectInfoType ty
| maybeParrTy ty = return VIParr
| otherwise
= do { sType <- isSimpleType ty
; if sType
then return VISimple
else return VIComplex
}
-- Checks whether the type might be a parallel array type. In particular, if the outermost
-- constructor is a type family, we conservatively assume that it may be a parallel array type.
maybeParrTy :: Type -> Bool
maybeParrTy ty
| Just ty' <- coreView ty = maybeParrTy ty'
| Just (tyCon, ts) <- splitTyConApp_maybe ty = isPArrTyCon tyCon || isSynFamilyTyCon tyCon
|| or (map maybeParrTy ts)
maybeParrTy _ = False
isSimpleType:: Type -> VM Bool
isSimpleType ty
| Just (c, _cs) <- splitTyConApp_maybe ty
= do { globals <- globalScalarTyCons
; traceVt ("isSimpleType " ++ (show (elemNameSet (tyConName c) globals ))) (ppr c)
; return (elemNameSet (tyConName c) globals )
}
| Nothing <- splitTyConApp_maybe ty
= return False
isSimpleType ty
= pprPanic "Vectorise.Exp.isSimpleType not handled" (ppr ty)
varsSimple :: VarSet -> VM Bool
varsSimple vs
= do { varTypes <- mapM isSimpleType $ map varType $ varSetElems vs
; return $ and varTypes
}
-- | Vectorise a polymorphic expression.
vectPolyExpr:: Bool -> [Var] -> CoreExprWithFVs
-> VM (Inline, Bool, VExpr)
vectPolyExpr loop_breaker recFns (_, AnnTick tickish expr)
= do { (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr
; return (inline, isScalarFn, vTick tickish expr')
}
vectPolyExpr loop_breaker recFns expr
= do { arity <- polyArity tvs
; polyAbstract tvs $ \args -> do
{ (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono
; return (addInlineArity inline arity, isScalarFn, mapVect (mkLams $ tvs ++ args) mono')
} }
= do { let vectAvoidance = True
; (tvs, mono) <- if vectAvoidance
then do { vi <- vectInfo expr
; extExpr <- encapsulateScalar vi expr
; traceVt "vectPolyExpr extended:" (ppr $ deAnnotate extExpr)
; return $ collectAnnTypeBinders extExpr
}
else return $ collectAnnTypeBinders expr
; arity <- polyArity tvs
; polyAbstract tvs $ \args ->
do { (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono
; return (addInlineArity inline arity, isScalarFn, mapVect (mkLams $ tvs ++ args) mono')
}
}
-- | encapsulate every purely sequentail subexpression with a simple return type
-- of a (potentially) parallel expression into a lambda abstraction over all its
-- free variables followed by the corresponding application to those variables.
-- Condition:
-- all free variables and the result type must be of `simple' type
-- the expression is 'complex enough', which is, for now, every expression
-- which is not constant and contains at least one operation.
--
encapsulateScalar :: VITree -> CoreExprWithFVs -> VM CoreExprWithFVs
encapsulateScalar _ ce@(_, AnnType _ty)
= return ce
encapsulateScalar _ ce@(_, AnnVar _v)
= return ce
encapsulateScalar _ ce@(_, AnnLit _)
= return ce
encapsulateScalar (VITNode _ [vit]) (fvs, AnnTick tck expr)
= do { extExpr <- encapsulateScalar vit expr
; return (fvs, AnnTick tck extExpr)
}
encapsulateScalar _ (_fvs, AnnTick _tck _expr)
= panic "encapsulateScalar AnnTick doesn't match up"
encapsulateScalar (VITNode _ [vit]) (fvs, AnnLam bndr expr)
= do { extExpr <- encapsulateScalar vit expr
; return (fvs, AnnLam bndr extExpr)
}
encapsulateScalar _ (_fvs, AnnLam _bndr _expr)
= panic "encapsulateScalar AnnLam doesn't match up"
encapsulateScalar (VITNode vi [vit1, vit2]) ce@(fvs, AnnApp ce1 ce2)
= do { varsS <- varsSimple fvs
; case (vi, varsS) of
(VISimple, True) -> return $ encaps ce
_ -> do { etaCe1 <- encapsulateScalar vit1 ce1
; etaCe2 <- encapsulateScalar vit2 ce2
; return (fvs, AnnApp etaCe1 etaCe2)
}
}
encapsulateScalar _ (_fvs, AnnApp _ce1 _ce2)
= panic "encapsulateScalar AnnApp doesn't match up"
encapsulateScalar (VITNode vi (scrutVit : altVits)) ce@(fvs, AnnCase scrut bndr ty alts)
= do { varsS <- varsSimple fvs
; case (vi, varsS) of
(VISimple, True) -> return $ encaps ce
_ -> do { extScrut <- encapsulateScalar scrutVit scrut
; extAlts <- zipWithM expAlt altVits alts
; return (fvs, AnnCase extScrut bndr ty extAlts)
}
}
where
(tvs, mono) = collectAnnTypeBinders expr
expAlt vt (con, bndrs, expr)
= do { extExpr <- encapsulateScalar vt expr
; return (con, bndrs, extExpr)
}
encapsulateScalar _ (_fvs, AnnCase _scrut _bndr _ty _alts)
= panic "encapsulateScalar AnnCase doesn't match up"
encapsulateScalar (VITNode vi [vt1, vt2]) ce@(fvs, AnnLet (AnnNonRec bndr expr1) expr2)
= do { varsS <- varsSimple fvs
; case (vi, varsS) of
(VISimple, True) -> return $ encaps ce
_ -> do { extExpr1 <- encapsulateScalar vt1 expr1
; extExpr2 <- encapsulateScalar vt2 expr2
; return (fvs, AnnLet (AnnNonRec bndr extExpr1) extExpr2)
}
}
encapsulateScalar _ (_fvs, AnnLet (AnnNonRec _bndr _expr1) _expr2)
= panic "encapsulateScalar AnnLet nonrec doesn't match up"
encapsulateScalar (VITNode vi (vtB : vtBnds)) ce@(fvs, AnnLet (AnnRec bndngs) expr)
= do { varsS <- varsSimple fvs
; case (vi, varsS) of
(VISimple, True) -> return $ encaps ce
_ -> do { extBnds <- zipWithM expBndg vtBnds bndngs
; extExpr <- encapsulateScalar vtB expr
; return (fvs, AnnLet (AnnRec extBnds) extExpr)
}
}
where
expBndg vit (bndr, expr)
= do { extExpr <- encapsulateScalar vit expr
; return (bndr, extExpr)
}
encapsulateScalar _ (_fvs, AnnLet (AnnRec _) _expr2)
= panic "encapsulateScalar AnnLet rec doesn't match up"
encapsulateScalar (VITNode _ [vit]) (fvs, AnnCast expr coercion)
= do { extExpr <- encapsulateScalar vit expr
; return (fvs, AnnCast extExpr coercion)
}
encapsulateScalar _ (_fvs, AnnCast _expr _coercion)
= panic "encapsulateScalar AnnCast rec doesn't match up"
encapsulateScalar _ _
= panic "encapsulateScalar case not handled"
mkAnnLam :: bndr -> AnnExpr bndr VarSet -> AnnExpr' bndr VarSet
mkAnnLam bndr ce = AnnLam bndr ce
mkAnnLams:: CoreExprWithFVs -> [Var] -> CoreExprWithFVs
mkAnnLams (fv, aex') [] = (fv, aex') -- fv should be empty. check!
mkAnnLams (fv, aex') (v:vs) = mkAnnLams (delVarSet fv v, (mkAnnLam v ((delVarSet fv v), aex'))) vs
mkAnnApp :: (AnnExpr bndr VarSet) -> Var -> (AnnExpr' bndr VarSet)
mkAnnApp aex v = AnnApp aex (unitVarSet v, (AnnVar v))
mkAnnApps:: CoreExprWithFVs -> [Var] -> CoreExprWithFVs
mkAnnApps (fv, aex') [] = (fv, aex')
mkAnnApps ae (v:vs) =
let
(fv, aex') = mkAnnApps ae vs
in (extendVarSet fv v, mkAnnApp (fv, aex') v)
-- CoreExprWithFVs, -- = AnnExpr Id VarSet
-- AnnExpr bndr VarSet = (annot, AnnExpr' bndr VarSet)
-- AnnLam :: bndr -> (AnnExpr bndr VarSet) -> AnnExpr' bndr VarSet
-- AnnLam bndr (AnnExpr bndr annot)
encaps :: CoreExprWithFVs -> CoreExprWithFVs
encaps (fvs, AnnCase expr bndr t alts)
| Just (c,_) <- splitTyConApp_maybe (exprType $ deAnnotate $ expr),
(not $ elem c [boolTyCon, intTyCon, doubleTyCon, floatTyCon]) -- TODO: globalScalarTyCons
= (fvs, AnnCase expr bndr t (map (\(ac, bndrs, aex) -> (ac, bndrs, encaps aex)) alts))
encaps ae@(fvs, _annEx)
= let
vars = varSetElems fvs
in mkAnnApps (mkAnnLams ae vars) vars
-- |Vectorise an expression.
--
......@@ -332,6 +628,25 @@ vectScalarFun forceScalar recFns expr
; let scalarVars = gscalarVars `extendVarSetList` recFns
(arg_tys, res_ty) = splitFunTys (exprType expr)
; MASSERT( not $ null arg_tys )
; traceVt ("vectScalarFun - not scalar? " ++
"\n\tall tycons scalar? : " ++ (show $all (is_scalar_ty scalarTyCons) arg_tys) ++
"\n\tresult scalar? : " ++ (show $is_scalar_ty scalarTyCons res_ty) ++
"\n\tscalar body? : " ++ (show $is_scalar scalarVars (is_scalar_ty scalarTyCons) expr) ++
"\n\tuses vars? : " ++ (show $uses scalarVars expr)
)
(ppr expr)
; onlyIfV (ptext (sLit "not a scalar function"))
(forceScalar -- user asserts the functions is scalar
||
all (is_scalar_ty scalarTyCons) arg_tys -- check whether the function is scalar
&& is_scalar_ty scalarTyCons res_ty
&& is_scalar scalarVars (is_scalar_ty scalarTyCons) expr
&& uses scalarVars expr)
$ do { traceVt "vectScalarFun - is scalar" (ppr expr)
; mkScalarFun arg_tys res_ty expr
}
}
{-
; onlyIfV (ptext (sLit "not a scalar function"))
(forceScalar -- user asserts the functions is scalar
||
......@@ -342,7 +657,9 @@ vectScalarFun forceScalar recFns expr
&& length arg_tys <= mAX_DPH_SCALAR_ARGS)
$ mkScalarFun arg_tys res_ty expr
}
-}
where
{-
-- !!!FIXME: We would like to allow scalar functions with arguments and results that can be
-- any 'scalarTyCons', but can't at the moment, as those argument and result types
-- need to be members of the 'Scalar' class (that in its current form would better
......@@ -354,12 +671,18 @@ vectScalarFun forceScalar recFns expr
= tyConName tycon `elem` [boolTyConName, intTyConName, word8TyConName, doubleTyConName]
| otherwise
= False
is_scalar_ty scalarTyCons ty
-}
is_scalar_ty _scalarTyCons ty
| isPredTy ty -- dictionaries never get into the environment
= True
| Just (tycon, []) <- splitTyConApp_maybe ty -- TODO: FIX THIS!
= tyConName tycon `elem` [boolTyConName, intTyConName, word8TyConName, doubleTyConName]
-- = tyConName tycon `elemNameSet` scalarTyCons
| Just (tycon, _) <- splitTyConApp_maybe ty
= tyConName tycon `elemNameSet` scalarTyCons
= tyConName tycon `elem` [boolTyConName, intTyConName, word8TyConName, doubleTyConName]
-- = tyConName tycon `elemNameSet` scalarTyCons
| otherwise
= False
......@@ -374,17 +697,17 @@ vectScalarFun forceScalar recFns expr
-- The second argument is a predicate that checks whether a type is scalar.
--
is_scalar :: VarSet -> (Type -> Bool) -> CoreExpr -> Bool
is_scalar scalars _isScalarTC (Var v) = v `elemVarSet` scalars
is_scalar scalars _isScalarTC (Var v) =
v `elemVarSet` scalars
is_scalar _scalars _isScalarTC (Lit _) = True
is_scalar scalars isScalarTC e@(App e1 e2)
| maybe_parr_ty (exprType e) = False
| otherwise = is_scalar scalars isScalarTC e1 &&
is_scalar scalars isScalarTC (App e1 e2) = is_scalar scalars isScalarTC e1 &&
is_scalar scalars isScalarTC e2
is_scalar scalars isScalarTC (Lam var body)
| maybe_parr_ty (varType var) = False
| otherwise = is_scalar (scalars `extendVarSet` var)
isScalarTC body
is_scalar scalars isScalarTC (Let bind body) = bindsAreScalar &&
is_scalar scalars isScalarTC (Let bind body) = trace ("is_scalar LET " ++ (show bindsAreScalar ) ++ " " ++ (show $ is_scalar scalars' isScalarTC body) ++ (show $ showSDoc $ ppr bind)) $
bindsAreScalar &&
is_scalar scalars' isScalarTC body
where
(bindsAreScalar, scalars') = is_scalar_bind scalars isScalarTC bind
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment