Vectorise.hs 20.2 KB
Newer Older
1
{-# OPTIONS -fno-warn-missing-signatures #-}
2

3 4 5
module Vectorise( vectorise )
where

6
import VectUtils
7
import VectVar
8
import VectType
9
import Vectorise.Vect
10
import Vectorise.Env
11 12
import Vectorise.Monad
import Vectorise.Builtins
13

14
import HscTypes hiding      ( MonadThings(..) )
15

16
import Module               ( PackageId )
17
import CoreSyn
18
import CoreUtils
19
import CoreUnfold           ( mkInlineRule )
20
import MkCore               ( mkWildCase )
21
import CoreFVs
Ian Lynagh's avatar
Ian Lynagh committed
22
import CoreMonad            ( CoreM, getHscEnv )
23
import DataCon
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
24
import TyCon
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
25
import Type
26
import FamInstEnv           ( extendFamInstEnvList )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
27 28
import Var
import VarEnv
29
import VarSet
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
30
import Id
31
import OccName
32
import BasicTypes           ( isLoopBreaker )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
33

34
import Literal
35
import TysWiredIn
36
import TysPrim              ( intPrimTy )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
37

38
import Outputable
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
39
import FastString
40 41
import Util                 ( zipLazy )
import Control.Monad
42
import Data.List            ( sortBy, unzip4 )
43

44 45 46 47 48 49

-- | Set to 'True' to make 'dtrace' emit tracing output while vectorising.
debug :: Bool
debug = False

-- | When 'debug' is enabled, trace the given document under the
--   "Vectorise" label before returning the second argument; otherwise just
--   return the second argument.
dtrace :: SDoc -> a -> a
dtrace s x = if debug then pprTrace "Vectorise" s x else x

-- | Vectorise a single module.
--   The 'PackageId' is the package containing the DPH backend in use,
--   e.g. either dph-par or dph-seq.
vectorise :: PackageId -> ModGuts -> CoreM ModGuts
vectorise backend guts
  = getHscEnv >>= \hsc_env -> liftIO (vectoriseIO backend hsc_env guts)

55

56
-- | Vectorise a single module, given its 'HscEnv'.
--
--   If 'initV' produces no result, this panics with an explicit message
--   instead of failing with an anonymous pattern-match error in IO.
vectoriseIO :: PackageId -> HscEnv -> ModGuts -> IO ModGuts
vectoriseIO backend hsc_env guts
 = do -- Get information about currently loaded external packages.
      eps <- hscEPS hsc_env

      -- Combine vectorisation info from the current module, and external ones.
      let info = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps

      -- Run the main VM computation. 'initV' returns a Maybe; a Nothing
      -- presumably means the VM computation failed, so report it explicitly.
      result <- initV backend hsc_env guts info (vectModule guts)
      case result of
        Just (info', guts') -> return (guts' { mg_vect_info = info' })
        Nothing             -> panic "vectoriseIO: vectorisation failed"
68

69 70

-- | Vectorise a single module, in the VM monad.
vectModule :: ModGuts -> VM ModGuts
vectModule guts = do
  -- Vectorise the type environment. This may add new TyCons and DataCons;
  -- it also yields new family instances and a list of bindings, which are
  -- emitted below as one recursive group ahead of the module's own bindings.
  (types', fam_insts, tc_binds) <- vectTypeEnv (mg_types guts)

  -- Extend the family instance environment with the new instances and
  -- install it in the global environment before any bindings are vectorised.
  let fam_inst_env' = extendFamInstEnvList (mg_fam_inst_env guts) fam_insts
  updGEnv (setFamInstEnv fam_inst_env')

  -- Vectorise all the top level bindings.
  binds' <- mapM vectTopBind (mg_binds guts)

  let guts' = guts { mg_types        = types'
                   , mg_binds        = Rec tc_binds : binds'
                   , mg_fam_inst_env = fam_inst_env'
                   , mg_fam_insts    = mg_fam_insts guts ++ fam_insts
                   }
  return guts'
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
93

94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130

-- | Try to vectorise a top-level binding.
--   If it doesn't vectorise then it is returned unchanged.
--
--   For example, for the binding
--
--   @
--      foo :: Int -> Int
--      foo = \x -> x + x
--   @
--
--   we get
--
--   @
--      foo  :: Int -> Int
--      foo  = \x -> v_foo $: x
--
--      v_foo :: Closure void vfoo lfoo
--      v_foo = closure vfoo lfoo void
--
--      vfoo :: Void -> Int -> Int
--      vfoo = ...
--
--      lfoo :: PData Void -> PData Int -> PData Int
--      lfoo = ...
--   @
--
--   @vfoo@ is the \"vectorised\", or scalar, version that does the same as
--   the original function @foo@, but takes an explicit environment.
--   @lfoo@ is the \"lifted\" version that works on arrays. @v_foo@ combines
--   both of these into a @Closure@ that also contains the environment.
--
--   The original binding @foo@ is rewritten (via 'tryConvert') to project
--   the vectorised version out of the closure. On success the result is a
--   single 'Rec' group holding the original binders, the vectorised binders,
--   and any bindings hoisted while vectorising the right-hand sides.
vectTopBind :: CoreBind -> VM CoreBind
vectTopBind b@(NonRec var expr)
  = vect_nonrec `orElseV` return b
  where
    vect_nonrec = do
      (inline, expr') <- vectTopRhs var expr
      var'            <- vectTopBinder var inline expr'
      -- Vectorising the body may create other top-level bindings.
      hoisted         <- takeHoisted
      -- To get the same functionality as the original body we project
      -- out its vectorised version from the closure.
      cexpr           <- tryConvert var var' expr
      return $ Rec ((var, cexpr) : (var', expr') : hoisted)

vectTopBind b@(Rec bs)
  = vect_rec `orElseV` return b
  where
    (vars, exprs) = unzip bs

    -- The vectorised binders need the inlining decisions and vectorised
    -- RHSs, but (the group being recursive) vectorising the RHSs needs the
    -- binders. fixV ties the knot; the lazy patterns below and the laziness
    -- of vectTopBinder in its last two arguments keep this from looping.
    vect_rec = do
      (vars', _, exprs') <- fixV $ \ ~(_, inlines, rhss) -> do
        new_vars              <- mapM (\(v, ~(inl, rhs)) -> vectTopBinder v inl rhs)
                                      (zipLazy vars (zip inlines rhss))
        (new_inlines, new_rhss) <- mapAndUnzipM (uncurry vectTopRhs) bs
        return (new_vars, new_inlines, new_rhss)
      hoisted <- takeHoisted
      cexprs  <- mapM (\(v, v', e) -> tryConvert v v' e) (zip3 vars vars' exprs)
      return $ Rec (zip vars cexprs ++ zip vars' exprs' ++ hoisted)

167 168 169 170 171 172 173 174 175 176 177 178 179

-- | Make the vectorised version of a top level binder (named with
--   'mkVectOcc'), give it an unfolding built from the vectorised RHS, and
--   record the mapping from the original binder to it with 'defGlobalVar'.
--
--   NOTE: vectTopBinder *MUST* be lazy in @inline@ and @expr@ because of how
--   it is used inside of fixV in vectTopBind.
vectTopBinder
  :: Var        -- ^ Name of the binding.
  -> Inline     -- ^ Whether it should be inlined, used to annotate it.
  -> CoreExpr   -- ^ RHS of the binding, used to set the unfolding of the returned 'Var'.
  -> VM Var     -- ^ Name of the vectorised binding.
vectTopBinder var inline expr
  = do vty   <- vectType (idType var)
       fresh <- cloneId mkVectOcc var vty
       -- The unfolding is defined in the where clause, so nothing in this
       -- do block inspects 'inline' or 'expr' directly.
       let var' = fresh `setIdUnfolding` unfolding
       defGlobalVar var var'
       return var'
  where
    unfolding = case inline of
                  DontInline   -> noUnfolding
                  Inline arity -> mkInlineRule expr (Just arity)
Ian Lynagh's avatar
Ian Lynagh committed
197

198 199 200 201 202 203 204

-- | Vectorise the RHS of a top-level binding, in an empty local environment.
--   Returns the inlining decision together with the 'vectorised' component
--   of the result.
vectTopRhs
  :: Var        -- ^ Name of the binding.
  -> CoreExpr   -- ^ Body of the binding.
  -> VM (Inline, CoreExpr)
vectTopRhs var expr
  = dtrace (vcat [text "vectTopRhs", ppr expr]) (closedV vect_rhs)
  where
    loop_breaker = isLoopBreaker (idOccInfo var)

    vect_rhs = do
      (inline, vexpr) <- inBind var (vectPolyExpr loop_breaker (freeVars expr))
      return (inline, vectorised vexpr)
212

213 214 215 216 217 218 219 220 221

-- | Project the vectorised version of a binding out of its closure, or
--   fall back to the original body if that doesn't work.
tryConvert
  :: Var        -- ^ Name of the original binding (eg @foo@).
  -> Var        -- ^ Name of the vectorised version of the binding.
  -> CoreExpr   -- ^ The original body of the binding.
  -> VM CoreExpr
tryConvert var vect_var rhs
  = orElseV (fromVect (idType var) (Var vect_var)) (return rhs)

225

rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
226
-- ----------------------------------------------------------------------------
-- Expressions

229 230 231 232 233 234 235 236

-- | Vectorise a polymorphic expression. Notes are pushed through; otherwise
--   the outer type lambdas are stripped, the monomorphic body is vectorised
--   with 'vectFnExpr', and the result is re-abstracted over the type
--   variables and the binders supplied by 'polyAbstract'.
vectPolyExpr
  :: Bool       -- ^ When vectorising the RHS of a binding, whether that
                --   binding is a loop breaker.
  -> CoreExprWithFVs
  -> VM (Inline, VExpr)
vectPolyExpr loop_breaker (_, AnnNote note expr)
  = liftM (\(inline, vexpr) -> (inline, vNote note vexpr))
          (vectPolyExpr loop_breaker expr)

vectPolyExpr loop_breaker expr
  = dtrace (vcat [text "vectPolyExpr", ppr (deAnnotate expr)]) $
    polyArity tvs >>= \arity ->
    polyAbstract tvs $ \args ->
      vectFnExpr False loop_breaker mono >>= \(inline, mono') ->
      return ( addInlineArity inline arity
             , mapVect (mkLams (tvs ++ args)) mono' )
  where
    (tvs, mono) = collectAnnTypeBinders expr

253 254

-- | Vectorise a core expression, producing its vectorised and lifted forms.
--   Anything not handled by an earlier equation is rejected with
--   'cantVectorise'.
vectExpr :: CoreExprWithFVs -> VM VExpr
vectExpr (_, AnnType ty)
  = vType `liftM` vectType ty

vectExpr (_, AnnVar v)
  = vectVar v

vectExpr (_, AnnLit lit)
  = vectLiteral lit

vectExpr (_, AnnNote note expr)
  = vNote note `liftM` vectExpr expr

-- Type application: gather all the type arguments and vectorise the head.
vectExpr e@(_, AnnApp _ arg)
  | isAnnTypeArg arg
  = uncurry vectTyAppExpr (collectAnnTypeArgs e)

-- The Int, Float or Double data constructor applied to a literal is kept
-- as-is in the vectorised code and lifted with 'liftPD'.
vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
  | Just con <- isDataConId_maybe v
  , con `elem` [intDataCon, floatDataCon, doubleDataCon]
  = do let vexpr = App (Var v) (Lit lit)
       lexpr <- liftPD vexpr
       return (vexpr, lexpr)

-- TODO: Avoid using closure application for dictionaries.
-- vectExpr (_, AnnApp fn arg)
--  | if is application of dictionary
--    just use regular app instead of closure app.
--
-- For the lifted version:
--      do liftPD (sub a dNumber)
--      lift the result of the selection, not sub and dNumber separately.

-- General application: vectorise function and argument, then build a
-- closure application at the vectorised argument and result types.
vectExpr (_, AnnApp fn arg)
  = dtrace (text "AnnApp" <+> ppr (deAnnotate fn) <+> ppr (deAnnotate arg)) $
    do fun_arg_ty' <- vectType fun_arg_ty
       fun_res_ty' <- vectType fun_res_ty
       dtrace (text "vectorising fn " <> ppr (deAnnotate fn)) $ return ()
       vfn  <- vectExpr fn
       dtrace (text "fn' = " <> ppr vfn) $ return ()
       varg <- vectExpr arg
       mkClosureApp fun_arg_ty' fun_res_ty' vfn varg
  where
    (fun_arg_ty, fun_res_ty) = splitFunTy . exprType $ deAnnotate fn

-- Case on a scrutinee of algebraic type.
vectExpr (_, AnnCase scrut bndr ty alts)
  | Just (tycon, ty_args) <- splitTyConApp_maybe (exprType (deAnnotate scrut))
  , isAlgTyCon tycon
  = vectAlgCase tycon ty_args scrut bndr ty alts

vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
  = do vrhs           <- localV (inBind bndr (liftM snd (vectPolyExpr False rhs)))
       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
       return (vLet (vNonRec vbndr vrhs) vbody)

vectExpr (_, AnnLet (AnnRec bs) body)
  = do (vbndrs, (vrhss, vbody))
         <- vectBndrsIn bndrs $
              do rhss'  <- zipWithM vect_one_rhs bndrs rhss
                 body'  <- vectExpr body
                 return (rhss', body')
       return (vLet (vRec vbndrs vrhss) vbody)
  where
    (bndrs, rhss) = unzip bs

    vect_one_rhs bndr rhs
      = localV . inBind bndr . liftM snd
      $ vectPolyExpr (isLoopBreaker (idOccInfo bndr)) rhs

-- Value lambdas; type lambdas fall through to the final equation.
vectExpr e@(_, AnnLam bndr _)
  | isId bndr = snd `liftM` vectFnExpr True False e

vectExpr e = cantVectorise "Can't vectorise expression" (ppr $ deAnnotate e)
348

349 350 351 352 353 354 355 356

-- | Vectorise an expression with an outer lambda abstraction.
--   A lambda with no free variables is first tried as a scalar function
--   ('vectScalarLam', not inlined); otherwise, or if that fails, it becomes a
--   general closure ('vectLam', marked 'inlineMe'). An expression that is not
--   a value lambda is handed to 'vectExpr' and marked 'DontInline'.
vectFnExpr
  :: Bool               -- ^ When the RHS of a binding, whether that binding should be inlined.
  -> Bool               -- ^ Whether the binding is a loop breaker.
  -> CoreExprWithFVs    -- ^ Expression to vectorise, normally with an outer `AnnLam`.
  -> VM (Inline, VExpr)
vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _)
  | isId bndr
  = as_scalar `orElseV` as_closure
  where
    (bs, body) = collectAnnValBinders e

    as_scalar  = onlyIfV (isEmptyVarSet fvs)
               $ mark DontInline (vectScalarLam bs (deAnnotate body))

    as_closure = mark inlineMe (vectLam inline loop_breaker fvs bs body)

vectFnExpr _ _ e = mark DontInline (vectExpr e)
365

366 367
-- | Run a computation and pair its result with the given 'Inline' annotation.
mark :: Inline -> VM a -> VM (Inline, a)
mark ann act = liftM ((,) ann) act
368

369 370 371 372 373 374

-- | Vectorise a function where all the args have scalar type, that is Int,
--   Float or Double.
--
--   Succeeds (via 'onlyIfV') only if the arguments and result are scalar,
--   the body is built solely from applications of global scalars, the
--   arguments and scalar literals, and the body mentions at least one global
--   scalar. The function and its scalar closure are then hoisted, and the
--   closure is lifted with 'liftPD'.
vectScalarLam
  :: [Var]      -- ^ Bound variables of function.
  -> CoreExpr   -- ^ Function body.
  -> VM VExpr
vectScalarLam args body
  = dtrace (vcat [text "vectScalarLam ", ppr args, ppr body]) $
    do scalars <- globalScalars
       let acceptable = and [ all is_scalar_ty arg_tys
                            , is_scalar_ty res_ty
                            , is_scalar (extendVarSetList scalars args) body
                            , mentions scalars body ]
       onlyIfV acceptable build_closure
  where
    arg_tys = map idType args
    res_ty  = exprType body

    build_closure
      = do fn_var  <- hoistExpr (fsLit "fn") (mkLams args body) DontInline
           zipf    <- zipScalars arg_tys res_ty
           clo     <- scalarClosure arg_tys res_ty (Var fn_var)
                                    (zipf `App` Var fn_var)
           clo_var <- hoistExpr (fsLit "clo") clo DontInline
           lclo    <- liftPD (Var clo_var)
           return (Var clo_var, lclo)

    is_scalar_ty ty
      = case splitTyConApp_maybe ty of
          Just (tycon, []) -> tycon `elem` [intTyCon, floatTyCon, doubleTyCon]
          _                -> False

    is_scalar scope (Var v)     = v `elemVarSet` scope
    is_scalar _     e@(Lit _)   = is_scalar_ty (exprType e)
    is_scalar scope (App e1 e2) = is_scalar scope e1 && is_scalar scope e2
    is_scalar _     _           = False

    -- A scalar function has to actually compute something. Without the check,
    -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to
    -- (map (\x -> x)) which is very bad. Normal lifting transforms it to
    -- (\n# x -> x) which is what we want.
    mentions funs (Var v)     = v `elemVarSet` funs
    mentions funs (App e1 e2) = mentions funs e1 || mentions funs e2
    mentions _    _           = False

415 416 417 418 419 420 421 422 423

-- | Vectorise a lambda abstraction as a closure.
--
--   The free variables that have an entry in the local environment become
--   the closure's environment. The body is vectorised with those variables
--   and the lambda's binders in scope, hoisted over the local type variables,
--   and wrapped up by 'buildClosures'.
vectLam
  :: Bool               -- ^ When the RHS of a binding, whether that binding should be inlined.
  -> Bool               -- ^ Whether the binding is a loop breaker.
  -> VarSet             -- ^ The free variables in the body.
  -> [Var]              -- ^ The value binders of the lambda.
  -> CoreExprWithFVs    -- ^ The body of the lambda.
  -> VM VExpr
vectLam inline loop_breaker fvs bs body
  = dtrace (vcat [ text "vectLam "
                 , text "free vars    = " <> ppr fvs
                 , text "binding vars = " <> ppr bs
                 , text "body         = " <> ppr (deAnnotate body)]) $
    do tyvars    <- localTyVars
       (vs, vvs) <- readLEnv $ \env ->
                      unzip [ (fv, vv) | fv      <- varSetElems fvs
                                       , Just vv <- [lookupVarEnv (local_vars env) fv] ]
       arg_tys   <- mapM (vectType . idType) bs
       dtrace (text "arg_tys = " <> ppr arg_tys) $ return ()
       res_ty    <- vectType (exprType $ deAnnotate body)
       dtrace (text "res_ty = " <> ppr res_ty) $ return ()
       buildClosures tyvars vvs arg_tys res_ty
         $ hoistPolyVExpr tyvars (inline_ann (length vs + length bs))
         $ do lc              <- builtin liftingContext
              (vbndrs, vbody) <- vectBndrsIn (vs ++ bs) (vectExpr body)
              dtrace (text "vbody = " <> ppr vbody) $ return ()
              vbody'          <- guard_loop lc res_ty vbody
              return (vLams lc vbndrs vbody')
  where
    inline_ann n = if inline then Inline n else DontInline

    -- For a loop breaker, the lifted body becomes a case on the lifting
    -- context that yields an empty PData when it is 0.
    guard_loop lc ty (ve, le)
      = if not loop_breaker
          then return (ve, le)
          else do empty <- emptyPD ty
                  lty   <- mkPDataType ty
                  return (ve, mkWildCase (Var lc) intPrimTy lty
                                [ (DEFAULT, [], le)
                                , (LitAlt (mkMachInt 0), [], empty) ])
 
Ian Lynagh's avatar
Ian Lynagh committed
468

469 470
-- | Vectorise an expression applied to type arguments. Only a bare variable
--   in head position is supported; anything else is rejected.
vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
vectTyAppExpr e tys
  = case e of
      (_, AnnVar v) -> vectPolyVar v tys
      _             -> cantVectorise "Can't vectorise expression"
                                     (ppr (deAnnotate e `mkTyApps` tys))
473 474 475 476 477 478 479

-- We convert
--
--   case e :: t of v { ... }
--
-- to
--
--   V:    let v' = e in case v' of _ { ... }
--   L:    let v' = e in case v' `cast` ... of _ { ... }
--
-- When lifting, we have to do it this way because v must have the type
-- [:V(T):] but the scrutinee must be cast to the representation type. We also
-- have to handle the case where v is a wild var correctly.
--

-- FIXME: this is too lazy
vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs -> Var -> Type
            -> [(AltCon, [Var], CoreExprWithFVs)]
            -> VM VExpr
-- A single alternative (DEFAULT or a constructor) that binds no variables.
vectAlgCase _tycon _ty_args scrut bndr ty [(con, [], body)]
  | binds_nothing con
  = do vscrut         <- vectExpr scrut
       (vty, lty)     <- vectAndLiftType ty
       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
       return (vCaseDEFAULT vscrut vbndr vty lty vbody)
  where
    binds_nothing DEFAULT     = True
    binds_nothing (DataAlt _) = True
    binds_nothing _           = False

-- A single constructor alternative with binders: the vectorised scrutinee is
-- matched against the vectorised constructor, the lifted one against the
-- (single) constructor of the PData type.
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
  = do (vty, lty) <- vectAndLiftType ty
       vexpr      <- vectExpr scrut
       (vbndr, (vbndrs, (vect_body, lift_body)))
                  <- vect_scrut_bndr (vectBndrsIn bndrs (vectExpr body))
       (vscrut, lscrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
       vect_dc    <- maybeV (lookupDataCon dc)
       let (vect_bndrs, lift_bndrs) = unzip vbndrs
           [pdata_dc]               = tyConDataCons pdata_tc
           vcase = single_alt_case vscrut vty vect_dc  vect_bndrs vect_body
           lcase = single_alt_case lscrut lty pdata_dc lift_bndrs lift_body
       return (vLet (vNonRec vbndr vexpr) (vcase, lcase))
  where
    vect_scrut_bndr
      | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
      | otherwise         = vectBndrIn bndr

    single_alt_case expr res_ty con con_bndrs rhs
      = mkWildCase expr (exprType expr) res_ty [(DataAlt con, con_bndrs, rhs)]

-- General case: the alternatives are sorted (DEFAULT first, then by
-- constructor tag) and dispatched on a selector in the lifted code.
vectAlgCase tycon _ty_args scrut bndr ty alts
  = do vect_tc    <- maybeV (lookupTyCon tycon)
       (vty, lty) <- vectAndLiftType ty

       let arity = length (tyConDataCons vect_tc)
       sel_ty     <- builtin (selTy arity)
       sel_bndr   <- newLocalVar (fsLit "sel") sel_ty
       let sel = Var sel_bndr

       (vbndr, valts) <- vect_scrut_bndr (mapM (vect_alt arity sel vty lty) sorted_alts)
       let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts
           (vect_bodies, lift_bodies)                    = unzip vbodies

       vexpr <- vectExpr scrut
       (vect_scrut, lift_scrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
       let [pdata_dc] = tyConDataCons pdata_tc

       vdummy <- newDummyVar (exprType vect_scrut)
       ldummy <- newDummyVar (exprType lift_scrut)
       let vect_case = Case vect_scrut vdummy vty
                            (zipWith3 (\con con_bndrs rhs -> (DataAlt con, con_bndrs, rhs))
                                      vect_dcs vect_bndrss vect_bodies)

       lc    <- builtin liftingContext
       lbody <- combinePD vty (Var lc) sel lift_bodies
       let lift_case = Case lift_scrut ldummy lty
                            [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss, lbody)]

       return (vLet (vNonRec vbndr vexpr) (vect_case, lift_case))
  where
    vect_scrut_bndr
      | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
      | otherwise         = vectBndrIn bndr

    sorted_alts = sortBy (\(c1, _, _) (c2, _, _) -> cmp_con c1 c2) alts

    cmp_con (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
    cmp_con DEFAULT       DEFAULT       = EQ
    cmp_con DEFAULT       _             = LT
    cmp_con _             DEFAULT       = GT
    cmp_con _             _             = panic "vectAlgCase/cmp"

    -- Vectorise one constructor alternative. In the lifted body, each local
    -- free variable of the alternative is rebound via 'pack_var'.
    vect_alt arity sel _ lty (DataAlt dc, bndrs, body)
      = do vect_dc <- maybeV (lookupDataCon dc)
           let ntag = dataConTagZ vect_dc
               tag  = mkDataConTag vect_dc
               fvs  = freeVarsOf body `delVarSetList` bndrs
           sel_tags <- liftM (`App` sel) (builtin (selTags arity))
           lc       <- builtin liftingContext
           elems    <- builtin (selElements arity ntag)
           (vbndrs, vbody)
             <- vectBndrsIn bndrs . localV $
                  do binds    <- mapM (pack_var (Var lc) sel_tags tag)
                                      (filter isLocalId (varSetElems fvs))
                     (ve, le) <- vectExpr body
                     return (ve, Case (elems `App` sel) lc lty
                                      [(DEFAULT, [], mkLets (concat binds) le)])
           let (vect_bndrs, lift_bndrs) = unzip vbndrs
           return (vect_dc, vect_bndrs, lift_bndrs, vbody)
    vect_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"

    -- For a local variable, bind a fresh copy of its lifted version to the
    -- result of 'packByTagPD' and point the local environment at that copy;
    -- any other variable needs no binding.
    pack_var len tags t v
      = do r <- lookupVar v
           case r of
             Local (vv, lv) ->
               do lv'  <- cloneVar lv
                  expr <- packByTagPD (idType vv) (Var lv) len tags t
                  updLEnv (\env -> env { local_vars = extendVarEnv (local_vars env) v (vv, lv') })
                  return [NonRec lv' expr]
             _ -> return []
621