Commit 3736e30f authored by rl@cse.unsw.edu.au's avatar rl@cse.unsw.edu.au
Browse files

Separate length from data in DPH arrays

parent 63b4ba98
......@@ -57,7 +57,7 @@ module OccName (
mkSuperDictSelOcc, mkLocalOcc, mkMethodOcc, mkInstTyTcOcc,
mkInstTyCoOcc, mkEqPredCoOcc,
mkVectOcc, mkVectTyConOcc, mkVectDataConOcc, mkVectIsoOcc,
mkPArrayTyConOcc, mkPArrayDataConOcc,
mkPDataTyConOcc, mkPDataDataConOcc,
mkPReprTyConOcc,
mkPADFunOcc,
......@@ -529,7 +529,7 @@ mkDataConWrapperOcc, mkWorkerOcc, mkDefaultMethodOcc, mkDerivedTyConOcc,
mkInstTyCoOcc, mkEqPredCoOcc,
mkCon2TagOcc, mkTag2ConOcc, mkMaxTagOcc,
mkVectOcc, mkVectTyConOcc, mkVectDataConOcc, mkVectIsoOcc,
mkPArrayTyConOcc, mkPArrayDataConOcc, mkPReprTyConOcc, mkPADFunOcc
mkPDataTyConOcc, mkPDataDataConOcc, mkPReprTyConOcc, mkPADFunOcc
:: OccName -> OccName
-- These derived variables have a prefix that no Haskell value could have
......@@ -568,8 +568,8 @@ mkVectOcc = mk_simple_deriv varName "$v_"
mkVectTyConOcc = mk_simple_deriv tcName ":V_"
mkVectDataConOcc = mk_simple_deriv dataName ":VD_"
mkVectIsoOcc = mk_simple_deriv varName "$VI_"
mkPArrayTyConOcc = mk_simple_deriv tcName ":VP_"
mkPArrayDataConOcc = mk_simple_deriv dataName ":VPD_"
mkPDataTyConOcc = mk_simple_deriv tcName ":VP_"
mkPDataDataConOcc = mk_simple_deriv dataName ":VPD_"
mkPReprTyConOcc = mk_simple_deriv tcName ":VR_"
mkPADFunOcc = mk_simple_deriv varName "$PA_"
......
module VectBuiltIn (
Builtins(..), sumTyCon, prodTyCon,
combinePAVar, scalarZip, closureCtrFun,
Builtins(..), sumTyCon, prodTyCon, prodDataCon,
selTy, selReplicate, selPick, selElements,
combinePDVar, scalarZip, closureCtrFun,
initBuiltins, initBuiltinVars, initBuiltinTyCons, initBuiltinDataCons,
initBuiltinPAs, initBuiltinPRs,
initBuiltinBoxedTyCons, initBuiltinScalars,
......@@ -15,6 +16,7 @@ import Module
import DataCon ( DataCon, dataConName, dataConWorkId )
import TyCon ( TyCon, tyConName, tyConDataCons )
import Class ( Class )
import CoreSyn ( CoreExpr, Expr(..) )
import Var ( Var )
import Id ( mkSysLocal )
import Name ( Name, getOccString )
......@@ -44,7 +46,7 @@ mAX_DPH_PROD :: Int
mAX_DPH_PROD = 5
mAX_DPH_SUM :: Int
mAX_DPH_SUM = 3
mAX_DPH_SUM = 2
mAX_DPH_COMBINE :: Int
mAX_DPH_COMBINE = 2
......@@ -60,6 +62,7 @@ data Modules = Modules {
, dph_Instances :: Module
, dph_Combinators :: Module
, dph_Scalar :: Module
, dph_Selector :: Module
, dph_Prelude_PArr :: Module
, dph_Prelude_Int :: Module
, dph_Prelude_Word8 :: Module
......@@ -77,6 +80,7 @@ dph_Modules pkg = Modules {
, dph_Instances = mk (fsLit "Data.Array.Parallel.Lifted.Instances")
, dph_Combinators = mk (fsLit "Data.Array.Parallel.Lifted.Combinators")
, dph_Scalar = mk (fsLit "Data.Array.Parallel.Lifted.Scalar")
, dph_Selector = mk (fsLit "Data.Array.Parallel.Lifted.Selector")
, dph_Prelude_PArr = mk (fsLit "Data.Array.Parallel.Prelude.Base.PArr")
, dph_Prelude_Int = mk (fsLit "Data.Array.Parallel.Prelude.Base.Int")
......@@ -92,42 +96,61 @@ dph_Modules pkg = Modules {
data Builtins = Builtins {
dphModules :: Modules
, parrayTyCon :: TyCon
, parrayDataCon :: DataCon
, pdataTyCon :: TyCon
, paTyCon :: TyCon
, paDataCon :: DataCon
, preprTyCon :: TyCon
, prTyCon :: TyCon
, prDataCon :: DataCon
, intPrimArrayTy :: Type
, voidTyCon :: TyCon
, wrapTyCon :: TyCon
, enumerationTyCon :: TyCon
, selTys :: Array Int Type
, selReplicates :: Array Int CoreExpr
, selPicks :: Array Int CoreExpr
, selEls :: Array (Int, Int) CoreExpr
, sumTyCons :: Array Int TyCon
, closureTyCon :: TyCon
, voidVar :: Var
, pvoidVar :: Var
, punitVar :: Var
, mkPRVar :: Var
, mkClosureVar :: Var
, applyClosureVar :: Var
, mkClosurePVar :: Var
, applyClosurePVar :: Var
, replicatePAIntPrimVar :: Var
, upToPAIntPrimVar :: Var
, selectPAIntPrimVar :: Var
, truesPABoolPrimVar :: Var
, lengthPAVar :: Var
, replicatePAVar :: Var
, emptyPAVar :: Var
, packPAVar :: Var
, combinePAVars :: Array Int Var
, closureVar :: Var
, applyVar :: Var
, liftedClosureVar :: Var
, liftedApplyVar :: Var
, replicatePDVar :: Var
, emptyPDVar :: Var
, packPDVar :: Var
, combinePDVars :: Array Int Var
, scalarClass :: Class
, scalarZips :: Array Int Var
, closureCtrFuns :: Array Int Var
, liftingContext :: Var
}
indexBuiltin :: (Ix i, Outputable i) => String -> (Builtins -> Array i a)
-> i -> Builtins -> a
indexBuiltin fn f i bi
| inRange (bounds xs) i = xs ! i
| otherwise = pprPanic fn (ppr i)
where
xs = f bi
selTy :: Int -> Builtins -> Type
selTy = indexBuiltin "selTy" selTys
selReplicate :: Int -> Builtins -> CoreExpr
selReplicate = indexBuiltin "selReplicate" selReplicates
selPick :: Int -> Builtins -> CoreExpr
selPick = indexBuiltin "selPick" selPicks
selElements :: Int -> Int -> Builtins -> CoreExpr
selElements i j = indexBuiltin "selElements" selEls (i,j)
sumTyCon :: Int -> Builtins -> TyCon
sumTyCon n bi
| n >= 2 && n <= mAX_DPH_SUM = sumTyCons bi ! n
| otherwise = pprPanic "sumTyCon" (ppr n)
sumTyCon = indexBuiltin "sumTyCon" sumTyCons
prodTyCon :: Int -> Builtins -> TyCon
prodTyCon n bi
......@@ -135,72 +158,77 @@ prodTyCon n bi
| n >= 0 && n <= mAX_DPH_PROD = tupleTyCon Boxed n
| otherwise = pprPanic "prodTyCon" (ppr n)
combinePAVar :: Int -> Builtins -> Var
combinePAVar n bi
| n >= 2 && n <= mAX_DPH_COMBINE = combinePAVars bi ! n
| otherwise = pprPanic "combinePAVar" (ppr n)
prodDataCon :: Int -> Builtins -> DataCon
prodDataCon n bi = case tyConDataCons (prodTyCon n bi) of
[con] -> con
combinePDVar :: Int -> Builtins -> Var
combinePDVar = indexBuiltin "combinePDVar" combinePDVars
scalarZip :: Int -> Builtins -> Var
scalarZip n bi
| n >= 1 && n <= mAX_DPH_SCALAR_ARGS = scalarZips bi ! n
| otherwise = pprPanic "scalarZip" (ppr n)
scalarZip = indexBuiltin "scalarZip" scalarZips
closureCtrFun :: Int -> Builtins -> Var
closureCtrFun n bi
| n >= 1 && n <= mAX_DPH_SCALAR_ARGS = closureCtrFuns bi ! n
| otherwise = pprPanic "closureCtrFun" (ppr n)
closureCtrFun = indexBuiltin "closureCtrFun" closureCtrFuns
initBuiltins :: PackageId -> DsM Builtins
initBuiltins pkg
= do
parrayTyCon <- externalTyCon dph_PArray (fsLit "PArray")
let [parrayDataCon] = tyConDataCons parrayTyCon
pdataTyCon <- externalTyCon dph_PArray (fsLit "PData")
paTyCon <- externalTyCon dph_PArray (fsLit "PA")
let [paDataCon] = tyConDataCons paTyCon
preprTyCon <- externalTyCon dph_PArray (fsLit "PRepr")
prTyCon <- externalTyCon dph_PArray (fsLit "PR")
let [prDataCon] = tyConDataCons prTyCon
intPrimArrayTy <- externalType dph_Unboxed (fsLit "PArray_Int#")
closureTyCon <- externalTyCon dph_Closure (fsLit ":->")
voidTyCon <- externalTyCon dph_Repr (fsLit "Void")
wrapTyCon <- externalTyCon dph_Repr (fsLit "Wrap")
enumerationTyCon <- externalTyCon dph_Repr (fsLit "Enumeration")
sum_tcs <- mapM (externalTyCon dph_Repr)
[mkFastString ("Sum" ++ show i) | i <- [2..mAX_DPH_SUM]]
let sumTyCons = listArray (2, mAX_DPH_SUM) sum_tcs
sel_tys <- mapM (externalType dph_Selector)
(numbered "Sel" 2 mAX_DPH_SUM)
sel_replicates <- mapM (externalFun dph_Selector)
(numbered "replicate" 2 mAX_DPH_SUM)
sel_picks <- mapM (externalFun dph_Selector)
(numbered "pick" 2 mAX_DPH_SUM)
sel_els <- mapM mk_elements
[(i,j) | i <- [2..mAX_DPH_SUM], j <- [0..i-1]]
sum_tcs <- mapM (externalTyCon dph_Repr)
(numbered "Sum" 2 mAX_DPH_SUM)
let selTys = listArray (2, mAX_DPH_SUM) sel_tys
selReplicates = listArray (2, mAX_DPH_SUM) sel_replicates
selPicks = listArray (2, mAX_DPH_SUM) sel_picks
selEls = array ((2,0), (mAX_DPH_SUM, mAX_DPH_SUM)) sel_els
sumTyCons = listArray (2, mAX_DPH_SUM) sum_tcs
voidVar <- externalVar dph_Repr (fsLit "void")
pvoidVar <- externalVar dph_Repr (fsLit "pvoid")
punitVar <- externalVar dph_Repr (fsLit "punit")
mkPRVar <- externalVar dph_PArray (fsLit "mkPR")
mkClosureVar <- externalVar dph_Closure (fsLit "mkClosure")
applyClosureVar <- externalVar dph_Closure (fsLit "$:")
mkClosurePVar <- externalVar dph_Closure (fsLit "mkClosureP")
applyClosurePVar <- externalVar dph_Closure (fsLit "$:^")
replicatePAIntPrimVar <- externalVar dph_Unboxed (fsLit "replicatePA_Int#")
upToPAIntPrimVar <- externalVar dph_Unboxed (fsLit "upToPA_Int#")
selectPAIntPrimVar <- externalVar dph_Unboxed (fsLit "selectPA_Int#")
truesPABoolPrimVar <- externalVar dph_Unboxed (fsLit "truesPA_Bool#")
lengthPAVar <- externalVar dph_PArray (fsLit "lengthPA#")
replicatePAVar <- externalVar dph_PArray (fsLit "replicatePA#")
emptyPAVar <- externalVar dph_PArray (fsLit "emptyPA")
packPAVar <- externalVar dph_PArray (fsLit "packPA#")
closureVar <- externalVar dph_Closure (fsLit "closure")
applyVar <- externalVar dph_Closure (fsLit "$:")
liftedClosureVar <- externalVar dph_Closure (fsLit "liftedClosure")
liftedApplyVar <- externalVar dph_Closure (fsLit "liftedApply")
replicatePDVar <- externalVar dph_PArray (fsLit "replicatePD")
emptyPDVar <- externalVar dph_PArray (fsLit "emptyPD")
packPDVar <- externalVar dph_PArray (fsLit "packPD")
combines <- mapM (externalVar dph_PArray)
[mkFastString ("combine" ++ show i ++ "PA#")
[mkFastString ("combine" ++ show i ++ "PD")
| i <- [2..mAX_DPH_COMBINE]]
let combinePAVars = listArray (2, mAX_DPH_COMBINE) combines
let combinePDVars = listArray (2, mAX_DPH_COMBINE) combines
scalarClass <- externalClass dph_Scalar (fsLit "Scalar")
scalar_map <- externalVar dph_Scalar (fsLit "scalar_map")
scalar_zip2 <- externalVar dph_Scalar (fsLit "scalar_zipWith")
scalar_zips <- mapM (externalVar dph_Scalar)
[mkFastString ("scalar_zipWith" ++ show i)
| i <- [3 .. mAX_DPH_SCALAR_ARGS]]
(numbered "scalar_zipWith" 3 mAX_DPH_SCALAR_ARGS)
let scalarZips = listArray (1, mAX_DPH_SCALAR_ARGS)
(scalar_map : scalar_zip2 : scalar_zips)
closures <- mapM (externalVar dph_Closure)
[mkFastString ("closure" ++ show i)
| i <- [1 .. mAX_DPH_SCALAR_ARGS]]
(numbered "closure" 1 mAX_DPH_SCALAR_ARGS)
let closureCtrFuns = listArray (1, mAX_DPH_COMBINE) closures
liftingContext <- liftM (\u -> mkSysLocal (fsLit "lc") u intPrimTy)
......@@ -209,32 +237,33 @@ initBuiltins pkg
return $ Builtins {
dphModules = modules
, parrayTyCon = parrayTyCon
, parrayDataCon = parrayDataCon
, pdataTyCon = pdataTyCon
, paTyCon = paTyCon
, paDataCon = paDataCon
, preprTyCon = preprTyCon
, prTyCon = prTyCon
, prDataCon = prDataCon
, intPrimArrayTy = intPrimArrayTy
, voidTyCon = voidTyCon
, wrapTyCon = wrapTyCon
, enumerationTyCon = enumerationTyCon
, selTys = selTys
, selReplicates = selReplicates
, selPicks = selPicks
, selEls = selEls
, sumTyCons = sumTyCons
, closureTyCon = closureTyCon
, voidVar = voidVar
, pvoidVar = pvoidVar
, punitVar = punitVar
, mkPRVar = mkPRVar
, mkClosureVar = mkClosureVar
, applyClosureVar = applyClosureVar
, mkClosurePVar = mkClosurePVar
, applyClosurePVar = applyClosurePVar
, replicatePAIntPrimVar = replicatePAIntPrimVar
, upToPAIntPrimVar = upToPAIntPrimVar
, selectPAIntPrimVar = selectPAIntPrimVar
, truesPABoolPrimVar = truesPABoolPrimVar
, lengthPAVar = lengthPAVar
, replicatePAVar = replicatePAVar
, emptyPAVar = emptyPAVar
, packPAVar = packPAVar
, combinePAVars = combinePAVars
, closureVar = closureVar
, applyVar = applyVar
, liftedClosureVar = liftedClosureVar
, liftedApplyVar = liftedApplyVar
, replicatePDVar = replicatePDVar
, emptyPDVar = emptyPDVar
, packPDVar = packPDVar
, combinePDVars = combinePDVars
, scalarClass = scalarClass
, scalarZips = scalarZips
, closureCtrFuns = closureCtrFuns
......@@ -245,11 +274,22 @@ initBuiltins pkg
dph_PArray = dph_PArray
, dph_Repr = dph_Repr
, dph_Closure = dph_Closure
, dph_Selector = dph_Selector
, dph_Unboxed = dph_Unboxed
, dph_Scalar = dph_Scalar
})
= dph_Modules pkg
numbered :: String -> Int -> Int -> [FastString]
numbered pfx m n = [mkFastString (pfx ++ show i) | i <- [m..n]]
mk_elements :: (Int, Int) -> DsM ((Int, Int), CoreExpr)
mk_elements (i,j)
= do
v <- externalVar dph_Selector
$ mkFastString ("elementsSel" ++ show i ++ "_" ++ show j ++ "#")
return ((i,j), Var v)
initBuiltinVars :: Builtins -> DsM [(Var, Var)]
initBuiltinVars (Builtins { dphModules = mods })
......@@ -302,7 +342,7 @@ preludeVars (Modules { dph_Combinators = dph_Combinators
, mk' dph_Prelude_Int "mod" "modV"
, mk' dph_Prelude_Int "sqrt" "sqrtV"
, mk' dph_Prelude_Int "enumFromToP" "enumFromToPA"
, mk' dph_Prelude_Int "upToP" "upToPA"
-- , mk' dph_Prelude_Int "upToP" "upToPA"
]
++ vars_Ord dph_Prelude_Int
++ vars_Num dph_Prelude_Int
......@@ -456,7 +496,6 @@ builtinPRs bi@(Builtins { dphModules = mods }) =
mk (tyConName unitTyCon) (dph_Repr mods) (fsLit "dPR_Unit")
, mk (tyConName $ voidTyCon bi) (dph_Repr mods) (fsLit "dPR_Void")
, mk (tyConName $ wrapTyCon bi) (dph_Repr mods) (fsLit "dPR_Wrap")
, mk (tyConName $ enumerationTyCon bi) (dph_Repr mods) (fsLit "dPR_Enumeration")
, mk (tyConName $ closureTyCon bi) (dph_Closure mods) (fsLit "dPR_Clo")
-- temporary
......@@ -572,6 +611,12 @@ externalVar :: Module -> FastString -> DsM Var
externalVar mod fs
= dsLookupGlobalId =<< lookupOrig mod (mkVarOccFS fs)
externalFun :: Module -> FastString -> DsM CoreExpr
externalFun mod fs
= do
var <- externalVar mod fs
return $ Var var
externalTyCon :: Module -> FastString -> DsM TyCon
externalTyCon mod fs
= dsLookupTyCon =<< lookupOrig mod (mkTcOccFS fs)
......
......@@ -4,11 +4,13 @@ module VectCore (
vectorised, lifted,
mapVect,
vVarType,
vNonRec, vRec,
vVar, vType, vNote, vLet,
vLams, vLamsWithoutLC, vVarApps,
vCaseDEFAULT, vCaseProd, vInlineMe
vCaseDEFAULT, vInlineMe
) where
#include "HsVersions.h"
......@@ -38,6 +40,9 @@ mapVect f (x,y) = (f x, f y)
zipWithVect :: (a -> b -> c) -> Vect a -> Vect b -> Vect c
zipWithVect f (x1,y1) (x2,y2) = (f x1 x2, f y1 y2)
vVarType :: VVar -> Type
vVarType = varType . vectorised
vVar :: VVar -> VExpr
vVar = mapVect Var
......@@ -81,17 +86,6 @@ vCaseDEFAULT (vscrut, lscrut) (vbndr, lbndr) vty lty (vbody, lbody)
where
mkDEFAULT e = [(DEFAULT, [], e)]
vCaseProd :: VExpr -> Type -> Type
-> DataCon -> DataCon -> [Var] -> [VVar] -> VExpr -> VExpr
vCaseProd (vscrut, lscrut) vty lty vdc ldc sh_bndrs bndrs
(vbody,lbody)
= (mkWildCase vscrut (exprType vscrut) vty
[(DataAlt vdc, vbndrs, vbody)],
mkWildCase lscrut (exprType lscrut) lty
[(DataAlt ldc, sh_bndrs ++ lbndrs, lbody)])
where
(vbndrs, lbndrs) = unzip bndrs
vInlineMe :: VExpr -> VExpr
vInlineMe (vexpr, lexpr) = (mkInlineMe vexpr, mkInlineMe lexpr)
......@@ -7,10 +7,11 @@ module VectMonad (
initV, cantVectorise, maybeCantVectorise, maybeCantVectoriseM,
liftDs,
cloneName, cloneId, cloneVar,
newExportedVar, newLocalVar, newDummyVar, newTyVar,
newExportedVar, newLocalVar, newLocalVars, newDummyVar, newTyVar,
Builtins(..), sumTyCon, prodTyCon,
combinePAVar, scalarZip, closureCtrFun,
Builtins(..), sumTyCon, prodTyCon, prodDataCon,
selTy, selReplicate, selPick, selElements,
combinePDVar, scalarZip, closureCtrFun,
builtin, builtins,
GlobalEnv(..),
......@@ -374,6 +375,9 @@ newLocalVar fs ty
u <- liftDs newUnique
return $ mkSysLocal fs u ty
newLocalVars :: FastString -> [Type] -> VM [Var]
newLocalVars fs = mapM (newLocalVar fs)
newDummyVar :: Type -> VM Var
newDummyVar = newLocalVar (fsLit "vv")
......
This diff is collapsed.
......@@ -5,13 +5,15 @@ module VectUtils (
newLocalVVar,
mkBuiltinCo,
mkPADictType, mkPArrayType, mkPReprType,
mkBuiltinCo, voidType,
mkPADictType, mkPArrayType, mkPDataType, mkPReprType, mkPArray,
parrayReprTyCon, parrayReprDataCon, mkVScrut,
pdataReprTyCon, pdataReprDataCon, mkVScrut,
prDFunOfTyCon,
paDictArgType, paDictOfType, paDFunType,
paMethod, mkPR, lengthPA, replicatePA, emptyPA, packPA, combinePA, liftPA,
paMethod, mkPR, replicatePD, emptyPD, packPD,
combinePD,
liftPD,
zipScalars, scalarClosure,
polyAbstract, polyApply, polyVApply,
hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
......@@ -95,21 +97,8 @@ mkBuiltinTyConApps get_tc tys ty
where
mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
{-
mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
mkBuiltinTyConApps1 _ dft [] = return dft
mkBuiltinTyConApps1 get_tc _ tys
= do
tc <- builtin get_tc
case tys of
[] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
_ -> return $ foldr1 (mk tc) tys
where
mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
mkClosureType :: Type -> Type -> VM Type
mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
-}
voidType :: VM Type
voidType = mkBuiltinTyConApp voidTyCon []
mkClosureTypes :: [Type] -> Type -> VM Type
mkClosureTypes = mkBuiltinTyConApps closureTyCon
......@@ -130,27 +119,38 @@ mkPArrayType ty
Nothing -> cantVectorise "Primitive tycon not vectorised" (ppr tycon)
mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
mkPDataType :: Type -> VM Type
mkPDataType ty = mkBuiltinTyConApp pdataTyCon [ty]
mkPArray :: Type -> CoreExpr -> CoreExpr -> VM CoreExpr
mkPArray ty len dat = do
tc <- builtin parrayTyCon
let [dc] = tyConDataCons tc
return $ mkConApp dc [Type ty, len, dat]
mkBuiltinCo :: (Builtins -> TyCon) -> VM Coercion
mkBuiltinCo get_tc
= do
tc <- builtin get_tc
return $ mkTyConApp tc []
parrayReprTyCon :: Type -> VM (TyCon, [Type])
parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
pdataReprTyCon :: Type -> VM (TyCon, [Type])
pdataReprTyCon ty = builtin pdataTyCon >>= (`lookupFamInst` [ty])
parrayReprDataCon :: Type -> VM (DataCon, [Type])
parrayReprDataCon ty
pdataReprDataCon :: Type -> VM (DataCon, [Type])
pdataReprDataCon ty
= do
(tc, arg_tys) <- parrayReprTyCon ty
(tc, arg_tys) <- pdataReprTyCon ty
let [dc] = tyConDataCons tc
return (dc, arg_tys)
mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
mkVScrut :: VExpr -> VM (CoreExpr, CoreExpr, TyCon, [Type])
mkVScrut (ve, le)
= do
(tc, arg_tys) <- parrayReprTyCon (exprType ve)
return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
(tc, arg_tys) <- pdataReprTyCon ty
return (ve, unwrapFamInstScrut tc arg_tys le, tc, arg_tys)
where
ty = exprType ve
prDFunOfTyCon :: TyCon -> VM CoreExpr
prDFunOfTyCon tycon
......@@ -217,20 +217,14 @@ paDFunApply dfun tys
type PAMethod = (Builtins -> Var, String)
pa_length, pa_replicate, pa_empty, pa_pack :: (Builtins -> Var, String)
pa_length = (lengthPAVar, "lengthPA")
pa_replicate = (replicatePAVar, "replicatePA")
pa_empty = (emptyPAVar, "emptyPA")
pa_pack = (packPAVar, "packPA")
paMethod :: PAMethod -> Type -> VM CoreExpr
paMethod (_method, name) ty
paMethod :: (Builtins -> Var) -> String -> Type -> VM CoreExpr
paMethod _ name ty
| Just tycon <- splitPrimTyCon ty
= liftM Var
. maybeCantVectoriseM "No PA method" (text name <+> text "for" <+> ppr tycon)
$ lookupPrimMethod tycon name
paMethod (method, _name) ty
paMethod method _ ty
= do
fn <- builtin method
dict <- paDictOfType ty
......@@ -243,33 +237,30 @@ mkPR ty
dict <- paDictOfType ty
return $ mkApps (Var fn) [Type ty, dict]
lengthPA :: Type -> CoreExpr -> VM CoreExpr
lengthPA ty x = liftM (`App` x) (paMethod pa_length ty)
replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
replicatePA len x = liftM (`mkApps` [len,x])
(paMethod pa_replicate (exprType x))
replicatePD :: CoreExpr -> CoreExpr -> VM CoreExpr
replicatePD len x = liftM (`mkApps` [len,x])
(paMethod replicatePDVar "replicatePD" (exprType x))
emptyPA :: Type -> VM CoreExpr
emptyPA = paMethod pa_empty
emptyPD :: Type -> VM CoreExpr
emptyPD = paMethod emptyPDVar "emptyPD"
packPA :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> VM CoreExpr
packPA ty xs len sel = liftM (`mkApps` [xs, len, sel])
(paMethod pa_pack ty)
packPD :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> VM CoreExpr
packPD ty xs len sel = liftM (`mkApps` [xs, len, sel])
(paMethod packPDVar "packPD" ty)
combinePA :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> [CoreExpr]
combinePD :: Type -> CoreExpr -> CoreExpr -> [CoreExpr]
-> VM CoreExpr
combinePA ty len sel is xs
= liftM (`mkApps` (len : sel : is : xs))
(paMethod (combinePAVar n, "combine" ++ show n ++ "PA") ty)
combinePD ty len sel xs
= liftM (`mkApps` (len : sel : xs))
(paMethod (combinePDVar n) ("combine" ++ show n ++ "PD") ty)
where
n = length xs
liftPA :: CoreExpr -> VM CoreExpr
liftPA x
liftPD :: CoreExpr -> VM CoreExpr
liftPD x
= do
lc <- builtin liftingContext
replicatePA (Var lc) x
replicatePD (Var lc) x
zipScalars :: [Type] -> Type -> VM CoreExpr
zipScalars arg_tys res_ty
......@@ -292,7 +283,7 @@ scalarClosure arg_tys res_ty scalar_fun array_fun
newLocalVVar :: FastString -> Type -> VM VVar
newLocalVVar fs vty
= do
lty <- mkPArrayType vty
lty <- mkPDataType vty
vv <- newLocalVar fs vty
lv <- newLocalVar fs lty
return (vv,lv)
......@@ -377,18 +368,19 @@ mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)