Commit 28bb3c3c authored by rl@cse.unsw.edu.au's avatar rl@cse.unsw.edu.au
Browse files

Try not to avoid vectorising purely scalar functions

parent 7106cd1b
module VectBuiltIn (
Builtins(..), sumTyCon, prodTyCon,
combinePAVar,
combinePAVar, scalarZip, closureCtrFun,
initBuiltins, initBuiltinVars, initBuiltinTyCons, initBuiltinDataCons,
initBuiltinPAs, initBuiltinPRs,
initBuiltinBoxedTyCons,
initBuiltinBoxedTyCons, initBuiltinScalars,
primMethod, primPArray
) where
......@@ -14,6 +14,7 @@ import IfaceEnv ( lookupOrig )
import Module
import DataCon ( DataCon, dataConName, dataConWorkId )
import TyCon ( TyCon, tyConName, tyConDataCons )
import Class ( Class )
import Var ( Var )
import Id ( mkSysLocal )
import Name ( Name, getOccString )
......@@ -48,6 +49,9 @@ mAX_DPH_SUM = 3
mAX_DPH_COMBINE :: Int
mAX_DPH_COMBINE = 2
mAX_DPH_SCALAR_ARGS :: Int
mAX_DPH_SCALAR_ARGS = 3
data Modules = Modules {
dph_PArray :: Module
, dph_Repr :: Module
......@@ -55,6 +59,7 @@ data Modules = Modules {
, dph_Unboxed :: Module
, dph_Instances :: Module
, dph_Combinators :: Module
, dph_Scalar :: Module
, dph_Prelude_PArr :: Module
, dph_Prelude_Int :: Module
, dph_Prelude_Word8 :: Module
......@@ -71,6 +76,7 @@ dph_Modules pkg = Modules {
, dph_Unboxed = mk (fsLit "Data.Array.Parallel.Lifted.Unboxed")
, dph_Instances = mk (fsLit "Data.Array.Parallel.Lifted.Instances")
, dph_Combinators = mk (fsLit "Data.Array.Parallel.Lifted.Combinators")
, dph_Scalar = mk (fsLit "Data.Array.Parallel.Lifted.Scalar")
, dph_Prelude_PArr = mk (fsLit "Data.Array.Parallel.Prelude.Base.PArr")
, dph_Prelude_Int = mk (fsLit "Data.Array.Parallel.Prelude.Base.Int")
......@@ -112,6 +118,9 @@ data Builtins = Builtins {
, emptyPAVar :: Var
, packPAVar :: Var
, combinePAVars :: Array Int Var
, scalarClass :: Class
, scalarZips :: Array Int Var
, closureCtrFuns :: Array Int Var
, liftingContext :: Var
}
......@@ -131,6 +140,16 @@ combinePAVar n bi
| n >= 2 && n <= mAX_DPH_COMBINE = combinePAVars bi ! n
| otherwise = pprPanic "combinePAVar" (ppr n)
scalarZip :: Int -> Builtins -> Var
scalarZip n bi
| n >= 1 && n <= mAX_DPH_SCALAR_ARGS = scalarZips bi ! n
| otherwise = pprPanic "scalarZip" (ppr n)
closureCtrFun :: Int -> Builtins -> Var
closureCtrFun n bi
| n >= 1 && n <= mAX_DPH_SCALAR_ARGS = closureCtrFuns bi ! n
| otherwise = pprPanic "closureCtrFun" (ppr n)
initBuiltins :: PackageId -> DsM Builtins
initBuiltins pkg
= do
......@@ -171,6 +190,19 @@ initBuiltins pkg
| i <- [2..mAX_DPH_COMBINE]]
let combinePAVars = listArray (2, mAX_DPH_COMBINE) combines
scalarClass <- externalClass dph_Scalar (fsLit "Scalar")
scalar_map <- externalVar dph_Scalar (fsLit "scalar_map")
scalar_zip2 <- externalVar dph_Scalar (fsLit "scalar_zipWith")
scalar_zips <- mapM (externalVar dph_Scalar)
[mkFastString ("scalar_zipWith" ++ show i)
| i <- [3 .. mAX_DPH_SCALAR_ARGS]]
let scalarZips = listArray (1, mAX_DPH_SCALAR_ARGS)
(scalar_map : scalar_zip2 : scalar_zips)
closures <- mapM (externalVar dph_Closure)
[mkFastString ("closure" ++ show i)
| i <- [1 .. mAX_DPH_SCALAR_ARGS]]
let closureCtrFuns = listArray (1, mAX_DPH_COMBINE) closures
liftingContext <- liftM (\u -> mkSysLocal (fsLit "lc") u intPrimTy)
newUnique
......@@ -203,6 +235,9 @@ initBuiltins pkg
, emptyPAVar = emptyPAVar
, packPAVar = packPAVar
, combinePAVars = combinePAVars
, scalarClass = scalarClass
, scalarZips = scalarZips
, closureCtrFuns = closureCtrFuns
, liftingContext = liftingContext
}
where
......@@ -211,6 +246,7 @@ initBuiltins pkg
, dph_Repr = dph_Repr
, dph_Closure = dph_Closure
, dph_Unboxed = dph_Unboxed
, dph_Scalar = dph_Scalar
})
= dph_Modules pkg
......@@ -447,6 +483,91 @@ builtinBoxedTyCons :: Builtins -> [(Name, TyCon)]
builtinBoxedTyCons _ =
[(tyConName intPrimTyCon, intTyCon)]
initBuiltinScalars :: Builtins -> DsM [Var]
initBuiltinScalars bi
= mapM (uncurry externalVar) (preludeScalars $ dphModules bi)
preludeScalars :: Modules -> [(Module, FastString)]
preludeScalars (Modules { dph_Prelude_Int = dph_Prelude_Int
, dph_Prelude_Word8 = dph_Prelude_Word8
, dph_Prelude_Double = dph_Prelude_Double
})
= [
mk dph_Prelude_Int "div"
, mk dph_Prelude_Int "mod"
, mk dph_Prelude_Int "sqrt"
]
++ scalars_Ord dph_Prelude_Int
++ scalars_Num dph_Prelude_Int
++ scalars_Ord dph_Prelude_Word8
++ scalars_Num dph_Prelude_Word8
++
[ mk dph_Prelude_Word8 "div"
, mk dph_Prelude_Word8 "mod"
, mk dph_Prelude_Word8 "fromInt"
, mk dph_Prelude_Word8 "toInt"
]
++ scalars_Ord dph_Prelude_Double
++ scalars_Num dph_Prelude_Double
++ scalars_Fractional dph_Prelude_Double
++ scalars_Floating dph_Prelude_Double
++ scalars_RealFrac dph_Prelude_Double
where
mk mod s = (mod, fsLit s)
scalars_Ord mod = [mk mod "=="
,mk mod "/="
,mk mod "<="
,mk mod "<"
,mk mod ">="
,mk mod ">"
,mk mod "min"
,mk mod "max"
]
scalars_Num mod = [mk mod "+"
,mk mod "-"
,mk mod "*"
,mk mod "negate"
,mk mod "abs"
]
scalars_Fractional mod = [mk mod "/"
,mk mod "recip"
]
scalars_Floating mod = [mk mod "pi"
,mk mod "exp"
,mk mod "sqrt"
,mk mod "log"
,mk mod "sin"
,mk mod "tan"
,mk mod "cos"
,mk mod "asin"
,mk mod "atan"
,mk mod "acos"
,mk mod "sinh"
,mk mod "tanh"
,mk mod "cosh"
,mk mod "asinh"
,mk mod "atanh"
,mk mod "acosh"
,mk mod "**"
,mk mod "logBase"
]
scalars_RealFrac mod = [mk mod "fromInt"
,mk mod "truncate"
,mk mod "round"
,mk mod "ceiling"
,mk mod "floor"
]
externalVar :: Module -> FastString -> DsM Var
externalVar mod fs
= dsLookupGlobalId =<< lookupOrig mod (mkVarOccFS fs)
......@@ -461,6 +582,10 @@ externalType mod fs
tycon <- externalTyCon mod fs
return $ mkTyConApp tycon []
externalClass :: Module -> FastString -> DsM Class
externalClass mod fs
= dsLookupClass =<< lookupOrig mod (mkTcOccFS fs)
unitTyConName :: Name
unitTyConName = tyConName unitTyCon
......
......@@ -2,14 +2,15 @@ module VectMonad (
Scope(..),
VM,
noV, traceNoV, tryV, maybeV, traceMaybeV, orElseV, fixV, localV, closedV,
noV, traceNoV, ensureV, traceEnsureV, tryV, maybeV, traceMaybeV, orElseV,
onlyIfV, fixV, localV, closedV,
initV, cantVectorise, maybeCantVectorise, maybeCantVectoriseM,
liftDs,
cloneName, cloneId, cloneVar,
newExportedVar, newLocalVar, newDummyVar, newTyVar,
Builtins(..), sumTyCon, prodTyCon,
combinePAVar,
combinePAVar, scalarZip, closureCtrFun,
builtin, builtins,
GlobalEnv(..),
......@@ -21,7 +22,7 @@ module VectMonad (
getBindName, inBind,
lookupVar, defGlobalVar,
lookupVar, defGlobalVar, globalScalars,
lookupTyCon, defTyCon,
lookupDataCon, defDataCon,
lookupTyConPA, defTyConPA, defTyConPAs,
......@@ -30,7 +31,7 @@ module VectMonad (
lookupPrimMethod, lookupPrimPArray,
lookupTyVarPA, defLocalTyVar, defLocalTyVarWithPA, localTyVars,
{-lookupInst,-} lookupFamInst
lookupInst, lookupFamInst
) where
#include "HsVersions.h"
......@@ -40,10 +41,12 @@ import VectBuiltIn
import HscTypes hiding ( MonadThings(..) )
import Module ( PackageId )
import CoreSyn
import Class
import TyCon
import DataCon
import Type
import Var
import VarSet
import VarEnv
import Id
import Name
......@@ -71,6 +74,10 @@ data GlobalEnv = GlobalEnv {
--
global_vars :: VarEnv Var
-- Purely scalar variables. Code which mentions only these
-- variables doesn't have to be lifted.
, global_scalars :: VarSet
-- Exported variables which have a vectorised version
--
, global_exported_vars :: VarEnv (Var, Var)
......@@ -130,6 +137,7 @@ initGlobalEnv :: VectInfo -> (InstEnv, InstEnv) -> FamInstEnvs -> GlobalEnv
initGlobalEnv info instEnvs famInstEnvs
= GlobalEnv {
global_vars = mapVarEnv snd $ vectInfoVar info
, global_scalars = emptyVarSet
, global_exported_vars = emptyVarEnv
, global_tycons = mapNameEnv snd $ vectInfoTyCon info
, global_datacons = mapNameEnv snd $ vectInfoDataCon info
......@@ -145,6 +153,10 @@ extendImportedVarsEnv :: [(Var, Var)] -> GlobalEnv -> GlobalEnv
extendImportedVarsEnv ps genv
= genv { global_vars = extendVarEnvList (global_vars genv) ps }
extendScalars :: [Var] -> GlobalEnv -> GlobalEnv
extendScalars vs genv
= genv { global_scalars = extendVarSetList (global_scalars genv) vs }
setFamInstEnv :: FamInstEnv -> GlobalEnv -> GlobalEnv
setFamInstEnv l_fam_inst genv
= genv { global_fam_inst_env = (g_fam_inst, l_fam_inst) }
......@@ -231,6 +243,17 @@ noV = VM $ \_ _ _ -> return No
traceNoV :: String -> SDoc -> VM a
traceNoV s d = pprTrace s d noV
ensureV :: Bool -> VM ()
ensureV False = noV
ensureV True = return ()
onlyIfV :: Bool -> VM a -> VM a
onlyIfV b p = ensureV b >> p
traceEnsureV :: String -> SDoc -> Bool -> VM ()
traceEnsureV s d False = traceNoV s d
traceEnsureV s d True = return ()
tryV :: VM a -> VM (Maybe a)
tryV (VM p) = VM $ \bi genv lenv ->
do
......@@ -301,10 +324,8 @@ setLEnv lenv = VM $ \_ genv _ -> return (Yes genv lenv ())
updLEnv :: (LocalEnv -> LocalEnv) -> VM ()
updLEnv f = VM $ \_ genv lenv -> return (Yes genv (f lenv) ())
{-
getInstEnv :: VM (InstEnv, InstEnv)
getInstEnv = readGEnv global_inst_env
-}
getFamInstEnv :: VM FamInstEnvs
getFamInstEnv = readGEnv global_fam_inst_env
......@@ -382,6 +403,9 @@ lookupVar v
. maybeCantVectoriseM "Variable not vectorised:" (ppr v)
. readGEnv $ \env -> lookupVarEnv (global_vars env) v
globalScalars :: VM VarSet
globalScalars = readGEnv global_scalars
lookupTyCon :: TyCon -> VM (Maybe TyCon)
lookupTyCon tc
| isUnLiftedTyCon tc || isTupleTyCon tc = return (Just tc)
......@@ -453,7 +477,6 @@ localTyVars = readLEnv (reverse . local_tyvars)
-- instances head (i.e., no flexi vars); for details for what this means,
-- see the docs at InstEnv.lookupInstEnv.
--
{-
lookupInst :: Class -> [Type] -> VM (DFunId, [Type])
lookupInst cls tys
= do { instEnv <- getInstEnv
......@@ -465,12 +488,12 @@ lookupInst cls tys
where
inst_tys' = [ty | Right ty <- inst_tys]
noFlexiVar = all isRight inst_tys
_other -> traceNoV "lookupInst" (ppr cls <+> ppr tys)
_other ->
pprPanic "VectMonad.lookupInst: not found " (ppr cls <+> ppr tys)
}
where
isRight (Left _) = False
isRight (Right _) = True
-}
-- Look up the representation tycon of a family instance.
--
......@@ -520,12 +543,14 @@ initV pkg hsc_env guts info p
builtin_pas <- initBuiltinPAs builtins
builtin_prs <- initBuiltinPRs builtins
builtin_boxed <- initBuiltinBoxedTyCons builtins
builtin_scalars <- initBuiltinScalars builtins
eps <- liftIO $ hscEPS hsc_env
let famInstEnvs = (eps_fam_inst_env eps, mg_fam_inst_env guts)
instEnvs = (eps_inst_env eps, mg_inst_env guts)
let genv = extendImportedVarsEnv builtin_vars
. extendScalars builtin_scalars
. extendTyConsEnv builtin_tycons
. extendDataConsEnv builtin_datacons
. extendPAFunsEnv builtin_pas
......
......@@ -12,6 +12,7 @@ module VectUtils (
prDFunOfTyCon,
paDictArgType, paDictOfType, paDFunType,
paMethod, mkPR, lengthPA, replicatePA, emptyPA, packPA, combinePA, liftPA,
zipScalars, scalarClosure,
polyAbstract, polyApply, polyVApply,
hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
buildClosure, buildClosures,
......@@ -270,6 +271,24 @@ liftPA x
lc <- builtin liftingContext
replicatePA (Var lc) x
zipScalars :: [Type] -> Type -> VM CoreExpr
zipScalars arg_tys res_ty
= do
scalar <- builtin scalarClass
(dfuns, _) <- mapAndUnzipM (\ty -> lookupInst scalar [ty]) ty_args
zipf <- builtin (scalarZip $ length arg_tys)
return $ Var zipf `mkTyApps` ty_args `mkApps` map Var dfuns
where
ty_args = arg_tys ++ [res_ty]
scalarClosure :: [Type] -> Type -> CoreExpr -> CoreExpr -> VM CoreExpr
scalarClosure arg_tys res_ty scalar_fun array_fun
= do
ctr <- builtin (closureCtrFun $ length arg_tys)
pas <- mapM paDictOfType (init arg_tys)
return $ Var ctr `mkTyApps` (arg_tys ++ [res_ty])
`mkApps` (pas ++ [scalar_fun, array_fun])
newLocalVVar :: FastString -> Type -> VM VVar
newLocalVVar fs vty
= do
......
......@@ -264,12 +264,44 @@ vectExpr (_, AnnLet (AnnRec bs) body)
$ vectExpr rhs
vectExpr e@(fvs, AnnLam bndr _)
| isId bndr = vectLam fvs bs body
| isId bndr = onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
`orElseV` vectLam fvs bs body
where
(bs,body) = collectAnnValBinders e
vectExpr e = cantVectorise "Can't vectorise expression" (ppr $ deAnnotate e)
vectScalarLam :: [Var] -> CoreExpr -> VM VExpr
vectScalarLam args body
= do
scalars <- globalScalars
onlyIfV (all is_scalar_ty arg_tys
&& is_scalar_ty res_ty
&& is_scalar (extendVarSetList scalars args) body)
$ do
fn_var <- hoistExpr (fsLit "fn") (mkLams args body)
zipf <- zipScalars arg_tys res_ty
clo <- scalarClosure arg_tys res_ty (Var fn_var)
(zipf `App` Var fn_var)
clo_var <- hoistExpr (fsLit "clo") clo
lclo <- liftPA (Var clo_var)
return (Var clo_var, lclo)
where
arg_tys = map idType args
res_ty = exprType body
is_scalar_ty ty | Just (tycon, []) <- splitTyConApp_maybe ty
= tycon == intTyCon
|| tycon == floatTyCon
|| tycon == doubleTyCon
| otherwise = False
is_scalar vs (Var v) = v `elemVarSet` vs
is_scalar _ e@(Lit l) = is_scalar_ty $ exprType e
is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2
is_scalar _ _ = False
vectLam :: VarSet -> [Var] -> CoreExprWithFVs -> VM VExpr
vectLam fvs bs body
= do
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment