Commit facf3d6c authored by rl@cse.unsw.edu.au's avatar rl@cse.unsw.edu.au

Fix vectorisation of nullary data constructors

parent 7ab46257
...@@ -46,9 +46,11 @@ data Builtins = Builtins { ...@@ -46,9 +46,11 @@ data Builtins = Builtins {
, prTyCon :: TyCon , prTyCon :: TyCon
, prDataCon :: DataCon , prDataCon :: DataCon
, parrayIntPrimTyCon :: TyCon , parrayIntPrimTyCon :: TyCon
, voidTyCon :: TyCon
, wrapTyCon :: TyCon , wrapTyCon :: TyCon
, sumTyCons :: Array Int TyCon , sumTyCons :: Array Int TyCon
, closureTyCon :: TyCon , closureTyCon :: TyCon
, voidVar :: Var
, mkPRVar :: Var , mkPRVar :: Var
, mkClosureVar :: Var , mkClosureVar :: Var
, applyClosureVar :: Var , applyClosureVar :: Var
...@@ -71,8 +73,9 @@ sumTyCon n bi ...@@ -71,8 +73,9 @@ sumTyCon n bi
prodTyCon :: Int -> Builtins -> TyCon prodTyCon :: Int -> Builtins -> TyCon
prodTyCon n bi prodTyCon n bi
| n == 0 = voidTyCon bi
| n == 1 = wrapTyCon bi | n == 1 = wrapTyCon bi
| n >= 0 && n <= mAX_NDP_PROD = tupleTyCon Boxed n | n >= 2 && n <= mAX_NDP_PROD = tupleTyCon Boxed n
| otherwise = pprPanic "prodTyCon" (ppr n) | otherwise = pprPanic "prodTyCon" (ppr n)
initBuiltins :: DsM Builtins initBuiltins :: DsM Builtins
...@@ -87,12 +90,14 @@ initBuiltins ...@@ -87,12 +90,14 @@ initBuiltins
parrayIntPrimTyCon <- dsLookupTyCon parrayIntPrimTyConName parrayIntPrimTyCon <- dsLookupTyCon parrayIntPrimTyConName
closureTyCon <- dsLookupTyCon closureTyConName closureTyCon <- dsLookupTyCon closureTyConName
voidTyCon <- lookupExternalTyCon nDP_REPR FSLIT("Void")
wrapTyCon <- lookupExternalTyCon nDP_REPR FSLIT("Wrap") wrapTyCon <- lookupExternalTyCon nDP_REPR FSLIT("Wrap")
sum_tcs <- mapM (lookupExternalTyCon nDP_REPR) sum_tcs <- mapM (lookupExternalTyCon nDP_REPR)
[mkFastString ("Sum" ++ show i) | i <- [2..mAX_NDP_SUM]] [mkFastString ("Sum" ++ show i) | i <- [2..mAX_NDP_SUM]]
let sumTyCons = listArray (2, mAX_NDP_SUM) sum_tcs let sumTyCons = listArray (2, mAX_NDP_SUM) sum_tcs
voidVar <- lookupExternalVar nDP_REPR FSLIT("void")
mkPRVar <- dsLookupGlobalId mkPRName mkPRVar <- dsLookupGlobalId mkPRName
mkClosureVar <- dsLookupGlobalId mkClosureName mkClosureVar <- dsLookupGlobalId mkClosureName
applyClosureVar <- dsLookupGlobalId applyClosureName applyClosureVar <- dsLookupGlobalId applyClosureName
...@@ -117,9 +122,11 @@ initBuiltins ...@@ -117,9 +122,11 @@ initBuiltins
, prTyCon = prTyCon , prTyCon = prTyCon
, prDataCon = prDataCon , prDataCon = prDataCon
, parrayIntPrimTyCon = parrayIntPrimTyCon , parrayIntPrimTyCon = parrayIntPrimTyCon
, voidTyCon = voidTyCon
, wrapTyCon = wrapTyCon , wrapTyCon = wrapTyCon
, sumTyCons = sumTyCons , sumTyCons = sumTyCons
, closureTyCon = closureTyCon , closureTyCon = closureTyCon
, voidVar = voidVar
, mkPRVar = mkPRVar , mkPRVar = mkPRVar
, mkClosureVar = mkClosureVar , mkClosureVar = mkClosureVar
, applyClosureVar = applyClosureVar , applyClosureVar = applyClosureVar
...@@ -154,16 +161,18 @@ initBuiltinDicts ps ...@@ -154,16 +161,18 @@ initBuiltinDicts ps
where where
(tcs, mods, fss) = unzip3 ps (tcs, mods, fss) = unzip3 ps
initBuiltinPAs = initBuiltinDicts builtinPAs initBuiltinPAs = initBuiltinDicts . builtinPAs
builtinPAs :: [(Name, Module, FastString)] builtinPAs :: Builtins -> [(Name, Module, FastString)]
builtinPAs = [ builtinPAs bi
mk closureTyConName nDP_CLOSURE FSLIT("dPA_Clo") = [
, mk unitTyConName nDP_INSTANCES FSLIT("dPA_Unit") mk closureTyConName nDP_CLOSURE FSLIT("dPA_Clo")
, mk (tyConName $ voidTyCon bi) nDP_REPR FSLIT("dPA_Void")
, mk unitTyConName nDP_INSTANCES FSLIT("dPA_Unit")
, mk intTyConName nDP_INSTANCES FSLIT("dPA_Int") , mk intTyConName nDP_INSTANCES FSLIT("dPA_Int")
] ]
++ tups ++ tups
where where
mk name mod fs = (name, mod, fs) mk name mod fs = (name, mod, fs)
...@@ -178,6 +187,7 @@ builtinPRs :: Builtins -> [(Name, Module, FastString)] ...@@ -178,6 +187,7 @@ builtinPRs :: Builtins -> [(Name, Module, FastString)]
builtinPRs bi = builtinPRs bi =
[ [
mk (tyConName unitTyCon) nDP_REPR FSLIT("dPR_Unit") mk (tyConName unitTyCon) nDP_REPR FSLIT("dPR_Unit")
, mk (tyConName $ voidTyCon bi) nDP_REPR FSLIT("dPR_Void")
, mk (tyConName $ wrapTyCon bi) nDP_REPR FSLIT("dPR_Wrap") , mk (tyConName $ wrapTyCon bi) nDP_REPR FSLIT("dPR_Wrap")
, mk closureTyConName nDP_CLOSURE FSLIT("dPR_Clo") , mk closureTyConName nDP_CLOSURE FSLIT("dPR_Clo")
......
...@@ -463,7 +463,7 @@ initV hsc_env guts info p ...@@ -463,7 +463,7 @@ initV hsc_env guts info p
do do
builtins <- initBuiltins builtins <- initBuiltins
builtin_tycons <- initBuiltinTyCons builtin_tycons <- initBuiltinTyCons
builtin_pas <- initBuiltinPAs builtin_pas <- initBuiltinPAs builtins
builtin_prs <- initBuiltinPRs builtins builtin_prs <- initBuiltinPRs builtins
eps <- ioToIOEnv $ hscEPS hsc_env eps <- ioToIOEnv $ hscEPS hsc_env
......
...@@ -226,6 +226,20 @@ data Repr = ProdRepr { ...@@ -226,6 +226,20 @@ data Repr = ProdRepr {
| IdRepr Type | IdRepr Type
| VoidRepr {
void_tycon :: TyCon
, void_bottom :: CoreExpr
}
mkVoid :: VM Repr
mkVoid = do
tycon <- builtin voidTyCon
var <- builtin voidVar
return $ VoidRepr {
void_tycon = tycon
, void_bottom = Var var
}
mkProduct :: [Type] -> VM Repr mkProduct :: [Type] -> VM Repr
mkProduct tys mkProduct tys
= do = do
...@@ -246,6 +260,7 @@ mkProduct tys ...@@ -246,6 +260,7 @@ mkProduct tys
arity = length tys arity = length tys
mkSubProduct :: [Type] -> VM Repr mkSubProduct :: [Type] -> VM Repr
mkSubProduct [] = mkVoid
mkSubProduct [ty] = return $ IdRepr ty mkSubProduct [ty] = return $ IdRepr ty
mkSubProduct tys = mkProduct tys mkSubProduct tys = mkProduct tys
...@@ -275,6 +290,7 @@ reprType (ProdRepr { prod_tycon = tycon, prod_components = tys }) ...@@ -275,6 +290,7 @@ reprType (ProdRepr { prod_tycon = tycon, prod_components = tys })
reprType (SumRepr { sum_tycon = tycon, sum_components = reprs }) reprType (SumRepr { sum_tycon = tycon, sum_components = reprs })
= mkTyConApp tycon (map reprType reprs) = mkTyConApp tycon (map reprType reprs)
reprType (IdRepr ty) = ty reprType (IdRepr ty) = ty
reprType (VoidRepr { void_tycon = tycon }) = mkTyConApp tycon []
arrReprType :: Repr -> VM Type arrReprType :: Repr -> VM Type
arrReprType = mkPArrayType . reprType arrReprType = mkPArrayType . reprType
...@@ -286,6 +302,7 @@ arrShapeTys (SumRepr {}) ...@@ -286,6 +302,7 @@ arrShapeTys (SumRepr {})
return [intPrimTy, mkTyConApp int_arr [], mkTyConApp int_arr []] return [intPrimTy, mkTyConApp int_arr [], mkTyConApp int_arr []]
arrShapeTys (ProdRepr {}) = return [intPrimTy] arrShapeTys (ProdRepr {}) = return [intPrimTy]
arrShapeTys (IdRepr _) = return [] arrShapeTys (IdRepr _) = return []
arrShapeTys (VoidRepr {}) = return [intPrimTy]
arrShapeVars :: Repr -> VM [Var] arrShapeVars :: Repr -> VM [Var]
arrShapeVars repr = mapM (newLocalVar FSLIT("sh")) =<< arrShapeTys repr arrShapeVars repr = mapM (newLocalVar FSLIT("sh")) =<< arrShapeTys repr
...@@ -298,22 +315,31 @@ replicateShape (SumRepr {}) len tag ...@@ -298,22 +315,31 @@ replicateShape (SumRepr {}) len tag
up <- builtin upToPAIntPrimVar up <- builtin upToPAIntPrimVar
return [len, Var rep `mkApps` [len, tag], Var up `App` len] return [len, Var rep `mkApps` [len, tag], Var up `App` len]
replicateShape (IdRepr _) _ _ = return [] replicateShape (IdRepr _) _ _ = return []
replicateShape (VoidRepr {}) len _ = return [len]
arrReprElemTys :: Repr -> [[Type]] arrReprElemTys :: Repr -> VM [[Type]]
arrReprElemTys (SumRepr { sum_components = prods }) arrReprElemTys (SumRepr { sum_components = prods })
= map arrProdElemTys prods = mapM arrProdElemTys prods
arrReprElemTys prod@(ProdRepr {}) arrReprElemTys prod@(ProdRepr {})
= [arrProdElemTys prod] = do
arrReprElemTys (IdRepr ty) = [[ty]] tys <- arrProdElemTys prod
return [tys]
arrReprElemTys (IdRepr ty) = return [[ty]]
arrReprElemTys (VoidRepr { void_tycon = tycon })
= return [[mkTyConApp tycon []]]
arrProdElemTys (ProdRepr { prod_components = [] }) arrProdElemTys (ProdRepr { prod_components = [] })
= [unitTy] = do
void <- builtin voidTyCon
return [mkTyConApp void []]
arrProdElemTys (ProdRepr { prod_components = tys }) arrProdElemTys (ProdRepr { prod_components = tys })
= tys = return tys
arrProdElemTys (IdRepr ty) = [ty] arrProdElemTys (IdRepr ty) = return [ty]
arrProdElemTys (VoidRepr { void_tycon = tycon })
= return [mkTyConApp tycon []]
arrReprTys :: Repr -> VM [[Type]] arrReprTys :: Repr -> VM [[Type]]
arrReprTys = mapM (mapM mkPArrayType) . arrReprElemTys arrReprTys repr = mapM (mapM mkPArrayType) =<< arrReprElemTys repr
arrReprVars :: Repr -> VM [[Var]] arrReprVars :: Repr -> VM [[Var]]
arrReprVars repr arrReprVars repr
...@@ -376,6 +402,10 @@ buildToPRepr repr vect_tc prepr_tc _ ...@@ -376,6 +402,10 @@ buildToPRepr repr vect_tc prepr_tc _
var <- newLocalVar FSLIT("y") ty var <- newLocalVar FSLIT("y") ty
return ([var], Var var) return ([var], Var var)
prod_alt (VoidRepr { void_bottom = bottom })
= return ([], bottom)
buildFromPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr buildFromPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
buildFromPRepr repr vect_tc prepr_tc _ buildFromPRepr repr vect_tc prepr_tc _
= do = do
...@@ -418,6 +448,9 @@ buildFromPRepr repr vect_tc prepr_tc _ ...@@ -418,6 +448,9 @@ buildFromPRepr repr vect_tc prepr_tc _
from_prod (IdRepr _) con expr from_prod (IdRepr _) con expr
= return $ con `App` expr = return $ con `App` expr
from_prod (VoidRepr {}) con expr
= return con
buildToArrPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr buildToArrPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
buildToArrPRepr repr vect_tc prepr_tc arr_tc buildToArrPRepr repr vect_tc prepr_tc arr_tc
= do = do
...@@ -483,8 +516,9 @@ buildToArrPRepr repr vect_tc prepr_tc arr_tc ...@@ -483,8 +516,9 @@ buildToArrPRepr repr vect_tc prepr_tc arr_tc
. mkConApp data_con . mkConApp data_con
$ map Type tys ++ len : map Var repr_vars $ map Type tys ++ len : map Var repr_vars
to_prod [var] (IdRepr ty) to_prod [var] (IdRepr ty) = return (Var var)
= return (Var var) to_prod [var] (VoidRepr {}) = return (Var var)
buildFromArrPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr buildFromArrPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
buildFromArrPRepr repr vect_tc prepr_tc arr_tc buildFromArrPRepr repr vect_tc prepr_tc arr_tc
...@@ -571,7 +605,17 @@ buildFromArrPRepr repr vect_tc prepr_tc arr_tc ...@@ -571,7 +605,17 @@ buildFromArrPRepr repr vect_tc prepr_tc arr_tc
body body
= return $ Let (NonRec repr_var expr) body = return $ Let (NonRec repr_var expr) body
from_prod (VoidRepr {})
expr
shape_vars
[repr_var]
res_ty
body
= return $ Let (NonRec repr_var expr) body
buildPRDictRepr :: Repr -> VM CoreExpr buildPRDictRepr :: Repr -> VM CoreExpr
buildPRDictRepr (VoidRepr { void_tycon = tycon })
= prDFunOfTyCon tycon
buildPRDictRepr (IdRepr ty) = mkPR ty buildPRDictRepr (IdRepr ty) = mkPR ty
buildPRDictRepr (ProdRepr { buildPRDictRepr (ProdRepr {
prod_components = tys prod_components = tys
...@@ -679,6 +723,7 @@ vectDataConWorkers :: Repr -> TyCon -> TyCon -> TyCon ...@@ -679,6 +723,7 @@ vectDataConWorkers :: Repr -> TyCon -> TyCon -> TyCon
-> VM () -> VM ()
vectDataConWorkers repr orig_tc vect_tc arr_tc vectDataConWorkers repr orig_tc vect_tc arr_tc
= do = do
arr_tys <- arrReprElemTys repr
bs <- sequence bs <- sequence
. zipWith3 def_worker (tyConDataCons orig_tc) rep_tys . zipWith3 def_worker (tyConDataCons orig_tc) rep_tys
$ zipWith4 mk_data_con (tyConDataCons vect_tc) $ zipWith4 mk_data_con (tyConDataCons vect_tc)
...@@ -694,7 +739,6 @@ vectDataConWorkers repr orig_tc vect_tc arr_tc ...@@ -694,7 +739,6 @@ vectDataConWorkers repr orig_tc vect_tc arr_tc
res_ty = mkTyConApp vect_tc var_tys res_ty = mkTyConApp vect_tc var_tys
rep_tys = map dataConRepArgTys $ tyConDataCons vect_tc rep_tys = map dataConRepArgTys $ tyConDataCons vect_tc
arr_tys = arrReprElemTys repr
[arr_dc] = tyConDataCons arr_tc [arr_dc] = tyConDataCons arr_tc
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment