Commit 46fa261e authored by chak@cse.unsw.edu.au.'s avatar chak@cse.unsw.edu.au.

Add VECTORISE [SCALAR] type pragma

- Pragma to determine how a given type is vectorised
- At this stage only the VECTORISE SCALAR variant is used by the vectoriser.
- '{-# VECTORISE SCALAR type t #-}' implies that 't' cannot contain parallel arrays and may be used in vectorised code.  However, its constructors can only be used in scalar code.  We use this, e.g., for 'Int'.
- May be used on imported types

See also http://hackage.haskell.org/trac/ghc/wiki/DataParallel/VectPragma
parent 2d0438f3
......@@ -334,6 +334,8 @@ vectsFreeVars = foldr (unionVarSet . vectFreeVars) emptyVarSet
vectFreeVars (Vect _ Nothing) = noFVs
vectFreeVars (Vect _ (Just rhs)) = expr_fvs rhs isLocalId emptyVarSet
vectFreeVars (NoVect _) = noFVs
vectFreeVars (VectType _ _) = noFVs
-- this function is only concerned with values, not types
\end{code}
......
......@@ -735,7 +735,8 @@ substVects subst = map (substVect subst)
substVect :: Subst -> CoreVect -> CoreVect
substVect _subst (Vect v Nothing) = Vect v Nothing
substVect subst (Vect v (Just rhs)) = Vect v (Just (simpleOptExprWith subst rhs))
substVect _subst (NoVect v) = NoVect v
substVect _subst vd@(NoVect _) = vd
substVect _subst vd@(VectType _ _) = vd
------------------
substVarSet :: Subst -> VarSet -> VarSet
......
......@@ -87,12 +87,13 @@ import Coercion
import Name
import Literal
import DataCon
import TyCon
import BasicTypes
import FastString
import Outputable
import Util
import Data.Data
import Data.Data hiding (TyCon)
import Data.Word
infixl 4 `mkApps`, `mkTyApps`, `mkVarApps`, `App`, `mkCoApps`
......@@ -428,9 +429,9 @@ Representation of desugared vectorisation declarations that are fed to the vecto
'ModGuts').
\begin{code}
data CoreVect = Vect Id (Maybe CoreExpr)
| NoVect Id
data CoreVect = Vect Id (Maybe CoreExpr)
| NoVect Id
| VectType TyCon (Maybe Type)
\end{code}
......
......@@ -473,8 +473,11 @@ pprRule (Rule { ru_name = name, ru_act = act, ru_fn = fn,
\begin{code}
instance Outputable CoreVect where
ppr (Vect var Nothing) = ptext (sLit "VECTORISE SCALAR") <+> ppr var
ppr (Vect var (Just e)) = hang (ptext (sLit "VECTORISE") <+> ppr var <+> char '=')
4 (pprCoreExpr e)
ppr (NoVect var) = ptext (sLit "NOVECTORISE") <+> ppr var
ppr (Vect var Nothing) = ptext (sLit "VECTORISE SCALAR") <+> ppr var
ppr (Vect var (Just e)) = hang (ptext (sLit "VECTORISE") <+> ppr var <+> char '=')
4 (pprCoreExpr e)
ppr (NoVect var) = ptext (sLit "NOVECTORISE") <+> ppr var
ppr (VectType var Nothing) = ptext (sLit "VECTORISE SCALAR type") <+> ppr var
ppr (VectType var (Just ty)) = hang (ptext (sLit "VECTORISE type") <+> ppr var <+> char '=')
4 (ppr ty)
\end{code}
......@@ -403,7 +403,11 @@ dsVect (L loc (HsVect (L _ v) rhs))
= putSrcSpanDs loc $
do { rhs' <- fmapMaybeM dsLExpr rhs
; return $ Vect v rhs'
}
}
dsVect (L _loc (HsNoVect (L _ v)))
= return $ NoVect v
dsVect (L _loc (HsVectTypeOut tycon ty))
= return $ VectType tycon ty
dsVect vd@(L _ (HsVectTypeIn _ _ty))
= pprPanic "Desugar.dsVect: unexpected 'HsVectTypeIn'" (ppr vd)
\end{code}
......@@ -59,6 +59,7 @@ import HsBinds
import HsPat
import HsTypes
import HsDoc
import TyCon
import NameSet
import {- Kind parts of -} Type
import BasicTypes
......@@ -72,7 +73,7 @@ import SrcLoc
import FastString
import Control.Monad ( liftM )
import Data.Data
import Data.Data hiding (TyCon)
import Data.Maybe ( isJust )
\end{code}
......@@ -1014,6 +1015,9 @@ A vectorisation pragma, one of
{-# VECTORISE f = closure1 g (scalar_map g) #-}
{-# VECTORISE SCALAR f #-}
{-# NOVECTORISE f #-}
{-# VECTORISE type T = ty #-}
{-# VECTORISE SCALAR type T #-}
Note [Typechecked vectorisation pragmas]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......@@ -1036,11 +1040,19 @@ data VectDecl name
(Maybe (LHsExpr name)) -- 'Nothing' => SCALAR declaration
| HsNoVect
(Located name)
| HsVectTypeIn -- pre type-checking
(Located name)
(Maybe (LHsType name)) -- 'Nothing' => SCALAR declaration
| HsVectTypeOut -- post type-checking
TyCon
(Maybe Type) -- 'Nothing' => SCALAR declaration
deriving (Data, Typeable)
lvectDeclName :: LVectDecl name -> name
lvectDeclName (L _ (HsVect (L _ name) _)) = name
lvectDeclName (L _ (HsNoVect (L _ name))) = name
lvectDeclName :: Outputable name => LVectDecl name -> name
lvectDeclName (L _ (HsVect (L _ name) _)) = name
lvectDeclName (L _ (HsNoVect (L _ name))) = name
lvectDeclName (L _ (HsVectTypeIn (L _ name) _)) = name
lvectDeclName (L _ (HsVectTypeOut name _)) = pprPanic "HsDecls.HsVectTypeOut" (ppr name)
instance OutputableBndr name => Outputable (VectDecl name) where
ppr (HsVect v Nothing)
......@@ -1051,6 +1063,18 @@ instance OutputableBndr name => Outputable (VectDecl name) where
pprExpr (unLoc rhs) <+> text "#-}" ]
ppr (HsNoVect v)
= sep [text "{-# NOVECTORISE" <+> ppr v <+> text "#-}" ]
ppr (HsVectTypeIn t Nothing)
= sep [text "{-# VECTORISE SCALAR type" <+> ppr t <+> text "#-}" ]
ppr (HsVectTypeIn t (Just ty))
= sep [text "{-# VECTORISE type" <+> ppr t,
nest 4 $
ppr (unLoc ty) <+> text "#-}" ]
ppr (HsVectTypeOut t Nothing)
= sep [text "{-# VECTORISE SCALAR type" <+> ppr t <+> text "#-}" ]
ppr (HsVectTypeOut t (Just ty))
= sep [text "{-# VECTORISE type" <+> ppr t,
nest 4 $
ppr ty <+> text "#-}" ]
\end{code}
%************************************************************************
......
......@@ -563,8 +563,8 @@ topdecls :: { OrdList (LHsDecl RdrName) }
| topdecl { $1 }
topdecl :: { OrdList (LHsDecl RdrName) }
: cl_decl { unitOL (L1 (TyClD (unLoc $1))) }
| ty_decl { unitOL (L1 (TyClD (unLoc $1))) }
: cl_decl { unitOL (L1 (TyClD (unLoc $1))) }
| ty_decl { unitOL (L1 (TyClD (unLoc $1))) }
| 'instance' inst_type where_inst
{ let (binds, sigs, ats, _) = cvBindsAndSigs (unLoc $3)
in
......@@ -575,9 +575,13 @@ topdecl :: { OrdList (LHsDecl RdrName) }
| '{-# DEPRECATED' deprecations '#-}' { $2 }
| '{-# WARNING' warnings '#-}' { $2 }
| '{-# RULES' rules '#-}' { $2 }
| '{-# VECTORISE_SCALAR' qvar '#-}' { unitOL $ LL $ VectD (HsVect $2 Nothing) }
| '{-# VECTORISE' qvar '=' exp '#-}' { unitOL $ LL $ VectD (HsVect $2 (Just $4)) }
| '{-# NOVECTORISE' qvar '#-}' { unitOL $ LL $ VectD (HsNoVect $2) }
| '{-# VECTORISE_SCALAR' qvar '#-}' { unitOL $ LL $ VectD (HsVect $2 Nothing) }
| '{-# VECTORISE' qvar '=' exp '#-}' { unitOL $ LL $ VectD (HsVect $2 (Just $4)) }
| '{-# NOVECTORISE' qvar '#-}' { unitOL $ LL $ VectD (HsNoVect $2) }
| '{-# VECTORISE_SCALAR' 'type' qtycon '#-}'
{ unitOL $ LL $ VectD (HsVectTypeIn $3 Nothing) }
| '{-# VECTORISE' 'type' qtycon '=' ctype '#-}'
{ unitOL $ LL $ VectD (HsVectTypeIn $3 (Just $5)) }
| annotation { unitOL $1 }
| decl { unLoc $1 }
......
......@@ -659,24 +659,37 @@ badRuleLhsErr name lhs bad_e
\begin{code}
rnHsVectDecl :: VectDecl RdrName -> RnM (VectDecl Name, FreeVars)
rnHsVectDecl (HsVect var Nothing)
= do { var' <- wrapLocM lookupTopBndrRn var
= do { var' <- lookupLocatedTopBndrRn var
; return (HsVect var' Nothing, unitFV (unLoc var'))
}
rnHsVectDecl (HsVect var (Just rhs))
= do { var' <- wrapLocM lookupTopBndrRn var
= do { var' <- lookupLocatedTopBndrRn var
; (rhs', fv_rhs) <- rnLExpr rhs
; return (HsVect var' (Just rhs'), fv_rhs `addOneFV` unLoc var')
}
rnHsVectDecl (HsNoVect var)
= do { var' <- wrapLocM lookupTopBndrRn var
= do { var' <- lookupLocatedTopBndrRn var
; return (HsNoVect var', unitFV (unLoc var'))
}
rnHsVectDecl (HsVectTypeIn tycon Nothing)
= do { tycon' <- lookupLocatedOccRn tycon
; return (HsVectTypeIn tycon' Nothing, unitFV (unLoc tycon'))
}
rnHsVectDecl (HsVectTypeIn tycon (Just ty))
= do { tycon' <- lookupLocatedOccRn tycon
; (ty', fv_ty) <- rnHsTypeFVs vect_doc ty
; return (HsVectTypeIn tycon' (Just ty'), fv_ty `addOneFV` unLoc tycon')
}
where
vect_doc = text "In the VECTORISE pragma for type constructor" <+> quotes (ppr tycon)
rnHsVectDecl (HsVectTypeOut _ _)
= panic "RnSource.rnHsVectDecl: Unexpected 'HsVectTypeOut'"
\end{code}
%*********************************************************
%* *
%* *
\subsection{Type, class and iface sig declarations}
%* *
%* *
%*********************************************************
@rnTyDecl@ uses the `global name function' to create a new type
......@@ -711,7 +724,7 @@ rnTyClDecl (ForeignType {tcdLName = name, tcdExtName = ext_name})
return (ForeignType {tcdLName = name', tcdExtName = ext_name},
emptyFVs)
-- all flavours of type family declarations ("type family", "newtype family",
-- and "data family")
rnTyClDecl tydecl@TyFamily {} = rnFamily tydecl bindTyVarsFV
......
......@@ -24,6 +24,7 @@ import TcSimplify
import TcHsType
import TcPat
import TcMType
import TyCon
import TcType
-- import Coercion
import TysPrim
......@@ -682,10 +683,23 @@ tcVect (HsNoVect name)
do { id <- wrapLocM tcLookupId name
; return $ HsNoVect id
}
tcVect (HsVectTypeIn lname@(L _ name) ty)
= addErrCtxt (vectCtxt lname) $
do { tycon <- tcLookupTyCon name
; checkTc (tyConArity tycon /= 0) scalarTyConMustBeNullary
; ty' <- fmapMaybeM dsHsType ty
; return $ HsVectTypeOut tycon ty'
}
tcVect (HsVectTypeOut _ _)
= panic "TcBinds.tcVect: Unexpected 'HsVectTypeOut'"
vectCtxt :: Located Name -> SDoc
vectCtxt name = ptext (sLit "When checking the vectorisation declaration for") <+> ppr name
scalarTyConMustBeNullary :: Message
scalarTyConMustBeNullary = ptext (sLit "VECTORISE SCALAR type constructor must be nullary")
--------------
-- If typechecking the binds fails, then return with each
-- signature-less binder given type (forall a.a), to minimise
......
1%
%
% (c) The University of Glasgow 2006
% (c) The AQUA Project, Glasgow University, 1996-1998
%
......@@ -1022,19 +1022,20 @@ zonkVects :: ZonkEnv -> [LVectDecl TcId] -> TcM [LVectDecl Id]
zonkVects env = mappM (wrapLocM (zonkVect env))
zonkVect :: ZonkEnv -> VectDecl TcId -> TcM (VectDecl Id)
zonkVect env (HsVect v Nothing)
= do { v' <- wrapLocM (zonkIdBndr env) v
; return $ HsVect v' Nothing
}
zonkVect env (HsVect v (Just e))
zonkVect env (HsVect v e)
= do { v' <- wrapLocM (zonkIdBndr env) v
; e' <- zonkLExpr env e
; return $ HsVect v' (Just e')
; e' <- fmapMaybeM (zonkLExpr env) e
; return $ HsVect v' e'
}
zonkVect env (HsNoVect v)
= do { v' <- wrapLocM (zonkIdBndr env) v
; return $ HsNoVect v'
}
zonkVect _env (HsVectTypeOut t ty)
= do { ty' <- fmapMaybeM zonkTypeZapping ty
; return $ HsVectTypeOut t ty'
}
zonkVect _ (HsVectTypeIn _ _) = panic "TcHsSyn.zonkVect: HsVectTypeIn"
\end{code}
%************************************************************************
......
-- Main entry point to the vectoriser. It is invoked iff the option '-fvectorise' is passed.
--
-- This module provides the function 'vectorise', which vectorises an entire (desugared) module.
-- It vectorises all type declarations and value bindings. It also processes all VECTORISE pragmas
-- (aka vectorisation declarations), which can lead to the vectorisation of imported data types
-- and the enrichment of imported functions with vectorised versions.
module Vectorise ( vectorise )
where
......@@ -55,22 +61,22 @@ vectoriseIO hsc_env guts
-- | Vectorise a single module, in the VM monad.
--
vectModule :: ModGuts -> VM ModGuts
vectModule guts@(ModGuts { mg_types = types
, mg_binds = binds
, mg_fam_insts = fam_insts
vectModule guts@(ModGuts { mg_types = types
, mg_binds = binds
, mg_fam_insts = fam_insts
, mg_vect_decls = vect_decls
})
= do { dumpOptVt Opt_D_dump_vt_trace "Before vectorisation" $
pprCoreBindings binds
-- Vectorise the type environment.
-- This may add new TyCons and DataCons.
; (types', new_fam_insts, tc_binds) <- vectTypeEnv types
-- Vectorise the type environment. This will add vectorised type constructors, their
-- representaions, and the conrresponding data constructors. Moreover, we produce
-- bindings for dfuns and family instances of the classes and type families used in the
-- DPH library to represent array types.
; (types', new_fam_insts, tc_binds) <- vectTypeEnv types [vd | vd@(VectType _ _) <- vect_decls]
; (_, fam_inst_env) <- readGEnv global_fam_inst_env
-- dicts <- mapM buildPADict pa_insts
-- workers <- mapM vectDataConWorkers pa_insts
-- Vectorise all the top level bindings.
; binds' <- mapM vectTopBind binds
......
-- | Builtin types and functions used by the vectoriser.
-- The source program uses functions from Data.Array.Parallel, which the vectoriser rewrites
-- to use equivalent vectorised versions in the DPH backend packages.
--
-- The `Builtins` structure holds the name of all the things in the DPH packages
-- we will need. We can get specific things using the selectors, which print a
-- civilized panic message if the specified thing cannot be found.
-- Types and functions declared in the DPH packages and used by the vectoriser.
--
-- The @Builtins@ structure holds the name of all the things in the DPH packages that appear in
-- code generated by the vectoriser. We can get specific things using the selectors, which print a
-- civilized panic message if the specified thing cannot be found.
module Vectorise.Builtins (
-- * Builtins
Builtins(..),
......
-- Set up the data structures provided by 'Vectorise.Builtins'.
module Vectorise.Builtins.Initialise (
-- * Initialisation
......@@ -81,10 +82,10 @@ initBuiltins pkg
-- From dph-common:Data.Array.Parallel.PArray.Types
voidTyCon <- externalTyCon dph_PArray_Types (fsLit "Void")
voidVar <- externalVar dph_PArray_Types (fsLit "void")
fromVoidVar <- externalVar dph_PArray_Types (fsLit "fromVoid")
voidVar <- externalVar dph_PArray_Types (fsLit "void")
fromVoidVar <- externalVar dph_PArray_Types (fsLit "fromVoid")
wrapTyCon <- externalTyCon dph_PArray_Types (fsLit "Wrap")
sum_tcs <- mapM (externalTyCon dph_PArray_Types) (numbered "Sum" 2 mAX_DPH_SUM)
sum_tcs <- mapM (externalTyCon dph_PArray_Types) (numbered "Sum" 2 mAX_DPH_SUM)
-- from dph-common:Data.Array.Parallel.PArray.PDataInstances
pvoidVar <- externalVar dph_PArray_PDataInstances (fsLit "pvoid")
......
......@@ -76,55 +76,56 @@ emptyLocalEnv = LocalEnv {
--
data GlobalEnv
= GlobalEnv
{ global_vars :: VarEnv Var
{ global_vars :: VarEnv Var
-- ^Mapping from global variables to their vectorised versions — aka the /vectorisation
-- map/.
, global_vect_decls :: VarEnv (Type, CoreExpr)
, global_vect_decls :: VarEnv (Type, CoreExpr)
-- ^Mapping from global variables that have a vectorisation declaration to the right-hand
-- side of that declaration and its type. This mapping only applies to non-scalar
-- vectorisation declarations. All variables with a scalar vectorisation declaration are
-- mentioned in 'global_scalars_vars'.
, global_scalar_vars :: VarSet
, global_scalar_vars :: VarSet
-- ^Purely scalar variables. Code which mentions only these variables doesn't have to be
-- lifted. This includes variables from the current module that have a scalar
-- vectorisation declaration and those that the vectoriser determines to be scalar.
, global_scalar_tycons :: NameSet
-- ^Type constructors whose values can only contain scalar data. Scalar code may only
-- operate on such data.
, global_scalar_tycons :: NameSet
-- ^Type constructors whose values can only contain scalar data and that appear in a
-- 'VECTORISE SCALAR type' pragma in the current or an imported module. Scalar code may
-- only operate on such data.
, global_novect_vars :: VarSet
, global_novect_vars :: VarSet
-- ^Variables that are not vectorised. (They may be referenced in the right-hand sides
-- of vectorisation declarations, though.)
, global_exported_vars :: VarEnv (Var, Var)
, global_exported_vars :: VarEnv (Var, Var)
-- ^Exported variables which have a vectorised version.
, global_tycons :: NameEnv TyCon
, global_tycons :: NameEnv TyCon
-- ^Mapping from TyCons to their vectorised versions.
-- TyCons which do not have to be vectorised are mapped to themselves.
, global_datacons :: NameEnv DataCon
, global_datacons :: NameEnv DataCon
-- ^Mapping from DataCons to their vectorised versions.
, global_pa_funs :: NameEnv Var
, global_pa_funs :: NameEnv Var
-- ^Mapping from TyCons to their PA dfuns.
, global_pr_funs :: NameEnv Var
, global_pr_funs :: NameEnv Var
-- ^Mapping from TyCons to their PR dfuns.
, global_boxed_tycons :: NameEnv TyCon
, global_boxed_tycons :: NameEnv TyCon
-- ^Mapping from unboxed TyCons to their boxed versions.
, global_inst_env :: (InstEnv, InstEnv)
, global_inst_env :: (InstEnv, InstEnv)
-- ^External package inst-env & home-package inst-env for class instances.
, global_fam_inst_env :: FamInstEnvs
, global_fam_inst_env :: FamInstEnvs
-- ^External package inst-env & home-package inst-env for family instances.
, global_bindings :: [(Var, CoreExpr)]
, global_bindings :: [(Var, CoreExpr)]
-- ^Hoisted bindings.
}
......@@ -133,25 +134,26 @@ data GlobalEnv
initGlobalEnv :: VectInfo -> [CoreVect] -> (InstEnv, InstEnv) -> FamInstEnvs -> GlobalEnv
initGlobalEnv info vectDecls instEnvs famInstEnvs
= GlobalEnv
{ global_vars = mapVarEnv snd $ vectInfoVar info
, global_vect_decls = mkVarEnv vects
, global_scalar_vars = vectInfoScalarVars info `extendVarSetList` scalars
, global_scalar_tycons = vectInfoScalarTyCons info
, global_novect_vars = mkVarSet novects
, global_exported_vars = emptyVarEnv
, global_tycons = mapNameEnv snd $ vectInfoTyCon info
, global_datacons = mapNameEnv snd $ vectInfoDataCon info
, global_pa_funs = mapNameEnv snd $ vectInfoPADFun info
, global_pr_funs = emptyNameEnv
, global_boxed_tycons = emptyNameEnv
, global_inst_env = instEnvs
, global_fam_inst_env = famInstEnvs
, global_bindings = []
{ global_vars = mapVarEnv snd $ vectInfoVar info
, global_vect_decls = mkVarEnv vects
, global_scalar_vars = vectInfoScalarVars info `extendVarSetList` scalar_vars
, global_scalar_tycons = vectInfoScalarTyCons info `addListToNameSet` scalar_tycons
, global_novect_vars = mkVarSet novects
, global_exported_vars = emptyVarEnv
, global_tycons = mapNameEnv snd $ vectInfoTyCon info
, global_datacons = mapNameEnv snd $ vectInfoDataCon info
, global_pa_funs = mapNameEnv snd $ vectInfoPADFun info
, global_pr_funs = emptyNameEnv
, global_boxed_tycons = emptyNameEnv
, global_inst_env = instEnvs
, global_fam_inst_env = famInstEnvs
, global_bindings = []
}
where
vects = [(var, (varType var, exp)) | Vect var (Just exp) <- vectDecls]
scalars = [var | Vect var Nothing <- vectDecls]
novects = [var | NoVect var <- vectDecls]
vects = [(var, (varType var, exp)) | Vect var (Just exp) <- vectDecls]
scalar_vars = [var | Vect var Nothing <- vectDecls]
novects = [var | NoVect var <- vectDecls]
scalar_tycons = [tyConName tycon | VectType tycon Nothing <- vectDecls]
-- Operators on Global Environments -------------------------------------------
......@@ -214,9 +216,9 @@ modVectInfo :: GlobalEnv -> TypeEnv -> VectInfo -> VectInfo
modVectInfo env tyenv info
= info
{ vectInfoVar = global_exported_vars env
, vectInfoTyCon = mk_env typeEnvTyCons global_tycons
, vectInfoTyCon = mk_env typeEnvTyCons global_tycons
, vectInfoDataCon = mk_env typeEnvDataCons global_datacons
, vectInfoPADFun = mk_env typeEnvTyCons global_pa_funs
, vectInfoPADFun = mk_env typeEnvTyCons global_pa_funs
, vectInfoScalarVars = global_scalar_vars env `minusVarSet` vectInfoScalarVars info
, vectInfoScalarTyCons = global_scalar_tycons env `minusNameSet` vectInfoScalarTyCons info
}
......
......@@ -26,6 +26,7 @@ import CoreFVs
import DataCon
import TyCon
import Type
import NameSet
import Var
import VarEnv
import VarSet
......@@ -42,11 +43,11 @@ import Data.List
-- | Vectorise a polymorphic expression.
--
vectPolyExpr :: Bool -- ^ When vectorising the RHS of a binding, whether that
-- binding is a loop breaker.
-> [Var]
-> CoreExprWithFVs
-> VM (Inline, Bool, VExpr)
vectPolyExpr :: Bool -- ^ When vectorising the RHS of a binding, whether that
-- binding is a loop breaker.
-> [Var]
-> CoreExprWithFVs
-> VM (Inline, Bool, VExpr)
vectPolyExpr loop_breaker recFns (_, AnnNote note expr)
= do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr
return (inline, isScalarFn, vNote note expr')
......@@ -194,26 +195,24 @@ vectScalarFun :: Bool -- ^ Was the function marked as scalar by the user?
-> CoreExpr -- ^ Expression to be vectorised
-> VM VExpr
vectScalarFun forceScalar recFns expr
= do { gscalars <- globalScalars
; let scalars = gscalars `extendVarSetList` recFns
= do { gscalarVars <- globalScalarVars
; scalarTyCons <- globalScalarTyCons
; let scalarVars = gscalarVars `extendVarSetList` recFns
(arg_tys, res_ty) = splitFunTys (exprType expr)
; MASSERT( not $ null arg_tys )
; onlyIfV (forceScalar -- user asserts the functions is scalar
; onlyIfV (forceScalar -- user asserts the functions is scalar
||
all is_prim_ty arg_tys -- check whether the function is scalar
&& is_prim_ty res_ty
&& is_scalar scalars expr
&& uses scalars expr)
all (is_scalar_ty scalarTyCons) arg_tys -- check whether the function is scalar
&& is_scalar_ty scalarTyCons res_ty
&& is_scalar scalarVars (is_scalar_ty scalarTyCons) expr
&& uses scalarVars expr)
$ mkScalarFun arg_tys res_ty expr
}
where
-- FIXME: This is woefully insufficient!!! We need a scalar pragma for types!!!
is_prim_ty ty
| Just (tycon, []) <- splitTyConApp_maybe ty
= tycon == intTyCon
|| tycon == floatTyCon
|| tycon == doubleTyCon
| otherwise = False
is_scalar_ty scalarTyCons ty
| Just (tycon, _) <- splitTyConApp_maybe ty
= tyConName tycon `elemNameSet` scalarTyCons
| otherwise = False
-- Checks whether an expression contain a non-scalar subexpression.
--
......@@ -223,40 +222,45 @@ vectScalarFun forceScalar recFns expr
-- them to the list of scalar variables) and then check them. If one of them turns out not to
-- be scalar, the entire group is regarded as not being scalar.
--
-- FIXME: Currently, doesn't regard external (non-data constructor) variable and anonymous
-- data constructor as scalar. Should be changed once scalar types are passed
-- through VectInfo.
-- The second argument is a predicate that checks whether a type is scalar.
--
is_scalar :: VarSet -> CoreExpr -> Bool
is_scalar scalars (Var v) = v `elemVarSet` scalars
is_scalar _scalars (Lit _) = True
is_scalar scalars e@(App e1 e2)
| maybe_parr_ty (exprType e) = False
| otherwise = is_scalar scalars e1 && is_scalar scalars e2
is_scalar scalars (Lam var body)
| maybe_parr_ty (varType var) = False
| otherwise = is_scalar (scalars `extendVarSet` var) body
is_scalar scalars (Let bind body) = bindsAreScalar && is_scalar scalars' body
is_scalar :: VarSet -> (Type -> Bool) -> CoreExpr -> Bool
is_scalar scalars _isScalarTC (Var v) = v `elemVarSet` scalars
is_scalar _scalars _isScalarTC (Lit _) = True
is_scalar scalars isScalarTC e@(App e1 e2)
| maybe_parr_ty (exprType e) = False
| otherwise = is_scalar scalars isScalarTC e1 &&
is_scalar scalars isScalarTC e2
is_scalar scalars isScalarTC (Lam var body)
| maybe_parr_ty (varType var) = False
| otherwise = is_scalar (scalars `extendVarSet` var)
isScalarTC body
is_scalar scalars isScalarTC (Let bind body) = bindsAreScalar &&
is_scalar scalars' isScalarTC body
where
(bindsAreScalar, scalars') = is_scalar_bind scalars bind
is_scalar scalars (Case e var ty alts)
| is_prim_ty ty = is_scalar scalars' e && all (is_scalar_alt scalars') alts
(bindsAreScalar, scalars') = is_scalar_bind scalars isScalarTC bind
is_scalar scalars isScalarTC (Case e var ty alts)
| isScalarTC ty = is_scalar scalars' isScalarTC e &&
all (is_scalar_alt scalars' isScalarTC) alts
| otherwise = False
where
scalars' = scalars `extendVarSet` var
is_scalar scalars (Cast e _coe) = is_scalar scalars e
is_scalar scalars (Note _ e ) = is_scalar scalars e
is_scalar _scalars (Type {}) = True
is_scalar _scalars (Coercion {}) = True
is_scalar scalars isScalarTC (Cast e _coe) = is_scalar scalars isScalarTC e
is_scalar scalars isScalarTC (Note _ e ) = is_scalar scalars isScalarTC e
is_scalar _scalars _isScalarTC (Type {}) = True
is_scalar _scalars _isScalarTC (Coercion {}) = True
-- Result: (<is this binding group scalar>, scalars ++ variables bound in this group)
is_scalar_bind scalars (NonRec var e) = (is_scalar scalars e, scalars `extendVarSet` var)
is_scalar_bind scalars (Rec bnds) = (all (is_scalar scalars') es, scalars')
is_scalar_bind scalars isScalarTCs (NonRec var e) = (is_scalar scalars isScalarTCs e,
scalars `extendVarSet` var)
is_scalar_bind scalars isScalarTCs (Rec bnds) = (all (is_scalar scalars' isScalarTCs) es,
scalars')
where
(vars, es) = unzip bnds
scalars' = scalars `extendVarSetList` vars
is_scalar_alt scalars (_, vars, e) = is_scalar (scalars `extendVarSetList ` vars) e
is_scalar_alt scalars isScalarTCs (_, vars, e) = is_scalar (scalars `extendVarSetList ` vars)
isScalarTCs e
-- Checks whether the type might be a parallel array type. In particular, if the outermost
-- constructor is a type family, we conservatively assume that it may be a parallel array type.
......
-- Operations on the global state of the vectorisation monad.
module Vectorise.Monad.Global (
readGEnv,
......@@ -11,12 +12,11 @@ module Vectorise.Monad.Global (
lookupVectDecl, noVectDecl,
-- * Scalars
globalScalars, isGlobalScalar,
globalScalarVars, isGlobalScalar, globalScalarTyCons,
-- * TyCons
lookupTyCon,
lookupBoxedTyCon,
defTyCon,
lookupTyCon, lookupBoxedTyCon,
defTyCon, globalVectTyCons,
-- * Datacons
lookupDataCon,
......@@ -24,7 +24,6 @@ module Vectorise.Monad.Global (
-- * PA Dictionaries
lookupTyConPA,
defTyConPA,
defTyConPAs,
-- * PR Dictionaries
......@@ -39,6 +38,7 @@ import Type
import TyCon
import DataCon
import NameEnv
import NameSet
import Var
import VarEnv
import VarSet
......@@ -49,17 +49,17 @@ import VarSet
-- |Project something from the global environment.
--
readGEnv :: (GlobalEnv -> a) -> VM a
readGEnv f = VM $ \_ genv lenv -> return (Yes genv lenv (f genv))
readGEnv f = VM $ \_ genv lenv -> return (Yes genv lenv (f genv))
-- |Set the value of the global environment.
--
setGEnv :: GlobalEnv -> VM ()
setGEnv genv = VM $ \_ _ lenv -> return (Yes genv lenv ())
setGEnv genv = VM $ \_ _ lenv -> return (Yes genv lenv ())
-- |Update the global environment using the provided function.
--
updGEnv :: (GlobalEnv -> GlobalEnv) -> VM ()
updGEnv f = VM $ \_ genv lenv -> return (Yes (f genv) lenv ())
updGEnv f = VM $ \_ genv lenv -> return (Yes (f genv) lenv ())
-- Vars -----------------------------------------------------------------------
......@@ -93,13 +93,19 @@ noVectDecl var = readGEnv $ \env -> elemVarSet var (global_novect_vars env)
-- |Get the set of global scalar variables.
--
globalScalars :: VM VarSet
globalScalars = readGEnv global_scalar_vars
globalScalarVars :: VM VarSet
globalScalarVars = readGEnv global_scalar_vars
-- |Check whether a given variable is in the set of global scalar variables.
--
isGlobalScalar :: Var -> VM Bool
isGlobalScalar var = readGEnv $ \env -> elemVarSet var (global_scalar_vars env)
isGlobalScalar var = readGEnv $ \env -> var `elemVarSet` global_scalar_vars env
-- |Get the set of global scalar type constructors including both those scalar type constructors
-- declared in an imported module and those declared in the current module.
--
globalScalarTyCons :: VM NameSet
globalScalarTyCons = readGEnv global_scalar_tycons
-- TyCons ---------------------------------------------------------------------
...