-- Vectorise.hs
module Vectorise( vectorise )
where

5
import VectMonad
6
import VectUtils
7
import VectType
8
import VectCore
9

10
import HscTypes hiding      ( MonadThings(..) )
11

12
import Module               ( PackageId )
13
import CoreSyn
14
import CoreUtils
15
import CoreUnfold           ( mkInlineRule )
16
import MkCore               ( mkWildCase )
17
import CoreFVs
Ian Lynagh's avatar
Ian Lynagh committed
18
import CoreMonad            ( CoreM, getHscEnv )
19
import DataCon
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
20
import TyCon
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
21
import Type
22
import FamInstEnv           ( extendFamInstEnvList )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
23 24
import Var
import VarEnv
25
import VarSet
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
26
import Id
27
import OccName
28
import BasicTypes           ( isLoopBreaker )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
29

30
import Literal              ( Literal, mkMachInt )
31
import TysWiredIn
32
import TysPrim              ( intPrimTy )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
33

34
import Outputable
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
35
import FastString
36 37
import Util                 ( zipLazy )
import Control.Monad
38
import Data.List            ( sortBy, unzip4 )
39

-- | Vectorise a module. 'CoreM' is only used to obtain the 'HscEnv';
--   the work itself is done in 'IO' by 'vectoriseIO'.
vectorise :: PackageId -> ModGuts -> CoreM ModGuts
vectorise backend guts
  = getHscEnv >>= \env -> liftIO (vectoriseIO backend env guts)

-- | Vectorise a single module, given the compiler's 'HscEnv'.
vectoriseIO :: PackageId -> HscEnv -> ModGuts -> IO ModGuts
vectoriseIO backend hsc_env guts
 = do -- Get information about currently loaded external packages.
      eps <- hscEPS hsc_env

      -- Combine vectorisation info from the home package table and from the
      -- external package state.
      let info = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps

      -- Run the main VM computation.
      result <- initV backend hsc_env guts info (vectModule guts)
      case result of
        Just (info', guts') -> return (guts' { mg_vect_info = info' })
        -- 'initV' returns a 'Maybe'; report 'Nothing' with a descriptive
        -- message rather than an anonymous pattern-match failure.
        Nothing             -> fail "vectoriseIO: vectorisation of the module failed"
-- | Vectorise a single module, in the VM monad.
--   Vectorises the type environment first, then every top-level binding.
vectModule :: ModGuts -> VM ModGuts
vectModule guts
 = do -- Vectorise the type environment.
      -- This may add new TyCons and DataCons. The bindings it returns
      -- (tc_binds) are emitted below as a single recursive group in front
      -- of the module's vectorised bindings.
      (types', fam_insts, tc_binds) <- vectTypeEnv (mg_types guts)

      -- Extend the module's family instance environment with the instances
      -- produced by 'vectTypeEnv' and install it in the VM's global
      -- environment before any binding is vectorised.
      let fam_inst_env' = extendFamInstEnvList (mg_fam_inst_env guts) fam_insts
      updGEnv (setFamInstEnv fam_inst_env')

      -- dicts   <- mapM buildPADict pa_insts
      -- workers <- mapM vectDataConWorkers pa_insts

      -- Vectorise all the top level bindings.
      binds'  <- mapM vectTopBind (mg_binds guts)

      return $ guts { mg_types        = types'
                    , mg_binds        = Rec tc_binds : binds'
                    , mg_fam_inst_env = fam_inst_env'
                    , mg_fam_insts    = mg_fam_insts guts ++ fam_insts
                    }

-- | Try to vectorise a top-level binding.
--   If it doesn't vectorise then return it unharmed.
--
--   For example, for the binding
--
--   @
--      foo :: Int -> Int
--      foo = \x -> x + x
--   @
--
--   we get
--   @
--      foo  :: Int -> Int
--      foo  = \x -> v_foo $: x
--
--      v_foo :: Closure void vfoo lfoo
--      v_foo = closure vfoo lfoo void
--
--      vfoo :: Void -> Int -> Int
--      vfoo = ...
--
--      lfoo :: PData Void -> PData Int -> PData Int
--      lfoo = ...
--   @
--
--   @vfoo@ is the "vectorised", or scalar, version that does the same as the original
--   function foo, but takes an explicit environment.
--
--   @lfoo@ is the "lifted" version that works on arrays.
--
--   @v_foo@ combines both of these into a `Closure` that also contains the
--   environment.
--
--   The original binding @foo@ is rewritten to call the vectorised version
--   present in the closure.
--
--   On success the result is a single 'Rec' group holding the original
--   binder(s), their vectorised counterparts, and any bindings hoisted out
--   while vectorising the right-hand sides.
vectTopBind :: CoreBind -> VM CoreBind
vectTopBind b@(NonRec var expr)
 = do
      (inline, expr') <- vectTopRhs var expr
      var'            <- vectTopBinder var inline expr'

      -- Vectorising the body may create other top-level bindings.
      hs              <- takeHoisted

      -- To get the same functionality as the original body we project
      -- out its vectorised version from the closure.
      cexpr           <- tryConvert var var' expr

      return . Rec $ (var, cexpr) : (var', expr') : hs
  `orElseV`
    return b

vectTopBind b@(Rec bs)
 = do
      -- 'vectTopBinder' registers each var |-> var' mapping (defGlobalVar)
      -- that the mutually recursive RHSs need, but it also takes the inline
      -- annotation and vectorised RHS produced by 'vectTopRhs'. 'fixV' ties
      -- that knot; the lazy patterns and 'zipLazy' must stay lazy (see the
      -- NOTE on 'vectTopBinder').
      (vars', _, exprs')
        <- fixV $ \ ~(_, inlines, rhss) ->
            do vars' <- sequence [vectTopBinder var inline rhs
                                      | (var, ~(inline, rhs)) <- zipLazy vars (zip inlines rhss)]
               (inlines', exprs')
                     <- mapAndUnzipM (uncurry vectTopRhs) bs

               return (vars', inlines', exprs')

      hs     <- takeHoisted
      cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
      return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
  `orElseV`
    return b
  where
    (vars, exprs) = unzip bs

-- | Make the vectorised version of this top level binder, and add the mapping
--   between it and the original to the state. For some binder @foo@ the vectorised
--   version is @$v_foo@
--
--   NOTE: vectTopBinder *MUST* be lazy in inline and expr because of how it is
--   used inside of fixV in vectTopBind
vectTopBinder
        :: Var          -- ^ Name of the binding.
        -> Inline       -- ^ Whether it should be inlined, used to annotate it.
        -> CoreExpr     -- ^ RHS of the binding, used to set the `Unfolding` of the returned `Var`.
        -> VM Var       -- ^ Name of the vectorised binding.

vectTopBinder var inline expr
 = do
      -- Vectorise the type attached to the var.
      vty  <- vectType (idType var)
      -- Clone the binder at the vectorised type under a vectorised name
      -- (mkVectOcc) and attach the unfolding chosen below.
      var' <- liftM (`setIdUnfolding` unfolding) $ cloneId mkVectOcc var vty
      -- Record the original |-> vectorised mapping in the global environment.
      defGlobalVar var var'
      return var'
  where
    -- Must remain a lazy binding: it inspects `inline` and `expr`, which
    -- are not yet available when this runs inside fixV (see NOTE above).
    unfolding = case inline of
                  Inline arity -> mkInlineRule expr (Just arity)
                  DontInline   -> noUnfolding

-- | Vectorise the RHS of a top-level binding, in an empty local environment.
--   Returns the inline annotation computed for the binding together with the
--   vectorised component of the resulting 'VExpr'.
vectTopRhs
        :: Var          -- ^ Name of the binding.
        -> CoreExpr     -- ^ Body of the binding.
        -> VM (Inline, CoreExpr)

vectTopRhs var expr
 = dtrace (vcat [text "vectTopRhs", ppr expr])
 $ closedV
 $ do (inline, vexpr) <- inBind var
                      $ vectPolyExpr (isLoopBreaker $ idOccInfo var)
                                      (freeVars expr)
      -- Only the vectorised half of the VExpr is returned here.
      return (inline, vectorised vexpr)

-- | Project out the vectorised version of a binding from some closure,
--   or return the original body if that doesn't work.
tryConvert
        :: Var          -- ^ Name of the original binding (eg @foo@)
        -> Var          -- ^ Name of vectorised version of binding (eg @$vfoo@)
        -> CoreExpr     -- ^ The original body of the binding.
        -> VM CoreExpr

tryConvert var vect_var rhs
  = fromVect (idType var) (Var vect_var) `orElseV` return rhs

-- ----------------------------------------------------------------------------
-- Bindings

-- | Vectorise a binder variable, along with its attached type.
--   The vectorised and lifted versions keep the original variable and only
--   change its type; both are recorded in the local environment under @v@.
vectBndr :: Var -> VM VVar
vectBndr v
  = do
      (vty, lty) <- vectAndLiftType (idType v)
      let vv = v `Id.setIdType` vty
          lv = v `Id.setIdType` lty
      updLEnv (mapTo vv lv)
      return (vv, lv)
  where
    mapTo vv lv env = env { local_vars = extendVarEnv (local_vars env) v (vv, lv) }

-- | Vectorise a binder variable, along with its attached type,
--   but give the result a new name.
--   The fresh variables are recorded in the local environment under @v@.
vectBndrNew :: Var -> FastString -> VM VVar
vectBndrNew v fs
  = do
      vty <- vectType (idType v)
      vv  <- newLocalVVar fs vty
      updLEnv (upd vv)
      return vv
  where
    upd vv env = env { local_vars = extendVarEnv (local_vars env) v vv }

-- | Vectorise a binder then run a computation with that binder in scope.
--   Both steps run under 'localV' (presumably restoring the local
--   environment afterwards -- see VectMonad).
vectBndrIn :: Var -> VM a -> VM (VVar, a)
vectBndrIn v p
  = localV
  $ do
      vv <- vectBndr v
      x <- p
      return (vv, x)

-- | Vectorise a binder, give it a new name, then run a computation with that binder in scope.
vectBndrNewIn :: Var -> FastString -> VM a -> VM (VVar, a)
-- The binder is given its fresh name (and entered into the local
-- environment) before @p@ runs; both happen inside a single 'localV'.
vectBndrNewIn v fs p = localV (liftM2 (,) (vectBndrNew v fs) p)

-- | Vectorise some binders, then run a computation with them in scope.
--   The binders are vectorised in list order before @p@ runs, all under
--   a single 'localV'.
vectBndrsIn :: [Var] -> VM a -> VM ([VVar], a)
vectBndrsIn vs p
  = localV
  $ do
      vvs <- mapM vectBndr vs
      x <- p
      return (vvs, x)

-- ----------------------------------------------------------------------------
-- Expressions

-- | Vectorise a variable, producing the vectorised and lifted versions.
--   Locals already carry both versions in the local environment; for a
--   global the lifted version is built from the vectorised one with 'liftPD'.
vectVar :: Var -> VM VExpr
vectVar v
 = do
      -- lookup the variable from the environment.
      r <- lookupVar v

      case r of
        Local (vv,lv) -> return (Var vv, Var lv)
        Global vv     -> do
                           let vexpr = Var vv
                           lexpr <- liftPD vexpr
                           return (vexpr, lexpr)

-- | Like `vectVar` but also add type applications to the variables.
--   The type arguments are vectorised before being applied.
vectPolyVar :: Var -> [Type] -> VM VExpr
vectPolyVar v tys
  = do
      vtys <- mapM vectType tys
      r    <- lookupVar v
      case r of
        Local (vv, lv)
         -> liftM2 (,) (polyApply (Var vv) vtys)
                       (polyApply (Var lv) vtys)

        Global poly
         -> do vexpr <- polyApply (Var poly) vtys
               lexpr <- liftPD vexpr
               return (vexpr, lexpr)

-- | Lifted literals are created by replicating them (via 'liftPD');
--   the vectorised version is the literal itself.
vectLiteral :: Literal -> VM VExpr
vectLiteral lit
  = do
      lexpr <- liftPD (Lit lit)
      return (Lit lit, lexpr)

-- | Vectorise a polymorphic expression.
--   Notes are pushed through. Otherwise the outer type binders are stripped,
--   the monomorphic body is vectorised with 'vectFnExpr' (its inline flag set
--   to False), and the result is re-abstracted over the type variables and the
--   extra arguments supplied by 'polyAbstract'.
vectPolyExpr
        :: Bool                 -- ^ When vectorising the RHS of a binding, whether that
                                --   binding is a loop breaker.
        -> CoreExprWithFVs
        -> VM (Inline, VExpr)

vectPolyExpr loop_breaker (_, AnnNote note expr)
 = do (inline, expr') <- vectPolyExpr loop_breaker expr
      return (inline, vNote note expr')

vectPolyExpr loop_breaker expr
  = do
      arity <- polyArity tvs
      polyAbstract tvs $ \args ->
        do
          (inline, mono') <- vectFnExpr False loop_breaker mono
          return (addInlineArity inline arity,
                  mapVect (mkLams $ tvs ++ args) mono')
  where
    (tvs, mono) = collectAnnTypeBinders expr

-- | Vectorise a core expression.
--   Expressions not matched by any specific equation below are rejected
--   with 'cantVectorise'.
vectExpr :: CoreExprWithFVs -> VM VExpr
vectExpr (_, AnnType ty)
  = liftM vType (vectType ty)

vectExpr (_, AnnVar v)
  = vectVar v

vectExpr (_, AnnLit lit)
  = vectLiteral lit

vectExpr (_, AnnNote note expr)
  = liftM (vNote note) (vectExpr expr)

-- Type application: vectorise the head together with all its type arguments.
vectExpr e@(_, AnnApp _ arg)
  | isAnnTypeArg arg
  = vectTyAppExpr fn tys
  where
    (fn, tys) = collectAnnTypeArgs e

-- The Int, Float or Double data constructor applied to a literal is kept
-- as-is in the vectorised version and lifted with 'liftPD'.
vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
  | Just con <- isDataConId_maybe v
  , is_special_con con
  = do
      let vexpr = App (Var v) (Lit lit)
      lexpr <- liftPD vexpr
      return (vexpr, lexpr)
  where
    is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon]


-- TODO: Avoid using closure application for dictionaries.
-- vectExpr (_, AnnApp fn arg)
--  | if is application of dictionary
--    just use regular app instead of closure app.

-- for lifted version.
--      do liftPD (sub a dNumber)
--      lift the result of the selection, not sub and dNumber separately.

-- Value application becomes closure application at the vectorised types.
vectExpr (_, AnnApp fn arg)
  = do
      arg_ty' <- vectType arg_ty
      res_ty' <- vectType res_ty
      fn'     <- vectExpr fn
      arg'    <- vectExpr arg

      mkClosureApp arg_ty' res_ty' fn' arg'
  where
    (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn

-- Only cases whose scrutinee has an algebraic type are handled here; other
-- cases fall through to the final 'cantVectorise' equation.
vectExpr (_, AnnCase scrut bndr ty alts)
  | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
  , isAlgTyCon tycon
  = vectAlgCase tycon ty_args scrut bndr ty alts
  where
    scrut_ty = exprType (deAnnotate scrut)

-- Non-recursive let: the RHS is vectorised before the binder is in scope.
vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
  = do
      vrhs <- localV . inBind bndr . liftM snd $ vectPolyExpr False rhs
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
      return $ vLet (vNonRec vbndr vrhs) vbody

-- Recursive let: all binders are in scope for every RHS and for the body.
vectExpr (_, AnnLet (AnnRec bs) body)
  = do
      (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
                                $ liftM2 (,)
                                  (zipWithM vect_rhs bndrs rhss)
                                  (vectExpr body)
      return $ vLet (vRec vbndrs vrhss) vbody
  where
    (bndrs, rhss) = unzip bs

    vect_rhs bndr rhs = localV
                      . inBind bndr
                      . liftM snd
                      $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) rhs

-- Value lambdas only; type lambdas fall through to the final equation.
vectExpr e@(_, AnnLam bndr _)
  | isId bndr = liftM snd $ vectFnExpr True False e
{-
onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
                `orElseV` vectLam True fvs bs body
  where
    (bs,body) = collectAnnValBinders e
-}

vectExpr e = cantVectorise "Can't vectorise expression" (ppr $ deAnnotate e)

-- | Vectorise an expression with an outer lambda abstraction.
--   A value lambda with no free variables is first tried as a scalar lambda
--   ('vectScalarLam', marked 'DontInline'); otherwise, or if that fails, it is
--   vectorised as a closure by 'vectLam' and marked 'inlineMe'. Any other
--   expression is vectorised with 'vectExpr' and marked 'DontInline'.
vectFnExpr
        :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
        -> Bool                 -- ^ Whether the binding is a loop breaker.
        -> CoreExprWithFVs      -- ^ Expression to vectorise. Expected to have an outer `AnnLam`;
                                --   anything else falls back to 'vectExpr'.
        -> VM (Inline, VExpr)

vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _)
  | isId bndr = onlyIfV (isEmptyVarSet fvs)
                        (mark DontInline . vectScalarLam bs $ deAnnotate body)
                `orElseV` mark inlineMe (vectLam inline loop_breaker fvs bs body)
  where
    (bs,body) = collectAnnValBinders e
vectFnExpr _ _ e = mark DontInline $ vectExpr e

-- | Run a computation and pair its result with the given 'Inline' annotation.
mark :: Inline -> VM a -> VM (Inline, a)
mark b = liftM ((,) b)
-- | Vectorise a function where all the args have scalar type, that is Int, Float or Double.
--   The function is hoisted unchanged ('hoistExpr'), and a closure is built
--   from it and its zipped form ('zipScalars', 'scalarClosure') and hoisted too.
vectScalarLam
        :: [Var]        -- ^ Bound variables of function.
        -> CoreExpr     -- ^ Function body.
        -> VM VExpr
vectScalarLam args body
  = do
      scalars <- globalScalars
      onlyIfV (all is_scalar_ty arg_tys
               && is_scalar_ty res_ty
               && is_scalar (extendVarSetList scalars args) body
               && uses scalars body)
        $ do
            fn_var  <- hoistExpr (fsLit "fn") (mkLams args body) DontInline
            zipf    <- zipScalars arg_tys res_ty
            clo     <- scalarClosure arg_tys res_ty (Var fn_var)
                                                (zipf `App` Var fn_var)
            clo_var <- hoistExpr (fsLit "clo") clo DontInline
            lclo    <- liftPD (Var clo_var)
            return (Var clo_var, lclo)
  where
    arg_tys = map idType args
    res_ty  = exprType body

    -- Only Int, Float and Double (with no type arguments) count as scalar.
    is_scalar_ty ty
        | Just (tycon, [])   <- splitTyConApp_maybe ty
        =    tycon == intTyCon
          || tycon == floatTyCon
          || tycon == doubleTyCon

        | otherwise = False

    -- An expression is scalar if it is built only from variables in the
    -- given set, literals of scalar type, and applications of these.
    is_scalar vs (Var v)     = v `elemVarSet` vs
    is_scalar _ e@(Lit _)    = is_scalar_ty $ exprType e
    is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2
    is_scalar _ _            = False

    -- A scalar function has to actually compute something. Without the check,
    -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to
    -- (map (\x -> x)) which is very bad. Normal lifting transforms it to
    -- (\n# x -> x) which is what we want.
    uses funs (Var v)     = v `elemVarSet` funs
    uses funs (App e1 e2) = uses funs e1 || uses funs e2
    uses _ _              = False

-- | Vectorise a lambda abstraction as a closure.
--   Free variables of the body that are bound in the local environment form
--   the closure environment. The body is vectorised with those variables and
--   the lambda's binders in scope, hoisted with 'hoistPolyVExpr', and wrapped
--   up by 'buildClosures'.
vectLam
        :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
        -> Bool                 -- ^ Whether the binding is a loop breaker.
        -> VarSet               -- ^ The free variables in the body.
        -> [Var]                -- ^ Value binders of the lambda.
        -> CoreExprWithFVs      -- ^ Body of the lambda.
        -> VM VExpr

vectLam inline loop_breaker fvs bs body
  = do
      tyvars <- localTyVars
      (vs, vvs) <- readLEnv $ \env ->
                   unzip [(var, vv) | var <- varSetElems fvs
                                    , Just vv <- [lookupVarEnv (local_vars env) var]]

      arg_tys <- mapM (vectType . idType) bs
      res_ty  <- vectType (exprType $ deAnnotate body)

      buildClosures tyvars vvs arg_tys res_ty
        . hoistPolyVExpr tyvars (maybe_inline (length vs + length bs))
        $ do
            lc <- builtin liftingContext
            (vbndrs, vbody) <- vectBndrsIn (vs ++ bs)
                                           (vectExpr body)
            vbody' <- break_loop lc res_ty vbody
            return $ vLams lc vbndrs vbody'
  where
    -- The inline arity covers both the environment variables and the binders.
    maybe_inline n | inline    = Inline n
                   | otherwise = DontInline

    -- For a loop breaker the lifted body is guarded by a case on the lifting
    -- context @lc@: when it is 0 the value built by 'emptyPD' is returned
    -- instead of the body.
    break_loop lc ty (ve, le)
      | loop_breaker
      = do
          empty <- emptyPD ty
          lty <- mkPDataType ty
          return (ve, mkWildCase (Var lc) intPrimTy lty
                        [(DEFAULT, [], le),
                         (LitAlt (mkMachInt 0), [], empty)])

      | otherwise = return (ve, le)

-- | Vectorise a type application. Only a variable head is supported; any
--   other head is rejected with 'cantVectorise'.
vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
vectTyAppExpr (_, AnnVar v) tys = vectPolyVar v tys
vectTyAppExpr e tys = cantVectorise "Can't vectorise expression"
                        (ppr $ deAnnotate e `mkTyApps` tys)

-- We convert
--
--   case e :: t of v { ... }
--
-- to
--
--   V:    let v' = e in case v' of _ { ... }
--   L:    let v' = e in case v' `cast` ... of _ { ... }
--
-- When lifting, we have to do it this way because v must have the type
-- [:V(T):] but the scrutinee must be cast to the representation type. We also
-- have to handle the case where v is a wild var correctly.
--

-- FIXME: this is too lazy
vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs -> Var -> Type
            -> [(AltCon, [Var], CoreExprWithFVs)]
            -> VM VExpr
-- A lone DEFAULT alternative is handled with 'vCaseDEFAULT'.
vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)]
  = do
      vscrut         <- vectExpr scrut
      (vty, lty)     <- vectAndLiftType ty
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
      return $ vCaseDEFAULT vscrut vbndr vty lty vbody

-- A single constructor alternative without binders is handled exactly like
-- a lone DEFAULT.
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)]
  = do
      vscrut         <- vectExpr scrut
      (vty, lty)     <- vectAndLiftType ty
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
      return $ vCaseDEFAULT vscrut vbndr vty lty vbody

-- A single constructor alternative with binders: the vectorised case matches
-- on the vectorised data constructor, the lifted case on the single data
-- constructor of the PData type returned by 'mkVScrut'.
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
  = do
      (vty, lty) <- vectAndLiftType ty
      vexpr      <- vectExpr scrut
      (vbndr, (vbndrs, (vect_body, lift_body)))
         <- vect_scrut_bndr
          . vectBndrsIn bndrs
          $ vectExpr body
      let (vect_bndrs, lift_bndrs) = unzip vbndrs
      (vscrut, lscrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
      vect_dc <- maybeV (lookupDataCon dc)
      let [pdata_dc] = tyConDataCons pdata_tc

      let vcase = mk_wild_case vscrut vty vect_dc  vect_bndrs vect_body
          lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body

      return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
  where
    -- A dead case binder still needs a name, so give it a fresh one.
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
                    | otherwise         = vectBndrIn bndr

    mk_wild_case expr ty dc bndrs body
      = mkWildCase expr (exprType expr) ty [(DataAlt dc, bndrs, body)]

-- General case: several constructor alternatives. They are sorted by data
-- constructor tag; a DEFAULT among several alternatives is not supported
-- ('proc_alt' panics on it). The lifted result combines the per-alternative
-- lifted bodies with 'combinePD' under the selector bound by the PData match.
vectAlgCase tycon _ty_args scrut bndr ty alts
  = do
      vect_tc     <- maybeV (lookupTyCon tycon)
      (vty, lty)  <- vectAndLiftType ty

      let arity = length (tyConDataCons vect_tc)
      sel_ty <- builtin (selTy arity)
      sel_bndr <- newLocalVar (fsLit "sel") sel_ty
      let sel = Var sel_bndr

      (vbndr, valts) <- vect_scrut_bndr
                      $ mapM (proc_alt arity sel vty lty) alts'
      let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts

      vexpr <- vectExpr scrut
      (vect_scrut, lift_scrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
      let [pdata_dc] = tyConDataCons pdata_tc

      let (vect_bodies, lift_bodies) = unzip vbodies

      vdummy <- newDummyVar (exprType vect_scrut)
      ldummy <- newDummyVar (exprType lift_scrut)
      let vect_case = Case vect_scrut vdummy vty
                           (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)

      lc <- builtin liftingContext
      lbody <- combinePD vty (Var lc) sel lift_bodies
      let lift_case = Case lift_scrut ldummy lty
                           [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
                             lbody)]

      return . vLet (vNonRec vbndr vexpr)
             $ (vect_case, lift_case)
  where
    -- A dead case binder still needs a name, so give it a fresh one.
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
                    | otherwise         = vectBndrIn bndr

    alts' = sortBy (\(alt1, _, _) (alt2, _, _) -> cmp alt1 alt2) alts

    cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
    cmp DEFAULT       DEFAULT       = EQ
    cmp DEFAULT       _             = LT
    cmp _             DEFAULT       = GT
    cmp _             _             = panic "vectAlgCase/cmp"

    -- Vectorise one constructor alternative. In the lifted body the lifting
    -- context @lc@ is rebound to @elems `App` sel@ (presumably the number of
    -- elements carrying this alternative's tag), and every local free
    -- variable of the body is replaced by its packed version (see pack_var).
    proc_alt arity sel _ lty (DataAlt dc, bndrs, body)
      = do
          vect_dc <- maybeV (lookupDataCon dc)
          let ntag = dataConTagZ vect_dc
              tag  = mkDataConTag vect_dc
              fvs  = freeVarsOf body `delVarSetList` bndrs

          sel_tags  <- liftM (`App` sel) (builtin (selTags arity))
          lc        <- builtin liftingContext
          elems     <- builtin (selElements arity ntag)

          (vbndrs, vbody)
            <- vectBndrsIn bndrs
             . localV
             $ do
                 binds    <- mapM (pack_var (Var lc) sel_tags tag)
                           . filter isLocalId
                           $ varSetElems fvs
                 (ve, le) <- vectExpr body
                 return (ve, Case (elems `App` sel) lc lty
                             [(DEFAULT, [], (mkLets (concat binds) le))])
                 -- empty    <- emptyPD vty
                 -- return (ve, Case (elems `App` sel) lc lty
                 --             [(DEFAULT, [], Let (NonRec flags_var flags_expr)
                 --                             $ mkLets (concat binds) le),
                 --               (LitAlt (mkMachInt 0), [], empty)])
          let (vect_bndrs, lift_bndrs) = unzip vbndrs
          return (vect_dc, vect_bndrs, lift_bndrs, vbody)

    proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"

    mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)

    -- For a local variable, bind a clone of its lifted version to the result
    -- of 'packByTagPD' and make the local environment point at the clone.
    -- Non-local variables produce no bindings.
    pack_var len tags t v
      = do
          r <- lookupVar v
          case r of
            Local (vv, lv) ->
              do
                lv'  <- cloneVar lv
                expr <- packByTagPD (idType vv) (Var lv) len tags t
                updLEnv (\env -> env { local_vars = extendVarEnv
                                                (local_vars env) v (vv, lv') })
                return [(NonRec lv' expr)]

            _ -> return []