Commit ff3bfae6 authored by rl@cse.unsw.edu.au's avatar rl@cse.unsw.edu.au
Browse files

Fix vectorisation of recursive types

parent 869984cd
......@@ -479,7 +479,6 @@ Library
Vectorise.Utils.Closure
Vectorise.Utils.Hoisting
Vectorise.Utils.PADict
Vectorise.Utils.PRDict
Vectorise.Utils.Poly
Vectorise.Utils
Vectorise.Type.Env
......@@ -487,7 +486,6 @@ Library
Vectorise.Type.PData
Vectorise.Type.PRepr
Vectorise.Type.PADict
Vectorise.Type.PRDict
Vectorise.Type.Type
Vectorise.Type.TyConDecl
Vectorise.Type.Classify
......
......@@ -18,7 +18,6 @@ import CoreSyn
import CoreUnfold ( mkInlineUnfolding )
import CoreFVs
import CoreMonad ( CoreM, getHscEnv )
import FamInstEnv ( extendFamInstEnvList )
import Var
import Id
import OccName
......@@ -62,9 +61,7 @@ vectModule guts
-- TODO: What new binds do we get back here?
(types', fam_insts, tc_binds) <- vectTypeEnv (mg_types guts)
-- TODO: What is this?
let fam_inst_env' = extendFamInstEnvList (mg_fam_inst_env guts) fam_insts
updGEnv (setFamInstEnv fam_inst_env')
(_, fam_inst_env) <- readGEnv global_fam_inst_env
-- dicts <- mapM buildPADict pa_insts
-- workers <- mapM vectDataConWorkers pa_insts
......@@ -74,7 +71,7 @@ vectModule guts
return $ guts { mg_types = types'
, mg_binds = Rec tc_binds : binds'
, mg_fam_inst_env = fam_inst_env'
, mg_fam_inst_env = fam_inst_env
, mg_fam_insts = mg_fam_insts guts ++ fam_insts
}
......
......@@ -61,10 +61,12 @@ data Builtins
, parrayTyCon :: TyCon -- ^ PArray
, parrayDataCon :: DataCon -- ^ PArray
, pdataTyCon :: TyCon -- ^ PData
, paClass :: Class -- ^ PA
, paTyCon :: TyCon -- ^ PA
, paDataCon :: DataCon -- ^ PA
, paPRSel :: Var -- ^ PA
, preprTyCon :: TyCon -- ^ PRepr
, prClass :: Class -- ^ PR
, prTyCon :: TyCon -- ^ PR
, prDataCon :: DataCon -- ^ PR
, replicatePDVar :: Var -- ^ replicatePD
......
......@@ -46,14 +46,15 @@ initBuiltins pkg
let [parrayDataCon] = tyConDataCons parrayTyCon
pdataTyCon <- externalTyCon dph_PArray (fsLit "PData")
pa <- externalClass dph_PArray (fsLit "PA")
let paTyCon = classTyCon pa
paClass <- externalClass dph_PArray (fsLit "PA")
let paTyCon = classTyCon paClass
[paDataCon] = tyConDataCons paTyCon
paPRSel = classSCSelId pa 0
paPRSel = classSCSelId paClass 0
preprTyCon <- externalTyCon dph_PArray (fsLit "PRepr")
prTyCon <- externalClassTyCon dph_PArray (fsLit "PR")
let [prDataCon] = tyConDataCons prTyCon
prClass <- externalClass dph_PArray (fsLit "PR")
let prTyCon = classTyCon prClass
[prDataCon] = tyConDataCons prTyCon
closureTyCon <- externalTyCon dph_Closure (fsLit ":->")
......@@ -127,10 +128,12 @@ initBuiltins pkg
, parrayTyCon = parrayTyCon
, parrayDataCon = parrayDataCon
, pdataTyCon = pdataTyCon
, paClass = paClass
, paTyCon = paTyCon
, paDataCon = paDataCon
, paPRSel = paPRSel
, preprTyCon = preprTyCon
, prClass = prClass
, prTyCon = prTyCon
, prDataCon = prDataCon
, voidTyCon = voidTyCon
......@@ -308,9 +311,3 @@ externalClass :: Module -> FastString -> DsM Class
externalClass mod fs
= dsLookupClass =<< lookupOrig mod (mkClsOccFS fs)
-- | Like `externalClass`, but get the TyCon of of the class.
externalClassTyCon :: Module -> FastString -> DsM TyCon
externalClassTyCon mod fs = liftM classTyCon (externalClass mod fs)
......@@ -11,7 +11,8 @@ module Vectorise.Env (
initGlobalEnv,
extendImportedVarsEnv,
extendScalars,
setFamInstEnv,
setFamEnv,
extendFamEnv,
extendTyConsEnv,
extendDataConsEnv,
extendPAFunsEnv,
......@@ -142,11 +143,16 @@ extendScalars vs genv
-- | Set the list of type family instances in an environment.
setFamInstEnv :: FamInstEnv -> GlobalEnv -> GlobalEnv
setFamInstEnv l_fam_inst genv
setFamEnv :: FamInstEnv -> GlobalEnv -> GlobalEnv
setFamEnv l_fam_inst genv
= genv { global_fam_inst_env = (g_fam_inst, l_fam_inst) }
where (g_fam_inst, _) = global_fam_inst_env genv
extendFamEnv :: [FamInst] -> GlobalEnv -> GlobalEnv
extendFamEnv new genv
= genv { global_fam_inst_env = (g_fam_inst, extendFamInstEnvList l_fam_inst new) }
where (g_fam_inst, l_fam_inst) = global_fam_inst_env genv
-- | Extend the list of type constructors in an environment.
extendTyConsEnv :: [(Name, TyCon)] -> GlobalEnv -> GlobalEnv
......
......@@ -77,7 +77,6 @@ maybeCantVectoriseM s d p
Just x -> return x
Nothing -> cantVectorise s d
-- Control --------------------------------------------------------------------
-- | Return some result saying we've failed.
noV :: VM a
......
......@@ -82,6 +82,13 @@ vectTypeEnv env
let vect_tcs = filter (not . isClassTyCon)
$ keep_tcs ++ new_tcs
reprs <- mapM tyConRepr vect_tcs
repr_tcs <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
updGEnv $ extendFamEnv
$ map mkLocalFamInst
$ repr_tcs ++ pdata_tcs
-- Create PRepr and PData instances for the vectorised types.
-- We get back the binds for the instance functions,
-- and some new type constructors for the representation types.
......@@ -89,8 +96,6 @@ vectTypeEnv env
do
defTyConPAs (zipLazy vect_tcs dfuns')
reprs <- mapM tyConRepr vect_tcs
repr_tcs <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
dfuns <- sequence
$ zipWith5 buildTyConBindings
......
......@@ -6,7 +6,6 @@ import Vectorise.Monad
import Vectorise.Builtins
import Vectorise.Type.Repr
import Vectorise.Type.PRepr
import Vectorise.Type.PRDict
import Vectorise.Utils
import BasicTypes
......@@ -19,7 +18,8 @@ import TypeRep
import Id
import Var
import Name
import Outputable
import FastString
-- import Outputable
-- debug = False
-- dtrace s x = if debug then pprTrace "Vectoris.Type.PADict" s x else x
......@@ -29,38 +29,52 @@ import Outputable
buildPADict
:: TyCon -- ^ tycon of the type being vectorised.
-> TyCon -- ^ tycon of the type used for the vectorised representation.
-> TyCon --
-> TyCon -- ^ PRepr instance tycon
-> SumRepr -- ^ representation used for the type being vectorised.
-> VM Var -- ^ name of the top-level dictionary function.
buildPADict vect_tc prepr_tc arr_tc repr
= polyAbstract tvs $ \args ->
case args of
(_:_) -> pprPanic "Vectorise.Type.PADict.buildPADict" (text "why do we need superclass dicts?")
[] -> do
-- TODO: I'm forcing args to [] because I'm not sure why we need them.
-- class PA has superclass (PR (PRepr a)) but we're not using
-- the superclass dictionary to build the PA dictionary.
do
-- The superclass dictionary is an argument if the tycon is polymorphic
let mk_super_ty = do
r <- mkPReprType inst_ty
pr_cls <- builtin prClass
return $ PredTy $ ClassP pr_cls [r]
super_tys <- sequence [mk_super_ty | not (null tvs)]
super_args <- mapM (newLocalVar (fsLit "pr")) super_tys
let args' = super_args ++ args
-- it is constant otherwise
super_consts <- sequence [prDictOfPReprInstTyCon inst_ty prepr_tc []
| null tvs]
-- Get ids for each of the methods in the dictionary.
method_ids <- mapM (method args) paMethods
method_ids <- mapM (method args') paMethods
-- Expression to build the dictionary.
pa_dc <- builtin paDataCon
let dict = mkLams (tvs ++ args)
let dict = mkLams (tvs ++ args')
$ mkConApp pa_dc
$ Type inst_ty : map (method_call args) method_ids
$ Type inst_ty
: map Var super_args ++ super_consts
-- the superclass dictionary is
-- either lambda-bound or
-- constant
++ map (method_call args') method_ids
-- Build the type of the dictionary function.
pa_tc <- builtin paTyCon
let Just pa_cls = tyConClass_maybe pa_tc
pa_cls <- builtin paClass
let dfun_ty = mkForAllTys tvs
$ mkFunTys (map varType args) (PredTy $ ClassP pa_cls [inst_ty])
$ mkFunTys (map varType args')
(PredTy $ ClassP pa_cls [inst_ty])
-- Set the unfolding for the inliner.
raw_dfun <- newExportedVar dfun_name dfun_ty
let dfun_unf = mkDFunUnfolding dfun_ty (map (DFunPolyArg . Var) method_ids)
let dfun_unf = mkDFunUnfolding dfun_ty
$ map (const $ DFunLamArg 0) super_args
++ map DFunConstArg super_consts
++ map (DFunPolyArg . Var) method_ids
dfun = raw_dfun `setIdUnfolding` dfun_unf
`setInlinePragma` dfunInlinePragma
......@@ -91,8 +105,7 @@ buildPADict vect_tc prepr_tc arr_tc repr
paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
paMethods = [("dictPRepr", buildPRDict),
("toPRepr", buildToPRepr),
paMethods = [("toPRepr", buildToPRepr),
("fromPRepr", buildFromPRepr),
("toArrPRepr", buildToArrPRepr),
("fromArrPRepr", buildFromArrPRepr)]
......
module Vectorise.Type.PRDict
(buildPRDict)
where
import Vectorise.Utils
import Vectorise.Monad
import Vectorise.Builtins
import Vectorise.Type.Repr
import CoreSyn
import CoreUtils
import TyCon
import Type
import Coercion
buildPRDict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
buildPRDict vect_tc prepr_tc _ r
= do
dict <- sum_dict r
pr_co <- mkBuiltinCo prTyCon
let co = mkAppCoercion pr_co
. mkSymCoercion
$ mkTyConApp arg_co ty_args
return (mkCoerce co dict)
where
ty_args = mkTyVarTys (tyConTyVars vect_tc)
Just arg_co = tyConFamilyCoercion_maybe prepr_tc
sum_dict EmptySum = prDFunOfTyCon =<< builtin voidTyCon
sum_dict (UnarySum r) = con_dict r
sum_dict (Sum { repr_sum_tc = sum_tc
, repr_con_tys = tys
, repr_cons = cons
})
= do
dicts <- mapM con_dict cons
dfun <- prDFunOfTyCon sum_tc
return $ dfun `mkTyApps` tys `mkApps` dicts
con_dict (ConRepr _ r) = prod_dict r
prod_dict EmptyProd = prDFunOfTyCon =<< builtin voidTyCon
prod_dict (UnaryProd r) = comp_dict r
prod_dict (Prod { repr_tup_tc = tup_tc
, repr_comp_tys = tys
, repr_comps = comps })
= do
dicts <- mapM comp_dict comps
dfun <- prDFunOfTyCon tup_tc
return $ dfun `mkTyApps` tys `mkApps` dicts
comp_dict (Keep _ pr) = return pr
comp_dict (Wrap ty) = wrapPR ty
......@@ -82,7 +82,7 @@ tyConRepr tc = sum_repr (tyConDataCons tc)
where
arity = length tys
comp_repr ty = liftM (Keep ty) (prDictOfType ty)
comp_repr ty = liftM (Keep ty) (prDictOfReprType ty)
`orElseV` return (Wrap ty)
sumReprType :: SumRepr -> VM Type
......
......@@ -4,7 +4,6 @@ module Vectorise.Utils (
module Vectorise.Utils.Closure,
module Vectorise.Utils.Hoisting,
module Vectorise.Utils.PADict,
module Vectorise.Utils.PRDict,
module Vectorise.Utils.Poly,
-- * Annotated Exprs
......@@ -28,7 +27,6 @@ import Vectorise.Utils.Base
import Vectorise.Utils.Closure
import Vectorise.Utils.Hoisting
import Vectorise.Utils.PADict
import Vectorise.Utils.PRDict
import Vectorise.Utils.Poly
import Vectorise.Monad
import Vectorise.Builtins
......
......@@ -2,7 +2,9 @@
module Vectorise.Utils.PADict (
paDictArgType,
paDictOfType,
paMethod
paMethod,
prDictOfReprType,
prDictOfPReprInstTyCon
)
where
import Vectorise.Monad
......@@ -42,7 +44,9 @@ paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
go ty k
| isLiftedTypeKind k
= liftM Just (mkBuiltinTyConApp paTyCon [ty])
= do
pa_cls <- builtin paClass
return $ Just $ PredTy $ ClassP pa_cls [ty]
go _ _ = return Nothing
......@@ -108,17 +112,36 @@ prDictOfPReprInst :: Type -> VM CoreExpr
prDictOfPReprInst ty
= do
(prepr_tc, prepr_args) <- preprSynTyCon ty
case coreView (mkTyConApp prepr_tc prepr_args) of
Just rhs -> do
dict <- prDictOfReprType rhs
pr_co <- mkBuiltinCo prTyCon
let Just arg_co = tyConFamilyCoercion_maybe prepr_tc
let co = mkAppCoercion pr_co
$ mkSymCoercion
$ mkTyConApp arg_co prepr_args
return $ mkCoerce co dict
Nothing -> cantVectorise "Invalid PRepr type instance"
$ ppr ty
prDictOfPReprInstTyCon ty prepr_tc prepr_args
-- | Given a type @ty@, its PRepr synonym tycon and its type arguments,
-- return the PR @PRepr ty@. Suppose we have:
--
-- > type instance PRepr (T a1 ... an) = t
--
-- which is internally translated into
--
-- > type :R:PRepr a1 ... an = t
--
-- and the corresponding coercion. Then,
--
-- > prDictOfPReprInstTyCon (T a1 ... an) :R:PRepr u1 ... un = PR (T u1 ... un)
--
-- Note that @ty@ is only used for error messages
--
prDictOfPReprInstTyCon :: Type -> TyCon -> [Type] -> VM CoreExpr
prDictOfPReprInstTyCon ty prepr_tc prepr_args
| Just rhs <- coreView (mkTyConApp prepr_tc prepr_args)
= do
dict <- prDictOfReprType' rhs
pr_co <- mkBuiltinCo prTyCon
let Just arg_co = tyConFamilyCoercion_maybe prepr_tc
let co = mkAppCoercion pr_co
$ mkSymCoercion
$ mkTyConApp arg_co prepr_args
return $ mkCoerce co dict
| otherwise = cantVectorise "Invalid PRepr type instance" (ppr ty)
-- | Get the PR dictionary for a type. The argument must be a representation
-- type.
......@@ -129,14 +152,13 @@ prDictOfReprType ty
prepr <- builtin preprTyCon
if tycon == prepr
then do
[ty'] <- return tyargs
prDictOfPReprInst ty'
let [ty'] = tyargs
pa <- paDictOfType ty'
sel <- builtin paPRSel
return $ Var sel `App` Type ty' `App` pa
else do
-- a representation tycon must have a PR instance
dfun <- maybeCantVectoriseM
"No PR dictionary for type constructor"
(ppr tycon <+> text "in" <+> ppr ty)
$ lookupTyConPR tycon
dfun <- maybeV $ lookupTyConPR tycon
prDFunApply dfun tyargs
| otherwise
......@@ -153,6 +175,11 @@ prDictOfReprType ty
prsel <- builtin paPRSel
return $ Var prsel `mkApps` [Type ty, pa]
prDictOfReprType' :: Type -> VM CoreExpr
prDictOfReprType' ty = prDictOfReprType ty `orElseV`
cantVectorise "No PR dictionary for representation type"
(ppr ty)
-- | Apply a tycon's PR dfun to dictionary arguments (PR or PA) corresponding
-- to the argument types.
prDFunApply :: Var -> [Type] -> VM CoreExpr
......
module Vectorise.Utils.PRDict (
prDictOfType,
wrapPR
)
where
import Vectorise.Monad
import Vectorise.Builtins
import Vectorise.Utils.Base
import Vectorise.Utils.PADict
import CoreSyn
import Type
import TypeRep
import Control.Monad
prDictOfType :: Type -> VM CoreExpr
prDictOfType ty = prDictOfTyApp ty_fn ty_args
where
(ty_fn, ty_args) = splitAppTys ty
prDictOfTyApp :: Type -> [Type] -> VM CoreExpr
prDictOfTyApp ty_fn ty_args
| Just ty_fn' <- coreView ty_fn = prDictOfTyApp ty_fn' ty_args
prDictOfTyApp (TyConApp tc _) ty_args
= do
dfun <- liftM Var $ maybeV (lookupTyConPR tc)
prDFunApply dfun ty_args
prDictOfTyApp _ _ = noV
prDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
prDFunApply dfun tys
= do
dicts <- mapM prDictOfType tys
return $ mkApps (mkTyApps dfun tys) dicts
wrapPR :: Type -> VM CoreExpr
wrapPR ty
= do
pa_dict <- paDictOfType ty
pr_dfun <- prDFunOfTyCon =<< builtin wrapTyCon
return $ mkApps pr_dfun [Type ty, pa_dict]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment