Commit 551472b1 authored by chak@cse.unsw.edu.au

Vectoriser: don't pack free *scalar* variables

parent 2af18952
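The change below stops the vectoriser from packing free *scalar* variables when it lifts the alternatives of a case expression: only variables that actually have a vectorised (lifted array) binding are packed by the alternative's selector tag. The following stand-alone sketch illustrates the underlying idea; the names (SelTag, packByTag) are invented for illustration and are not the vectoriser's API.

-- Illustration only (invented names, not the GHC vectoriser API): in the lifted
-- interpretation of a case expression, every free *array* variable must be
-- restricted ("packed") to the elements whose selector tag picks the current
-- alternative.  A free *scalar* variable is the same for every element, so there
-- is nothing to pack, which is why only variables with a vectorised binding
-- should be packed.
type SelTag = Int

packByTag :: [SelTag]  -- selector tag of each element of the scrutinee array
          -> SelTag    -- tag of the alternative being lifted
          -> [a]       -- lifted value of a free array variable
          -> [a]       -- the same variable, restricted to this alternative
packByTag tags t xs = [x | (x, tag) <- zip xs tags, tag == t]

main :: IO ()
main = print (packByTag [0, 1, 0, 1] 1 "abcd")   -- prints "bd"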
@@ -210,7 +210,9 @@ vectTopBind b@(Rec binds)
; cantVectorise dflags noVectoriseErr (ppr b)
}
else do
- { -- For all bindings *with* a pragma, just use the pragma-supplied vectorised expression
+ { traceVt "[Vanilla]" $ vcat [ppr var <+> char '=' <+> ppr expr | (var, expr) <- binds]
+ -- For all bindings *with* a pragma, just use the pragma-supplied vectorised expression
; newBindsWPragma <- concat <$>
sequence [ vectTopBindAndConvert bind inlineMe expr'
| (bind, (_, Just (_, expr'))) <- zip binds vectDecls]
@@ -99,11 +99,19 @@ vectTopExprs binds
= do
{ exprVIs <- mapM (vectAvoidAndEncapsulate emptyVarSet) exprs
; if all isVIEncaps exprVIs
- then
- return Nothing
+ -- if all bindings are scalar => don't vectorise this group of bindings
+ then return Nothing
else do
- { (areVIParr, vExprs) <- unzip <$> mapM encapsulateAndVect binds
- ; return $ Just (or areVIParr, vExprs)
+ { -- non-scalar bindings need to be vectorised
+ ; let areVIParr = any isVIParr exprVIs
+ ; revised_exprVIs <- if not areVIParr
+ -- if no binding is parallel => 'exprVIs' is ready for vectorisation
+ then return exprVIs
+ -- if any binding is parallel => recompute the vectorisation info
+ else mapM (vectAvoidAndEncapsulate (mkVarSet vars)) exprs
+ ; vExprs <- zipWithM vect vars revised_exprVIs
+ ; return $ Just (areVIParr, vExprs)
}
}
where
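The hunk above reworks vectTopExprs so that a binding group in which every binding is scalar is skipped entirely, and so that the vectorisation info is recomputed whenever any binding in the group is parallel. A small stand-alone model of that decision, using invented names rather than the real types of Vectorise.Exp:

-- Stand-alone model of the decision above; the real code works on annotated
-- Core expressions, not this toy type.
data VectInfo = VIEncaps   -- the binding is scalar and has been encapsulated
              | VIParr     -- the binding contains parallel array computations
              | VIPlain    -- neither of the above
              deriving (Eq, Show)

-- Nothing   => leave the whole group unvectorised (all bindings are scalar)
-- Just parr => vectorise the group; 'parr' records whether any binding is
--              parallel, in which case the vectorisation info is recomputed
--              with the group's own binders treated as parallel.
decideGroup :: [VectInfo] -> Maybe Bool
decideGroup infos
  | all (== VIEncaps) infos = Nothing
  | otherwise               = Just (any (== VIParr) infos)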
@@ -111,14 +119,13 @@ vectTopExprs binds
vectAvoidAndEncapsulate pvs = encapsulateScalars <=< vectAvoidInfo pvs . freeVars
- encapsulateAndVect (var, expr)
+ vect var exprVI
= do
- { exprVI <- vectAvoidAndEncapsulate (mkVarSet vars) expr
- ; vExpr <- closedV $
+ { vExpr <- closedV $
inBind var $
vectAnnPolyExpr (isStrongLoopBreaker $ idOccInfo var) exprVI
; inline <- computeInline exprVI
- ; return (isVIParr exprVI, (inline, vectorised vExpr))
+ ; return (inline, vectorised vExpr)
}
-- |Vectorise a polymorphic expression annotated with vectorisation information.
@@ -302,8 +309,8 @@ vectExpr (_, AnnVar v)
vectExpr (_, AnnLit lit)
= vectConst $ Lit lit
- vectExpr aexpr@(_, AnnLam bndr _)
- = vectFnExpr True False aexpr
+ vectExpr aexpr@(_, AnnLam _ _)
+ = traceVt "vectExpr [AnnLam]:" (ppr . deAnnotate $ aexpr) >> vectFnExpr True False aexpr
-- SPECIAL CASE: Vectorise/lift 'patError @ ty err' by only vectorising/lifting the type 'ty';
-- its only purpose is to abort the program, but we need to adjust the type to keep CoreLint
@@ -368,19 +375,24 @@ vectExpr (_, AnnCase scrut bndr ty alts)
vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
= do
- { vrhs <- localV $
+ { traceVt "let binding (non-recursive)" empty
+ ; vrhs <- localV $
inBind bndr $
vectAnnPolyExpr False rhs
; traceVt "let body (non-recursive)" empty
; (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
; return $ vLet (vNonRec vbndr vrhs) vbody
}
vectExpr (_, AnnLet (AnnRec bs) body)
= do
- { (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
- $ liftM2 (,)
- (zipWithM vect_rhs bndrs rhss)
- (vectExpr body)
+ { (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs $ do
+ { traceVt "let bindings (recursive)" empty
+ ; vrhss <- zipWithM vect_rhs bndrs rhss
+ ; traceVt "let body (recursive)" empty
+ ; vbody <- vectExpr body
+ ; return (vrhss, vbody)
+ }
; return $ vLet (vRec vbndrs vrhss) vbody
}
where
@@ -442,7 +454,7 @@ vectFnExpr _ _ aexpr
= vectScalarFun . deAnnotate $ aexpr
| otherwise
-- not an abstraction: vectorise as a non-scalar vanilla expression
- -- NB: we can get here legitimately due to the recursion in the first case above
+ -- NB: we can get here due to the recursion in the first case above and from 'vectAnnPolyExpr'
= vectExpr aexpr
-- |Vectorise type and dictionary applications.
@@ -570,7 +582,7 @@ vectDictExpr (Coercion coe)
vectScalarFun :: CoreExpr -> VM VExpr
vectScalarFun expr
= do
{ traceVt "vectorise scalar functions:" (ppr expr)
{ traceVt "vectScalarFun:" (ppr expr)
; let (arg_tys, res_ty) = splitFunTys (exprType expr)
; mkScalarFun arg_tys res_ty expr
}
@@ -700,7 +712,9 @@ vectLam :: Bool -- ^ Should the RHS of a binding be inlined?
-> CoreExprWithVectInfo -- ^ Body of abstraction.
-> VM VExpr
vectLam inline loop_breaker expr@((fvs, _vi), AnnLam _ _)
- = do { let (bndrs, body) = collectAnnValBinders expr
+ = do { traceVt "fully vectorise a lambda expression" (ppr . deAnnotate $ expr)
+ ; let (bndrs, body) = collectAnnValBinders expr
-- grab the in-scope type variables
; tyvars <- localTyVars
@@ -769,40 +783,47 @@ vectLam _ _ _ = panic "Vectorise.Exp.vectLam: not a lambda"
-- have to handle the case where v is a wild var correctly.
--
- -- FIXME: this is too lazy
+ -- FIXME: this is too lazy...is it?
vectAlgCase :: TyCon -> [Type] -> CoreExprWithVectInfo -> Var -> Type
-> [(AltCon, [Var], CoreExprWithVectInfo)]
-> VM VExpr
vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)]
= do
- vscrut <- vectExpr scrut
- (vty, lty) <- vectAndLiftType ty
- (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
- return $ vCaseDEFAULT vscrut vbndr vty lty vbody
+ { traceVt "scrutinee (DEFAULT only)" empty
+ ; vscrut <- vectExpr scrut
+ ; (vty, lty) <- vectAndLiftType ty
+ ; traceVt "alternative body (DEFAULT only)" empty
+ ; (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
+ ; return $ vCaseDEFAULT vscrut vbndr vty lty vbody
+ }
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)]
= do
- vscrut <- vectExpr scrut
- (vty, lty) <- vectAndLiftType ty
- (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
- return $ vCaseDEFAULT vscrut vbndr vty lty vbody
+ { traceVt "scrutinee (one shot w/o binders)" empty
+ ; vscrut <- vectExpr scrut
+ ; (vty, lty) <- vectAndLiftType ty
+ ; traceVt "alternative body (one shot w/o binders)" empty
+ ; (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
+ ; return $ vCaseDEFAULT vscrut vbndr vty lty vbody
+ }
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
= do
- (vty, lty) <- vectAndLiftType ty
- vexpr <- vectExpr scrut
- (vbndr, (vbndrs, (vect_body, lift_body)))
- <- vect_scrut_bndr
- . vectBndrsIn bndrs
- $ vectExpr body
- let (vect_bndrs, lift_bndrs) = unzip vbndrs
- (vscrut, lscrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
- vect_dc <- maybeV dataConErr (lookupDataCon dc)
- let vcase = mk_wild_case vscrut vty vect_dc vect_bndrs vect_body
+ { traceVt "scrutinee (one shot w/ binders)" empty
+ ; vexpr <- vectExpr scrut
+ ; (vty, lty) <- vectAndLiftType ty
+ ; traceVt "alternative body (one shot w/ binders)" empty
+ ; (vbndr, (vbndrs, (vect_body, lift_body)))
+ <- vect_scrut_bndr
+ . vectBndrsIn bndrs
+ $ vectExpr body
+ ; let (vect_bndrs, lift_bndrs) = unzip vbndrs
+ ; (vscrut, lscrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
+ ; vect_dc <- maybeV dataConErr (lookupDataCon dc)
+ ; let vcase = mk_wild_case vscrut vty vect_dc vect_bndrs vect_body
lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body
- return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
+ ; return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
+ }
where
vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
| otherwise = vectBndrIn bndr
@@ -814,36 +835,40 @@ vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
vectAlgCase tycon _ty_args scrut bndr ty alts
= do
- vect_tc <- vectTyCon tycon
- (vty, lty) <- vectAndLiftType ty
+ { traceVt "scrutinee (general case)" empty
+ ; vexpr <- vectExpr scrut
+ ; vect_tc <- vectTyCon tycon
+ ; (vty, lty) <- vectAndLiftType ty
- let arity = length (tyConDataCons vect_tc)
- sel_ty <- builtin (selTy arity)
- sel_bndr <- newLocalVar (fsLit "sel") sel_ty
- let sel = Var sel_bndr
+ ; let arity = length (tyConDataCons vect_tc)
+ ; sel_ty <- builtin (selTy arity)
+ ; sel_bndr <- newLocalVar (fsLit "sel") sel_ty
+ ; let sel = Var sel_bndr
- (vbndr, valts) <- vect_scrut_bndr
+ ; traceVt "alternatives' body (general case)" empty
+ ; (vbndr, valts) <- vect_scrut_bndr
$ mapM (proc_alt arity sel vty lty) alts'
- let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts
+ ; let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts
- vexpr <- vectExpr scrut
- (vect_scrut, lift_scrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
+ ; (vect_scrut, lift_scrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
- let (vect_bodies, lift_bodies) = unzip vbodies
+ ; let (vect_bodies, lift_bodies) = unzip vbodies
- vdummy <- newDummyVar (exprType vect_scrut)
- ldummy <- newDummyVar (exprType lift_scrut)
- let vect_case = Case vect_scrut vdummy vty
+ ; vdummy <- newDummyVar (exprType vect_scrut)
+ ; ldummy <- newDummyVar (exprType lift_scrut)
+ ; let vect_case = Case vect_scrut vdummy vty
(zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)
- lc <- builtin liftingContext
- lbody <- combinePD vty (Var lc) sel lift_bodies
- let lift_case = Case lift_scrut ldummy lty
+ ; lc <- builtin liftingContext
+ ; lbody <- combinePD vty (Var lc) sel lift_bodies
+ ; let lift_case = Case lift_scrut ldummy lty
[(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
lbody)]
- return . vLet (vNonRec vbndr vexpr)
+ ; return . vLet (vNonRec vbndr vexpr)
$ (vect_case, lift_case)
+ }
where
vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
| otherwise = vectBndrIn bndr
@@ -871,12 +896,14 @@ vectAlgCase tycon _ty_args scrut bndr ty alts
<- vectBndrsIn bndrs
. localV
$ do
- binds <- mapM (pack_var (Var lc) sel_tags tag)
+ { binds <- mapM (pack_var (Var lc) sel_tags tag)
. filter isLocalId
$ varSetElems fvs
- (ve, le) <- vectExpr body
- return (ve, Case (elems `App` sel) lc lty
+ ; traceVt "case alternative:" (ppr . deAnnotate $ body)
+ ; (ve, le) <- vectExpr body
+ ; return (ve, Case (elems `App` sel) lc lty
[(DEFAULT, [], (mkLets (concat binds) le))])
+ }
-- empty <- emptyPD vty
-- return (ve, Case (elems `App` sel) lc lty
-- [(DEFAULT, [], Let (NonRec flags_var flags_expr)
@@ -887,25 +914,26 @@
where
dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc)
proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"
mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)
+ -- Pack a variable for a case alternative context *if* the variable is vectorised. If it
+ -- isn't, ignore it as scalar variables don't need to be packed.
pack_var len tags t v
= do
- r <- lookupVar v
- case r of
- Local (vv, lv) ->
+ { r <- lookupVar_maybe v
+ ; case r of
+ Just (Local (vv, lv)) ->
do
- lv' <- cloneVar lv
- expr <- packByTagPD (idType vv) (Var lv) len tags t
- updLEnv (\env -> env { local_vars = extendVarEnv
- (local_vars env) v (vv, lv') })
- return [(NonRec lv' expr)]
+ { lv' <- cloneVar lv
+ ; expr <- packByTagPD (idType vv) (Var lv) len tags t
+ ; updLEnv (\env -> env { local_vars = extendVarEnv (local_vars env) v (vv, lv') })
+ ; return [(NonRec lv' expr)]
+ }
_ -> return []
+ }
-- Support to compute information for vectorisation avoidance ------------------
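The pack_var change above is the heart of the commit: a free variable of a case alternative is packed only if lookupVar_maybe finds a vectorised local binding for it, while a free scalar variable has no lifted counterpart and is left alone. A hypothetical, self-contained model of that guarded packing follows; the environment type and function names are invented, not the vectoriser's API.

-- Hypothetical model of the guarded packing in 'pack_var' above: a variable is
-- packed only if the environment has a lifted (array) value for it.
import qualified Data.Map.Strict as Map

type Var    = String
type SelTag = Int
type LEnv a = Map.Map Var [a]   -- lifted values of the *vectorised* local variables

packFreeVar :: [SelTag]   -- selector tag of each element
            -> SelTag     -- tag of the current alternative
            -> LEnv a -> Var -> LEnv a
packFreeVar tags t env v =
  case Map.lookup v env of
    Just arr -> Map.insert v [x | (x, tag) <- zip arr tags, tag == t] env
    Nothing  -> env   -- a free *scalar* variable has no lifted value: don't pack it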
@@ -972,7 +1000,10 @@ vectAvoidInfo pvs ce@(fvs, AnnVar v)
; vi <- if v `elemVarSet` pvs || v `elemVarSet` gpvs
then return VIParr
else vectAvoidInfoTypeOf ce
- ; viTrace ce vi []
+ ; viTrace ce vi []
+ ; when (vi == VIParr) $
+ traceVt " reason:" $ if v `elemVarSet` pvs then text "local" else
+ if v `elemVarSet` gpvs then text "global" else text "parallel type"
; return ((fvs, vi), AnnVar v)
}
@@ -990,16 +1021,16 @@ vectAvoidInfo pvs ce@(fvs, AnnApp e1 e2)
; eVI1 <- vectAvoidInfo pvs e1
; eVI2 <- vectAvoidInfo pvs e2
; let vi = ceVI `unlessVIParrExpr` eVI1 `unlessVIParrExpr` eVI2
- ; viTrace ce vi [eVI1, eVI2]
+ -- ; viTrace ce vi [eVI1, eVI2]
; return ((fvs, vi), AnnApp eVI1 eVI2)
}
- vectAvoidInfo pvs ce@(fvs, AnnLam var body)
+ vectAvoidInfo pvs (fvs, AnnLam var body)
= do
{ bodyVI <- vectAvoidInfo pvs body
; varVI <- vectAvoidInfoType $ varType var
; let vi = vectAvoidInfoOf bodyVI `unlessVIParr` varVI
- ; viTrace ce vi [bodyVI]
+ -- ; viTrace ce vi [bodyVI]
; return ((fvs, vi), AnnLam var bodyVI)
}
@@ -1010,14 +1041,14 @@ vectAvoidInfo pvs ce@(fvs, AnnLet (AnnNonRec var e) body)
; isScalarTy <- isScalar $ varType var
; (bodyVI, vi) <- if isVIParr eVI && not isScalarTy
then do -- binding is parallel
- { bodyVI <- vectAvoidInfo (fvs `extendVarSet` var) body
+ { bodyVI <- vectAvoidInfo (pvs `extendVarSet` var) body
; return (bodyVI, VIParr)
}
else do -- binding doesn't affect parallelism
- { bodyVI <- vectAvoidInfo fvs body
+ { bodyVI <- vectAvoidInfo pvs body
; return (bodyVI, ceVI `unlessVIParrExpr` bodyVI)
}
- ; viTrace ce vi [eVI, bodyVI]
+ -- ; viTrace ce vi [eVI, bodyVI]
; return ((fvs, vi), AnnLet (AnnNonRec var eVI) bodyVI)
}
@@ -1032,13 +1063,13 @@ vectAvoidInfo pvs ce@(fvs, AnnLet (AnnRec bnds) body)
; let extendedPvs = pvs `extendVarSetList` new_pvs
; bndsVI <- mapM (vectAvoidInfoBnd extendedPvs) bnds
; bodyVI <- vectAvoidInfo extendedPvs body
- ; viTrace ce VIParr (map snd bndsVI ++ [bodyVI])
+ -- ; viTrace ce VIParr (map snd bndsVI ++ [bodyVI])
; return ((fvs, VIParr), AnnLet (AnnRec bndsVI) bodyVI)
}
else do -- demanded bindings cannot trigger parallelism
{ bodyVI <- vectAvoidInfo pvs body
; let vi = ceVI `unlessVIParrExpr` bodyVI
- ; viTrace ce vi (map snd bndsVI ++ [bodyVI])
+ -- ; viTrace ce vi (map snd bndsVI ++ [bodyVI])
; return ((fvs, vi), AnnLet (AnnRec bndsVI) bodyVI)
}
}
@@ -1058,7 +1089,7 @@ vectAvoidInfo pvs ce@(fvs, AnnCase e var ty alts)
; altsVI <- mapM (vectAvoidInfoAlt (isVIParr eVI)) alts
; let alteVIs = [eVI | (_, _, eVI) <- altsVI]
vi = foldl unlessVIParrExpr ceVI (eVI:alteVIs) -- NB: same effect as in the paper
- ; viTrace ce vi (eVI : alteVIs)
+ -- ; viTrace ce vi (eVI : alteVIs)
; return ((fvs, vi), AnnCase eVI var ty altsVI)
}
where