Vectorise.hs 20.2 KB
Newer Older
1
{-# OPTIONS -fno-warn-missing-signatures #-}
2

3 4 5
module Vectorise( vectorise )
where

6
import VectMonad
7
import VectUtils
8
import VectVar
9
import VectType
10
import Vectorise.Vect
11
import Vectorise.Env
12

13
import HscTypes hiding      ( MonadThings(..) )
14

15
import Module               ( PackageId )
16
import CoreSyn
17
import CoreUtils
18
import CoreUnfold           ( mkInlineRule )
19
import MkCore               ( mkWildCase )
20
import CoreFVs
Ian Lynagh's avatar
Ian Lynagh committed
21
import CoreMonad            ( CoreM, getHscEnv )
22
import DataCon
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
23
import TyCon
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
24
import Type
25
import FamInstEnv           ( extendFamInstEnvList )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
26 27
import Var
import VarEnv
28
import VarSet
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
29
import Id
30
import OccName
31
import BasicTypes           ( isLoopBreaker )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
32

33
import Literal
34
import TysWiredIn
35
import TysPrim              ( intPrimTy )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
36

37
import Outputable
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
38
import FastString
39 40
import Util                 ( zipLazy )
import Control.Monad
41
import Data.List            ( sortBy, unzip4 )
42

43 44 45 46 47 48

-- | Switch for the vectoriser's tracing output.
debug :: Bool
debug = False

-- | When 'debug' is set, print the given document (tagged "Vectorise")
--   before yielding the second argument; otherwise yield it untouched.
dtrace s x
  | debug     = pprTrace "Vectorise" s x
  | otherwise = x

-- | Vectorise a single module.
--   Takes the package containing the DPH backend we're using, eg either
--   dph-par or dph-seq.
--
--   This is the 'CoreM' entry point: it fetches the current 'HscEnv' and
--   hands the actual work to 'vectoriseIO'.
vectorise :: PackageId -> ModGuts -> CoreM ModGuts
vectorise backend guts
 = do hsc_env <- getHscEnv
      liftIO $ vectoriseIO backend hsc_env guts

-- | Vectorise a single module, given its HscEnv (code gen environment).
--
--   Combines the vectorisation info of the home package and of the loaded
--   external packages, runs 'vectModule' in the VM monad and records the
--   resulting vectorisation info in the returned 'ModGuts'.
vectoriseIO :: PackageId -> HscEnv -> ModGuts -> IO ModGuts
vectoriseIO backend hsc_env guts
 = do -- Get information about currently loaded external packages.
      eps <- hscEPS hsc_env

      -- Combine vectorisation info from the current module, and external ones.
      let info = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps

      -- Run the main VM computation.
      -- initV yields Nothing when the VM computation produces no result.
      -- Report that explicitly instead of through an anonymous
      -- pattern-match failure; 'fail' in IO still raises an IOError, so
      -- the exception type seen by callers is unchanged.
      mb_result <- initV backend hsc_env guts info (vectModule guts)
      case mb_result of
        Just (info', guts') -> return (guts' { mg_vect_info = info' })
        Nothing             -> fail "vectoriseIO: vectorisation of the module failed"

-- | Vectorise a single module, in the VM monad.
--
--   Vectorises the type environment first (which may produce new family
--   instances and extra bindings), then every top-level binding, and
--   returns the updated 'ModGuts'.
vectModule :: ModGuts -> VM ModGuts
vectModule guts
 = do -- Vectorise the type environment.
      -- This may add new TyCons and DataCons.
      -- TODO: What new binds do we get back here?
      (types', fam_insts, tc_binds) <- vectTypeEnv (mg_types guts)

      -- Extend the module's family instance environment with the family
      -- instances produced by vectorising the types, and install the
      -- extended environment in the global environment (updGEnv) so the
      -- rest of vectorisation sees it.
      let fam_inst_env' = extendFamInstEnvList (mg_fam_inst_env guts) fam_insts
      updGEnv (setFamInstEnv fam_inst_env')

      -- dicts   <- mapM buildPADict pa_insts
      -- workers <- mapM vectDataConWorkers pa_insts

      -- Vectorise all the top level bindings.
      binds'  <- mapM vectTopBind (mg_binds guts)

      -- The bindings returned by vectTypeEnv go into one recursive group
      -- placed in front of the vectorised top-level bindings.
      return $ guts { mg_types        = types'
                    , mg_binds        = Rec tc_binds : binds'
                    , mg_fam_inst_env = fam_inst_env'
                    , mg_fam_insts    = mg_fam_insts guts ++ fam_insts
                    }

-- | Try to vectorise a top-level binding.
--   If it doesn't vectorise then return it unharmed.
--
--   For example, for the binding
--
--   @
--      foo :: Int -> Int
--      foo = \x -> x + x
--   @
--
--   we get
--   @
--      foo  :: Int -> Int
--      foo  = \x -> v_foo $: x
--
--      v_foo :: Closure void vfoo lfoo
--      v_foo = closure vfoo lfoo void
--
--      vfoo :: Void -> Int -> Int
--      vfoo = ...
--
--      lfoo :: PData Void -> PData Int -> PData Int
--      lfoo = ...
--   @
--
--   @vfoo@ is the "vectorised", or scalar, version that does the same as the original
--   function foo, but takes an explicit environment.
--
--   @lfoo@ is the "lifted" version that works on arrays.
--
--   @v_foo@ combines both of these into a `Closure` that also contains the
--   environment.
--
--   The original binding @foo@ is rewritten to call the vectorised version
--   present in the closure.
--
--   The rewritten original bindings, the vectorised bindings and any bindings
--   hoisted while vectorising the right-hand sides are returned together as
--   a single recursive group.
vectTopBind :: CoreBind -> VM CoreBind
vectTopBind b@(NonRec var expr)
 = do
      (inline, expr') <- vectTopRhs var expr
      var'            <- vectTopBinder var inline expr'

      -- Vectorising the body may create other top-level bindings.
      hs              <- takeHoisted

      -- To get the same functionality as the original body we project
      -- out its vectorised version from the closure.
      cexpr           <- tryConvert var var' expr

      return . Rec $ (var, cexpr) : (var', expr') : hs
  `orElseV`
    return b

vectTopBind b@(Rec bs)
 = do
      -- The vectorised binders are built from the inline annotations and
      -- vectorised right-hand sides, while the right-hand sides may refer to
      -- the vectorised binders (vectTopBinder registers them with
      -- defGlobalVar). fixV ties this knot; the lazy patterns (and the
      -- laziness of vectTopBinder) keep the results from being forced early.
      (vars', _, exprs')
        <- fixV $ \ ~(_, inlines, rhss) ->
            do vars' <- sequence [vectTopBinder var inline rhs
                                      | (var, ~(inline, rhs)) <- zipLazy vars (zip inlines rhss)]
               (inlines', exprs')
                     <- mapAndUnzipM (uncurry vectTopRhs) bs

               return (vars', inlines', exprs')

      -- Vectorising the bodies may create other top-level bindings.
      hs     <- takeHoisted

      -- Rewrite each original binding in terms of its vectorised version.
      cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
      return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
  `orElseV`
    return b
  where
    (vars, exprs) = unzip bs

-- | Make the vectorised version of this top level binder, and add the mapping
--   between it and the original to the state. For some binder @foo@ the vectorised
--   version is @$v_foo@
--
--   NOTE: vectTopBinder *MUST* be lazy in inline and expr because of how it is
--   used inside of fixV in vectTopBind
vectTopBinder
        :: Var          -- ^ Name of the binding.
        -> Inline       -- ^ Whether it should be inlined, used to annotate it.
        -> CoreExpr     -- ^ RHS of the binding, used to set the `Unfolding` of the returned `Var`.
        -> VM Var       -- ^ Name of the vectorised binding.
vectTopBinder var inline expr
 = do
      -- Vectorise the type attached to the var.
      vty  <- vectType (idType var)

      -- Make the vectorised version of binding's name, and set the unfolding used for inlining.
      var' <- liftM (`setIdUnfolding` unfolding)
           $  cloneId mkVectOcc var vty

      -- Add the mapping between the plain and vectorised name to the state.
      defGlobalVar var var'

      return var'
  where
    -- Defined here rather than by matching on 'inline' in the function head,
    -- so neither 'inline' nor 'expr' is inspected when this action runs
    -- (see the NOTE above).
    unfolding = case inline of
                  Inline arity -> mkInlineRule expr (Just arity)
                  DontInline   -> noUnfolding

-- | Vectorise the RHS of a top-level binding, in an empty local environment.
--
--   Returns the inlining decision for the vectorised binding together with
--   the vectorised (not the lifted) form of the RHS.
vectTopRhs
        :: Var          -- ^ Name of the binding.
        -> CoreExpr     -- ^ Body of the binding.
        -> VM (Inline, CoreExpr)
vectTopRhs var expr
 = dtrace (vcat [text "vectTopRhs", ppr expr])
 $ closedV
 $ do (inline, vexpr) <- inBind var
                      $ vectPolyExpr (isLoopBreaker $ idOccInfo var)
                                     (freeVars expr)
      -- Keep only the 'vectorised' component of the result.
      return (inline, vectorised vexpr)

-- | Project out the vectorised version of a binding from some closure,
--   or return the original body if that doesn't work.
--
--   The conversion is driven by the type of the original binder.
tryConvert
        :: Var          -- ^ Name of the original binding (eg @foo@)
        -> Var          -- ^ Name of the vectorised version of the binding
        -> CoreExpr     -- ^ The original body of the binding, used as the fallback.
        -> VM CoreExpr
tryConvert var vect_var rhs
  = fromVect (idType var) (Var vect_var) `orElseV` return rhs

-- ----------------------------------------------------------------------------
-- Expressions

-- | Vectorise a polymorphic expression.
--
--   Notes are kept wrapped around the vectorised result. Otherwise the
--   outer type lambdas are stripped off, the remaining body is vectorised
--   with 'vectFnExpr', and the result is rebuilt under the type binders
--   followed by the extra binders supplied by 'polyAbstract'. The inline
--   arity is adjusted by the arity computed by 'polyArity'.
vectPolyExpr
        :: Bool                 -- ^ When vectorising the RHS of a binding, whether that
                                --   binding is a loop breaker.
        -> CoreExprWithFVs
        -> VM (Inline, VExpr)
vectPolyExpr loop_breaker (_, AnnNote note expr)
 = do (inline, expr') <- vectPolyExpr loop_breaker expr
      return (inline, vNote note expr')

vectPolyExpr loop_breaker expr
 = dtrace (vcat [text "vectPolyExpr", ppr (deAnnotate expr)])
 $ do
      arity <- polyArity tvs
      polyAbstract tvs $ \args ->
        do
          (inline, mono') <- vectFnExpr False loop_breaker mono
          return (addInlineArity inline arity,
                  mapVect (mkLams $ tvs ++ args) mono')
  where
    (tvs, mono) = collectAnnTypeBinders expr

-- | Vectorise a core expression.
--
--   The result pairs the vectorised expression with its lifted counterpart.
--   Expressions not covered by any of the equations below are rejected with
--   'cantVectorise'.
vectExpr :: CoreExprWithFVs -> VM VExpr
vectExpr (_, AnnType ty)
  = liftM vType (vectType ty)

vectExpr (_, AnnVar v)
  = vectVar v

vectExpr (_, AnnLit lit)
  = vectLiteral lit

vectExpr (_, AnnNote note expr)
  = liftM (vNote note) (vectExpr expr)

-- Type application: collect all the type arguments and vectorise the head.
vectExpr e@(_, AnnApp _ arg)
  | isAnnTypeArg arg
  = vectTyAppExpr fn tys
  where
    (fn, tys) = collectAnnTypeArgs e

-- The Int, Float or Double data constructor applied to a literal: keep the
-- application as it is for the vectorised version and lift it with liftPD.
vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
  | Just con <- isDataConId_maybe v
  , is_special_con con
  = do
      let vexpr = App (Var v) (Lit lit)
      lexpr <- liftPD vexpr
      return (vexpr, lexpr)
  where
    is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon]


-- TODO: Avoid using closure application for dictionaries.
-- vectExpr (_, AnnApp fn arg)
--  | if is application of dictionary
--    just use regular app instead of closure app.

-- for lifted version.
--      do liftPD (sub a dNumber)
--      lift the result of the selection, not sub and dNumber separately.

-- Value application: vectorise function and argument separately and combine
-- them with closure application.
vectExpr (_, AnnApp fn arg)
 = dtrace (text "AnnApp" <+> ppr (deAnnotate fn) <+> ppr (deAnnotate arg))
 $ do
      arg_ty' <- vectType arg_ty
      res_ty' <- vectType res_ty

      dtrace (text "vectorising fn " <> ppr (deAnnotate fn))  $ return ()
      fn'     <- vectExpr fn
      dtrace (text "fn' = "       <> ppr fn') $ return ()

      arg'    <- vectExpr arg

      mkClosureApp arg_ty' res_ty' fn' arg'
  where
    (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn

-- Case on a value of an algebraic data type.
vectExpr (_, AnnCase scrut bndr ty alts)
  | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
  , isAlgTyCon tycon
  = vectAlgCase tycon ty_args scrut bndr ty alts
  where
    scrut_ty = exprType (deAnnotate scrut)

vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
  = do
      vrhs <- localV . inBind bndr . liftM snd $ vectPolyExpr False rhs
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
      return $ vLet (vNonRec vbndr vrhs) vbody

-- The binders of a recursive group scope over the right-hand sides as well
-- as the body.
vectExpr (_, AnnLet (AnnRec bs) body)
  = do
      (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
                                $ liftM2 (,)
                                  (zipWithM vect_rhs bndrs rhss)
                                  (vectExpr body)
      return $ vLet (vRec vbndrs vrhss) vbody
  where
    (bndrs, rhss) = unzip bs

    vect_rhs bndr rhs = localV
                      . inBind bndr
                      . liftM snd
                      $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) rhs

-- Value lambdas are handled by vectFnExpr, which tries the scalar
-- translation first and otherwise builds a closure. Type lambdas fall
-- through to the final equation.
vectExpr e@(_, AnnLam bndr _)
  | isId bndr = liftM snd $ vectFnExpr True False e

vectExpr e = cantVectorise "Can't vectorise expression" (ppr $ deAnnotate e)

-- | Vectorise an expression with an outer lambda abstraction.
--
--   For a value lambda with no free variables, first try to treat it as a
--   scalar function ('vectScalarLam', never inlined). If that is not
--   possible, or the lambda has free variables, build a closure with
--   'vectLam'. Any other expression is vectorised with 'vectExpr' and
--   marked 'DontInline'.
vectFnExpr
        :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
        -> Bool                 -- ^ Whether the binding is a loop breaker.
        -> CoreExprWithFVs      -- ^ Expression to vectorise. Expected to have an outer
                                --   value `AnnLam`; anything else goes through 'vectExpr'.
        -> VM (Inline, VExpr)
vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _)
  | isId bndr = onlyIfV (isEmptyVarSet fvs)
                        (mark DontInline . vectScalarLam bs $ deAnnotate body)
                `orElseV` mark inlineMe (vectLam inline loop_breaker fvs bs body)
  where
    (bs,body) = collectAnnValBinders e

vectFnExpr _ _ e = mark DontInline $ vectExpr e

-- | Run a computation and pair its result with the given 'Inline' annotation.
mark :: Inline -> VM a -> VM (Inline, a)
mark b p = do { x <- p; return (b,x) }

-- | Vectorise a function where all the args have scalar type, that is Int, Float or Double.
--
--   The function is only handled here if, in addition, its result type is
--   scalar, its body is built solely by applying the arguments, the
--   variables from 'globalScalars' and scalar literals to each other, and the
--   body actually mentions one of the 'globalScalars'. Otherwise the
--   computation fails via 'onlyIfV'. On success the function and its scalar
--   closure are hoisted out as separate bindings, and the closure variable is
--   returned together with its lifted version.
vectScalarLam
        :: [Var]        -- ^ Bound variables of function.
        -> CoreExpr     -- ^ Function body.
        -> VM VExpr
vectScalarLam args body
 = dtrace (vcat [text "vectScalarLam ", ppr args, ppr body])
 $ do scalars <- globalScalars
      onlyIfV (all is_scalar_ty arg_tys
               && is_scalar_ty res_ty
               && is_scalar (extendVarSetList scalars args) body
               && uses scalars body)
        $ do
            fn_var  <- hoistExpr (fsLit "fn") (mkLams args body) DontInline
            zipf    <- zipScalars arg_tys res_ty
            clo     <- scalarClosure arg_tys res_ty (Var fn_var)
                                                (zipf `App` Var fn_var)
            clo_var <- hoistExpr (fsLit "clo") clo DontInline
            lclo    <- liftPD (Var clo_var)
            return (Var clo_var, lclo)
  where
    arg_tys = map idType args
    res_ty  = exprType body

    -- Only the boxed Int, Float and Double types count as scalar.
    is_scalar_ty ty
        | Just (tycon, [])   <- splitTyConApp_maybe ty
        =    tycon == intTyCon
          || tycon == floatTyCon
          || tycon == doubleTyCon

        | otherwise = False

    -- NOTE(review): Core literals such as machine ints presumably have
    -- unboxed types, which is_scalar_ty rejects, so the Lit case may never
    -- succeed — confirm.
    is_scalar vs (Var v)     = v `elemVarSet` vs
    is_scalar _ e@(Lit _)    = is_scalar_ty $ exprType e
    is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2
    is_scalar _ _            = False

    -- A scalar function has to actually compute something. Without the check,
    -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to
    -- (map (\x -> x)) which is very bad. Normal lifting transforms it to
    -- (\n# x -> x) which is what we want.
    uses funs (Var v)     = v `elemVarSet` funs
    uses funs (App e1 e2) = uses funs e1 || uses funs e2
    uses _ _              = False

-- | Vectorise a lambda abstraction by building closures for it.
--
--   The closure environment consists of those free variables that have a
--   binding in the local vectorisation environment. The body is vectorised
--   with those variables and the lambda's own binders in scope, and the
--   resulting code is hoisted via 'hoistPolyVExpr'.
vectLam
        :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
        -> Bool                 -- ^ Whether the binding is a loop breaker.
        -> VarSet               -- ^ The free variables of the lambda abstraction.
        -> [Var]                -- ^ The value binders of the lambda abstraction.
        -> CoreExprWithFVs      -- ^ The body of the lambda abstraction.
        -> VM VExpr
vectLam inline loop_breaker fvs bs body
 = dtrace (vcat [ text "vectLam "
                , text "free vars    = " <> ppr fvs
                , text "binding vars = " <> ppr bs
                , text "body         = " <> ppr (deAnnotate body)])
 $ do tyvars    <- localTyVars
      -- Free variables with a local vectorised binding, paired with it.
      (vs, vvs) <- readLEnv $ \env ->
                   unzip [(var, vv) | var <- varSetElems fvs
                                    , Just vv <- [lookupVarEnv (local_vars env) var]]

      arg_tys   <- mapM (vectType . idType) bs
      dtrace (text "arg_tys = " <> ppr arg_tys) $ return ()

      res_ty    <- vectType (exprType $ deAnnotate body)
      dtrace (text "res_ty = " <> ppr res_ty) $ return ()

      buildClosures tyvars vvs arg_tys res_ty
        . hoistPolyVExpr tyvars (maybe_inline (length vs + length bs))
        $ do
            lc              <- builtin liftingContext
            (vbndrs, vbody) <- vectBndrsIn (vs ++ bs) (vectExpr body)

            dtrace (text "vbody = " <> ppr vbody) $ return ()

            vbody' <- break_loop lc res_ty vbody
            return $ vLams lc vbndrs vbody'
  where
    -- The inline arity counts the environment variables as well as the binders.
    maybe_inline n | inline    = Inline n
                   | otherwise = DontInline

    -- For a loop breaker, the lifted body scrutinises the lifting context:
    -- when it is 0 the result is an empty PData, otherwise the lifted body.
    break_loop lc ty (ve, le)
      | loop_breaker
      = do
          empty <- emptyPD ty
          lty <- mkPDataType ty
          return (ve, mkWildCase (Var lc) intPrimTy lty
                        [(DEFAULT, [], le),
                         (LitAlt (mkMachInt 0), [], empty)])

      | otherwise = return (ve, le)

-- | Vectorise a type application. Only a variable applied to types is
--   supported; anything else is rejected with 'cantVectorise'.
vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
vectTyAppExpr (_, AnnVar v) tys = vectPolyVar v tys
vectTyAppExpr e tys = cantVectorise "Can't vectorise expression"
                        (ppr $ deAnnotate e `mkTyApps` tys)

-- We convert
--
--   case e :: t of v { ... }
--
-- to
--
--   V:    let v' = e in case v' of _ { ... }
--   L:    let v' = e in case v' `cast` ... of _ { ... }
--
-- When lifting, we have to do it this way because v must have the type
-- [:V(T):] but the scrutinee must be cast to the representation type. We also
-- have to handle the case where v is a wild var correctly.
--

-- FIXME: this is too lazy
vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs -> Var -> Type
            -> [(AltCon, [Var], CoreExprWithFVs)]
            -> VM VExpr

-- A single alternative that binds no variables is vectorised the same way
-- whether it is a DEFAULT or a constructor alternative.
vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)]
  = vectTrivialAlgCase scrut bndr ty body

vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)]
  = vectTrivialAlgCase scrut bndr ty body

-- A single constructor alternative with binders: match the vectorised data
-- constructor in the vectorised version, and the (single) PData constructor
-- in the lifted version.
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
  = do
      (vty, lty) <- vectAndLiftType ty
      vexpr      <- vectExpr scrut
      (vbndr, (vbndrs, (vect_body, lift_body)))
         <- vectScrutBndr bndr
          . vectBndrsIn bndrs
          $ vectExpr body
      let (vect_bndrs, lift_bndrs) = unzip vbndrs
      (vscrut, lscrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
      vect_dc <- maybeV (lookupDataCon dc)
      let [pdata_dc] = tyConDataCons pdata_tc

      let vcase = mk_wild_case vscrut vty vect_dc  vect_bndrs vect_body
          lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body

      return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
  where
    mk_wild_case expr res_ty con con_bndrs rhs
      = mkWildCase expr (exprType expr) res_ty [(DataAlt con, con_bndrs, rhs)]

-- General case: several constructor alternatives. The vectorised version is
-- an ordinary case over the vectorised data constructors; the lifted version
-- matches the PData constructor, binding a selector, and combines the lifted
-- alternatives with combinePD.
vectAlgCase tycon _ty_args scrut bndr ty alts
  = do
      vect_tc     <- maybeV (lookupTyCon tycon)
      (vty, lty)  <- vectAndLiftType ty

      let arity = length (tyConDataCons vect_tc)
      sel_ty <- builtin (selTy arity)
      sel_bndr <- newLocalVar (fsLit "sel") sel_ty
      let sel = Var sel_bndr

      (vbndr, valts) <- vectScrutBndr bndr
                      $ mapM (proc_alt arity sel vty lty) alts'
      let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts

      vexpr <- vectExpr scrut
      (vect_scrut, lift_scrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
      let [pdata_dc] = tyConDataCons pdata_tc

      let (vect_bodies, lift_bodies) = unzip vbodies

      vdummy <- newDummyVar (exprType vect_scrut)
      ldummy <- newDummyVar (exprType lift_scrut)
      let vect_case = Case vect_scrut vdummy vty
                           (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)

      lc <- builtin liftingContext
      lbody <- combinePD vty (Var lc) sel lift_bodies
      let lift_case = Case lift_scrut ldummy lty
                           [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
                             lbody)]

      return . vLet (vNonRec vbndr vexpr)
             $ (vect_case, lift_case)
  where
    -- Order the alternatives: DEFAULT first, then by data constructor tag.
    alts' = sortBy (\(alt1, _, _) (alt2, _, _) -> cmp alt1 alt2) alts

    cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
    cmp DEFAULT       DEFAULT       = EQ
    cmp DEFAULT       _             = LT
    cmp _             DEFAULT       = GT
    cmp _             _             = panic "vectAlgCase/cmp"

    -- Vectorise one constructor alternative. In the lifted body, the free
    -- local variables are packed down to the elements carrying this
    -- alternative's tag, and the lifting context lc is re-bound (as a case
    -- binder) to (selElements arity ntag) applied to the selector.
    proc_alt arity sel _ lty (DataAlt dc, bndrs, body)
      = do
          vect_dc <- maybeV (lookupDataCon dc)
          let ntag = dataConTagZ vect_dc
              tag  = mkDataConTag vect_dc
              fvs  = freeVarsOf body `delVarSetList` bndrs

          sel_tags  <- liftM (`App` sel) (builtin (selTags arity))
          lc        <- builtin liftingContext
          elems     <- builtin (selElements arity ntag)

          (vbndrs, vbody)
            <- vectBndrsIn bndrs
             . localV
             $ do
                 binds    <- mapM (pack_var (Var lc) sel_tags tag)
                           . filter isLocalId
                           $ varSetElems fvs
                 (ve, le) <- vectExpr body
                 return (ve, Case (elems `App` sel) lc lty
                             [(DEFAULT, [], (mkLets (concat binds) le))])
                 -- empty    <- emptyPD vty
                 -- return (ve, Case (elems `App` sel) lc lty
                 --             [(DEFAULT, [], Let (NonRec flags_var flags_expr)
                 --                             $ mkLets (concat binds) le),
                 --               (LitAlt (mkMachInt 0), [], empty)])
          let (vect_bndrs, lift_bndrs) = unzip vbndrs
          return (vect_dc, vect_bndrs, lift_bndrs, vbody)

    proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"

    mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)

    -- For a variable with a local (vectorised, lifted) binding, bind a fresh
    -- copy of the lifted variable to its packByTagPD-packed value and point
    -- the local environment at that copy. Other variables produce no binding.
    pack_var len tags t v
      = do
          r <- lookupVar v
          case r of
            Local (vv, lv) ->
              do
                lv'  <- cloneVar lv
                expr <- packByTagPD (idType vv) (Var lv) len tags t
                updLEnv (\env -> env { local_vars = extendVarEnv
                                                (local_vars env) v (vv, lv') })
                return [(NonRec lv' expr)]

            _ -> return []

-- | Vectorise a case with a single alternative that binds no variables as a
--   case with only a DEFAULT alternative.
vectTrivialAlgCase scrut bndr ty body
  = do
      vscrut         <- vectExpr scrut
      (vty, lty)     <- vectAndLiftType ty
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
      return $ vCaseDEFAULT vscrut vbndr vty lty vbody

-- | Vectorise the case binder around a computation. The vectorised code
--   always let-binds the scrutinee, so a dead (wild) binder is replaced by a
--   fresh binder named "scrut".
vectScrutBndr bndr
  | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
  | otherwise         = vectBndrIn bndr