Vectorise.hs 20.2 KB
Newer Older
1
{-# OPTIONS -fno-warn-missing-signatures #-}
2

3 4 5
module Vectorise( vectorise )
where

6
import VectMonad
7
import VectUtils
8
import VectVar
9
import VectType
10
import VectCore
11

12
import HscTypes hiding      ( MonadThings(..) )
13

14
import Module               ( PackageId )
15
import CoreSyn
16
import CoreUtils
17
import CoreUnfold           ( mkInlineRule )
18
import MkCore               ( mkWildCase )
19
import CoreFVs
Ian Lynagh's avatar
Ian Lynagh committed
20
import CoreMonad            ( CoreM, getHscEnv )
21
import DataCon
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
22
import TyCon
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
23
import Type
24
import FamInstEnv           ( extendFamInstEnvList )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
25 26
import Var
import VarEnv
27
import VarSet
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
28
import Id
29
import OccName
30
import BasicTypes           ( isLoopBreaker )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
31

32
import Literal
33
import TysWiredIn
34
import TysPrim              ( intPrimTy )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
35

36
import Outputable
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
37
import FastString
38 39
import Util                 ( zipLazy )
import Control.Monad
40
import Data.List            ( sortBy, unzip4 )
41

42 43 44 45 46 47

-- | Master switch for the vectoriser's tracing; set to True to get the
--   'dtrace' messages.
debug :: Bool
debug = False

-- | When 'debug' is on, emit @s@ via 'pprTrace' tagged "Vectorise" before
--   yielding @x@; otherwise just yield @x@.
dtrace :: SDoc -> a -> a
dtrace s x
  | debug     = pprTrace "Vectorise" s x
  | otherwise = x

-- | Vectorise a single module.
--
--   The 'PackageId' is the DPH backend package in use (e.g. dph-par or
--   dph-seq). This is only the 'CoreM' entry point: it fetches the current
--   'HscEnv' and delegates to 'vectoriseIO'.
vectorise :: PackageId -> ModGuts -> CoreM ModGuts
vectorise backend guts
  = getHscEnv >>= \hsc_env -> liftIO (vectoriseIO backend hsc_env guts)

53

54
-- | Vectorise a single module, given its HscEnv (code gen environment).
--
--   The vectorisation info from the home package table ('hptVectInfo') and
--   from the external packages loaded so far ('eps_vect_info') is combined
--   and used to seed the 'VM' computation. The resulting vectorisation info
--   is stored in the returned 'ModGuts'.
vectoriseIO :: PackageId -> HscEnv -> ModGuts -> IO ModGuts
vectoriseIO backend hsc_env guts
 = do -- Get information about currently loaded external packages.
      eps <- hscEPS hsc_env

      -- Combine vectorisation info from the home package and external ones.
      let info = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps

      -- Run the main VM computation.
      result <- initV backend hsc_env guts info (vectModule guts)
      case result of
        Just (info', guts') -> return (guts' { mg_vect_info = info' })

        -- initV produced no result. Report this as an internal error that
        -- names the module, rather than an anonymous do-pattern failure.
        Nothing -> pprPanic "vectoriseIO"
                            (text "vectorisation failed for module"
                             <+> ppr (mg_module guts))
66

67 68

-- | Vectorise a single module, in the VM monad.
--
--   Vectorising the type environment comes first. It yields the new type
--   environment, the family instances it created, and extra bindings. The
--   module's family-instance environment is extended with those instances
--   before the top-level bindings are vectorised.
vectModule :: ModGuts -> VM ModGuts
vectModule guts
 = do (new_types, new_insts, type_binds) <- vectTypeEnv (mg_types guts)

      -- Extend the module's family-instance environment with the new
      -- instances and install it in the global VM environment (presumably so
      -- that vectorising the bindings below can see them).
      let new_fam_env = extendFamInstEnvList (mg_fam_inst_env guts) new_insts
      updGEnv (setFamInstEnv new_fam_env)

      -- Vectorise every top-level binding of the module.
      vbinds <- mapM vectTopBind (mg_binds guts)

      -- The bindings produced from the type environment go in front, as a
      -- single recursive group.
      return (rebuild new_types new_insts new_fam_env (Rec type_binds : vbinds))
  where
    rebuild tys insts fam_env binds
      = guts { mg_types        = tys
             , mg_binds        = binds
             , mg_fam_inst_env = fam_env
             , mg_fam_insts    = mg_fam_insts guts ++ insts
             }
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
91

92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128

-- | Try to vectorise a top-level binding.
--   If it doesn't vectorise then return it unharmed.
--
--   For example, for the binding
--
--   @
--      foo :: Int -> Int
--      foo = \x -> x + x
--   @
--
--   we get (schematically)
--
--   @
--      foo  :: Int -> Int
--      foo  = \x -> v_foo $: x
--
--      v_foo :: Closure void vfoo lfoo
--      v_foo = closure vfoo lfoo void
--
--      vfoo :: Void -> Int -> Int
--      vfoo = ...
--
--      lfoo :: PData Void -> PData Int -> PData Int
--      lfoo = ...
--   @
--
--   @vfoo@ is the "vectorised", or scalar, version that does the same as the original
--   function foo, but takes an explicit environment.
--
--   @lfoo@ is the "lifted" version that works on arrays.
--
--   @v_foo@ combines both of these into a `Closure` that also contains the
--   environment.
--
--   The original binding @foo@ is rewritten to call the vectorised version
--   present in the closure. The actual conversion is produced by 'tryConvert',
--   which keeps the original right-hand side if the conversion fails.
--
--   The result is always a 'Rec'. It holds the rewritten original binding
--   (which refers to the vectorised binder), the vectorised binding, and any
--   bindings hoisted while vectorising the body.
vectTopBind :: CoreBind -> VM CoreBind
vectTopBind b@(NonRec var expr)
 = do
      (inline, expr') <- vectTopRhs var expr
      var'            <- vectTopBinder var inline expr'

      -- Vectorising the body may create other top-level bindings.
      hs              <- takeHoisted

      -- To get the same functionality as the original body we project
      -- out its vectorised version from the closure.
      cexpr           <- tryConvert var var' expr

      -- cexpr mentions var', so the whole lot goes into one recursive group.
      return . Rec $ (var, cexpr) : (var', expr') : hs
  `orElseV`
    return b

vectTopBind b@(Rec bs)
 = do
      -- Tie the knot with fixV. Each vectorised binder's unfolding needs the
      -- inline info and vectorised RHS (the lazy ~ patterns below). Those come
      -- from vectorising the RHSs, which happens after vectTopBinder has
      -- registered every binder of the group. vectTopBinder must therefore
      -- stay lazy in its inline/expr arguments.
      (vars', _, exprs')
        <- fixV $ \ ~(_, inlines, rhss) ->
            do vars' <- sequence [vectTopBinder var inline rhs
                                      | (var, ~(inline, rhs)) <- zipLazy vars (zip inlines rhss)]
               (inlines', exprs')
                     <- mapAndUnzipM (uncurry vectTopRhs) bs

               return (vars', inlines', exprs')

      hs     <- takeHoisted
      cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
      return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
  `orElseV`
    return b
  where
    (vars, exprs) = unzip bs

165 166 167 168 169 170 171 172 173 174 175 176 177

-- | Make the vectorised version of this top level binder, and record the
--   mapping from the original to it in the global state. The new binder's
--   name is derived from the original via 'mkVectOcc', and its type is the
--   vectorised type of the original.
--
--   NOTE: vectTopBinder *MUST* be lazy in inline and expr because of how it is
--   used inside of fixV in vectTopBind. Both are only mentioned by the lazily
--   bound unfolding below; nothing in this function forces it.
vectTopBinder
        :: Var          -- ^ Name of the binding.
        -> Inline       -- ^ Whether it should be inlined, used to annotate it.
        -> CoreExpr     -- ^ RHS of the binding, used to set the `Unfolding` of the returned `Var`.
        -> VM Var       -- ^ Name of the vectorised binding.
vectTopBinder var inline expr
 = do vty   <- vectType (idType var)
      fresh <- cloneId mkVectOcc var vty
      let var' = fresh `setIdUnfolding` unfolding
      -- Remember var |-> var' so later lookups of var find the vectorised one.
      defGlobalVar var var'
      return var'
  where
    unfolding = case inline of
                  DontInline   -> noUnfolding
                  Inline arity -> mkInlineRule expr (Just arity)
Ian Lynagh's avatar
Ian Lynagh committed
195

196 197 198 199 200 201 202

-- | Vectorise the RHS of a top-level binding, in an empty local environment
--   (via 'closedV'). Returns the inlining decision and the 'vectorised'
--   component of the result.
vectTopRhs
        :: Var          -- ^ Name of the binding.
        -> CoreExpr     -- ^ Body of the binding.
        -> VM (Inline, CoreExpr)
vectTopRhs var expr
  = dtrace (vcat [text "vectTopRhs", ppr expr])
  . closedV
  $ do (inline, vexpr) <- inBind var (vectPolyExpr is_loop_breaker annotated)
       return (inline, vectorised vexpr)
  where
    is_loop_breaker = isLoopBreaker (idOccInfo var)
    -- The body annotated with its free variables.
    annotated       = freeVars expr
210

211 212 213 214 215 216 217 218 219

-- | Project out the vectorised version of a binding from some closure,
--   falling back to the original body if that doesn't work.
tryConvert
        :: Var          -- ^ Name of the original binding (eg @foo@)
        -> Var          -- ^ Name of vectorised version of binding (eg @$vfoo@)
        -> CoreExpr     -- ^ The original body of the binding.
        -> VM CoreExpr
tryConvert var vect_var rhs
  = orElseV converted (return rhs)
  where
    -- Convert the vectorised binder back at the original binder's type.
    converted = fromVect (idType var) (Var vect_var)

223

rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
224
-- ----------------------------------------------------------------------------
225 226
-- Expressions

227 228 229 230 231 232 233 234

-- | Vectorise a polymorphic expression.
--
--   Notes are vectorised through and re-attached. Otherwise the outer type
--   lambdas are peeled off, the remaining body is vectorised with
--   'vectFnExpr' under 'polyAbstract', and the lambdas (type variables plus
--   the arguments supplied by 'polyAbstract') are rebuilt around the result.
vectPolyExpr
        :: Bool                 -- ^ When vectorising the RHS of a binding, whether that
                                --   binding is a loop breaker.
        -> CoreExprWithFVs
        -> VM (Inline, VExpr)
vectPolyExpr loop_breaker (_, AnnNote note inner)
 = do (inline, vinner) <- vectPolyExpr loop_breaker inner
      return (inline, vNote note vinner)

vectPolyExpr loop_breaker expr
 = dtrace (vcat [text "vectPolyExpr", ppr (deAnnotate expr)])
 $ do arity <- polyArity tyvars
      polyAbstract tyvars (vect_body arity)
  where
    (tyvars, body) = collectAnnTypeBinders expr

    vect_body arity abs_args
      = do (inline, vbody) <- vectFnExpr False loop_breaker body
           return ( addInlineArity inline arity
                  , mapVect (mkLams (tyvars ++ abs_args)) vbody )

251 252

-- | Vectorise a core expression, producing its vectorised and lifted forms.
--
--   Equations are tried in order; an equation whose guard fails falls through
--   to the next one, and anything left over ends in 'cantVectorise'.
vectExpr :: CoreExprWithFVs -> VM VExpr
vectExpr (_, AnnType ty)       = liftM vType (vectType ty)

vectExpr (_, AnnVar v)         = vectVar v

vectExpr (_, AnnLit lit)       = vectLiteral lit

vectExpr (_, AnnNote note sub) = liftM (vNote note) (vectExpr sub)

-- Type application: collect all the type arguments and vectorise the head.
vectExpr e@(_, AnnApp _ arg)
  | isAnnTypeArg arg
  = uncurry vectTyAppExpr (collectAnnTypeArgs e)

-- The Int, Float or Double data constructor applied to a literal: keep the
-- application as the vectorised form and lift it as a whole.
vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
  | Just con <- isDataConId_maybe v
  , con `elem` [intDataCon, floatDataCon, doubleDataCon]
  = liftM ((,) boxed) (liftPD boxed)
  where
    boxed = App (Var v) (Lit lit)

-- TODO: Avoid using closure application for dictionaries.
-- vectExpr (_, AnnApp fn arg)
--  | if is application of dictionary
--    just use regular app instead of closure app.

-- for lifted version.
--      do liftPD (sub a dNumber)
--      lift the result of the selection, not sub and dNumber seprately.

-- Ordinary value application becomes a closure application.
vectExpr (_, AnnApp fn arg)
 = dtrace (text "AnnApp" <+> ppr (deAnnotate fn) <+> ppr (deAnnotate arg))
 $ do
      varg_ty <- vectType fun_arg_ty
      vres_ty <- vectType fun_res_ty

      dtrace (text "vectorising fn " <> ppr (deAnnotate fn))  $ return ()
      vfn     <- vectExpr fn
      dtrace (text "fn' = "       <> ppr vfn) $ return ()

      varg    <- vectExpr arg

      mkClosureApp varg_ty vres_ty vfn varg
  where
    (fun_arg_ty, fun_res_ty) = splitFunTy (exprType (deAnnotate fn))

-- Only cases on algebraic data types are handled.
vectExpr (_, AnnCase scrut bndr ty alts)
  | Just (tycon, ty_args) <- splitTyConApp_maybe (exprType (deAnnotate scrut))
  , isAlgTyCon tycon
  = vectAlgCase tycon ty_args scrut bndr ty alts

vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
  = do vrhs           <- localV (inBind bndr (liftM snd (vectPolyExpr False rhs)))
       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
       return (vLet (vNonRec vbndr vrhs) vbody)

vectExpr (_, AnnLet (AnnRec binds) body)
  = do (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
                                 $ liftM2 (,) (zipWithM vect_rhs bndrs rhss)
                                              (vectExpr body)
       return (vLet (vRec vbndrs vrhss) vbody)
  where
    (bndrs, rhss) = unzip binds

    vect_rhs bndr rhs
      = localV
      . inBind bndr
      . liftM snd
      $ vectPolyExpr (isLoopBreaker (idOccInfo bndr)) rhs

-- Value lambdas; type lambdas fall through to the failure case.
vectExpr e@(_, AnnLam bndr _)
  | isId bndr = liftM snd (vectFnExpr True False e)

vectExpr e = cantVectorise "Can't vectorise expression" (ppr $ deAnnotate e)
346

347 348 349 350 351 352 353 354

-- | Vectorise an expression with an outer lambda abstraction.
--
--   A value lambda with no free variables is first tried as a scalar
--   function ('vectScalarLam', marked 'DontInline'). If that fails, or the
--   lambda has free variables, it goes through 'vectLam' and is marked
--   'inlineMe'. Any other expression goes to 'vectExpr' and is marked
--   'DontInline'.
vectFnExpr
        :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
        -> Bool                 -- ^ Whether the binding is a loop breaker.
        -> CoreExprWithFVs      -- ^ Expression to vectorise; value lambdas get special treatment.
        -> VM (Inline, VExpr)
vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _)
  | isId bndr
  = orElseV (onlyIfV (isEmptyVarSet fvs) as_scalar) as_closure
  where
    (params, body) = collectAnnValBinders e
    as_scalar      = mark DontInline (vectScalarLam params (deAnnotate body))
    as_closure     = mark inlineMe (vectLam inline loop_breaker fvs params body)

vectFnExpr _ _ e = mark DontInline (vectExpr e)
363

364 365
-- | Run a computation and pair its result with the given 'Inline' annotation.
mark :: Inline -> VM a -> VM (Inline, a)
mark b p = p >>= \x -> return (b, x)
366

367 368 369 370 371 372

-- | Vectorise a function where all the args have scalar type, that is Int, Float or Double.
--
--   The function is only treated as scalar when all of the following hold:
--
--     * every argument type and the result type is Int, Float or Double;
--
--     * the body consists only of scalar literals, applications, and
--       variables that are arguments or global scalars ('globalScalars');
--
--     * the body mentions at least one global scalar (see 'uses' below).
--
--   In that case the function is hoisted as a top-level "fn" binding. A
--   scalar closure over it (also using the 'zipScalars' combinator) is
--   hoisted as "clo". The closure variable and its lifted form are returned.
--   Otherwise the 'onlyIfV' guard fails, which lets 'vectFnExpr' fall back
--   to 'vectLam'.
vectScalarLam
        :: [Var]        -- ^ Bound variables of function.
        -> CoreExpr     -- ^ Function body.
        -> VM VExpr
vectScalarLam args body
 = dtrace (vcat [text "vectScalarLam ", ppr args, ppr body])
 $ do scalars <- globalScalars
      onlyIfV (all is_scalar_ty arg_tys
               && is_scalar_ty res_ty
               && is_scalar (extendVarSetList scalars args) body
               && uses scalars body)
        $ do
            fn_var  <- hoistExpr (fsLit "fn") (mkLams args body) DontInline
            zipf    <- zipScalars arg_tys res_ty
            clo     <- scalarClosure arg_tys res_ty (Var fn_var)
                                                (zipf `App` Var fn_var)
            clo_var <- hoistExpr (fsLit "clo") clo DontInline
            lclo    <- liftPD (Var clo_var)
            return (Var clo_var, lclo)
  where
    arg_tys = map idType args
    res_ty  = exprType body

    -- Int, Float and Double (nullary type constructors) count as scalar.
    is_scalar_ty ty
        | Just (tycon, [])   <- splitTyConApp_maybe ty
        =    tycon == intTyCon
          || tycon == floatTyCon
          || tycon == doubleTyCon

        | otherwise = False

    -- The body is scalar if it only applies scalar variables (from vs) to
    -- scalar literals and to each other.
    is_scalar vs (Var v)     = v `elemVarSet` vs
    is_scalar _ e@(Lit _)    = is_scalar_ty $ exprType e
    is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2
    is_scalar _ _            = False

    -- A scalar function has to actually compute something. Without the check,
    -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to
    -- (map (\x -> x)) which is very bad. Normal lifting transforms it to
    -- (\n# x -> x) which is what we want.
    uses funs (Var v)     = v `elemVarSet` funs
    uses funs (App e1 e2) = uses funs e1 || uses funs e2
    uses _ _              = False

413 414 415 416 417 418 419 420 421

-- | Vectorise a value lambda by building closures for it.
--
--   The lambda's free variables that are bound in the local environment are
--   handed (as their vectorised/lifted pairs) to 'buildClosures', presumably
--   as the closure environment. The body is vectorised with those variables
--   followed by the lambda's own binders in scope, then hoisted to the top
--   level via 'hoistPolyVExpr'. When inlining is requested, the arity
--   recorded for inlining is the number of those variables plus the number
--   of binders.
vectLam
        :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
        -> Bool                 -- ^ Whether the binding is a loop breaker.
        -> VarSet               -- ^ The free variables in the body.
        -> [Var]                -- ^ The lambda's value binders (its parameters).
        -> CoreExprWithFVs      -- ^ The lambda body, annotated with free variables.
        -> VM VExpr
vectLam inline loop_breaker fvs bs body
 = dtrace (vcat [ text "vectLam "
                , text "free vars    = " <> ppr fvs
                , text "binding vars = " <> ppr bs
                , text "body         = " <> ppr (deAnnotate body)])
 $ do tyvars    <- localTyVars

      -- Free variables that have a local vectorised binding, paired with it.
      (vs, vvs) <- readLEnv $ \env ->
                   unzip [(var, vv) | var <- varSetElems fvs
                                    , Just vv <- [lookupVarEnv (local_vars env) var]]

      arg_tys   <- mapM (vectType . idType) bs

      dtrace (text "arg_tys = " <> ppr arg_tys) $ return ()

      res_ty    <- vectType (exprType $ deAnnotate body)

      dtrace (text "res_ty = " <> ppr res_ty) $ return ()

      buildClosures tyvars vvs arg_tys res_ty
        . hoistPolyVExpr tyvars (maybe_inline (length vs + length bs))
        $ do
            lc              <- builtin liftingContext
            (vbndrs, vbody) <- vectBndrsIn (vs ++ bs) (vectExpr body)

            dtrace (text "vbody = " <> ppr vbody) $ return ()

            vbody' <- break_loop lc res_ty vbody
            return $ vLams lc vbndrs vbody'
  where
    maybe_inline n | inline    = Inline n
                   | otherwise = DontInline

    -- For a loop breaker, guard the lifted body with a case on the lifting
    -- context lc (an Int#). When lc is 0 the empty PData is returned instead
    -- of the body, presumably so recursion on empty arrays terminates.
    break_loop lc ty (ve, le)
      | loop_breaker
      = do
          empty <- emptyPD ty
          lty <- mkPDataType ty
          return (ve, mkWildCase (Var lc) intPrimTy lty
                        [(DEFAULT, [], le),
                         (LitAlt (mkMachInt 0), [], empty)])

      | otherwise = return (ve, le)
 
Ian Lynagh's avatar
Ian Lynagh committed
466

467 468
-- | Vectorise an expression applied to type arguments. Only a variable head
--   is supported (via 'vectPolyVar'); anything else is rejected.
vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
vectTyAppExpr e tys
  = case e of
      (_, AnnVar v) -> vectPolyVar v tys
      _             -> cantVectorise "Can't vectorise expression"
                         (ppr $ deAnnotate e `mkTyApps` tys)
471 472 473 474 475 476 477

-- We convert
--
--   case e :: t of v { ... }
--
-- to
--
--   V:    let v' = e in case v' of _ { ... }
--   L:    let v' = e in case v' `cast` ... of _ { ... }
--
-- When lifting, we have to do it this way because v must have the type
-- [:V(T):] but the scrutinee must be cast to the representation type. We also
-- have to handle the case where v is a wild var correctly.
--

-- FIXME: this is too lazy
vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs -> Var -> Type
            -> [(AltCon, [Var], CoreExprWithFVs)]
            -> VM VExpr

-- A single alternative that binds no fields: either DEFAULT or a constructor
-- with no field binders. Only the case binder matters, so this is built with
-- 'vCaseDEFAULT'. Any other alternative (e.g. a LitAlt) fails the guard and
-- falls through to the later equations.
vectAlgCase _tycon _ty_args scrut bndr ty [(alt, [], body)]
  | binds_no_fields alt
  = do
      vscrut         <- vectExpr scrut
      (vty, lty)     <- vectAndLiftType ty
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
      return $ vCaseDEFAULT vscrut vbndr vty lty vbody
  where
    binds_no_fields DEFAULT     = True
    binds_no_fields (DataAlt _) = True
    binds_no_fields _           = False

-- A single constructor alternative with field binders: bind the scrutinee,
-- then match it against the vectorised constructor (vectorised side) and the
-- single PData constructor (lifted side).
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
  = do
      (vty, lty) <- vectAndLiftType ty
      vexpr      <- vectExpr scrut
      (vbndr, (vbndrs, (vect_body, lift_body)))
         <- vect_scrut_bndr
          . vectBndrsIn bndrs
          $ vectExpr body
      let (vect_bndrs, lift_bndrs) = unzip vbndrs
      (vscrut, lscrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
      vect_dc <- maybeV (lookupDataCon dc)
      let [pdata_dc] = tyConDataCons pdata_tc

      let vcase = mk_wild_case vscrut vty vect_dc  vect_bndrs vect_body
          lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body

      return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
  where
    -- A dead case binder still needs a name, since the scrutinee is let-bound.
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
                    | otherwise         = vectBndrIn bndr

    mk_wild_case expr ty dc bndrs body
      = mkWildCase expr (exprType expr) ty [(DataAlt dc, bndrs, body)]

-- General case with several alternatives. The vectorised side is an ordinary
-- case on the vectorised constructors. The lifted side binds a selector
-- (sel_bndr) out of the PData scrutinee and merges the per-alternative lifted
-- bodies with 'combinePD'.
vectAlgCase tycon _ty_args scrut bndr ty alts
  = do
      vect_tc     <- maybeV (lookupTyCon tycon)
      (vty, lty)  <- vectAndLiftType ty

      let arity = length (tyConDataCons vect_tc)
      sel_ty <- builtin (selTy arity)
      sel_bndr <- newLocalVar (fsLit "sel") sel_ty
      let sel = Var sel_bndr

      (vbndr, valts) <- vect_scrut_bndr
                      $ mapM (proc_alt arity sel vty lty) alts'
      let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts

      vexpr <- vectExpr scrut
      (vect_scrut, lift_scrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
      let [pdata_dc] = tyConDataCons pdata_tc

      let (vect_bodies, lift_bodies) = unzip vbodies

      vdummy <- newDummyVar (exprType vect_scrut)
      ldummy <- newDummyVar (exprType lift_scrut)
      let vect_case = Case vect_scrut vdummy vty
                           (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)

      lc <- builtin liftingContext
      lbody <- combinePD vty (Var lc) sel lift_bodies
      let lift_case = Case lift_scrut ldummy lty
                           [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
                             lbody)]

      return . vLet (vNonRec vbndr vexpr)
             $ (vect_case, lift_case)
  where
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
                    | otherwise         = vectBndrIn bndr

    -- Alternatives in constructor-tag order (DEFAULT first).
    alts' = sortBy (\(alt1, _, _) (alt2, _, _) -> cmp alt1 alt2) alts

    cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
    cmp DEFAULT       DEFAULT       = EQ
    cmp DEFAULT       _             = LT
    cmp _             DEFAULT       = GT
    cmp _             _             = panic "vectAlgCase/cmp"

    -- Vectorise one constructor alternative. Inside it, every local free
    -- variable of the body is rebound to its lifted value packed by this
    -- alternative's tag (see pack_var). The lifted body is wrapped in a case
    -- on the selector's element count for this tag.
    proc_alt arity sel _ lty (DataAlt dc, bndrs, body)
      = do
          vect_dc <- maybeV (lookupDataCon dc)
          let ntag = dataConTagZ vect_dc
              tag  = mkDataConTag vect_dc
              fvs  = freeVarsOf body `delVarSetList` bndrs

          sel_tags  <- liftM (`App` sel) (builtin (selTags arity))
          lc        <- builtin liftingContext
          elems     <- builtin (selElements arity ntag)

          (vbndrs, vbody)
            <- vectBndrsIn bndrs
             . localV
             $ do
                 binds    <- mapM (pack_var (Var lc) sel_tags tag)
                           . filter isLocalId
                           $ varSetElems fvs
                 (ve, le) <- vectExpr body
                 return (ve, Case (elems `App` sel) lc lty
                             [(DEFAULT, [], (mkLets (concat binds) le))])
                 -- empty    <- emptyPD vty
                 -- return (ve, Case (elems `App` sel) lc lty
                 --             [(DEFAULT, [], Let (NonRec flags_var flags_expr)
                 --                             $ mkLets (concat binds) le),
                 --               (LitAlt (mkMachInt 0), [], empty)])
          let (vect_bndrs, lift_bndrs) = unzip vbndrs
          return (vect_dc, vect_bndrs, lift_bndrs, vbody)

    proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"

    mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)

    -- For a locally bound variable, bind a fresh clone of its lifted version
    -- to the value packed by tag t, and point the local environment at the
    -- clone. Non-local variables need no rebinding.
    pack_var len tags t v
      = do
          r <- lookupVar v
          case r of
            Local (vv, lv) ->
              do
                lv'  <- cloneVar lv
                expr <- packByTagPD (idType vv) (Var lv) len tags t
                updLEnv (\env -> env { local_vars = extendVarEnv
                                                (local_vars env) v (vv, lv') })
                return [(NonRec lv' expr)]

            _ -> return []
619