Commit 791ad758 authored by benl@ouroborus.net's avatar benl@ouroborus.net
Browse files

Vectorisation of method types

parent 112780e0
......@@ -324,6 +324,7 @@ data OverlapFlag
--
-- Example: constraint (Foo [Int])
-- instances (Foo [Int])
-- (Foo [a]) OverlapOk
-- Since the second instance has the OverlapOk flag,
-- the first instance will be chosen (otherwise
......
......@@ -9,11 +9,12 @@ The @Class@ datatype
module Class (
Class, ClassOpItem,
DefMeth (..),
defMethSpecOfDefMeth,
FunDep, pprFundeps, pprFunDep,
mkClass, classTyVars, classArity,
classKey, className, classATs, classSelIds, classTyCon, classMethods,
classKey, className, classATs, classSelIds, classTyCon, classMethods, classOpItems,
classOpItems,classBigSig, classExtraBigSig, classTvsFds, classSCTheta
) where
......@@ -74,6 +75,16 @@ data DefMeth = NoDefMeth -- No default method
| DefMeth Name -- A polymorphic default method
| GenDefMeth -- A generic default method
deriving Eq
-- | Convert a `DefMethSpec` to a `DefMeth`, which discards the name field in
-- the `DefMeth` constructor of the `DefMeth`.
defMethSpecOfDefMeth :: DefMeth -> DefMethSpec
defMethSpecOfDefMeth meth
= case meth of
NoDefMeth -> NoDM
DefMeth _ -> VanillaDM
GenDefMeth -> GenericDM
\end{code}
The @mkClass@ function fills in the indirect superclasses.
......@@ -122,7 +133,8 @@ classMethods (Class {classOpStuff = op_stuff})
= [op_sel | (op_sel, _) <- op_stuff]
classOpItems :: Class -> [ClassOpItem]
classOpItems (Class {classOpStuff = op_stuff}) = op_stuff
classOpItems (Class { classOpStuff = op_stuff})
= op_stuff
classTvsFds :: Class -> ([TyVar], [FunDep TyVar])
classTvsFds c
......
{-# LANGUAGE NamedFieldPuns #-}
-- | The Vectorisation monad.
module VectMonad (
......@@ -461,9 +462,25 @@ lookupVar v
case r of
Just e -> return (Local e)
Nothing -> liftM Global
. maybeCantVectoriseM "Variable not vectorised:" (ppr v)
. maybeCantVectoriseVarM v
. readGEnv $ \env -> lookupVarEnv (global_vars env) v
maybeCantVectoriseVarM :: Monad m => Var -> m (Maybe Var) -> m Var
maybeCantVectoriseVarM v p
= do r <- p
case r of
Just x -> return x
Nothing -> dumpVar v
dumpVar :: Var -> a
dumpVar var
| Just cls <- isClassOpId_maybe var
= cantVectorise "ClassOpId not vectorised:" (ppr var)
| otherwise
= cantVectorise "Variable not vectorised:" (ppr var)
-------------------------------------------------------------------------------
globalScalars :: VM VarSet
globalScalars = readGEnv global_scalars
......
{-# OPTIONS -fno-warn-missing-signatures #-}
module VectType ( vectTyCon, vectAndLiftType, vectType, vectTypeEnv,
-- arrSumArity, pdataCompTys, pdataCompVars,
buildPADict,
......@@ -9,6 +11,7 @@ import VectUtils
import VectCore
import HscTypes ( TypeEnv, extendTypeEnvList, typeEnvTyCons )
import BasicTypes
import CoreSyn
import CoreUtils
import CoreUnfold
......@@ -16,6 +19,7 @@ import MkCore ( mkWildCase )
import BuildTyCl
import DataCon
import TyCon
import Class
import Type
import TypeRep
import Coercion
......@@ -23,9 +27,7 @@ import FamInstEnv ( FamInst, mkLocalFamInst )
import OccName
import Id
import MkId
import BasicTypes ( HsBang(..), boolToRecFlag,
alwaysInlinePragma, dfunInlinePragma )
import Var ( Var, TyVar, varType )
import Var ( Var, TyVar, varType, varName )
import Name ( Name, getOccName )
import NameEnv
......@@ -40,7 +42,11 @@ import FastString
import MonadUtils ( zipWith3M, foldrM, concatMapM )
import Control.Monad ( liftM, liftM2, zipWithM, zipWithM_, mapAndUnzipM )
import Data.List ( inits, tails, zipWith4, zipWith5 )
import Data.List
import Data.Maybe
debug = False
dtrace s x = if debug then pprTrace "VectType" s x else x
-- ----------------------------------------------------------------------------
-- Types
......@@ -72,29 +78,57 @@ vectAndLiftType ty
-- | Vectorise a type.
vectType :: Type -> VM Type
vectType ty | Just ty' <- coreView ty = vectType ty'
vectType ty
| Just ty' <- coreView ty
= vectType ty'
vectType (TyVarTy tv) = return $ TyVarTy tv
vectType (AppTy ty1 ty2) = liftM2 AppTy (vectType ty1) (vectType ty2)
vectType (TyConApp tc tys) = liftM2 TyConApp (vectTyCon tc) (mapM vectType tys)
vectType (FunTy ty1 ty2) = liftM2 TyConApp (builtin closureTyCon)
(mapM vectAndBoxType [ty1,ty2])
-- For each quantified var we need to add a PA dictionary out the front of the type.
-- So forall a. C a => a -> a
-- turns into forall a. Cv a => PA a => a :-> a
vectType ty@(ForAllTy _ _)
= do
mdicts <- mapM paDictArgType tyvars
mono_ty' <- vectType mono_ty
return $ abstractType tyvars [dict | Just dict <- mdicts] mono_ty'
where
(tyvars, mono_ty) = splitForAllTys ty
-- split the type into the quantified vars, its dictionaries and the body.
let (tyvars, tyBody) = splitForAllTys ty
let (tyArgs, tyResult) = splitFunTys tyBody
let (tyArgs_dict, tyArgs_regular)
= partition isDictType tyArgs
-- vectorise the body.
let tyBody' = mkFunTys tyArgs_regular tyResult
tyBody'' <- vectType tyBody'
-- vectorise the dictionary parameters.
dictsVect <- mapM vectType tyArgs_dict
-- make a PA dictionary for each of the type variables.
dictsPA <- liftM catMaybes $ mapM paDictArgType tyvars
-- pack it all back together.
return $ abstractType tyvars (dictsVect ++ dictsPA) tyBody''
vectType ty = cantVectorise "Can't vectorise type" (ppr ty)
vectAndBoxType :: Type -> VM Type
vectAndBoxType ty = vectType ty >>= boxType
-- | Add quantified vars and dictionary parameters to the front of a type.
abstractType :: [TyVar] -> [Type] -> Type -> Type
abstractType tyvars dicts = mkForAllTys tyvars . mkFunTys dicts
-- | Check if some type is a type class dictionary.
isDictType :: Type -> Bool
isDictType ty
= case splitTyConApp_maybe ty of
Just (tyCon, _) -> isClassTyCon tyCon
_ -> False
-- ----------------------------------------------------------------------------
-- Boxing
......@@ -110,6 +144,10 @@ boxType ty
boxType ty = return ty
vectAndBoxType :: Type -> VM Type
vectAndBoxType ty = vectType ty >>= boxType
-- ----------------------------------------------------------------------------
-- Type definitions
......@@ -119,7 +157,8 @@ type TyConGroup = ([TyCon], UniqSet TyCon)
-- The type environment contains all the type things defined in a module.
vectTypeEnv :: TypeEnv -> VM (TypeEnv, [FamInst], [(Var, CoreExpr)])
vectTypeEnv env
= do
= dtrace (ppr env)
$ do
cs <- readGEnv $ mk_map . global_tycons
-- Split the list of TyCons into the ones we have to vectorise vs the
......@@ -127,13 +166,29 @@ vectTypeEnv env
-- types that use non Haskell98 features, as we don't handle those.
let (conv_tcs, keep_tcs) = classifyTyCons cs groups
keep_dcs = concatMap tyConDataCons keep_tcs
dtrace (text "conv_tcs = " <> ppr conv_tcs) $ return ()
zipWithM_ defTyCon keep_tcs keep_tcs
zipWithM_ defDataCon keep_dcs keep_dcs
new_tcs <- vectTyConDecls conv_tcs
dtrace (text "new_tcs = " <> ppr new_tcs) $ return ()
let orig_tcs = keep_tcs ++ conv_tcs
vect_tcs = keep_tcs ++ new_tcs
-- We don't need to make new representation types for dictionary
-- constructors. The constructors are always fully applied, and we don't
-- need to lift them to arrays as a dictionary of a particular type
-- always has the same value.
let vect_tcs = filter (not . isClassTyCon)
$ keep_tcs ++ new_tcs
dtrace (text "vect_tcs = " <> ppr vect_tcs) $ return ()
mapM_ dumpTycon $ new_tcs
(_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
do
......@@ -141,11 +196,15 @@ vectTypeEnv env
reprs <- mapM tyConRepr vect_tcs
repr_tcs <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
dfuns <- sequence $ zipWith5 buildTyConBindings orig_tcs
dfuns <- sequence
$ zipWith5 buildTyConBindings
orig_tcs
vect_tcs
repr_tcs
pdata_tcs
reprs
binds <- takeHoisted
return (dfuns, binds, repr_tcs ++ pdata_tcs)
......@@ -171,25 +230,106 @@ vectTyConDecls tcs = fixV $ \tcs' ->
mapM_ (uncurry defTyCon) (zipLazy tcs tcs')
mapM vectTyConDecl tcs
vectTyConDecl :: TyCon -> VM TyCon
vectTyConDecl tc
= do
name' <- cloneName mkVectTyConOcc name
rhs' <- vectAlgTyConRhs tc (algTyConRhs tc)
dumpTycon :: TyCon -> VM ()
dumpTycon tycon
| Just cls <- tyConClass_maybe tycon
= dtrace (vcat [ ppr tycon
, ppr [(m, varType m) | m <- classMethods cls ]])
$ return ()
liftDs $ buildAlgTyCon name'
tyvars
| otherwise
= return ()
-- | Vectorise a single type construcrtor.
vectTyConDecl :: TyCon -> VM TyCon
vectTyConDecl tycon
-- a type class constructor.
-- TODO: check for no stupid theta, fds, assoc types.
| isClassTyCon tycon
, Just cls <- tyConClass_maybe tycon
= do -- make the name of the vectorised class tycon.
name' <- cloneName mkVectTyConOcc (tyConName tycon)
-- vectorise right of definition.
rhs' <- vectAlgTyConRhs tycon (algTyConRhs tycon)
-- vectorise method selectors.
-- This also adds a mapping between the original and vectorised method selector
-- to the state.
methods' <- mapM vectMethod
$ [(id, defMethSpecOfDefMeth meth)
| (id, meth) <- classOpItems cls]
-- keep the original recursiveness flag.
let rec_flag = boolToRecFlag (isRecursiveTyCon tycon)
-- Calling buildclass here attaches new quantifiers and dictionaries to the method types.
cls' <- liftDs
$ buildClass
False -- include unfoldings on dictionary selectors.
name' -- new name V_T:Class
(tyConTyVars tycon) -- keep original type vars
[] -- no stupid theta
[] -- no functional dependencies
[] -- no associated types
methods' -- method info
rec_flag -- whether recursive
let tycon' = mkClassTyCon name'
(tyConKind tycon)
(tyConTyVars tycon)
rhs'
cls'
rec_flag
return $ tycon'
-- a regular algebraic type constructor.
-- TODO: check for stupid theta, generaics, GADTS etc
| isAlgTyCon tycon
= do name' <- cloneName mkVectTyConOcc (tyConName tycon)
rhs' <- vectAlgTyConRhs tycon (algTyConRhs tycon)
let rec_flag = boolToRecFlag (isRecursiveTyCon tycon)
liftDs $ buildAlgTyCon
name' -- new name
(tyConTyVars tycon) -- keep original type vars.
[] -- no stupid theta.
rhs' -- new constructor defs.
rec_flag -- FIXME: is this ok?
False -- FIXME: no generics
False -- not GADT syntax
Nothing -- not a family instance
where
name = tyConName tc
tyvars = tyConTyVars tc
rec_flag = boolToRecFlag (isRecursiveTyCon tc)
-- some other crazy thing that we don't handle.
| otherwise
= cantVectorise "Can't vectorise type constructor: " (ppr tycon)
-- | Vectorise a class method.
vectMethod :: (Id, DefMethSpec) -> VM (Name, DefMethSpec, Type)
vectMethod (id, defMeth)
= do
-- Vectorise the method type.
typ' <- vectType (varType id)
-- Create a name for the vectorised method.
id' <- cloneId mkVectOcc id typ'
defGlobalVar id id'
-- When we call buildClass in vectTyConDecl, it adds foralls and dictionaries
-- to the types of each method. However, the types we get back from vectType
-- above already already have these, so we need to chop them off here otherwise
-- we'll get two copies in the final version.
let (_tyvars, tyBody) = splitForAllTys typ'
let (_dict, tyRest) = splitFunTy tyBody
return (Var.varName id', defMeth, tyRest)
-- | Vectorise the RHS of an algebraic type.
vectAlgTyConRhs :: TyCon -> AlgTyConRhs -> VM AlgTyConRhs
vectAlgTyConRhs _ (DataTyCon { data_cons = data_cons
, is_enum = is_enum
......@@ -200,31 +340,39 @@ vectAlgTyConRhs _ (DataTyCon { data_cons = data_cons
return $ DataTyCon { data_cons = data_cons'
, is_enum = is_enum
}
vectAlgTyConRhs tc _ = cantVectorise "Can't vectorise type definition:" (ppr tc)
vectAlgTyConRhs tc _
= cantVectorise "Can't vectorise type definition:" (ppr tc)
-- | Vectorise a data constructor.
-- Vectorises its argument and return types.
vectDataCon :: DataCon -> VM DataCon
vectDataCon dc
| not . null $ dataConExTyVars dc
= cantVectorise "Can't vectorise constructor (existentials):" (ppr dc)
| not . null $ dataConEqSpec dc
= cantVectorise "Can't vectorise constructor (eq spec):" (ppr dc)
| otherwise
= do
name' <- cloneName mkVectDataConOcc name
tycon' <- vectTyCon tycon
arg_tys <- mapM vectType rep_arg_tys
liftDs $ buildDataCon name'
liftDs $ buildDataCon
name'
False -- not infix
(map (const HsNoBang) arg_tys)
(map (const HsNoBang) arg_tys) -- strictness annots on args.
[] -- no labelled fields
univ_tvs
univ_tvs -- universally quantified vars
[] -- no existential tvs for now
[] -- no eq spec for now
[] -- no context
arg_tys
(mkFamilyTyConApp tycon' (mkTyVarTys univ_tvs))
tycon'
arg_tys -- argument types
(mkFamilyTyConApp tycon' (mkTyVarTys univ_tvs)) -- return type
tycon' -- representation tycon
where
name = dataConName dc
univ_tvs = dataConUnivTyVars dc
......@@ -861,6 +1009,7 @@ paMethods = [("dictPRepr", buildPRDict),
("toArrPRepr", buildToArrPRepr),
("fromArrPRepr", buildFromArrPRepr)]
-- | Split the given tycons into two sets depending on whether they have to be
-- converted (first list) or not (second list). The first argument contains
-- information about the conversion status of external tycons:
......@@ -929,8 +1078,31 @@ tyConsOfTypes = unionManyUniqSets . map tyConsOfType
-- ----------------------------------------------------------------------------
-- Conversions
fromVect :: Type -> CoreExpr -> VM CoreExpr
fromVect ty expr | Just ty' <- coreView ty = fromVect ty' expr
-- | Build an expression that calls the vectorised version of some
-- function from a `Closure`.
--
-- For example
-- @
-- \(x :: Double) ->
-- \(y :: Double) ->
-- ($v_foo $: x) $: y
-- @
--
-- We use the type of the original binding to work out how many
-- outer lambdas to add.
--
fromVect
:: Type -- ^ The type of the original binding.
-> CoreExpr -- ^ Expression giving the closure to use, eg @$v_foo@.
-> VM CoreExpr
-- Convert the type to the core view if it isn't already.
fromVect ty expr
| Just ty' <- coreView ty
= fromVect ty' expr
-- For each function constructor in the original type we add an outer
-- lambda to bind the parameter variable, and an inner application of it.
fromVect (FunTy arg_ty res_ty) expr
= do
arg <- newLocalVar (fsLit "x") arg_ty
......@@ -941,12 +1113,16 @@ fromVect (FunTy arg_ty res_ty) expr
body <- fromVect res_ty
$ Var apply `mkTyApps` [varg_ty, vres_ty] `mkApps` [expr, varg]
return $ Lam arg body
-- If the type isn't a function then it's time to call on the closure.
fromVect ty expr
= identityConv ty >> return expr
toVect :: Type -> CoreExpr -> VM CoreExpr
toVect ty expr = identityConv ty >> return expr
identityConv :: Type -> VM ()
identityConv ty | Just ty' <- coreView ty = identityConv ty'
identityConv (TyConApp tycon tys)
......
......@@ -163,6 +163,7 @@ prDFunOfTyCon tycon
. maybeCantVectoriseM "No PR dictionary for tycon" (ppr tycon)
$ lookupTyConPR tycon
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
where
......@@ -183,26 +184,40 @@ paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
go _ _ = return Nothing
paDictOfType :: Type -> VM CoreExpr
paDictOfType ty = paDictOfTyApp ty_fn ty_args
-- | Get the PA dictionary for some type, or `Nothing` if there isn't one.
paDictOfType :: Type -> VM (Maybe CoreExpr)
paDictOfType ty
= paDictOfTyApp ty_fn ty_args
where
(ty_fn, ty_args) = splitAppTys ty
paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
paDictOfTyApp ty_fn ty_args
| Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
paDictOfTyApp (TyVarTy tv) ty_args
= do
dfun <- maybeV (lookupTyVarPA tv)
paDFunApply dfun ty_args
paDictOfTyApp (TyConApp tc _) ty_args
= do
dfun <- maybeCantVectoriseM "No PA dictionary for tycon" (ppr tc)
$ lookupTyConPA tc
paDFunApply (Var dfun) ty_args
paDictOfTyApp ty _
paDictOfTyApp :: Type -> [Type] -> VM (Maybe CoreExpr)
paDictOfTyApp ty_fn ty_args
| Just ty_fn' <- coreView ty_fn
= paDictOfTyApp ty_fn' ty_args
paDictOfTyApp (TyVarTy tv) ty_args
= do dfun <- maybeV (lookupTyVarPA tv)
liftM Just $ paDFunApply dfun ty_args
paDictOfTyApp (TyConApp tc _) ty_args
= do mdfun <- lookupTyConPA tc
case mdfun of
Nothing
-> pprTrace "VectUtils.paDictOfType"
(vcat [ text "No PA dictionary"
, text "for tycon: " <> ppr tc
, text "in type: " <> ppr ty])
$ return Nothing
Just dfun -> liftM Just $ paDFunApply (Var dfun) ty_args
paDictOfTyApp ty _
= cantVectorise "Can't construct PA dictionary for type" (ppr ty)
paDFunType :: TyCon -> VM Type
paDFunType tc
= do
......@@ -216,10 +231,10 @@ paDFunType tc
paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
paDFunApply dfun tys
= do
dicts <- mapM paDictOfType tys
= do Just dicts <- liftM sequence $ mapM paDictOfType tys
return $ mkApps (mkTyApps dfun tys) dicts
paMethod :: (Builtins -> Var) -> String -> Type -> VM CoreExpr
paMethod _ name ty
| Just tycon <- splitPrimTyCon ty
......@@ -230,7 +245,7 @@ paMethod _ name ty
paMethod method _ ty
= do
fn <- builtin method
dict <- paDictOfType ty
Just dict <- paDictOfType ty
return $ mkApps (Var fn) [Type ty, dict]
prDictOfType :: Type -> VM CoreExpr
......@@ -256,7 +271,7 @@ prDFunApply dfun tys
wrapPR :: Type -> VM CoreExpr
wrapPR ty
= do
pa_dict <- paDictOfType ty
Just pa_dict <- paDictOfType ty
pr_dfun <- prDFunOfTyCon =<< builtin wrapTyCon
return $ mkApps pr_dfun [Type ty, pa_dict]
......@@ -302,7 +317,7 @@ scalarClosure :: [Type] -> Type -> CoreExpr -> CoreExpr -> VM CoreExpr
scalarClosure arg_tys res_ty scalar_fun array_fun
= do
ctr <- builtin (closureCtrFun $ length arg_tys)
pas <- mapM paDictOfType (init arg_tys)
Just pas <- liftM sequence $ mapM paDictOfType (init arg_tys)
return $ Var ctr `mkTyApps` (arg_tys ++ [res_ty])
`mkApps` (pas ++ [scalar_fun, array_fun])
......@@ -338,24 +353,26 @@ polyArity tvs = do
polyApply :: CoreExpr -> [Type] -> VM CoreExpr
polyApply expr tys
= do
dicts <- mapM paDictOfType tys
= do Just dicts <- liftM sequence $ mapM paDictOfType tys
return $ expr `mkTyApps` tys `mkApps` dicts
polyVApply :: VExpr -> [Type] -> VM VExpr
polyVApply expr tys
= do
dicts <- mapM paDictOfType tys
= do Just dicts <- liftM sequence $ mapM paDictOfType tys
return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
data Inline = Inline Arity
-- Inline ---------------------------------------------------------------------
-- | Records whether we should inline a particular binding.
data Inline
= Inline Arity
| DontInline
-- | Add to the arity contained within an `Inline`, if any.
addInlineArity :: Inline -> Int -> Inline
addInlineArity (Inline m) n = Inline (m+n)
addInlineArity DontInline _ = DontInline
-- | Says to always inline a binding.
inlineMe :: Inline
inlineMe = Inline 0
......@@ -424,6 +441,7 @@ mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
mkClosureApp :: Type -> Type -> VExpr -> VExpr -> VM VExpr
mkClosureApp arg_ty res_ty (vclo, lclo) (varg, larg)
= do vapply <- builtin applyVar
......