Vectorise.hs 20.2 KB
Newer Older
1
{-# OPTIONS -fno-warn-missing-signatures #-}
2

3
4
5
module Vectorise( vectorise )
where

6
import VectMonad
7
import VectUtils
8
import VectVar
9
import VectType
10
import VectCore
11
import Vectorise.Env
12

13
import HscTypes hiding      ( MonadThings(..) )
14

15
import Module               ( PackageId )
16
import CoreSyn
17
import CoreUtils
18
import CoreUnfold           ( mkInlineRule )
19
import MkCore               ( mkWildCase )
20
import CoreFVs
Ian Lynagh's avatar
Ian Lynagh committed
21
import CoreMonad            ( CoreM, getHscEnv )
22
import DataCon
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
23
import TyCon
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
24
import Type
25
import FamInstEnv           ( extendFamInstEnvList )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
26
27
import Var
import VarEnv
28
import VarSet
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
29
import Id
30
import OccName
31
import BasicTypes           ( isLoopBreaker )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
32

33
import Literal
34
import TysWiredIn
35
import TysPrim              ( intPrimTy )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
36

37
import Outputable
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
38
import FastString
39
40
import Util                 ( zipLazy )
import Control.Monad
41
import Data.List            ( sortBy, unzip4 )
-- | Set to True to get tracing output from the vectoriser.
debug :: Bool
debug = False

-- | Emit a trace message tagged "Vectorise" when 'debug' is on;
--   otherwise just return the second argument unchanged.
dtrace :: SDoc -> a -> a
dtrace msg x
  | debug     = pprTrace "Vectorise" msg x
  | otherwise = x

-- | Vectorise a single module.
--   The 'PackageId' names the DPH backend we're using, e.g. either dph-par or dph-seq.
--   This is a thin 'CoreM' wrapper that fetches the 'HscEnv' and defers to 'vectoriseIO'.
vectorise :: PackageId -> ModGuts -> CoreM ModGuts
vectorise backend guts
  = getHscEnv >>= \hsc_env ->
    liftIO (vectoriseIO backend hsc_env guts)

-- | Vectorise a single module, given the compiler's 'HscEnv'.
--
--   The vectorisation info of the home package ('hptVectInfo') is combined with
--   that of the already-loaded external packages and handed to the VM
--   computation; the info it returns is stored in the resulting 'ModGuts'.
vectoriseIO :: PackageId -> HscEnv -> ModGuts -> IO ModGuts
vectoriseIO backend hsc_env guts
 = do -- Get information about currently loaded external packages.
      eps <- hscEPS hsc_env

      -- Combine vectorisation info from the current module, and external ones.
      let info = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps

      -- Run the main VM computation.
      result <- initV backend hsc_env guts info (vectModule guts)
      case result of
        Just (info', guts') -> return (guts' { mg_vect_info = info' })

        -- initV yields Nothing (presumably when the VM computation fails).
        -- Report it with 'fail', which raises the same IO userError that a
        -- failed `Just` pattern in a do block would, but with a message that
        -- says what actually went wrong.
        Nothing -> fail "vectoriseIO: vectorisation failed (initV returned Nothing)"

-- | Vectorise a single module, in the VM monad.
--
--   The type environment is vectorised first. This may yield new TyCons,
--   DataCons, family instances and bindings. Every top-level binding of the
--   module is vectorised afterwards.
vectModule :: ModGuts -> VM ModGuts
vectModule guts
 = do -- Vectorise the type environment.
      -- This may add new TyCons and DataCons.
      -- TODO: What new binds do we get back here?
      (types', fam_insts, tc_binds) <- vectTypeEnv (mg_types guts)

      -- Extend the module's family instance environment with the instances
      -- produced above, and install it in the VM's global environment.
      let famInstEnv = extendFamInstEnvList (mg_fam_inst_env guts) fam_insts
      updGEnv (setFamInstEnv famInstEnv)

      -- dicts   <- mapM buildPADict pa_insts
      -- workers <- mapM vectDataConWorkers pa_insts

      -- Vectorise all the top level bindings.
      binds' <- mapM vectTopBind (mg_binds guts)

      -- The bindings produced by vectTypeEnv come first, as one recursive group.
      let newBinds = Rec tc_binds : binds'
      return (guts { mg_types        = types'
                   , mg_binds        = newBinds
                   , mg_fam_inst_env = famInstEnv
                   , mg_fam_insts    = mg_fam_insts guts ++ fam_insts
                   })

-- | Try to vectorise a top-level binding.
--   If it doesn't vectorise then return it unharmed.
--
--   For example, for the binding
--
--   @
--      foo :: Int -> Int
--      foo = \x -> x + x
--   @
--
--   we get
--
--   @
--      foo  :: Int -> Int
--      foo  = \x -> vfoo $: x
--
--      v_foo :: Closure void vfoo lfoo
--      v_foo = closure vfoo lfoo void
--
--      vfoo :: Void -> Int -> Int
--      vfoo = ...
--
--      lfoo :: PData Void -> PData Int -> PData Int
--      lfoo = ...
--   @
--
--   @vfoo@ is the "vectorised", or scalar, version that does the same as the original
--   function foo, but takes an explicit environment.
--
--   @lfoo@ is the "lifted" version that works on arrays.
--
--   @v_foo@ combines both of these into a `Closure` that also contains the
--   environment.
--
--   The original binding @foo@ is rewritten to call the vectorised version
--   present in the closure.
--
--   The result is always a single 'Rec' group containing the (possibly
--   converted) original bindings, their vectorised versions and any bindings
--   hoisted out while vectorising.
vectTopBind :: CoreBind -> VM CoreBind
vectTopBind b@(NonRec var expr)
  = attempt `orElseV` return b
  where
    attempt
      = do (inline, expr') <- vectTopRhs var expr
           var'            <- vectTopBinder var inline expr'

           -- Vectorising the body may create other top-level bindings.
           hoisted         <- takeHoisted

           -- To get the same functionality as the original body we project
           -- out its vectorised version from the closure.
           cexpr           <- tryConvert var var' expr

           return (Rec ((var, cexpr) : (var', expr') : hoisted))

vectTopBind b@(Rec bs)
  = attempt `orElseV` return b
  where
    (vars, exprs) = unzip bs

    attempt
      = do (vars', _, exprs') <- fixV knot
           hoisted            <- takeHoisted
           cexprs             <- mapM convert (zip3 vars vars' exprs)
           return (Rec (zip vars cexprs ++ zip vars' exprs' ++ hoisted))

    convert (var, var', expr) = tryConvert var var' expr

    -- The vectorised binders are created from the (not yet computed) vectorised
    -- RHSs, and vectTopBinder registers each binder with defGlobalVar so the
    -- mutually recursive RHSs can refer to one another. fixV ties the knot.
    -- The lazy patterns are essential: nothing here may force the fed-back
    -- inline annotations or RHSs before they are produced.
    knot ~(_, inlines, rhss)
      = do vars' <- sequence [ vectTopBinder var inline rhs
                             | (var, ~(inline, rhs)) <- zipLazy vars (zip inlines rhss) ]
           (inlines', exprs') <- mapAndUnzipM (uncurry vectTopRhs) bs
           return (vars', inlines', exprs')

-- | Make the vectorised version of this top level binder, and add the mapping
--   between it and the original to the state. For some binder @foo@ the vectorised
--   version is @$v_foo@
--
--   NOTE: vectTopBinder *MUST* be lazy in inline and expr because of how it is
--   used inside of fixV in vectTopBind. Both are only consulted when the
--   unfolding attached to the returned 'Var' is demanded.
vectTopBinder
        :: Var          -- ^ Name of the binding.
        -> Inline       -- ^ Whether it should be inlined, used to annotate it.
        -> CoreExpr     -- ^ RHS of the binding, used to set the `Unfolding` of the returned `Var`.
        -> VM Var       -- ^ Name of the vectorised binding.
vectTopBinder var inline expr
 = do -- Vectorise the type attached to the var.
      vty <- vectType (idType var)

      -- Make the vectorised version of binding's name, then attach the
      -- unfolding used for inlining (lazily, see the NOTE above).
      bareVar <- cloneId mkVectOcc var vty
      let var' = bareVar `setIdUnfolding` unfolding

      -- Add the mapping between the plain and vectorised name to the state.
      defGlobalVar var var'
      return var'
  where
    unfolding = unfoldingFor inline

    unfoldingFor (Inline arity) = mkInlineRule expr (Just arity)
    unfoldingFor DontInline     = noUnfolding

-- | Vectorise the RHS of a top-level binding, in an empty local environment.
--   Returns the inlining annotation for the vectorised binding together with
--   the vectorised (scalar) part of the result.
vectTopRhs
        :: Var          -- ^ Name of the binding.
        -> CoreExpr     -- ^ Body of the binding.
        -> VM (Inline, CoreExpr)
vectTopRhs var expr
  = dtrace (vcat [text "vectTopRhs", ppr expr])
  . closedV
  $ do (inline, vexpr) <- inBind var (vectPolyExpr isBreaker (freeVars expr))
       return (inline, vectorised vexpr)
  where
    isBreaker = isLoopBreaker (idOccInfo var)

-- | Project out the vectorised version of a binding from some closure,
--   or return the original body if that doesn't work.
tryConvert
        :: Var          -- ^ Name of the original binding (eg @foo@)
        -> Var          -- ^ Name of vectorised version of binding (eg @$vfoo@)
        -> CoreExpr     -- ^ The original body of the binding.
        -> VM CoreExpr
tryConvert var vect_var rhs
  = orElseV converted (return rhs)
  where
    converted = fromVect (idType var) (Var vect_var)


-- ----------------------------------------------------------------------------
-- Expressions

-- | Vectorise a polymorphic expression.
--   Notes are pushed through unchanged. Otherwise the outer type lambdas are
--   abstracted with 'polyAbstract', and the remaining monomorphic body is
--   vectorised with 'vectFnExpr'.
vectPolyExpr
        :: Bool                 -- ^ When vectorising the RHS of a binding, whether that
                                --   binding is a loop breaker.
        -> CoreExprWithFVs
        -> VM (Inline, VExpr)
vectPolyExpr loop_breaker (_, AnnNote note inner)
  = do (inl, inner') <- vectPolyExpr loop_breaker inner
       return (inl, vNote note inner')

vectPolyExpr loop_breaker expr
  = dtrace (vcat [text "vectPolyExpr", ppr (deAnnotate expr)])
  $ do arity <- polyArity tyvars
       polyAbstract tyvars $ \args ->
         do (inl, mono') <- vectFnExpr False loop_breaker mono
            let lams = mkLams (tyvars ++ args)
            return (addInlineArity inl arity, mapVect lams mono')
  where
    (tyvars, mono) = collectAnnTypeBinders expr

-- | Vectorise a core expression, producing both its vectorised and lifted forms.
vectExpr :: CoreExprWithFVs -> VM VExpr
vectExpr (_, AnnType ty)
  = vType `liftM` vectType ty

vectExpr (_, AnnVar v)
  = vectVar v

vectExpr (_, AnnLit lit)
  = vectLiteral lit

vectExpr (_, AnnNote note expr)
  = vNote note `liftM` vectExpr expr

-- Type application: gather all the type arguments and vectorise them
-- together with the head of the application.
vectExpr e@(_, AnnApp _ arg)
  | isAnnTypeArg arg
  = let (fn, tys) = collectAnnTypeArgs e
    in  vectTyAppExpr fn tys

-- A constructor for Int, Float or Double applied to a literal: the vectorised
-- form is the expression itself and the lifted form comes from liftPD.
vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
  | Just con <- isDataConId_maybe v
  , con `elem` [intDataCon, floatDataCon, doubleDataCon]
  = do let vexpr = App (Var v) (Lit lit)
       lexpr <- liftPD vexpr
       return (vexpr, lexpr)

-- TODO: Avoid using closure application for dictionaries.
-- vectExpr (_, AnnApp fn arg)
--  | if is application of dictionary
--    just use regular app instead of closure app.

-- for lifted version.
--      do liftPD (sub a dNumber)
--      lift the result of the selection, not sub and dNumber separately.

-- General value application: closure application of the vectorised function
-- to the vectorised argument.
vectExpr (_, AnnApp fn arg)
  = dtrace (text "AnnApp" <+> ppr (deAnnotate fn) <+> ppr (deAnnotate arg))
  $ do arg_ty' <- vectType arg_ty
       res_ty' <- vectType res_ty

       dtrace (text "vectorising fn " <> ppr (deAnnotate fn)) (return ())
       fn' <- vectExpr fn
       dtrace (text "fn' = " <> ppr fn') (return ())

       arg' <- vectExpr arg
       mkClosureApp arg_ty' res_ty' fn' arg'
  where
    (arg_ty, res_ty) = splitFunTy (exprType (deAnnotate fn))

-- Only cases over algebraic data types are handled here.
vectExpr (_, AnnCase scrut bndr ty alts)
  | Just (tycon, ty_args) <- splitTyConApp_maybe (exprType (deAnnotate scrut))
  , isAlgTyCon tycon
  = vectAlgCase tycon ty_args scrut bndr ty alts

vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
  = do vrhs           <- localV (inBind bndr (liftM snd (vectPolyExpr False rhs)))
       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
       return (vLet (vNonRec vbndr vrhs) vbody)

vectExpr (_, AnnLet (AnnRec bs) body)
  = do (vbndrs, (vrhss, vbody))
         <- vectBndrsIn bndrs
          $ liftM2 (,) (zipWithM vect_rhs bndrs rhss) (vectExpr body)
       return (vLet (vRec vbndrs vrhss) vbody)
  where
    (bndrs, rhss) = unzip bs

    vect_rhs bndr rhs
      = localV . inBind bndr . liftM snd
      $ vectPolyExpr (isLoopBreaker (idOccInfo bndr)) rhs

vectExpr e@(_, AnnLam bndr _)
  | isId bndr = snd `liftM` vectFnExpr True False e
{-
onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
                `orElseV` vectLam True fvs bs body
  where
    (bs,body) = collectAnnValBinders e
-}

vectExpr e = cantVectorise "Can't vectorise expression" (ppr $ deAnnotate e)

-- | Vectorise an expression with an outer lambda abstraction.
--
--   For a value lambda with no free variables, 'vectScalarLam' is tried first.
--   If that does not apply (or fails), the lambda is turned into a closure
--   with 'vectLam'. Any other expression is simply passed to 'vectExpr' and
--   marked 'DontInline'.
vectFnExpr
        :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
        -> Bool                 -- ^ Whether the binding is a loop breaker.
        -> CoreExprWithFVs      -- ^ Expression to vectorise; normally has an outer `AnnLam`.
        -> VM (Inline, VExpr)
vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _)
  | isId bndr
  = let (bs, body) = collectAnnValBinders e
        scalar     = mark DontInline (vectScalarLam bs (deAnnotate body))
        closure    = mark inlineMe (vectLam inline loop_breaker fvs bs body)
    in  onlyIfV (isEmptyVarSet fvs) scalar `orElseV` closure

vectFnExpr _ _ e
  = mark DontInline (vectExpr e)

-- | Run a computation and pair its result with the given 'Inline' annotation.
mark :: Inline -> VM a -> VM (Inline, a)
mark b p = liftM (\x -> (b, x)) p

-- | Vectorise a function where all the args have scalar type, that is Int, Float or Double.
--
--   The result type must be scalar as well. The body may only be built from
--   applications of the arguments, global scalar functions and scalar
--   literals, and it must use at least one global scalar function. If these
--   conditions hold, the function and a scalar closure over it are hoisted
--   to the top level. Otherwise the computation fails via 'onlyIfV'.
vectScalarLam
        :: [Var]        -- ^ Bound variables of function.
        -> CoreExpr     -- ^ Function body.
        -> VM VExpr
vectScalarLam args body
  = dtrace (vcat [text "vectScalarLam ", ppr args, ppr body])
  $ do scalars <- globalScalars
       let eligible = all is_scalar_ty arg_tys
                   && is_scalar_ty res_ty
                   && is_scalar (extendVarSetList scalars args) body
                   && uses scalars body
       onlyIfV eligible $
         do fn_var  <- hoistExpr (fsLit "fn") (mkLams args body) DontInline
            zipf    <- zipScalars arg_tys res_ty
            clo     <- scalarClosure arg_tys res_ty (Var fn_var) (zipf `App` Var fn_var)
            clo_var <- hoistExpr (fsLit "clo") clo DontInline
            lclo    <- liftPD (Var clo_var)
            return (Var clo_var, lclo)
  where
    arg_tys = map idType args
    res_ty  = exprType body

    -- Exactly Int, Float or Double (a nullary type constructor application).
    is_scalar_ty ty
      = case splitTyConApp_maybe ty of
          Just (tycon, []) -> tycon `elem` [intTyCon, floatTyCon, doubleTyCon]
          _                -> False

    is_scalar vs (Var v)     = v `elemVarSet` vs
    is_scalar _  e@(Lit _)   = is_scalar_ty (exprType e)
    is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2
    is_scalar _  _           = False

    -- A scalar function has to actually compute something. Without the check,
    -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to
    -- (map (\x -> x)) which is very bad. Normal lifting transforms it to
    -- (\n# x -> x) which is what we want.
    uses funs (Var v)     = v `elemVarSet` funs
    uses funs (App e1 e2) = uses funs e1 || uses funs e2
    uses _    _           = False

-- | Vectorise a lambda abstraction into closures.
--
--   The free variables that are bound in the local environment are passed,
--   with their vectorised versions, to 'buildClosures' (presumably as the
--   closure environment). The body is vectorised under binders for those
--   variables followed by the lambda's own binders.
vectLam
        :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
        -> Bool                 -- ^ Whether the binding is a loop breaker.
        -> VarSet               -- ^ The free variables in the body.
        -> [Var]                -- ^ The lambda's own bound variables.
        -> CoreExprWithFVs      -- ^ Body of the lambda.
        -> VM VExpr
vectLam inline loop_breaker fvs bs body
  = dtrace (vcat [ text "vectLam "
                 , text "free vars    = " <> ppr fvs
                 , text "binding vars = " <> ppr bs
                 , text "body         = " <> ppr (deAnnotate body)])
  $ do tyvars <- localTyVars

       -- Free variables that have a local vectorised binding, paired with it.
       (vs, vvs) <- readLEnv $ \env ->
                      unzip [ (var, vv) | var     <- varSetElems fvs
                                        , Just vv <- [lookupVarEnv (local_vars env) var] ]

       arg_tys <- mapM (vectType . idType) bs
       dtrace (text "arg_tys = " <> ppr arg_tys) (return ())

       res_ty <- vectType (exprType (deAnnotate body))
       dtrace (text "res_ty = " <> ppr res_ty) (return ())

       let arity = length vs + length bs
       buildClosures tyvars vvs arg_tys res_ty
         . hoistPolyVExpr tyvars (maybe_inline arity)
         $ do lc              <- builtin liftingContext
              (vbndrs, vbody) <- vectBndrsIn (vs ++ bs) (vectExpr body)
              dtrace (text "vbody = " <> ppr vbody) (return ())
              vbody' <- break_loop lc res_ty vbody
              return (vLams lc vbndrs vbody')
  where
    maybe_inline n
      | inline    = Inline n
      | otherwise = DontInline

    -- For a loop breaker, the lifted body is guarded by a case on the lifting
    -- context. When that is 0, an empty PData is returned instead of the body.
    break_loop lc ty (ve, le)
      | loop_breaker
      = do empty <- emptyPD ty
           lty   <- mkPDataType ty
           return (ve, mkWildCase (Var lc) intPrimTy lty
                         [ (DEFAULT,              [], le)
                         , (LitAlt (mkMachInt 0), [], empty) ])
      | otherwise
      = return (ve, le)

-- | Vectorise a type application. Only a variable applied to types is
--   supported; any other head is reported with 'cantVectorise'.
vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
vectTyAppExpr e tys
  = case e of
      (_, AnnVar v) -> vectPolyVar v tys
      _             -> cantVectorise "Can't vectorise expression"
                                     (ppr $ deAnnotate e `mkTyApps` tys)

-- We convert
--
--   case e :: t of v { ... }
--
-- to
--
--   V:    let v' = e in case v' of _ { ... }
--   L:    let v' = e in case v' `cast` ... of _ { ... }
--
-- When lifting, we have to do it this way because v must have the type
-- [:V(T):] but the scrutinee must be cast to the representation type. We also
-- have to handle the case where v is a wild var correctly.
--

-- FIXME: this is too lazy

-- | Vectorise a case expression whose scrutinee has an algebraic type.
--   Arguments: the scrutinee's tycon and its type arguments, the scrutinee,
--   the case binder, the result type and the alternatives.
vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs -> Var -> Type
            -> [(AltCon, [Var], CoreExprWithFVs)]
            -> VM VExpr

-- A single alternative that binds no variables (DEFAULT or a nullary
-- constructor). No fields need projecting, so it becomes a DEFAULT case.
vectAlgCase _tycon _ty_args scrut bndr ty [(alt, [], body)]
  | isNullaryAlt alt
  = do vscrut         <- vectExpr scrut
       (vty, lty)     <- vectAndLiftType ty
       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
       return $ vCaseDEFAULT vscrut vbndr vty lty vbody
  where
    isNullaryAlt DEFAULT     = True
    isNullaryAlt (DataAlt _) = True
    isNullaryAlt _           = False

-- A single constructor alternative with fields.
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
  = do (vty, lty) <- vectAndLiftType ty
       vexpr      <- vectExpr scrut
       (vbndr, (vbndrs, (vect_body, lift_body)))
          <- vect_scrut_bndr
           . vectBndrsIn bndrs
           $ vectExpr body
       let (vect_bndrs, lift_bndrs) = unzip vbndrs
       (vscrut, lscrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
       vect_dc <- maybeV (lookupDataCon dc)
       let pdata_dc = onlyPDataCon pdata_tc

       let vcase = mk_wild_case vscrut vty vect_dc  vect_bndrs vect_body
           lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body

       return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
  where
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
                    | otherwise         = vectBndrIn bndr

    mk_wild_case expr ty dc bndrs body
      = mkWildCase expr (exprType expr) ty [(DataAlt dc, bndrs, body)]

-- The general case: several alternatives, combined in the lifted version
-- via a selector.
vectAlgCase tycon _ty_args scrut bndr ty alts
  = do vect_tc    <- maybeV (lookupTyCon tycon)
       (vty, lty) <- vectAndLiftType ty

       let arity = length (tyConDataCons vect_tc)
       sel_ty   <- builtin (selTy arity)
       sel_bndr <- newLocalVar (fsLit "sel") sel_ty
       let sel = Var sel_bndr

       (vbndr, valts) <- vect_scrut_bndr
                       $ mapM (proc_alt arity sel vty lty) alts'
       let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts

       vexpr <- vectExpr scrut
       (vect_scrut, lift_scrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
       let pdata_dc = onlyPDataCon pdata_tc

       let (vect_bodies, lift_bodies) = unzip vbodies

       vdummy <- newDummyVar (exprType vect_scrut)
       ldummy <- newDummyVar (exprType lift_scrut)
       let vect_case = Case vect_scrut vdummy vty
                            (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)

       lc    <- builtin liftingContext
       lbody <- combinePD vty (Var lc) sel lift_bodies
       let lift_case = Case lift_scrut ldummy lty
                            [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss, lbody)]

       return . vLet (vNonRec vbndr vexpr)
              $ (vect_case, lift_case)
  where
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
                    | otherwise         = vectBndrIn bndr

    -- Alternatives ordered by constructor tag.
    -- NOTE(review): cmp sorts a DEFAULT alternative first, and proc_alt only
    -- handles DataAlt, so a DEFAULT here ends in the "vectAlgCase/proc_alt" panic.
    alts' = sortBy (\(alt1, _, _) (alt2, _, _) -> cmp alt1 alt2) alts

    cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
    cmp DEFAULT       DEFAULT       = EQ
    cmp DEFAULT       _             = LT
    cmp _             DEFAULT       = GT
    cmp _             _             = panic "vectAlgCase/cmp"

    proc_alt arity sel _ lty (DataAlt dc, bndrs, body)
      = do vect_dc <- maybeV (lookupDataCon dc)
           let ntag = dataConTagZ vect_dc
               tag  = mkDataConTag vect_dc
               fvs  = freeVarsOf body `delVarSetList` bndrs

           sel_tags <- liftM (`App` sel) (builtin (selTags arity))
           lc       <- builtin liftingContext
           elems    <- builtin (selElements arity ntag)

           (vbndrs, vbody)
             <- vectBndrsIn bndrs
              . localV
              $ do -- Rebind each local free variable of the body to the
                   -- elements of its lifted version that belong to this
                   -- alternative (see pack_var).
                   binds    <- mapM (pack_var (Var lc) sel_tags tag)
                             . filter isLocalId
                             $ varSetElems fvs
                   (ve, le) <- vectExpr body
                   return (ve, Case (elems `App` sel) lc lty
                                    [(DEFAULT, [], (mkLets (concat binds) le))])
                   -- empty    <- emptyPD vty
                   -- return (ve, Case (elems `App` sel) lc lty
                   --             [(DEFAULT, [], Let (NonRec flags_var flags_expr)
                   --                             $ mkLets (concat binds) le),
                   --               (LitAlt (mkMachInt 0), [], empty)])
           let (vect_bndrs, lift_bndrs) = unzip vbndrs
           return (vect_dc, vect_bndrs, lift_bndrs, vbody)

    proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"

    mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)

    -- For a local variable, clone its lifted binder and bind the clone to the
    -- lifted value packed by tag; the local environment is updated to use it.
    pack_var len tags t v
      = do r <- lookupVar v
           case r of
             Local (vv, lv) ->
               do lv'  <- cloneVar lv
                  expr <- packByTagPD (idType vv) (Var lv) len tags t
                  updLEnv (\env -> env { local_vars = extendVarEnv
                                                        (local_vars env) v (vv, lv') })
                  return [(NonRec lv' expr)]

             _ -> return []

-- | The single data constructor of a PData representation tycon, as returned
--   by 'mkVScrut'. This replaces a lazy @let [dc] = tyConDataCons tc@
--   pattern, whose failure would only report an irrefutable-pattern error.
onlyPDataCon :: TyCon -> DataCon
onlyPDataCon tc
  = case tyConDataCons tc of
      [dc] -> dc
      _    -> pprPanic "vectAlgCase: expected a PData tycon with exactly one data constructor"
                       (ppr tc)