Commit 02cff9df authored by rl@cse.unsw.edu.au's avatar rl@cse.unsw.edu.au
Browse files

More vectorisation-related smart constructors

parent f3114b4a
module VectCore (
Vect, VVar, VExpr,
Vect, VVar, VExpr, VBind,
vectorised, lifted,
mapVect,
vVar, mkVLams, mkVVarApps
vNonRec, vRec,
vVar, vType, vNote, vLet,
mkVLams, mkVVarApps
) where
#include "HsVersions.h"
import CoreSyn
import Type ( Type )
import Var
type Vect a = (a,a)
type VVar = Vect Var
type VExpr = Vect CoreExpr
type VBind = Vect CoreBind
vectorised :: Vect a -> a
vectorised = fst
......@@ -25,9 +30,30 @@ lifted = snd
mapVect :: (a -> b) -> Vect a -> Vect b
mapVect f (x,y) = (f x, f y)
zipWithVect :: (a -> b -> c) -> Vect a -> Vect b -> Vect c
zipWithVect f (x1,y1) (x2,y2) = (f x1 x2, f y1 y2)
vVar :: VVar -> VExpr
vVar = mapVect Var
vType :: Type -> VExpr
vType ty = (Type ty, Type ty)
vNote :: Note -> VExpr -> VExpr
vNote = mapVect . Note
vNonRec :: VVar -> VExpr -> VBind
vNonRec = zipWithVect NonRec
vRec :: [VVar] -> [VExpr] -> VBind
vRec vs es = (Rec (zip vvs ves), Rec (zip lvs les))
where
(vvs, lvs) = unzip vs
(ves, les) = unzip es
vLet :: VBind -> VExpr -> VExpr
vLet = zipWithVect Let
mkVLams :: [VVar] -> VExpr -> VExpr
mkVLams vvs (ve,le) = (mkLams vs ve, mkLams ls le)
where
......
......@@ -34,6 +34,7 @@ import OccName
import DsMonad hiding (mapAndUnzipM)
import DsUtils ( mkCoreTup, mkCoreTupTy )
import Literal ( Literal )
import PrelNames
import TysWiredIn
import TysPrim ( intPrimTy )
......@@ -179,6 +180,12 @@ vectPolyVar lc v tys
lexpr <- replicatePA vexpr (Var lc)
return (vexpr, lexpr)
vectLiteral :: Var -> Literal -> VM VExpr
vectLiteral lc lit
= do
lexpr <- replicatePA (Lit lit) (Var lc)
return (Lit lit, lexpr)
vectPolyExpr :: Var -> CoreExprWithFVs -> VM VExpr
vectPolyExpr lc expr
= polyAbstract tvs $ \abstract ->
......@@ -191,22 +198,14 @@ vectPolyExpr lc expr
vectExpr :: Var -> CoreExprWithFVs -> VM VExpr
vectExpr lc (_, AnnType ty)
= do
vty <- vectType ty
return (Type vty, Type vty)
= liftM vType (vectType ty)
vectExpr lc (_, AnnVar v) = vectVar lc v
vectExpr lc (_, AnnLit lit)
= do
let vexpr = Lit lit
lexpr <- replicatePA vexpr (Var lc)
return (vexpr, lexpr)
vectExpr lc (_, AnnLit lit) = vectLiteral lc lit
vectExpr lc (_, AnnNote note expr)
= do
(vexpr, lexpr) <- vectExpr lc expr
return (Note note vexpr, Note note lexpr)
= liftM (vNote note) (vectExpr lc expr)
vectExpr lc e@(_, AnnApp _ arg)
| isAnnTypeArg arg
......@@ -225,24 +224,19 @@ vectExpr lc (_, AnnCase expr bndr ty alts)
vectExpr lc (_, AnnLet (AnnNonRec bndr rhs) body)
= do
(vrhs, lrhs) <- vectPolyExpr lc rhs
((vbndr, lbndr), (vbody, lbody)) <- vectBndrIn bndr (vectExpr lc body)
return (Let (NonRec vbndr vrhs) vbody,
Let (NonRec lbndr lrhs) lbody)
vrhs <- vectPolyExpr lc rhs
(vbndr, vbody) <- vectBndrIn bndr (vectExpr lc body)
return $ vLet (vNonRec vbndr vrhs) vbody
vectExpr lc (_, AnnLet (AnnRec prs) body)
vectExpr lc (_, AnnLet (AnnRec bs) body)
= do
(bndrs, (vrhss, vbody, lrhss, lbody)) <- vectBndrsIn bndrs vect
let (vbndrs, lbndrs) = unzip bndrs
return (Let (Rec (zip vbndrs vrhss)) vbody,
Let (Rec (zip lbndrs lrhss)) lbody)
(vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
$ liftM2 (,)
(mapM (vectExpr lc) rhss)
(vectPolyExpr lc body)
return $ vLet (vRec vbndrs vrhss) vbody
where
(bndrs, rhss) = unzip prs
vect = do
(vrhss, lrhss) <- mapAndUnzipM (vectExpr lc) rhss
(vbody, lbody) <- vectPolyExpr lc body
return (vrhss, vbody, lrhss, lbody)
(bndrs, rhss) = unzip bs
vectExpr lc e@(_, AnnLam bndr body)
| isTyVar bndr = pprPanic "vectExpr" (ppr $ deAnnotate e)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment