Commit bfddbe30 authored by rl@cse.unsw.edu.au's avatar rl@cse.unsw.edu.au

Rewrite vectorisation of product DataCon workers

parent 9f28e733
......@@ -295,19 +295,22 @@ arrShapeTys repr = return [intPrimTy]
arrShapeVars :: Repr -> VM [Var]
arrShapeVars repr = mapM (newLocalVar FSLIT("sh")) =<< arrShapeTys repr
arrReprTys :: Repr -> VM [[Type]]
arrReprTys (SumRepr { sum_components = prods })
= mapM arrProdTys prods
arrReprTys prod
= do
tys <- arrProdTys prod
return [tys]
replicateShape :: Repr -> CoreExpr -> CoreExpr -> VM [CoreExpr]
replicateShape (ProdRepr {}) len _ = return [len]
arrProdTys (ProdRepr { prod_components = tys })
= mapM mkPArrayType (mk_types tys)
where
mk_types [] = [unitTy]
mk_types tys = tys
arrReprElemTys :: Repr -> [[Type]]
arrReprElemTys (SumRepr { sum_components = prods })
= map arrProdElemTys prods
arrReprElemTys prod@(ProdRepr {})
= [arrProdElemTys prod]
arrProdElemTys (ProdRepr { prod_components = [] })
= [unitTy]
arrProdElemTys (ProdRepr { prod_components = tys })
= tys
arrReprTys :: Repr -> VM [[Type]]
arrReprTys = mapM (mapM mkPArrayType) . arrReprElemTys
arrReprVars :: Repr -> VM [[Var]]
arrReprVars repr
......@@ -658,11 +661,7 @@ buildTyConBindings orig_tc vect_tc prepr_tc arr_tc dfun
= do
shape <- tyConShape vect_tc
repr <- mkRepr vect_tc
sequence_ (zipWith4 (vectDataConWorker shape vect_tc arr_tc arr_dc)
orig_dcs
vect_dcs
(inits repr_tys)
(tails repr_tys))
vectDataConWorkers repr orig_tc vect_tc arr_tc
dict <- buildPADict repr vect_tc prepr_tc arr_tc dfun
binds <- takeHoisted
return $ (dfun, dict) : binds
......@@ -673,6 +672,75 @@ buildTyConBindings orig_tc vect_tc prepr_tc arr_tc dfun
repr_tys = map dataConRepArgTys vect_dcs
vectDataConWorkers :: Repr -> TyCon -> TyCon -> TyCon
-> VM ()
vectDataConWorkers repr orig_tc vect_tc arr_tc
= do
bs <- sequence
. zipWith3 def_worker (tyConDataCons orig_tc) rep_tys
$ zipWith4 mk_data_con (tyConDataCons vect_tc)
rep_tys
(inits arr_tys)
(tail $ tails arr_tys)
mapM_ (uncurry hoistBinding) bs
where
tyvars = tyConTyVars vect_tc
var_tys = mkTyVarTys tyvars
ty_args = map Type var_tys
res_ty = mkTyConApp vect_tc var_tys
rep_tys = map dataConRepArgTys $ tyConDataCons vect_tc
arr_tys = arrReprElemTys repr
[arr_dc] = tyConDataCons arr_tc
mk_data_con con tys pre post
= liftM2 (,) (vect_data_con con)
(lift_data_con tys (concat pre)
(concat post)
(mkDataConTag con))
vect_data_con con = return $ mkConApp con ty_args
lift_data_con tys pre_tys post_tys tag
= do
len <- builtin liftingContext
args <- mapM (newLocalVar FSLIT("xs"))
=<< mapM mkPArrayType tys
shape <- replicateShape repr (Var len) tag
repr <- mk_arr_repr (Var len) (map Var args)
pre <- mapM emptyPA pre_tys
post <- mapM emptyPA post_tys
return . mkLams (len : args)
. wrapFamInstBody arr_tc var_tys
. mkConApp arr_dc
$ ty_args ++ shape ++ pre ++ repr ++ post
mk_arr_repr len []
= do
units <- replicatePA len (Var unitDataConId)
return [units]
mk_arr_repr len arrs = return arrs
def_worker data_con arg_tys mk_body
= do
body <- closedV
. inBind orig_worker
. polyAbstract tyvars $ \abstract ->
liftM (abstract . vectorised)
$ buildClosures tyvars [] arg_tys res_ty mk_body
vect_worker <- cloneId mkVectOcc orig_worker (exprType body)
defGlobalVar orig_worker vect_worker
return (vect_worker, body)
where
orig_worker = dataConWorkId data_con
vectDataConWorker :: Shape -> TyCon -> TyCon -> DataCon
-> DataCon -> DataCon -> [[Type]] -> [[Type]]
-> VM ()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment