Commit 9f28e733 authored by rl@cse.unsw.edu.au's avatar rl@cse.unsw.edu.au
Browse files

Rewrite generation of PA dictionaries

parent eaaecbae
......@@ -36,7 +36,7 @@ import Digraph ( SCC(..), stronglyConnComp )
import Outputable
import Control.Monad ( liftM, liftM2, zipWithM, zipWithM_ )
import Control.Monad ( liftM, liftM2, zipWithM, zipWithM_, mapAndUnzipM )
import Data.List ( inits, tails, zipWith4, zipWith5 )
-- ----------------------------------------------------------------------------
......@@ -208,399 +208,353 @@ buildPReprTyCon orig_tc vect_tc
where
tyvars = tyConTyVars vect_tc
data TyConRepr = ProdRepr {
repr_prod_arg_tys :: [Type]
, repr_prod_tycon :: TyCon
, repr_prod_data_con :: DataCon
, repr_prod_arr_tycon :: TyCon
, repr_prod_arr_data_con :: DataCon
, repr_type :: Type
}
| SumRepr {
repr_tys :: [[Type]]
, repr_prod_tycons :: [Maybe TyCon]
, repr_prod_data_cons :: [Maybe DataCon]
, repr_prod_tys :: [Type]
, repr_sum_tycon :: TyCon
, repr_type :: Type
}
arrShapeTys :: TyConRepr -> VM [Type]
arrShapeTys (ProdRepr {}) = return [intPrimTy]
arrShapeTys (SumRepr {})
= do
uarr <- builtin uarrTyCon
return [intPrimTy, mkTyConApp uarr [intTy]]
arrReprTys :: TyConRepr -> VM [Type]
arrReprTys (ProdRepr { repr_prod_arg_tys = tys })
= mapM mkPArrayType tys
arrReprTys (SumRepr { repr_tys = tys })
= concat `liftM` mapM (mapM mkPArrayType) (map mk_prod tys)
where
mk_prod [] = [unitTy]
mk_prod tys = tys
mkTyConRepr :: TyCon -> VM TyConRepr
mkTyConRepr vect_tc
| is_product
= let
[prod_arg_tys] = repr_tys
arity = length prod_arg_tys
in
do
prod_tycon <- builtin (prodTyCon arity)
let [prod_data_con] = tyConDataCons prod_tycon
data Repr = ProdRepr {
prod_components :: [Type]
, prod_tycon :: TyCon
, prod_data_con :: DataCon
, prod_arr_tycon :: TyCon
, prod_arr_data_con :: DataCon
}
(arr_tycon, _) <- parrayReprTyCon
. mkTyConApp prod_tycon
$ replicate arity unitTy
| SumRepr {
sum_components :: [Repr]
, sum_tycon :: TyCon
, sum_arr_tycon :: TyCon
, sum_arr_data_con :: DataCon
}
mkProduct :: [Type] -> VM Repr
mkProduct tys
= do
tycon <- builtin (prodTyCon arity)
let [data_con] = tyConDataCons tycon
(arr_tycon, _) <- parrayReprTyCon $ mkTyConApp tycon tys
let [arr_data_con] = tyConDataCons arr_tycon
return $ ProdRepr {
repr_prod_arg_tys = prod_arg_tys
, repr_prod_tycon = prod_tycon
, repr_prod_data_con = prod_data_con
, repr_prod_arr_tycon = arr_tycon
, repr_prod_arr_data_con = arr_data_con
, repr_type = mkTyConApp prod_tycon prod_arg_tys
prod_components = tys
, prod_tycon = tycon
, prod_data_con = data_con
, prod_arr_tycon = arr_tycon
, prod_arr_data_con = arr_data_con
}
where
arity = length tys
| otherwise
mkSum :: [Repr] -> VM Repr
mkSum [repr] = return repr
mkSum reprs
= do
uarr <- builtin uarrTyCon
prod_tycons <- mapM (mk_tycon prodTyCon) repr_tys
let prod_tys = zipWith mk_tc_app_maybe prod_tycons repr_tys
sum_tycon <- builtin (sumTyCon $ length repr_tys)
arr_repr_tys <- mapM (mapM mkPArrayType . arr_repr_elem_tys) repr_tys
tycon <- builtin (sumTyCon arity)
(arr_tycon, _) <- parrayReprTyCon
. mkTyConApp tycon
$ map reprType reprs
let [arr_data_con] = tyConDataCons arr_tycon
return $ SumRepr {
repr_tys = repr_tys
, repr_prod_tycons = prod_tycons
, repr_prod_data_cons = map (fmap mk_single_datacon) prod_tycons
, repr_prod_tys = prod_tys
, repr_sum_tycon = sum_tycon
, repr_type = mkTyConApp sum_tycon prod_tys
sum_components = reprs
, sum_tycon = tycon
, sum_arr_tycon = arr_tycon
, sum_arr_data_con = arr_data_con
}
where
tyvars = tyConTyVars vect_tc
data_cons = tyConDataCons vect_tc
repr_tys = map dataConRepArgTys data_cons
is_product | [_] <- data_cons = True
| otherwise = False
arity = length reprs
mk_shape uarr = intPrimTy : mk_sel uarr
reprProducts :: Repr -> [Repr]
reprProducts (SumRepr { sum_components = rs }) = rs
reprProducts repr = [repr]
mk_sel uarr | is_product = []
| otherwise = [uarr_int, uarr_int]
where
uarr_int = mkTyConApp uarr [intTy]
reprType :: Repr -> Type
reprType (ProdRepr { prod_tycon = tycon, prod_components = tys })
= mkTyConApp tycon tys
reprType (SumRepr { sum_tycon = tycon, sum_components = reprs })
= mkTyConApp tycon (map reprType reprs)
mk_tycon get_tc tys
| n > 1 = builtin (Just . get_tc n)
| otherwise = return Nothing
where n = length tys
arrReprType :: Repr -> VM Type
arrReprType = mkPArrayType . reprType
mk_single_datacon tc | [dc] <- tyConDataCons tc = dc
reprTys :: Repr -> [[Type]]
reprTys (SumRepr { sum_components = prods }) = map prodTys prods
reprTys prod = [prodTys prod]
mk_tc_app_maybe Nothing [] = unitTy
mk_tc_app_maybe Nothing [ty] = ty
mk_tc_app_maybe (Just tc) tys = mkTyConApp tc tys
prodTys (ProdRepr { prod_components = tys }) = tys
arr_repr_elem_tys [] = [unitTy]
arr_repr_elem_tys tys = tys
reprVars :: Repr -> VM [[Var]]
reprVars = mapM (mapM (newLocalVar FSLIT("r"))) . reprTys
buildPReprType :: TyCon -> VM Type
buildPReprType = liftM repr_type . mkTyConRepr
buildToPRepr :: TyConRepr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
buildToPRepr (ProdRepr {
repr_prod_arg_tys = prod_arg_tys
, repr_prod_data_con = prod_data_con
, repr_type = repr_type
})
vect_tc prepr_tc _
arrShapeTys :: Repr -> VM [Type]
arrShapeTys (SumRepr {})
= do
arg <- newLocalVar FSLIT("x") arg_ty
vars <- mapM (newLocalVar FSLIT("x")) prod_arg_tys
uarr <- builtin uarrTyCon
return [intPrimTy, mkTyConApp uarr [intTy]]
arrShapeTys repr = return [intPrimTy]
return . Lam arg
. wrapFamInstBody prepr_tc var_tys
$ Case (Var arg) (mkWildId arg_ty) repr_type
[(DataAlt data_con, vars,
mkConApp prod_data_con (map Type prod_arg_tys ++ map Var vars))]
where
var_tys = mkTyVarTys $ tyConTyVars vect_tc
arg_ty = mkTyConApp vect_tc var_tys
[data_con] = tyConDataCons vect_tc
buildToPRepr (SumRepr {
repr_tys = repr_tys
, repr_prod_data_cons = prod_data_cons
, repr_prod_tys = prod_tys
, repr_sum_tycon = sum_tycon
, repr_type = repr_type
})
vect_tc prepr_tc _
= do
arg <- newLocalVar FSLIT("x") arg_ty
vars <- mapM (mapM (newLocalVar FSLIT("x"))) repr_tys
arrShapeVars :: Repr -> VM [Var]
arrShapeVars repr = mapM (newLocalVar FSLIT("sh")) =<< arrShapeTys repr
return . Lam arg
. wrapFamInstBody prepr_tc var_tys
. Case (Var arg) (mkWildId arg_ty) repr_type
. zipWith4 mk_alt data_cons vars sum_data_cons
. zipWith3 mk_prod prod_data_cons repr_tys $ map (map Var) vars
where
var_tys = mkTyVarTys $ tyConTyVars vect_tc
arg_ty = mkTyConApp vect_tc var_tys
data_cons = tyConDataCons vect_tc
sum_data_cons = tyConDataCons sum_tycon
mk_alt dc vars sum_dc expr = (DataAlt dc, vars,
mkConApp sum_dc (map Type prod_tys ++ [expr]))
mk_prod _ _ [] = Var unitDataConId
mk_prod _ _ [expr] = expr
mk_prod (Just dc) tys exprs = mkConApp dc (map Type tys ++ exprs)
buildFromPRepr :: TyConRepr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
buildFromPRepr (ProdRepr {
repr_prod_arg_tys = prod_arg_tys
, repr_prod_data_con = prod_data_con
, repr_type = repr_type
})
vect_tc prepr_tc _
arrReprTys :: Repr -> VM [[Type]]
arrReprTys (SumRepr { sum_components = prods })
= mapM arrProdTys prods
arrReprTys prod
= do
arg_ty <- mkPReprType res_ty
arg <- newLocalVar FSLIT("x") arg_ty
vars <- mapM (newLocalVar FSLIT("x")) prod_arg_tys
tys <- arrProdTys prod
return [tys]
return . Lam arg
$ Case (unwrapFamInstScrut prepr_tc var_tys (Var arg))
(mkWildId repr_type)
res_ty
[(DataAlt prod_data_con, vars,
mkConApp data_con (map Type var_tys ++ map Var vars))]
arrProdTys (ProdRepr { prod_components = tys })
= mapM mkPArrayType (mk_types tys)
where
var_tys = mkTyVarTys $ tyConTyVars vect_tc
ty_args = map Type var_tys
res_ty = mkTyConApp vect_tc var_tys
[data_con] = tyConDataCons vect_tc
buildFromPRepr (SumRepr {
repr_tys = repr_tys
, repr_prod_data_cons = prod_data_cons
, repr_prod_tys = prod_tys
, repr_sum_tycon = sum_tycon
, repr_type = repr_type
})
vect_tc prepr_tc _
mk_types [] = [unitTy]
mk_types tys = tys
arrReprVars :: Repr -> VM [[Var]]
arrReprVars repr
= mapM (mapM (newLocalVar FSLIT("rs"))) =<< arrReprTys repr
mkRepr :: TyCon -> VM Repr
mkRepr vect_tc
= mkSum
=<< mapM mkProduct (map dataConRepArgTys $ tyConDataCons vect_tc)
buildPReprType :: TyCon -> VM Type
buildPReprType = liftM reprType . mkRepr
buildToPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
buildToPRepr repr vect_tc prepr_tc _
= do
arg_ty <- mkPReprType res_ty
arg <- newLocalVar FSLIT("x") arg_ty
result <- to_repr repr (Var arg)
liftM (Lam arg
. Case (unwrapFamInstScrut prepr_tc var_tys (Var arg))
(mkWildId repr_type)
res_ty
. zipWith mk_alt sum_data_cons)
(sequence
$ zipWith4 un_prod data_cons prod_data_cons prod_tys repr_tys)
return . Lam arg
. wrapFamInstBody prepr_tc var_tys
$ result
where
var_tys = mkTyVarTys $ tyConTyVars vect_tc
ty_args = map Type var_tys
res_ty = mkTyConApp vect_tc var_tys
data_cons = tyConDataCons vect_tc
var_tys = mkTyVarTys $ tyConTyVars vect_tc
arg_ty = mkTyConApp vect_tc var_tys
res_ty = reprType repr
sum_data_cons = tyConDataCons sum_tycon
cons = tyConDataCons vect_tc
[con] = cons
un_prod dc _ _ []
to_repr (SumRepr { sum_components = prods
, sum_tycon = tycon })
expr
= do
var <- newLocalVar FSLIT("u") unitTy
return (var, mkConApp dc ty_args)
un_prod dc _ _ [ty]
(vars, bodies) <- mapAndUnzipM prod_alt prods
return . Case expr (mkWildId (exprType expr)) res_ty
$ zipWith4 mk_alt cons vars (tyConDataCons tycon) bodies
where
mk_alt con vars sum_con body
= (DataAlt con, vars, mkConApp sum_con (ty_args ++ [body]))
ty_args = map (Type . reprType) prods
to_repr prod expr
= do
var <- newLocalVar FSLIT("x") ty
return (var, mkConApp dc (ty_args ++ [Var var]))
(vars, body) <- prod_alt prod
return $ Case expr (mkWildId (exprType expr)) res_ty
[(DataAlt con, vars, body)]
un_prod dc (Just prod_dc) prod_ty tys
prod_alt (ProdRepr { prod_components = tys
, prod_data_con = data_con })
= do
vars <- mapM (newLocalVar FSLIT("x")) tys
pv <- newLocalVar FSLIT("p") prod_ty
vars <- mapM (newLocalVar FSLIT("r")) tys
return (vars, mkConApp data_con (map Type tys ++ map Var vars))
let res = mkConApp dc (ty_args ++ map Var vars)
expr = Case (Var pv) (mkWildId prod_ty) res_ty
[(DataAlt prod_dc, vars, res)]
buildFromPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
buildFromPRepr repr vect_tc prepr_tc _
= do
arg_ty <- mkPReprType res_ty
arg <- newLocalVar FSLIT("x") arg_ty
return (pv, expr)
liftM (Lam arg)
. from_repr repr
$ unwrapFamInstScrut prepr_tc var_tys (Var arg)
where
var_tys = mkTyVarTys $ tyConTyVars vect_tc
res_ty = mkTyConApp vect_tc var_tys
mk_alt sum_dc (var, expr) = (DataAlt sum_dc, [var], expr)
cons = map (`mkConApp` map Type var_tys) (tyConDataCons vect_tc)
[con] = cons
from_repr repr@(SumRepr { sum_components = prods
, sum_tycon = tycon })
expr
= do
vars <- mapM (newLocalVar FSLIT("x")) (map reprType prods)
bodies <- sequence . zipWith3 from_prod prods cons
$ map Var vars
return . Case expr (mkWildId (reprType repr)) res_ty
$ zipWith3 sum_alt (tyConDataCons tycon) vars bodies
where
sum_alt data_con var body = (DataAlt data_con, [var], body)
buildToArrPRepr :: TyConRepr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
buildToArrPRepr repr@(ProdRepr {
repr_prod_arg_tys = prod_arg_tys
, repr_prod_arr_tycon = prod_arr_tycon
, repr_prod_arr_data_con = prod_arr_data_con
, repr_type = repr_type
})
vect_tc prepr_tc arr_tc
from_repr repr expr = from_prod repr con expr
from_prod prod@(ProdRepr { prod_components = tys
, prod_data_con = data_con })
con
expr
= do
vars <- mapM (newLocalVar FSLIT("y")) tys
return $ Case expr (mkWildId (reprType prod)) res_ty
[(DataAlt data_con, vars, con `mkVarApps` vars)]
buildToArrPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
buildToArrPRepr repr vect_tc prepr_tc arr_tc
= do
arg_ty <- mkPArrayType el_ty
shape_tys <- arrShapeTys repr
arr_tys <- arrReprTys repr
res_ty <- mkPArrayType repr_type
rep_el_ty <- mkPReprType el_ty
arg <- newLocalVar FSLIT("xs") arg_ty
shape_vars <- mapM (newLocalVar FSLIT("sh")) shape_tys
rep_vars <- mapM (newLocalVar FSLIT("ys")) arr_tys
let vars = shape_vars ++ rep_vars
res_ty <- mkPArrayType (reprType repr)
shape_vars <- arrShapeVars repr
repr_vars <- arrReprVars repr
parray_co <- mkBuiltinCo parrayTyCon
let res = wrapFamInstBody prod_arr_tycon prod_arg_tys
. mkConApp prod_arr_data_con
$ map Type prod_arg_tys ++ map Var vars
let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
co = mkAppCoercion parray_co
. mkSymCoercion
$ mkTyConApp repr_co var_tys
scrut = unwrapFamInstScrut arr_tc var_tys (Var arg)
Just repr_co = tyConFamilyCoercion_maybe prepr_tc
co = mkAppCoercion parray_co
. mkSymCoercion
$ mkTyConApp repr_co var_tys
result <- to_repr shape_vars repr_vars repr
return . Lam arg
. mkCoerce co
$ Case (unwrapFamInstScrut arr_tc var_tys (Var arg))
(mkWildId (mkTyConApp arr_tc var_tys))
res_ty
[(DataAlt arr_dc, vars, res)]
$ Case scrut (mkWildId (mkTyConApp arr_tc var_tys)) res_ty
[(DataAlt arr_dc, shape_vars ++ concat repr_vars, result)]
where
var_tys = mkTyVarTys $ tyConTyVars vect_tc
el_ty = mkTyConApp vect_tc var_tys
[arr_dc] = tyConDataCons arr_tc
to_repr shape_vars@(len_var : _)
repr_vars
(SumRepr { sum_components = prods
, sum_arr_tycon = tycon
, sum_arr_data_con = data_con })
= do
exprs <- zipWithM (to_prod len_var) repr_vars prods
buildToArrPRepr _ _ _ _ = return (Var unitDataConId)
{-
buildToArrPRepr _ vect_tc prepr_tc arr_tc
= do
arg_ty <- mkPArrayType el_ty
rep_tys <- mapM (mapM mkPArrayType) rep_el_tys
arg <- newLocalVar FSLIT("xs") arg_ty
bndrss <- mapM (mapM (newLocalVar FSLIT("ys"))) rep_tys
len <- newLocalVar FSLIT("len") intPrimTy
sel <- newLocalVar FSLIT("sel") =<< mkPArrayType intTy
let add_sel xs | has_selector = sel : xs
| otherwise = xs
return . wrapFamInstBody tycon tys
. mkConApp data_con
$ map Type tys ++ map Var shape_vars ++ exprs
where
tys = map reprType prods
all_bndrs = len : add_sel (concat bndrss)
to_repr [len_var] [repr_vars] prod = to_prod len_var repr_vars prod
res <- parrayCoerce prepr_tc var_tys
=<< mkToArrPRepr (Var len) (Var sel) (map (map Var) bndrss)
res_ty <- mkPArrayType =<< mkPReprType el_ty
to_prod len_var
repr_vars
(ProdRepr { prod_components = tys
, prod_arr_tycon = tycon
, prod_arr_data_con = data_con })
= return . wrapFamInstBody tycon tys
. mkConApp data_con
$ map Type tys ++ map Var (len_var : repr_vars)
return . Lam arg
$ Case (unwrapFamInstScrut arr_tc var_tys (Var arg))
(mkWildId (mkTyConApp arr_tc var_tys))
res_ty
[(DataAlt arr_dc, all_bndrs, res)]
where
var_tys = mkTyVarTys $ tyConTyVars vect_tc
el_ty = mkTyConApp vect_tc var_tys
data_cons = tyConDataCons vect_tc
rep_el_tys = map dataConRepArgTys data_cons
[arr_dc] = tyConDataCons arr_tc
has_selector | [_] <- data_cons = False
| otherwise = True
-}
buildFromArrPRepr :: TyConRepr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
buildFromArrPRepr repr@(ProdRepr {
repr_prod_arg_tys = prod_arg_tys
, repr_prod_arr_tycon = prod_arr_tycon
, repr_prod_arr_data_con = prod_arr_data_con
, repr_type = repr_type
})
vect_tc prepr_tc arr_tc
buildFromArrPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
buildFromArrPRepr repr vect_tc prepr_tc arr_tc
= do
rep_el_ty <- mkPReprType el_ty
arg_ty <- mkPArrayType rep_el_ty
shape_tys <- arrShapeTys repr
arr_tys <- arrReprTys repr
res_ty <- mkPArrayType el_ty
arg_ty <- mkPArrayType =<< mkPReprType el_ty
arg <- newLocalVar FSLIT("xs") arg_ty
shape_vars <- mapM (newLocalVar FSLIT("sh")) shape_tys
rep_vars <- mapM (newLocalVar FSLIT("ys")) arr_tys
let vars = shape_vars ++ rep_vars
res_ty <- mkPArrayType el_ty
shape_vars <- arrShapeVars repr
repr_vars <- arrReprVars repr
parray_co <- mkBuiltinCo parrayTyCon
let res = wrapFamInstBody arr_tc var_tys
. mkConApp arr_dc
$ map Type var_tys ++ map Var vars
let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
co = mkAppCoercion parray_co
$ mkTyConApp repr_co var_tys
Just repr_co = tyConFamilyCoercion_maybe prepr_tc
co = mkAppCoercion parray_co
$ mkTyConApp repr_co var_tys
scrut = mkCoerce co (Var arg)
scrut = unwrapFamInstScrut prod_arr_tycon prod_arg_tys
$ mkCoerce co (Var arg)
result = wrapFamInstBody arr_tc var_tys
. mkConApp arr_dc
$ map Type var_tys ++ map Var (shape_vars ++ concat repr_vars)
return . Lam arg
$ Case (scrut)
(mkWildId (mkTyConApp prod_arr_tycon prod_arg_tys))
res_ty
[(DataAlt prod_arr_data_con, vars, res)]
liftM (Lam arg)
(from_repr repr scrut shape_vars repr_vars res_ty result)
where
var_tys = mkTyVarTys $ tyConTyVars vect_tc
el_ty = mkTyConApp vect_tc var_tys
[arr_dc] = tyConDataCons arr_tc
buildFromArrPRepr _ _ _ _ = return (Var unitDataConId)
buildPRDictRepr :: TyConRepr -> VM CoreExpr
from_repr (SumRepr { sum_components = prods
, sum_arr_tycon = tycon
, sum_arr_data_con = data_con })
expr
shape_vars
repr_vars
res_ty
body
= do
vars <- mapM (newLocalVar FSLIT("xs")) =<< mapM arrReprType prods
result <- go prods repr_vars vars body
let scrut = unwrapFamInstScrut tycon ty_args expr
return . Case scrut (mkWildId scrut_ty) res_ty
$ [(DataAlt data_con, shape_vars ++ vars, result)]
where
ty_args = map reprType prods
scrut_ty = mkTyConApp tycon ty_args
go [] [] [] body = return body
go (prod : prods) (repr_vars : rss) (var :<