Commit 80cb2c39 authored by keller@cse.unsw.edu.au's avatar keller@cse.unsw.edu.au
Browse files

Handling of recursive scalar functions in isScalarLam

parent 37b0cb11
......@@ -115,7 +115,7 @@ vectModule guts
vectTopBind :: CoreBind -> VM CoreBind
vectTopBind b@(NonRec var expr)
= do
(inline, expr') <- vectTopRhs var expr
(inline, _, expr') <- vectTopRhs [] var expr
var' <- vectTopBinder var inline expr'
-- Vectorising the body may create other top-level bindings.
......@@ -131,15 +131,23 @@ vectTopBind b@(NonRec var expr)
vectTopBind b@(Rec bs)
= do
-- pprTrace "in Rec" (ppr vars) $ return ()
(vars', _, exprs')
<- fixV $ \ ~(_, inlines, rhss) ->
do vars' <- sequence [vectTopBinder var inline rhs
| (var, ~(inline, rhs)) <- zipLazy vars (zip inlines rhss)]
(inlines', exprs')
<- mapAndUnzipM (uncurry vectTopRhs) bs
return (vars', inlines', exprs')
(inlines', areScalars', exprs')
<- mapAndUnzip3M (uncurry $ vectTopRhs vars) bs
if (and areScalars') || (length bs <= 1)
then do
-- pprTrace "in Rec - all scalars??" (ppr areScalars') $ return ()
return (vars', inlines', exprs')
else do
-- pprTrace "in Rec - not all scalars" (ppr areScalars') $ return ()
mapM deleteGlobalScalar vars
(inlines'', _, exprs'') <- mapAndUnzip3M (uncurry $ vectTopRhs []) bs
return (vars', inlines'', exprs'')
hs <- takeHoisted
cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
......@@ -147,7 +155,9 @@ vectTopBind b@(Rec bs)
return b
where
(vars, exprs) = unzip bs
mapAndUnzip3M f xs = do
ys <- mapM f xs
return $ unzip3 ys
-- | Make the vectorised version of this top level binder, and add the mapping
-- between it and the original to the state. For some binder @foo@ the vectorised
......@@ -182,21 +192,22 @@ vectTopBinder var inline expr
-- | Vectorise the RHS of a top-level binding, in an empty local environment.
vectTopRhs
:: Var -- ^ Name of the binding.
:: [Var] -- ^ Names of all functions in the rec block
-> Var -- ^ Name of the binding.
-> CoreExpr -- ^ Body of the binding.
-> VM (Inline, CoreExpr)
-> VM (Inline, Bool, CoreExpr)
vectTopRhs var expr
vectTopRhs recFs var expr
= dtrace (vcat [text "vectTopRhs", ppr expr])
$ closedV
$ do (inline, isScalar, vexpr) <- inBind var
$ pprTrace "vectTopRhs" (ppr var)
$ vectPolyExpr (isLoopBreaker $ idOccInfo var)
-- $ pprTrace "vectTopRhs" (ppr var)
$ vectPolyExpr (isLoopBreaker $ idOccInfo var) recFs
(freeVars expr)
if isScalar
then addGlobalScalar var
else return ()
return (inline, vectorised vexpr)
else deleteGlobalScalar var
return (inline, isScalar, vectorised vexpr)
-- | Project out the vectorised version of a binding from some closure,
......
......@@ -35,20 +35,21 @@ import Data.List
-- | Vectorise a polymorphic expression.
vectPolyExpr
:: Bool -- ^ When vectorising the RHS of a binding, whether that
-- binding is a loop breaker.
-- binding is a loop breaker.
-> [Var]
-> CoreExprWithFVs
-> VM (Inline, Bool, VExpr)
vectPolyExpr loop_breaker (_, AnnNote note expr)
= do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker expr
vectPolyExpr loop_breaker recFns (_, AnnNote note expr)
= do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr
return (inline, isScalarFn, vNote note expr')
vectPolyExpr loop_breaker expr
vectPolyExpr loop_breaker recFns expr
= do
arity <- polyArity tvs
polyAbstract tvs $ \args ->
do
(inline, isScalarFn, mono') <- vectFnExpr False loop_breaker mono
(inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono
return (addInlineArity inline arity, isScalarFn,
mapVect (mkLams $ tvs ++ args) mono')
where
......@@ -117,7 +118,7 @@ vectExpr (_, AnnCase scrut bndr ty alts)
vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
= do
vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False rhs
vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs
(vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
return $ vLet (vNonRec vbndr vrhs) vbody
......@@ -134,10 +135,10 @@ vectExpr (_, AnnLet (AnnRec bs) body)
vect_rhs bndr rhs = localV
. inBind bndr
. liftM (\(_,_,z)->z)
$ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) rhs
$ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) [] rhs
vectExpr e@(_, AnnLam bndr _)
| isId bndr = liftM (\(_,_,z) ->z) $ vectFnExpr True False e
| isId bndr = liftM (\(_,_,z) ->z) $ vectFnExpr True False [] e
{-
onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
`orElseV` vectLam True fvs bs body
......@@ -152,18 +153,19 @@ vectExpr e = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnno
vectFnExpr
:: Bool -- ^ When the RHS of a binding, whether that binding should be inlined.
-> Bool -- ^ Whether the binding is a loop breaker.
-> [Var]
-> CoreExprWithFVs -- ^ Expression to vectorise. Must have an outer `AnnLam`.
-> VM (Inline, Bool, VExpr)
vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _)
| isId bndr = pprTrace "vectFnExpr -- id" (ppr fvs )$
vectFnExpr inline loop_breaker recFns e@(fvs, AnnLam bndr _)
| isId bndr = -- pprTrace "vectFnExpr -- id" (ppr fvs )$
onlyIfV True -- (isEmptyVarSet fvs) -- we check for free variables later. TODO: clean up
(mark DontInline True . vectScalarLam bs $ deAnnotate body)
(mark DontInline True . vectScalarLam bs recFns $ deAnnotate body)
`orElseV` mark inlineMe False (vectLam inline loop_breaker fvs bs body)
where
(bs,body) = collectAnnValBinders e
vectFnExpr _ _ e = pprTrace "vectFnExpr -- otherwise" (ppr "a" )$ mark DontInline False $ vectExpr e
vectFnExpr _ _ _ e = pprTrace "vectFnExpr -- otherwise" (ppr "a" )$ mark DontInline False $ vectExpr e
mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) }
......@@ -172,13 +174,18 @@ mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) }
-- | Vectorise a function where are the args have scalar type,
-- that is Int, Float, Double etc.
vectScalarLam
:: [Var] -- ^ Bound variables of function.
:: [Var] -- ^ Bound variables of function
-> [Var]
-> CoreExpr -- ^ Function body.
-> VM VExpr
vectScalarLam args body
= do scalars <- globalScalars
pprTrace "vectScalarLam" (ppr $ is_scalar (extendVarSetList scalars args) body) $
vectScalarLam args recFns body
= do scalars' <- globalScalars
let scalars = unionVarSet (mkVarSet recFns) scalars'
{- pprTrace "vectScalarLam uses" (ppr $ uses scalars body) $
pprTrace "vectScalarLam is prim res" (ppr $ is_prim_ty res_ty) $
pprTrace "vectScalarLam is scalar body" (ppr $ is_scalar (extendVarSetList scalars args) body) $
pprTrace "vectScalarLam arg tys" (ppr $ arg_tys) $ -}
onlyIfV (all is_prim_ty arg_tys
&& is_prim_ty res_ty
&& is_scalar (extendVarSetList scalars args) body
......@@ -190,7 +197,7 @@ vectScalarLam args body
(zipf `App` Var fn_var)
clo_var <- hoistExpr (fsLit "clo") clo DontInline
lclo <- liftPD (Var clo_var)
pprTrace " lam is scalar" (ppr "") $
{- pprTrace " lam is scalar" (ppr "") $ -}
return (Var clo_var, lclo)
where
arg_tys = map idType args
......@@ -214,7 +221,7 @@ vectScalarLam args body
| isPrimTyCon tycon = False
| isAbstractTyCon tycon = True
| isFunTyCon tycon || isProductTyCon tycon || isTupleTyCon tycon = any (maybe_parr_ty' alreadySeen) args
| isDataTyCon tycon = pprTrace "isDataTyCon" (ppr tycon) $
| isDataTyCon tycon = -- pprTrace "isDataTyCon" (ppr tycon) $
any (maybe_parr_ty' alreadySeen) args ||
hasParrDataCon alreadySeen tycon
| otherwise = True
......
......@@ -17,7 +17,8 @@ module Vectorise.Monad (
maybeCantVectoriseVarM,
dumpVar,
addGlobalScalar,
deleteGlobalScalar,
-- * Primitives
lookupPrimPArray,
lookupPrimMethod
......@@ -146,6 +147,11 @@ addGlobalScalar :: Var -> VM ()
addGlobalScalar var
= updGEnv $ \env -> pprTrace "addGLobalScalar" (ppr var) env{global_scalars = extendVarSet (global_scalars env) var}
deleteGlobalScalar :: Var -> VM ()
deleteGlobalScalar var
= updGEnv $ \env -> pprTrace "deleteGLobalScalar" (ppr var) env{global_scalars = delVarSet (global_scalars env) var}
-- Primitives -----------------------------------------------------------------
lookupPrimPArray :: TyCon -> VM (Maybe TyCon)
lookupPrimPArray = liftBuiltinDs . primPArray
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment