Commit 791ad758 authored by benl@ouroborus.net's avatar benl@ouroborus.net
Browse files

Vectorisation of method types

parent 112780e0
...@@ -324,6 +324,7 @@ data OverlapFlag ...@@ -324,6 +324,7 @@ data OverlapFlag
-- --
-- Example: constraint (Foo [Int]) -- Example: constraint (Foo [Int])
-- instances (Foo [Int]) -- instances (Foo [Int])
-- (Foo [a]) OverlapOk -- (Foo [a]) OverlapOk
-- Since the second instance has the OverlapOk flag, -- Since the second instance has the OverlapOk flag,
-- the first instance will be chosen (otherwise -- the first instance will be chosen (otherwise
......
...@@ -9,11 +9,12 @@ The @Class@ datatype ...@@ -9,11 +9,12 @@ The @Class@ datatype
module Class ( module Class (
Class, ClassOpItem, Class, ClassOpItem,
DefMeth (..), DefMeth (..),
defMethSpecOfDefMeth,
FunDep, pprFundeps, pprFunDep, FunDep, pprFundeps, pprFunDep,
mkClass, classTyVars, classArity, mkClass, classTyVars, classArity,
classKey, className, classATs, classSelIds, classTyCon, classMethods, classKey, className, classATs, classSelIds, classTyCon, classMethods, classOpItems,
classOpItems,classBigSig, classExtraBigSig, classTvsFds, classSCTheta classOpItems,classBigSig, classExtraBigSig, classTvsFds, classSCTheta
) where ) where
...@@ -74,6 +75,16 @@ data DefMeth = NoDefMeth -- No default method ...@@ -74,6 +75,16 @@ data DefMeth = NoDefMeth -- No default method
| DefMeth Name -- A polymorphic default method | DefMeth Name -- A polymorphic default method
| GenDefMeth -- A generic default method | GenDefMeth -- A generic default method
deriving Eq deriving Eq
-- | Convert a `DefMethSpec` to a `DefMeth`, which discards the name field in
-- the `DefMeth` constructor of the `DefMeth`.
defMethSpecOfDefMeth :: DefMeth -> DefMethSpec
defMethSpecOfDefMeth meth
= case meth of
NoDefMeth -> NoDM
DefMeth _ -> VanillaDM
GenDefMeth -> GenericDM
\end{code} \end{code}
The @mkClass@ function fills in the indirect superclasses. The @mkClass@ function fills in the indirect superclasses.
...@@ -122,7 +133,8 @@ classMethods (Class {classOpStuff = op_stuff}) ...@@ -122,7 +133,8 @@ classMethods (Class {classOpStuff = op_stuff})
= [op_sel | (op_sel, _) <- op_stuff] = [op_sel | (op_sel, _) <- op_stuff]
classOpItems :: Class -> [ClassOpItem] classOpItems :: Class -> [ClassOpItem]
classOpItems (Class {classOpStuff = op_stuff}) = op_stuff classOpItems (Class { classOpStuff = op_stuff})
= op_stuff
classTvsFds :: Class -> ([TyVar], [FunDep TyVar]) classTvsFds :: Class -> ([TyVar], [FunDep TyVar])
classTvsFds c classTvsFds c
......
{-# LANGUAGE NamedFieldPuns #-}
-- | The Vectorisation monad. -- | The Vectorisation monad.
module VectMonad ( module VectMonad (
...@@ -461,9 +462,25 @@ lookupVar v ...@@ -461,9 +462,25 @@ lookupVar v
case r of case r of
Just e -> return (Local e) Just e -> return (Local e)
Nothing -> liftM Global Nothing -> liftM Global
. maybeCantVectoriseM "Variable not vectorised:" (ppr v) . maybeCantVectoriseVarM v
. readGEnv $ \env -> lookupVarEnv (global_vars env) v . readGEnv $ \env -> lookupVarEnv (global_vars env) v
maybeCantVectoriseVarM :: Monad m => Var -> m (Maybe Var) -> m Var
maybeCantVectoriseVarM v p
= do r <- p
case r of
Just x -> return x
Nothing -> dumpVar v
dumpVar :: Var -> a
dumpVar var
| Just cls <- isClassOpId_maybe var
= cantVectorise "ClassOpId not vectorised:" (ppr var)
| otherwise
= cantVectorise "Variable not vectorised:" (ppr var)
-------------------------------------------------------------------------------
globalScalars :: VM VarSet globalScalars :: VM VarSet
globalScalars = readGEnv global_scalars globalScalars = readGEnv global_scalars
......
{-# OPTIONS -fno-warn-missing-signatures #-}
module VectType ( vectTyCon, vectAndLiftType, vectType, vectTypeEnv, module VectType ( vectTyCon, vectAndLiftType, vectType, vectTypeEnv,
-- arrSumArity, pdataCompTys, pdataCompVars, -- arrSumArity, pdataCompTys, pdataCompVars,
buildPADict, buildPADict,
...@@ -9,6 +11,7 @@ import VectUtils ...@@ -9,6 +11,7 @@ import VectUtils
import VectCore import VectCore
import HscTypes ( TypeEnv, extendTypeEnvList, typeEnvTyCons ) import HscTypes ( TypeEnv, extendTypeEnvList, typeEnvTyCons )
import BasicTypes
import CoreSyn import CoreSyn
import CoreUtils import CoreUtils
import CoreUnfold import CoreUnfold
...@@ -16,6 +19,7 @@ import MkCore ( mkWildCase ) ...@@ -16,6 +19,7 @@ import MkCore ( mkWildCase )
import BuildTyCl import BuildTyCl
import DataCon import DataCon
import TyCon import TyCon
import Class
import Type import Type
import TypeRep import TypeRep
import Coercion import Coercion
...@@ -23,9 +27,7 @@ import FamInstEnv ( FamInst, mkLocalFamInst ) ...@@ -23,9 +27,7 @@ import FamInstEnv ( FamInst, mkLocalFamInst )
import OccName import OccName
import Id import Id
import MkId import MkId
import BasicTypes ( HsBang(..), boolToRecFlag, import Var ( Var, TyVar, varType, varName )
alwaysInlinePragma, dfunInlinePragma )
import Var ( Var, TyVar, varType )
import Name ( Name, getOccName ) import Name ( Name, getOccName )
import NameEnv import NameEnv
...@@ -40,7 +42,11 @@ import FastString ...@@ -40,7 +42,11 @@ import FastString
import MonadUtils ( zipWith3M, foldrM, concatMapM ) import MonadUtils ( zipWith3M, foldrM, concatMapM )
import Control.Monad ( liftM, liftM2, zipWithM, zipWithM_, mapAndUnzipM ) import Control.Monad ( liftM, liftM2, zipWithM, zipWithM_, mapAndUnzipM )
import Data.List ( inits, tails, zipWith4, zipWith5 ) import Data.List
import Data.Maybe
debug = False
dtrace s x = if debug then pprTrace "VectType" s x else x
-- ---------------------------------------------------------------------------- -- ----------------------------------------------------------------------------
-- Types -- Types
...@@ -72,29 +78,57 @@ vectAndLiftType ty ...@@ -72,29 +78,57 @@ vectAndLiftType ty
-- | Vectorise a type. -- | Vectorise a type.
vectType :: Type -> VM Type vectType :: Type -> VM Type
vectType ty | Just ty' <- coreView ty = vectType ty' vectType ty
vectType (TyVarTy tv) = return $ TyVarTy tv | Just ty' <- coreView ty
vectType (AppTy ty1 ty2) = liftM2 AppTy (vectType ty1) (vectType ty2) = vectType ty'
vectType (TyConApp tc tys) = liftM2 TyConApp (vectTyCon tc) (mapM vectType tys)
vectType (FunTy ty1 ty2) = liftM2 TyConApp (builtin closureTyCon) vectType (TyVarTy tv) = return $ TyVarTy tv
(mapM vectAndBoxType [ty1,ty2]) vectType (AppTy ty1 ty2) = liftM2 AppTy (vectType ty1) (vectType ty2)
vectType (TyConApp tc tys) = liftM2 TyConApp (vectTyCon tc) (mapM vectType tys)
vectType (FunTy ty1 ty2) = liftM2 TyConApp (builtin closureTyCon)
(mapM vectAndBoxType [ty1,ty2])
-- For each quantified var we need to add a PA dictionary out the front of the type.
-- So forall a. C a => a -> a
-- turns into forall a. Cv a => PA a => a :-> a
vectType ty@(ForAllTy _ _) vectType ty@(ForAllTy _ _)
= do = do
mdicts <- mapM paDictArgType tyvars -- split the type into the quantified vars, its dictionaries and the body.
mono_ty' <- vectType mono_ty let (tyvars, tyBody) = splitForAllTys ty
return $ abstractType tyvars [dict | Just dict <- mdicts] mono_ty' let (tyArgs, tyResult) = splitFunTys tyBody
where
(tyvars, mono_ty) = splitForAllTys ty let (tyArgs_dict, tyArgs_regular)
= partition isDictType tyArgs
-- vectorise the body.
let tyBody' = mkFunTys tyArgs_regular tyResult
tyBody'' <- vectType tyBody'
-- vectorise the dictionary parameters.
dictsVect <- mapM vectType tyArgs_dict
-- make a PA dictionary for each of the type variables.
dictsPA <- liftM catMaybes $ mapM paDictArgType tyvars
-- pack it all back together.
return $ abstractType tyvars (dictsVect ++ dictsPA) tyBody''
vectType ty = cantVectorise "Can't vectorise type" (ppr ty) vectType ty = cantVectorise "Can't vectorise type" (ppr ty)
vectAndBoxType :: Type -> VM Type
vectAndBoxType ty = vectType ty >>= boxType
-- | Add quantified vars and dictionary parameters to the front of a type. -- | Add quantified vars and dictionary parameters to the front of a type.
abstractType :: [TyVar] -> [Type] -> Type -> Type abstractType :: [TyVar] -> [Type] -> Type -> Type
abstractType tyvars dicts = mkForAllTys tyvars . mkFunTys dicts abstractType tyvars dicts = mkForAllTys tyvars . mkFunTys dicts
-- | Check if some type is a type class dictionary.
isDictType :: Type -> Bool
isDictType ty
= case splitTyConApp_maybe ty of
Just (tyCon, _) -> isClassTyCon tyCon
_ -> False
-- ---------------------------------------------------------------------------- -- ----------------------------------------------------------------------------
-- Boxing -- Boxing
...@@ -110,6 +144,10 @@ boxType ty ...@@ -110,6 +144,10 @@ boxType ty
boxType ty = return ty boxType ty = return ty
vectAndBoxType :: Type -> VM Type
vectAndBoxType ty = vectType ty >>= boxType
-- ---------------------------------------------------------------------------- -- ----------------------------------------------------------------------------
-- Type definitions -- Type definitions
...@@ -119,7 +157,8 @@ type TyConGroup = ([TyCon], UniqSet TyCon) ...@@ -119,7 +157,8 @@ type TyConGroup = ([TyCon], UniqSet TyCon)
-- The type environment contains all the type things defined in a module. -- The type environment contains all the type things defined in a module.
vectTypeEnv :: TypeEnv -> VM (TypeEnv, [FamInst], [(Var, CoreExpr)]) vectTypeEnv :: TypeEnv -> VM (TypeEnv, [FamInst], [(Var, CoreExpr)])
vectTypeEnv env vectTypeEnv env
= do = dtrace (ppr env)
$ do
cs <- readGEnv $ mk_map . global_tycons cs <- readGEnv $ mk_map . global_tycons
-- Split the list of TyCons into the ones we have to vectorise vs the -- Split the list of TyCons into the ones we have to vectorise vs the
...@@ -127,26 +166,46 @@ vectTypeEnv env ...@@ -127,26 +166,46 @@ vectTypeEnv env
-- types that use non Haskell98 features, as we don't handle those. -- types that use non Haskell98 features, as we don't handle those.
let (conv_tcs, keep_tcs) = classifyTyCons cs groups let (conv_tcs, keep_tcs) = classifyTyCons cs groups
keep_dcs = concatMap tyConDataCons keep_tcs keep_dcs = concatMap tyConDataCons keep_tcs
dtrace (text "conv_tcs = " <> ppr conv_tcs) $ return ()
zipWithM_ defTyCon keep_tcs keep_tcs zipWithM_ defTyCon keep_tcs keep_tcs
zipWithM_ defDataCon keep_dcs keep_dcs zipWithM_ defDataCon keep_dcs keep_dcs
new_tcs <- vectTyConDecls conv_tcs new_tcs <- vectTyConDecls conv_tcs
dtrace (text "new_tcs = " <> ppr new_tcs) $ return ()
let orig_tcs = keep_tcs ++ conv_tcs let orig_tcs = keep_tcs ++ conv_tcs
vect_tcs = keep_tcs ++ new_tcs
-- We don't need to make new representation types for dictionary
-- constructors. The constructors are always fully applied, and we don't
-- need to lift them to arrays as a dictionary of a particular type
-- always has the same value.
let vect_tcs = filter (not . isClassTyCon)
$ keep_tcs ++ new_tcs
dtrace (text "vect_tcs = " <> ppr vect_tcs) $ return ()
mapM_ dumpTycon $ new_tcs
(_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) -> (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
do do
defTyConPAs (zipLazy vect_tcs dfuns') defTyConPAs (zipLazy vect_tcs dfuns')
reprs <- mapM tyConRepr vect_tcs reprs <- mapM tyConRepr vect_tcs
repr_tcs <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs repr_tcs <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
dfuns <- sequence $ zipWith5 buildTyConBindings orig_tcs
vect_tcs dfuns <- sequence
repr_tcs $ zipWith5 buildTyConBindings
pdata_tcs orig_tcs
reprs vect_tcs
binds <- takeHoisted repr_tcs
pdata_tcs
reprs
binds <- takeHoisted
return (dfuns, binds, repr_tcs ++ pdata_tcs) return (dfuns, binds, repr_tcs ++ pdata_tcs)
let all_new_tcs = new_tcs ++ inst_tcs let all_new_tcs = new_tcs ++ inst_tcs
...@@ -171,25 +230,106 @@ vectTyConDecls tcs = fixV $ \tcs' -> ...@@ -171,25 +230,106 @@ vectTyConDecls tcs = fixV $ \tcs' ->
mapM_ (uncurry defTyCon) (zipLazy tcs tcs') mapM_ (uncurry defTyCon) (zipLazy tcs tcs')
mapM vectTyConDecl tcs mapM vectTyConDecl tcs
vectTyConDecl :: TyCon -> VM TyCon dumpTycon :: TyCon -> VM ()
vectTyConDecl tc dumpTycon tycon
= do | Just cls <- tyConClass_maybe tycon
name' <- cloneName mkVectTyConOcc name = dtrace (vcat [ ppr tycon
rhs' <- vectAlgTyConRhs tc (algTyConRhs tc) , ppr [(m, varType m) | m <- classMethods cls ]])
$ return ()
| otherwise
= return ()
liftDs $ buildAlgTyCon name'
tyvars
[] -- no stupid theta
rhs'
rec_flag -- FIXME: is this ok?
False -- FIXME: no generics
False -- not GADT syntax
Nothing -- not a family instance
where
name = tyConName tc
tyvars = tyConTyVars tc
rec_flag = boolToRecFlag (isRecursiveTyCon tc)
-- | Vectorise a single type construcrtor.
vectTyConDecl :: TyCon -> VM TyCon
vectTyConDecl tycon
-- a type class constructor.
-- TODO: check for no stupid theta, fds, assoc types.
| isClassTyCon tycon
, Just cls <- tyConClass_maybe tycon
= do -- make the name of the vectorised class tycon.
name' <- cloneName mkVectTyConOcc (tyConName tycon)
-- vectorise right of definition.
rhs' <- vectAlgTyConRhs tycon (algTyConRhs tycon)
-- vectorise method selectors.
-- This also adds a mapping between the original and vectorised method selector
-- to the state.
methods' <- mapM vectMethod
$ [(id, defMethSpecOfDefMeth meth)
| (id, meth) <- classOpItems cls]
-- keep the original recursiveness flag.
let rec_flag = boolToRecFlag (isRecursiveTyCon tycon)
-- Calling buildclass here attaches new quantifiers and dictionaries to the method types.
cls' <- liftDs
$ buildClass
False -- include unfoldings on dictionary selectors.
name' -- new name V_T:Class
(tyConTyVars tycon) -- keep original type vars
[] -- no stupid theta
[] -- no functional dependencies
[] -- no associated types
methods' -- method info
rec_flag -- whether recursive
let tycon' = mkClassTyCon name'
(tyConKind tycon)
(tyConTyVars tycon)
rhs'
cls'
rec_flag
return $ tycon'
-- a regular algebraic type constructor.
-- TODO: check for stupid theta, generaics, GADTS etc
| isAlgTyCon tycon
= do name' <- cloneName mkVectTyConOcc (tyConName tycon)
rhs' <- vectAlgTyConRhs tycon (algTyConRhs tycon)
let rec_flag = boolToRecFlag (isRecursiveTyCon tycon)
liftDs $ buildAlgTyCon
name' -- new name
(tyConTyVars tycon) -- keep original type vars.
[] -- no stupid theta.
rhs' -- new constructor defs.
rec_flag -- FIXME: is this ok?
False -- FIXME: no generics
False -- not GADT syntax
Nothing -- not a family instance
-- some other crazy thing that we don't handle.
| otherwise
= cantVectorise "Can't vectorise type constructor: " (ppr tycon)
-- | Vectorise a class method.
vectMethod :: (Id, DefMethSpec) -> VM (Name, DefMethSpec, Type)
vectMethod (id, defMeth)
= do
-- Vectorise the method type.
typ' <- vectType (varType id)
-- Create a name for the vectorised method.
id' <- cloneId mkVectOcc id typ'
defGlobalVar id id'
-- When we call buildClass in vectTyConDecl, it adds foralls and dictionaries
-- to the types of each method. However, the types we get back from vectType
-- above already already have these, so we need to chop them off here otherwise
-- we'll get two copies in the final version.
let (_tyvars, tyBody) = splitForAllTys typ'
let (_dict, tyRest) = splitFunTy tyBody
return (Var.varName id', defMeth, tyRest)
-- | Vectorise the RHS of an algebraic type.
vectAlgTyConRhs :: TyCon -> AlgTyConRhs -> VM AlgTyConRhs vectAlgTyConRhs :: TyCon -> AlgTyConRhs -> VM AlgTyConRhs
vectAlgTyConRhs _ (DataTyCon { data_cons = data_cons vectAlgTyConRhs _ (DataTyCon { data_cons = data_cons
, is_enum = is_enum , is_enum = is_enum
...@@ -200,31 +340,39 @@ vectAlgTyConRhs _ (DataTyCon { data_cons = data_cons ...@@ -200,31 +340,39 @@ vectAlgTyConRhs _ (DataTyCon { data_cons = data_cons
return $ DataTyCon { data_cons = data_cons' return $ DataTyCon { data_cons = data_cons'
, is_enum = is_enum , is_enum = is_enum
} }
vectAlgTyConRhs tc _ = cantVectorise "Can't vectorise type definition:" (ppr tc)
vectAlgTyConRhs tc _
= cantVectorise "Can't vectorise type definition:" (ppr tc)
-- | Vectorise a data constructor.
-- Vectorises its argument and return types.
vectDataCon :: DataCon -> VM DataCon vectDataCon :: DataCon -> VM DataCon
vectDataCon dc vectDataCon dc
| not . null $ dataConExTyVars dc | not . null $ dataConExTyVars dc
= cantVectorise "Can't vectorise constructor (existentials):" (ppr dc) = cantVectorise "Can't vectorise constructor (existentials):" (ppr dc)
| not . null $ dataConEqSpec dc | not . null $ dataConEqSpec dc
= cantVectorise "Can't vectorise constructor (eq spec):" (ppr dc) = cantVectorise "Can't vectorise constructor (eq spec):" (ppr dc)
| otherwise | otherwise
= do = do
name' <- cloneName mkVectDataConOcc name name' <- cloneName mkVectDataConOcc name
tycon' <- vectTyCon tycon tycon' <- vectTyCon tycon
arg_tys <- mapM vectType rep_arg_tys arg_tys <- mapM vectType rep_arg_tys
liftDs $ buildDataCon name' liftDs $ buildDataCon
False -- not infix name'
(map (const HsNoBang) arg_tys) False -- not infix
[] -- no labelled fields (map (const HsNoBang) arg_tys) -- strictness annots on args.
univ_tvs [] -- no labelled fields
[] -- no existential tvs for now univ_tvs -- universally quantified vars
[] -- no eq spec for now [] -- no existential tvs for now
[] -- no context [] -- no eq spec for now
arg_tys [] -- no context
(mkFamilyTyConApp tycon' (mkTyVarTys univ_tvs)) arg_tys -- argument types
tycon' (mkFamilyTyConApp tycon' (mkTyVarTys univ_tvs)) -- return type
tycon' -- representation tycon
where where
name = dataConName dc name = dataConName dc
univ_tvs = dataConUnivTyVars dc univ_tvs = dataConUnivTyVars dc
...@@ -861,6 +1009,7 @@ paMethods = [("dictPRepr", buildPRDict), ...@@ -861,6 +1009,7 @@ paMethods = [("dictPRepr", buildPRDict),
("toArrPRepr", buildToArrPRepr), ("toArrPRepr", buildToArrPRepr),
("fromArrPRepr", buildFromArrPRepr)] ("fromArrPRepr", buildFromArrPRepr)]
-- | Split the given tycons into two sets depending on whether they have to be -- | Split the given tycons into two sets depending on whether they have to be
-- converted (first list) or not (second list). The first argument contains -- converted (first list) or not (second list). The first argument contains
-- information about the conversion status of external tycons: -- information about the conversion status of external tycons:
...@@ -929,8 +1078,31 @@ tyConsOfTypes = unionManyUniqSets . map tyConsOfType ...@@ -929,8 +1078,31 @@ tyConsOfTypes = unionManyUniqSets . map tyConsOfType
-- ---------------------------------------------------------------------------- -- ----------------------------------------------------------------------------
-- Conversions -- Conversions
fromVect :: Type -> CoreExpr -> VM CoreExpr -- | Build an expression that calls the vectorised version of some
fromVect ty expr | Just ty' <- coreView ty = fromVect ty' expr -- function from a `Closure`.
--
-- For example
-- @
-- \(x :: Double) ->
-- \(y :: Double) ->
-- ($v_foo $: x) $: y
-- @
--
-- We use the type of the original binding to work out how many
-- outer lambdas to add.
--
fromVect
:: Type -- ^ The type of the original binding.
-> CoreExpr -- ^ Expression giving the closure to use, eg @$v_foo@.
-> VM CoreExpr
-- Convert the type to the core view if it isn't already.
fromVect ty expr
| Just ty' <- coreView ty
= fromVect ty' expr
-- For each function constructor in the original type we add an outer
-- lambda to bind the parameter variable, and an inner application of it.
fromVect (FunTy arg_ty res_ty) expr fromVect (FunTy arg_ty res_ty) expr
= do = do
arg <- newLocalVar (fsLit "x") arg_ty arg <- newLocalVar (fsLit "x") arg_ty
...@@ -941,12 +1113,16 @@ fromVect (FunTy arg_ty res_ty) expr ...@@ -941,12 +1113,16 @@ fromVect (FunTy arg_ty res_ty) expr
body <- fromVect res_ty body <- fromVect res_ty
$ Var apply `mkTyApps` [varg_ty, vres_ty] `mkApps` [expr, varg] $ Var apply `mkTyApps` [varg_ty, vres_ty] `mkApps` [expr, varg]
return $ Lam arg body return $ Lam arg body
-- If the type isn't a function then it's time to call on the closure.
fromVect ty expr fromVect ty expr
= identityConv ty >> return expr = identityConv ty >> return expr
toVect :: Type -> CoreExpr -> VM CoreExpr toVect :: Type -> CoreExpr -> VM CoreExpr
toVect ty expr = identityConv ty >> return expr toVect ty expr = identityConv ty >> return expr
identityConv :: Type -> VM () identityConv :: Type -> VM ()
identityConv ty | Just ty' <- coreView ty = identityConv ty'