Commit 9097e67b authored by chak@cse.unsw.edu.au.'s avatar chak@cse.unsw.edu.au.
Browse files

First cut at scalar vectorisation of class instances

parent 44d999bb
......@@ -81,25 +81,15 @@ vectModule guts@(ModGuts { mg_tcs = tycons
-- array types.
; (new_tycons, new_fam_insts, tc_binds) <- vectTypeEnv tycons ty_vect_decls cls_vect_decls
{- TODO:
instance Num Int where
(+) = primAdd
{-# VECTORISE SCALAR instance Num Int #-}
==> $dNumInt :: Num Int; $dNumInt = Num primAdd
=>> $v$dNumInt :: $vNum Int
$v$dNumInt = $vNum (closure1 (scalar_zipWith primAdd) (scalar_zipWith primAdd))
$dNumInt -v> $v$dNumInt
-}
-- Family instance environment for /all/ home-package modules including those instances
-- generated by 'vectTypeEnv'.
; (_, fam_inst_env) <- readGEnv global_fam_inst_env
-- Vectorise all the top level bindings and VECTORISE declarations on imported identifiers
; let impBinds = [imp_id | Vect imp_id _ <- vect_decls, isGlobalId imp_id] ++
[imp_id | VectInst True imp_id <- vect_decls, isGlobalId imp_id]
; binds_top <- mapM vectTopBind binds
; binds_imp <- mapM vectImpBind [imp_id | Vect imp_id _ <- vect_decls, isGlobalId imp_id]
; binds_imp <- mapM vectImpBind impBinds
; return $ guts { mg_tcs = tycons ++ new_tycons
-- we produce no new classes or instances, only new class type constructors
......@@ -283,21 +273,63 @@ vectTopBinder var inline expr
unfolding = case inline of
Inline arity -> mkInlineUnfolding (Just arity) expr
DontInline -> noUnfolding
{-
!!!TODO: dfuns and unfoldings:
-- Do not inline the dfun; instead give it a magic DFunFunfolding
-- See Note [ClassOp/DFun selection]
-- See also note [Single-method classes]
dfun_id_w_fun
| isNewTyCon class_tc
= dfun_id `setInlinePragma` alwaysInlinePragma { inl_sat = Just 0 }
| otherwise
= dfun_id `setIdUnfolding` mkDFunUnfolding dfun_ty dfun_args
`setInlinePragma` dfunInlinePragma
-}
-- | Vectorise the RHS of a top-level binding, in an empty local environment.
--
-- We need to distinguish three cases:
-- We need to distinguish four cases:
--
-- (1) We have a (non-scalar) vectorisation declaration for the variable (which explicitly provides
-- vectorised code implemented by the user)
-- => no automatic vectorisation & instead use the user-supplied code
--
-- (2) We have a scalar vectorisation declaration for the variable
-- (2) We have a scalar vectorisation declaration for a variable that is no dfun
-- => generate vectorised code that uses a scalar 'map'/'zipWith' to lift the computation
--
-- (3) There is no vectorisation declaration for the variable
-- (3) We have a scalar vectorisation declaration for a variable that *is* a dfun
-- => generate vectorised code according to the the "Note [Scalar dfuns]" below
--
-- (4) There is no vectorisation declaration for the variable
-- => perform automatic vectorisation of the RHS
--
-- Note [Scalar dfuns]
-- ~~~~~~~~~~~~~~~~~~~
--
-- Here is the translation scheme for scalar dfuns — assume the instance declaration:
--
-- instance Num Int where
-- (+) = primAdd
-- {-# VECTORISE SCALAR instance Num Int #-}
--
-- It desugars to
--
-- $dNumInt :: Num Int
-- $dNumInt = D:Num primAdd
--
-- We vectorise it to
--
-- $v$dNumInt :: V:Num Int
-- $v$dNumInt = D:V:Num (closure2 ((+) $dNumInt) (scalar_zipWith ((+) $dNumInt))))
--
-- while adding the following entry to the vectorisation map: '$dNumInt' --> '$v$dNumInt'.
--
-- See "Note [Vectorising classes]" in 'Vectorise.Type.Env' for the definition of 'V:Num'.
--
-- NB: The outlined vectorisation scheme does not require the right-hand side of the original dfun.
-- In fact, we definitely want to refer to the dfn variable instead of the right-hand side to
-- ensure that the dictionary selection rules fire.
--
vectTopRhs :: [Var] -- ^ Names of all functions in the rec block
-> Var -- ^ Name of the binding.
-> CoreExpr -- ^ Body of the binding.
......@@ -308,19 +340,24 @@ vectTopRhs recFs var expr
= closedV
$ do { globalScalar <- isGlobalScalar var
; vectDecl <- lookupVectDecl var
; let isDFun = isDFunId var
; traceVt ("vectTopRhs of " ++ show var ++ info globalScalar vectDecl) $ ppr expr
; traceVt ("vectTopRhs of " ++ show var ++ info globalScalar isDFun vectDecl) $ ppr expr
; rhs globalScalar vectDecl
; rhs globalScalar isDFun vectDecl
}
where
rhs _globalScalar (Just (_, expr')) -- Case (1)
rhs _globalScalar _isDFun (Just (_, expr')) -- Case (1)
= return (inlineMe, False, expr')
rhs True Nothing -- Case (2)
rhs True False Nothing -- Case (2)
= do { expr' <- vectScalarFun True recFs expr
; return (inlineMe, True, vectorised expr')
}
rhs False Nothing -- Case (3)
rhs True True Nothing -- Case (3)
= do { expr' <- vectScalarDFun var recFs
; return (DontInline, True, expr')
}
rhs False _isDFun Nothing -- Case (4)
= do { let fvs = freeVars expr
; (inline, isScalar, vexpr)
<- inBind var $
......@@ -328,9 +365,10 @@ vectTopRhs recFs var expr
; return (inline, isScalar, vectorised vexpr)
}
info True _ = " [VECTORISE SCALAR]"
info False vectDecl | isJust vectDecl = " [VECTORISE]"
| otherwise = " (no pragma)"
info True False _ = " [VECTORISE SCALAR]"
info True True _ = " [VECTORISE SCALAR instance]"
info False _ vectDecl | isJust vectDecl = " [VECTORISE]"
| otherwise = " (no pragma)"
-- |Project out the vectorised version of a binding from some closure,
-- or return the original body if that doesn't work or the binding is scalar.
......
......@@ -145,7 +145,8 @@ initGlobalEnv info vectDecls instEnvs famInstEnvs
-- FIXME: we currently only allow RHSes consisting of a
-- single variable to be able to obtain the type without
-- inference — see also 'TcBinds.tcVect'
scalar_vars = [var | Vect var Nothing <- vectDecls]
scalar_vars = [var | Vect var Nothing <- vectDecls] ++
[var | VectInst True var <- vectDecls]
novects = [var | NoVect var <- vectDecls]
scalar_tycons = [tyConName tycon | VectType True tycon _ <- vectDecls]
......
-- |Vectorisation of expressions.
-- | Vectorisation of expressions.
module Vectorise.Exp (
-- Vectorise a polymorphic expression
vectPolyExpr,
-- Vectorise a scalar expression of functional type
vectScalarFun
) where
module Vectorise.Exp
( -- * Vectorise polymorphic expressions with special cases for right-hand sides of particular
-- variable bindings
vectPolyExpr
, vectScalarFun
, vectScalarDFun
)
where
#include "HsVersions.h"
import Vectorise.Type.Type
import Vectorise.Var
import Vectorise.Convert
import Vectorise.Vect
import Vectorise.Env
import Vectorise.Monad
import Vectorise.Builtins
import Vectorise.Utils
import CoreSyn
import CoreUtils
import MkCore
import CoreSyn
import CoreFVs
import Class
import DataCon
import TyCon
import TcType
import Type
import NameSet
import Var
......@@ -38,6 +41,7 @@ import TysPrim
import Outputable
import FastString
import Control.Monad
import Control.Applicative
import Data.List
......@@ -82,6 +86,7 @@ vectExpr (_, AnnTick tickish expr)
-- SPECIAL CASE: Vectorise/lift 'patError @ ty err' by only vectorising/lifting the type 'ty';
-- its only purpose is to abort the program, but we need to adjust the type to keep CoreLint
-- happy.
-- FIXME: can't be do this with a VECTORISE pragma on 'pAT_ERROR_ID' now?
vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err)
| v == pAT_ERROR_ID
= do { (vty, lty) <- vectAndLiftType ty
......@@ -168,7 +173,7 @@ onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
vectExpr e = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnnotate e)
-- | Vectorise an expression with an outer lambda abstraction.
-- |Vectorise an expression with an outer lambda abstraction.
--
vectFnExpr :: Bool -- ^ If we process the RHS of a binding, whether that binding should
-- be inlined
......@@ -201,7 +206,7 @@ vectScalarFun forceScalar recFns expr
; let scalarVars = gscalarVars `extendVarSetList` recFns
(arg_tys, res_ty) = splitFunTys (exprType expr)
; MASSERT( not $ null arg_tys )
; onlyIfV empty
; onlyIfV (ptext (sLit "not a scalar function"))
(forceScalar -- user asserts the functions is scalar
||
all (is_scalar_ty scalarTyCons) arg_tys -- check whether the function is scalar
......@@ -300,6 +305,109 @@ mkScalarFun arg_tys res_ty expr
; return (Var clo_var, lclo)
}
-- |Vectorise a dictionary function that has a 'VECTORISE SCALAR instance' pragma.
--
-- In other words, all methods in that dictionary are scalar functions — to be vectorised with
-- 'vectScalarFun'. The dictionary "function" itself may be a constant, though.
--
-- NB: You may think that we could implement this function guided by the struture of the Core
-- expression of the right-hand side of the dictionary function. We cannot proceed like this as
-- 'vectScalarDFun' must also work for *imported* dfuns, where we don't necessarily have access
-- to the Core code of the unvectorised dfun.
--
-- Here an example — assume,
--
-- > class Eq a where { (==) :: a -> a -> Bool }
-- > instance (Eq a, Eq b) => Eq (a, b) where { (==) = ... }
-- > {-# VECTORISE SCALAR instance Eq (a, b) }
--
-- The unvectorised dfun for the above instance has the following signature:
--
-- > $dEqPair :: forall a b. Eq a -> Eq b -> Eq (a, b)
--
-- We generate the following (scalar) vectorised dfun (liberally using TH notation):
--
-- > $v$dEqPair :: forall a b. V:Eq a -> V:Eq b -> V:Eq (a, b)
-- > $v$dEqPair = /\a b -> \dEqa :: V:Eq a -> \dEqb :: V:Eq b ->
-- > D:V:Eq $(vectScalarFun True recFns
-- > [| (==) @(a, b) ($dEqPair @a @b $(unVect dEqa) $(unVect dEqb)) |])
--
-- NB:
-- * '(,)' vectorises to '(,)' — hence, the type constructor in the result type remains the same.
-- * We share the '$(unVect di)' sub-expressions between the different selectors, but duplicate
-- the application of the unvectorised dfun, to enable the dictionary selection rules to fire.
--
vectScalarDFun :: Var -- ^ Original dfun
-> [Var] -- ^ Functions names in same recursive binding group
-> VM CoreExpr
vectScalarDFun var recFns
= do { -- bring the type variables into scope
; mapM_ defLocalTyVar tvs
-- vectorise dictionary argument types and generate variables for them
; vTheta <- mapM vectType theta
; vThetaBndr <- mapM (newLocalVar (fsLit "vd")) vTheta
; let vThetaVars = varsToCoreExprs vThetaBndr
-- vectorise superclass dictionaries and methods as scalar expressions
; thetaVars <- mapM (newLocalVar (fsLit "d")) theta
; thetaExprs <- zipWithM unVectDict theta vThetaVars
; let thetaDictBinds = zipWith NonRec thetaVars thetaExprs
dict = Var var `mkTyApps` (mkTyVarTys tvs) `mkVarApps` thetaVars
scsOps = map (\selId -> varToCoreExpr selId `mkTyApps` tys `mkApps` [dict])
selIds
; vScsOps <- mapM (\e -> vectorised <$> vectScalarFun True recFns e) scsOps
-- vectorised applications of the class-dictionary data constructor
; Just vDataCon <- lookupDataCon dataCon
; vTys <- mapM vectType tys
; let vBody = thetaDictBinds `mkLets` mkCoreConApps vDataCon (map Type vTys ++ vScsOps)
; return $ mkLams (tvs ++ vThetaBndr) vBody
}
where
ty = varType var
(tvs, theta, pty) = tcSplitSigmaTy ty -- 'theta' is the instance context
(cls, tys) = tcSplitDFunHead pty -- 'pty' is the instance head
selIds = classAllSelIds cls
dataCon = classDataCon cls
-- Build a value of the dictionary before vectorisation from original, unvectorised type and an
-- expression computing the vectorised dictionary.
--
-- Given the vectorised version of a dictionary 'vd :: V:C vt1..vtn', generate code that computes
-- the unvectorised version, thus:
--
-- > D:C op1 .. opm
-- > where
-- > opi = $(fromVect opTyi [| vSeli @vt1..vtk vd |])
--
-- where 'opTyi' is the type of the i-th superclass or op of the unvectorised dictionary.
--
unVectDict :: Type -> CoreExpr -> VM CoreExpr
unVectDict ty e
= do { vTys <- mapM vectType tys
; let meths = map (\sel -> Var sel `mkTyApps` vTys `mkApps` [e]) selIds
; scOps <- zipWithM fromVect methTys meths
; return $ mkCoreConApps dataCon (map Type tys ++ scOps)
}
where
(tycon, tys, dataCon, methTys) = splitProductType "unVectDict: original type" ty
cls = case tyConClass_maybe tycon of
Just cls -> cls
Nothing -> panic "Vectorise.Exp.unVectDict: no class"
selIds = classAllSelIds cls
{-
!!!How about 'isClassOpId_maybe'? Do we need to treat them specially to get the class ops for
!!!the vectorised instances or do they just work out?? (We may want to make sure that the
!!!vectorised Ids at least get the right IdDetails...)
!!!NB: For *locally defined* instances, the selector functions are part of the vectorised bindings,
!!! but not so for *imported* instances, where we need to generate the vectorised versions from
!!! scratch.
!!!Also need to take care of the builtin rules for selectors (see mkDictSelId).
-}
-- | Vectorise a lambda abstraction.
--
vectLam :: Bool -- ^ When the RHS of a binding, whether that binding should be inlined.
......
......@@ -137,7 +137,6 @@ lookupDataCon :: DataCon -> VM (Maybe DataCon)
lookupDataCon dc
| isTupleTyCon (dataConTyCon dc)
= return (Just dc)
| otherwise
= readGEnv $ \env -> lookupNameEnv (global_datacons env) (dataConName dc)
......
......@@ -9,15 +9,18 @@ module Vectorise.Monad.Naming
, newLocalVars
, newDummyVar
, newTyVar
) where
)
where
import Vectorise.Monad.Base
import DsMonad
import TcType
import Type
import Var
import Name
import SrcLoc
import MkId
import Id
import FastString
......@@ -43,7 +46,8 @@ mkLocalisedName mk_occ name =
; return new_name
}
-- |Produce the vectorised variant of an `Id` with the given type.
-- |Produce the vectorised variant of an `Id` with the given type, while taking care that vectorised
-- dfun ids must be dfuns again.
--
-- Force the new name to be a system name and, if the original was an external name, disambiguate
-- the new name with the module name of the original.
......@@ -51,10 +55,17 @@ mkLocalisedName mk_occ name =
mkVectId :: Id -> Type -> VM Id
mkVectId id ty
= do { name <- mkLocalisedName mkVectOcc (getName id)
; let id' | isExportedId id = Id.mkExportedLocalId name ty
; let id' | isDFunId id = MkId.mkDictFunId name tvs theta cls tys
| isExportedId id = Id.mkExportedLocalId name ty
| otherwise = Id.mkLocalId name ty
; return id'
}
where
-- Decompose a dictionary function signature: \forall tvs. theta -> cls tys
-- NB: We do *not* use closures '(:->)' for vectorised predicate abstraction as dictionary
-- functions are always fully applied.
(tvs, theta, pty) = tcSplitSigmaTy ty
(cls, tys) = tcSplitDFunHead pty
-- |Make a fresh instance of this var, with a new unique.
--
......
......@@ -108,16 +108,16 @@ import Data.List
--
-- It desugars to
--
-- data Num a = Num { (+) :: a -> a -> a }
-- data Num a = D:Num { (+) :: a -> a -> a }
--
-- which we vectorise to
--
-- data $vNum a = $vNum { ($v+) :: PArray a :-> PArray a :-> PArray a }
-- data V:Num a = D:V:Num { ($v+) :: PArray a :-> PArray a :-> PArray a }
--
-- while adding the following entries to the vectorisation map:
--
-- tycon : Num --> $vNum
-- datacon: Num --> $vNum
-- tycon : Num --> V:Num
-- datacon: D:Num --> D:V:Num
-- var : (+) --> ($v+)
-- |Vectorise type constructor including class type constructors.
......
......@@ -6,8 +6,7 @@ module Vectorise.Utils.Closure (
buildClosure,
buildClosures,
buildEnv
)
where
) where
import Vectorise.Builtins
import Vectorise.Vect
......@@ -28,15 +27,14 @@ import BasicTypes( TupleSort(..) )
import FastString
-- | Make a closure.
mkClosure
:: Type -- ^ Type of the argument.
-> Type -- ^ Type of the result.
-> Type -- ^ Type of the environment.
-> VExpr -- ^ The function to apply.
-> VExpr -- ^ The environment to use.
-> VM VExpr
-- |Make a closure.
--
mkClosure :: Type -- ^ Type of the argument.
-> Type -- ^ Type of the result.
-> Type -- ^ Type of the environment.
-> VExpr -- ^ The function to apply.
-> VExpr -- ^ The environment to use.
-> VM VExpr
mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
= do dict <- paDictOfType env_ty
mkv <- builtin closureVar
......@@ -44,15 +42,13 @@ mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
-- | Make a closure application.
mkClosureApp
:: Type -- ^ Type of the argument.
-> Type -- ^ Type of the result.
-> VExpr -- ^ Closure to apply.
-> VExpr -- ^ Argument to use.
-> VM VExpr
-- |Make a closure application.
--
mkClosureApp :: Type -- ^ Type of the argument.
-> Type -- ^ Type of the result.
-> VExpr -- ^ Closure to apply.
-> VExpr -- ^ Argument to use.
-> VM VExpr
mkClosureApp arg_ty res_ty (vclo, lclo) (varg, larg)
= do vapply <- builtin applyVar
lapply <- builtin liftedApplyVar
......@@ -60,21 +56,16 @@ mkClosureApp arg_ty res_ty (vclo, lclo) (varg, larg)
return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [Var lc, lclo, larg])
buildClosures
:: [TyVar]
-> [VVar]
-> [Type] -- ^ Type of the arguments.
-> Type -- ^ Type of result.
-> VM VExpr
-> VM VExpr
buildClosures :: [TyVar]
-> [VVar]
-> [Type] -- ^ Type of the arguments.
-> Type -- ^ Type of result.
-> VM VExpr
-> VM VExpr
buildClosures _ _ [] _ mk_body
= mk_body
buildClosures tvs vars [arg_ty] res_ty mk_body
= buildClosure tvs vars arg_ty res_ty mk_body
buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
= do res_ty' <- mkClosureTypes arg_tys res_ty
arg <- newLocalVVar (fsLit "x") arg_ty
......@@ -85,7 +76,6 @@ buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
return $ vLams lc (vars ++ [arg]) clo
-- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
-- where
-- f = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
......@@ -110,6 +100,7 @@ buildClosure tvs vars arg_ty res_ty mk_body
-- Environments ---------------------------------------------------------------
buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VExpr)
buildEnv []
= do
......@@ -117,10 +108,9 @@ buildEnv []
void <- builtin voidVar
pvoid <- builtin pvoidVar
return (ty, vVar (void, pvoid), \_ body -> body)
buildEnv [v] = return (vVarType v, vVar v,
\env body -> vLet (vNonRec v env) body)
buildEnv [v]
= return (vVarType v, vVar v,
\env body -> vLet (vNonRec v env) body)
buildEnv vs
= do (lenv_tc, lenv_tyargs) <- pdataReprTyCon ty
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment