VectType.hs 27 KB
Newer Older
1
module VectType ( vectTyCon, vectType, vectTypeEnv,
2
                   PAInstance, buildPADict )
3
4
5
6
7
8
where

#include "HsVersions.h"

import VectMonad
import VectUtils
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
9
import VectCore
10

11
import HscTypes          ( TypeEnv, extendTypeEnvList, typeEnvTyCons )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
12
import CoreSyn
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
13
import CoreUtils
14
import BuildTyCl
15
import DataCon
16
17
18
import TyCon
import Type
import TypeRep
19
import Coercion
20
import FamInstEnv        ( FamInst, mkLocalFamInst )
21
import InstEnv           ( Instance, mkLocalInstance, instanceDFunId )
22
23
import OccName
import MkId
24
import BasicTypes        ( StrictnessMark(..), OverlapFlag(..), boolToRecFlag )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
25
import Var               ( Var )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
26
import Id                ( mkWildId )
27
import Name              ( Name, getOccName )
28
import NameEnv
29
import TysWiredIn        ( unitTy, unitTyCon, intTy, intDataCon, unitDataConId )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
30
import TysPrim           ( intPrimTy )
31

32
import Unique
33
34
35
36
import UniqFM
import UniqSet
import Digraph           ( SCC(..), stronglyConnComp )

37
38
import Outputable

39
import Control.Monad  ( liftM, liftM2, zipWithM, zipWithM_, mapAndUnzipM )
40
import Data.List      ( inits, tails, zipWith4, zipWith5 )
41

rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
42
43
44
-- ----------------------------------------------------------------------------
-- Types

45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
-- | Map a type constructor to the tycon used in its place in vectorised
-- code.
--
-- The function tycon is replaced by the builtin closure tycon; boxed tuple
-- tycons and unlifted tycons are returned unchanged.  Every other tycon is
-- looked up with 'lookupTyCon'.
vectTyCon :: TyCon -> VM TyCon
vectTyCon tc
  | isFunTyCon tc        = builtin closureTyCon
  | isBoxedTupleTyCon tc = return tc
  | isUnLiftedTyCon tc   = return tc
  | otherwise = do
                  r <- lookupTyCon tc
                  case r of
                    Just tc' -> return tc'

                    -- FIXME: just for now
                    -- No mapping is recorded for this tycon: emit a trace
                    -- message and fall back to the original tycon.
                    Nothing  -> pprTrace "vectTyCon:" (ppr tc) $ return tc

-- | Translate a type to its vectorised counterpart.
--
-- The first clause looks through whatever 'coreView' exposes before any of
-- the structural clauses are tried, so the clause order matters.
vectType :: Type -> VM Type
vectType ty | Just ty' <- coreView ty = vectType ty'
-- Type variables are left untouched.
vectType (TyVarTy tv) = return $ TyVarTy tv
-- Applications are translated component-wise.
vectType (AppTy ty1 ty2) = liftM2 AppTy (vectType ty1) (vectType ty2)
-- The head tycon goes through 'vectTyCon', the arguments through 'vectType'.
vectType (TyConApp tc tys) = liftM2 TyConApp (vectTyCon tc) (mapM vectType tys)
-- A function type becomes an application of the builtin closure tycon to
-- the translated argument and result types.
vectType (FunTy ty1 ty2)   = liftM2 TyConApp (builtin closureTyCon)
                                             (mapM vectType [ty1,ty2])
-- For a forall type, all leading quantifiers are split off at once.  For
-- each quantified variable 'paDictArgType' may yield a dictionary type;
-- those that exist become extra function arguments in front of the
-- translated body, and the quantifiers are put back around the result.
vectType ty@(ForAllTy _ _)
  = do
      mdicts   <- mapM paDictArgType tyvars
      mono_ty' <- vectType mono_ty
      return $ tyvars `mkForAllTys` ([dict | Just dict <- mdicts] `mkFunTys` mono_ty')
  where
    (tyvars, mono_ty) = splitForAllTys ty

-- Any remaining form of type is not handled.
vectType ty = pprPanic "vectType:" (ppr ty)

75
76
77
78
79
-- ----------------------------------------------------------------------------
-- Type definitions

-- | A group of tycons paired with a set of tycons.  'tyConGroups' builds
-- these with the second component holding the tycons found in the
-- constructor argument types of the group's members.
type TyConGroup = ([TyCon], UniqSet TyCon)

80
-- | Record bundling a dictionary function variable with three tycons.
-- NOTE(review): judging by the field names these are the original tycon,
-- its vectorised version and its array tycon -- confirm against users of
-- this type.
data PAInstance = PAInstance {
                    painstDFun      :: Var
                  , painstOrigTyCon :: TyCon
                  , painstVectTyCon :: TyCon
                  , painstArrTyCon  :: TyCon
                  }

87
-- | Vectorise the tycons of a type environment.
--
-- Returns the environment extended with all newly built tycons and their
-- data constructors, the family instances for the new PRepr and PArray
-- tycons, and the bindings produced by 'buildTyConBindings'.
vectTypeEnv :: TypeEnv -> VM (TypeEnv, [FamInst], [(Var, CoreExpr)])
vectTypeEnv env
  = do
      cs <- readGEnv $ mk_map . global_tycons
      let (conv_tcs, keep_tcs) = classifyTyCons cs groups
          keep_dcs             = concatMap tyConDataCons keep_tcs

      -- Tycons that are kept, and their data constructors, map to
      -- themselves.
      zipWithM_ defTyCon   keep_tcs keep_tcs
      zipWithM_ defDataCon keep_dcs keep_dcs

      new_tcs <- vectTyConDecls conv_tcs

      -- The two lists are index-aligned: kept tycons first, then converted.
      let orig_tcs = keep_tcs ++ conv_tcs
          vect_tcs = keep_tcs ++ new_tcs

      repr_tcs <- zipWithM buildPReprTyCon  orig_tcs vect_tcs
      parr_tcs <- zipWithM buildPArrayTyCon orig_tcs vect_tcs
      dfuns    <- mapM mkPADFun vect_tcs
      defTyConPAs (zip vect_tcs dfuns)
      binds    <- sequence (zipWith5 buildTyConBindings orig_tcs
                                                        vect_tcs
                                                        repr_tcs
                                                        parr_tcs
                                                        dfuns)

      let all_new_tcs = new_tcs ++ repr_tcs ++ parr_tcs

      let new_env = extendTypeEnvList env
                       (map ATyCon all_new_tcs
                        ++ [ADataCon dc | tc <- all_new_tcs
                                        , dc <- tyConDataCons tc])

      return (new_env, map mkLocalFamInst (repr_tcs ++ parr_tcs), concat binds)
  where
    tycons = typeEnvTyCons env
    groups = tyConGroups tycons

    -- For every entry of the global tycon environment, record whether the
    -- name stored under a unique differs from that unique.
    mk_map tc_env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts tc_env]


129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
-- | Vectorise a list of tycon declarations together.
--
-- The result list is fed back in through 'fixV' (presumably a fixpoint
-- combinator for 'VM' -- TODO confirm): each original tycon is registered
-- with 'defTyCon' against the corresponding element of the result list
-- before the individual declarations are vectorised.
vectTyConDecls :: [TyCon] -> VM [TyCon]
vectTyConDecls tcs = fixV $ \tcs' ->
  do
    mapM_ (uncurry defTyCon) (lazy_zip tcs tcs')
    mapM vectTyConDecl tcs
  where
    -- Like 'zip' driven by the first list; the second list is matched with
    -- an irrefutable pattern so that pairing does not force it.
    lazy_zip [] _ = []
    lazy_zip (x:xs) ~(y:ys) = (x,y) : lazy_zip xs ys

-- | Build the vectorised version of a single algebraic tycon.
--
-- The new tycon gets a name cloned from the original with
-- 'mkVectTyConOcc', the original's type variables and recursiveness flag,
-- and a right-hand side produced by 'vectAlgTyConRhs'.
vectTyConDecl :: TyCon -> VM TyCon
vectTyConDecl tc
  = do
      name' <- cloneName mkVectTyConOcc name
      rhs'  <- vectAlgTyConRhs (algTyConRhs tc)

      liftDs $ buildAlgTyCon name'
                             tyvars
                             []           -- no stupid theta
                             rhs'
                             rec_flag     -- FIXME: is this ok?
                             False        -- FIXME: no generics
                             False        -- not GADT syntax
                             Nothing      -- not a family instance
  where
    name     = tyConName tc
    tyvars   = tyConTyVars tc
    rec_flag = boolToRecFlag (isRecursiveTyCon tc)

-- | Vectorise the right-hand side of an algebraic tycon.
--
-- Only 'DataTyCon' right-hand sides are supported: every data constructor
-- is vectorised with 'vectDataCon' and registered against its original via
-- 'defDataCon'; the @is_enum@ flag is carried over unchanged.  Any other
-- form of right-hand side is reported with an explicit panic instead of an
-- anonymous pattern-match failure.
vectAlgTyConRhs :: AlgTyConRhs -> VM AlgTyConRhs
vectAlgTyConRhs (DataTyCon { data_cons = dcs
                           , is_enum   = enum
                           })
  = do
      dcs' <- mapM vectDataCon dcs
      zipWithM_ defDataCon dcs dcs'
      return $ DataTyCon { data_cons = dcs'
                         , is_enum   = enum
                         }
vectAlgTyConRhs _
  = pprPanic "vectAlgTyConRhs:" (text "not a data tycon rhs")

-- | Build the vectorised version of a data constructor.
--
-- Constructors with existential type variables or an equality spec are
-- rejected with a panic.  Otherwise the new constructor takes the
-- vectorised representation argument types of the original, has no
-- strictness marks, field labels or context, and belongs to the tycon that
-- 'vectTyCon' yields for the original's tycon.
vectDataCon :: DataCon -> VM DataCon
vectDataCon dc
  | not . null $ dataConExTyVars dc = pprPanic "vectDataCon: existentials" (ppr dc)
  | not . null $ dataConEqSpec   dc = pprPanic "vectDataCon: eq spec" (ppr dc)
  | otherwise
  = do
      name'    <- cloneName mkVectDataConOcc name
      tycon'   <- vectTyCon tycon
      arg_tys  <- mapM vectType rep_arg_tys

      liftDs $ buildDataCon name'
                            False           -- not infix
                            (map (const NotMarkedStrict) arg_tys)
                            []              -- no labelled fields
                            univ_tvs
                            []              -- no existential tvs for now
                            []              -- no eq spec for now
                            []              -- no context
                            arg_tys
                            tycon'
  where
    name        = dataConName dc
    univ_tvs    = dataConUnivTyVars dc
    rep_arg_tys = dataConRepArgTys dc
    tycon       = dataConTyCon dc

194
195
196
197
-- | Pair a family tycon with a single instance argument: the given tycon
-- applied to its own type variables.
mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
mk_fam_inst fam_tc arg_tc = (fam_tc, [inst_ty])
  where
    inst_ty = mkTyConApp arg_tc (mkTyVarTys (tyConTyVars arg_tc))

rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
198
199
200
201
-- | Build the synonym tycon that serves as the instance of the builtin
-- PRepr family for a vectorised tycon.
--
-- The name is cloned from the original tycon with 'mkPReprTyConOcc'; the
-- right-hand side is the type computed by 'buildPReprType'.
buildPReprTyCon :: TyCon -> TyCon -> VM TyCon
buildPReprTyCon orig_tc vect_tc
  = do
      name     <- cloneName mkPReprTyConOcc (tyConName orig_tc)
      rhs_ty   <- buildPReprType vect_tc
      prepr_tc <- builtin preprTyCon
      liftDs $ buildSynTyCon name
                             tyvars
                             (SynonymTyCon rhs_ty)
                             (Just $ mk_fam_inst prepr_tc vect_tc)
  where
    tyvars = tyConTyVars vect_tc

211

212
213
214
215
216
217
218
-- | Description of a representation type: either a product of component
-- types, or a sum whose components are themselves representations.  Each
-- alternative also records the tycon used for the representation together
-- with an array tycon and that tycon's data constructor.
data Repr = ProdRepr {
              prod_components   :: [Type]
            , prod_tycon        :: TyCon
            , prod_data_con     :: DataCon
            , prod_arr_tycon    :: TyCon
            , prod_arr_data_con :: DataCon
            }

          | SumRepr {
              sum_components    :: [Repr]
            , sum_tycon         :: TyCon
            , sum_arr_tycon     :: TyCon
            , sum_arr_data_con  :: DataCon
            }

-- | Build the product representation for a list of component types.
--
-- The product tycon is the builtin one for the list's length.  Both it and
-- the array tycon obtained from 'parrayReprTyCon' are expected to have
-- exactly one data constructor (the single-element list patterns fail
-- otherwise, once the constructor is demanded).
mkProduct :: [Type] -> VM Repr
mkProduct tys
  = do
      tycon <- builtin (prodTyCon arity)
      let [data_con] = tyConDataCons tycon

      (arr_tycon, _) <- parrayReprTyCon $ mkTyConApp tycon tys
      let [arr_data_con] = tyConDataCons arr_tycon

      return $ ProdRepr {
                 prod_components   = tys
               , prod_tycon        = tycon
               , prod_data_con     = data_con
               , prod_arr_tycon    = arr_tycon
               , prod_arr_data_con = arr_data_con
               }
  where
    arity = length tys
245

246
247
248
-- | Combine representations into a sum representation.
--
-- A single representation is returned as is.  Otherwise the builtin sum
-- tycon for the list's length is used, and the array tycon obtained from
-- 'parrayReprTyCon' is expected to have exactly one data constructor.
mkSum :: [Repr] -> VM Repr
mkSum [repr] = return repr
mkSum reprs
  = do
      tycon <- builtin (sumTyCon arity)
      (arr_tycon, _) <- parrayReprTyCon
                      . mkTyConApp tycon
                      $ map reprType reprs

      let [arr_data_con] = tyConDataCons arr_tycon

      return $ SumRepr {
                 sum_components   = reprs
               , sum_tycon        = tycon
               , sum_arr_tycon    = arr_tycon
               , sum_arr_data_con = arr_data_con
               }
  where
    arity = length reprs
265

266
267
268
-- | The components of a sum representation; any other representation is
-- returned as a singleton list.
reprProducts :: Repr -> [Repr]
reprProducts repr
  = case repr of
      SumRepr { sum_components = rs } -> rs
      _                               -> [repr]
269

270
271
272
273
274
-- | The type described by a representation: its tycon applied to the
-- component types (for sums, to the types of the component
-- representations).
reprType :: Repr -> Type
reprType repr
  = case repr of
      ProdRepr { prod_tycon = tc, prod_components = tys }
        -> mkTyConApp tc tys
      SumRepr { sum_tycon = tc, sum_components = rs }
        -> mkTyConApp tc (map reprType rs)
275

276
277
-- | Apply 'mkPArrayType' to the type of a representation.
arrReprType :: Repr -> VM Type
arrReprType repr = mkPArrayType (reprType repr)
278

279
280
281
-- | Component types of a representation, one inner list per product.
reprTys :: Repr -> [[Type]]
reprTys repr
  = case repr of
      SumRepr { sum_components = prods } -> map prodTys prods
      _                                  -> [prodTys repr]
282

283
-- | Component types of a product representation.  Panics on a sum
-- representation (previously an anonymous pattern-match failure).
prodTys :: Repr -> [Type]
prodTys (ProdRepr { prod_components = tys }) = tys
prodTys (SumRepr {})
  = pprPanic "prodTys:" (text "not a product representation")
284

285
286
-- | Create one fresh local variable per component type of a
-- representation, grouped like the result of 'reprTys'.
reprVars :: Repr -> VM [[Var]]
reprVars repr = mapM (mapM mk_var) (reprTys repr)
  where
    mk_var = newLocalVar FSLIT("r")
287

288
289
-- | Types of the shape fields stored for an array of the given
-- representation: for a sum, an @Int#@ followed by two fields of the
-- builtin primitive-int array type; for anything else a single @Int#@.
arrShapeTys :: Repr -> VM [Type]
arrShapeTys (SumRepr  {})
  = do
      int_arr <- builtin parrayIntPrimTyCon
      return [intPrimTy, mkTyConApp int_arr [], mkTyConApp int_arr []]
arrShapeTys _ = return [intPrimTy]
294

295
296
-- | Create one fresh local variable for each type from 'arrShapeTys'.
arrShapeVars :: Repr -> VM [Var]
arrShapeVars repr
  = do
      shape_tys <- arrShapeTys repr
      mapM (newLocalVar FSLIT("sh")) shape_tys
297

298
299
-- | Shape expressions for an array built from a length expression and a
-- tag expression.
--
-- For a product only the length is used.  For a sum the result is the
-- length, the builtin 'replicatePAIntPrimVar' applied to length and tag,
-- and the builtin 'upToPAIntPrimVar' applied to the length.
replicateShape :: Repr -> CoreExpr -> CoreExpr -> VM [CoreExpr]
replicateShape (ProdRepr {}) len _ = return [len]
replicateShape (SumRepr {})  len tag
  = do
      rep <- builtin replicatePAIntPrimVar
      up  <- builtin upToPAIntPrimVar
      return [len, Var rep `mkApps` [len, tag], Var up `App` len]
305

306
307
308
309
310
311
312
313
314
315
316
317
318
-- | Element types of the arrays stored for a representation, one inner
-- list per product (see 'arrProdElemTys').
arrReprElemTys :: Repr -> [[Type]]
arrReprElemTys repr
  = case repr of
      SumRepr { sum_components = prods } -> map arrProdElemTys prods
      ProdRepr {}                        -> [arrProdElemTys repr]

-- | Element types of the arrays stored for a product representation: its
-- component types, or the unit type alone when it has no components.
-- Panics on a sum representation (previously an anonymous pattern-match
-- failure).
arrProdElemTys :: Repr -> [Type]
arrProdElemTys (ProdRepr { prod_components = [] })
  = [unitTy]
arrProdElemTys (ProdRepr { prod_components = tys })
  = tys
arrProdElemTys (SumRepr {})
  = pprPanic "arrProdElemTys:" (text "not a product representation")

-- | Apply 'mkPArrayType' to every element type from 'arrReprElemTys',
-- keeping the grouping.
arrReprTys :: Repr -> VM [[Type]]
arrReprTys repr = mapM (mapM mkPArrayType) (arrReprElemTys repr)
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333

-- | Create one fresh local variable for each type from 'arrReprTys',
-- keeping the grouping.
arrReprVars :: Repr -> VM [[Var]]
arrReprVars repr
  = do
      tyss <- arrReprTys repr
      mapM (mapM (newLocalVar FSLIT("rs"))) tyss

-- | Compute the representation of a tycon: one product per data
-- constructor (over its representation argument types), combined with
-- 'mkSum'.
mkRepr :: TyCon -> VM Repr
mkRepr vect_tc
  = do
      prods <- mapM (mkProduct . dataConRepArgTys) (tyConDataCons vect_tc)
      mkSum prods

-- | The type of the representation computed by 'mkRepr' for a tycon.
buildPReprType :: TyCon -> VM Type
buildPReprType vect_tc
  = do
      repr <- mkRepr vect_tc
      return (reprType repr)

-- | Build a function expression that takes a value of the vectorised tycon
-- (applied to its own type variables) and rebuilds it with the
-- constructors of its representation.  The body is wrapped with
-- 'wrapFamInstBody' for the PRepr tycon.  The last (array tycon) argument
-- is ignored.
buildToPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
buildToPRepr repr vect_tc prepr_tc _
  = do
      arg    <- newLocalVar FSLIT("x") arg_ty
      result <- to_repr repr (Var arg)

      return . Lam arg
             . wrapFamInstBody prepr_tc var_tys
             $ result
  where
    var_tys = mkTyVarTys $ tyConTyVars vect_tc
    arg_ty  = mkTyConApp vect_tc var_tys
    res_ty  = reprType repr

    cons    = tyConDataCons vect_tc
    -- Only demanded in the product case, where the tycon must have exactly
    -- one data constructor.
    [con]   = cons

    -- Sum: one case alternative per data constructor; each rebuilds the
    -- fields as a product and injects it with the matching constructor of
    -- the sum tycon.
    to_repr (SumRepr { sum_components = prods
                     , sum_tycon      = tycon })
            expr
      = do
          (vars, bodies) <- mapAndUnzipM prod_alt prods
          return . Case expr (mkWildId (exprType expr)) res_ty
                 $ zipWith4 mk_alt cons vars (tyConDataCons tycon) bodies
      where
        mk_alt con vars sum_con body
          = (DataAlt con, vars, mkConApp sum_con (ty_args ++ [body]))

        ty_args = map (Type . reprType) prods

    -- Product: a single alternative for the only data constructor.
    to_repr prod expr
      = do
          (vars, body) <- prod_alt prod
          return $ Case expr (mkWildId (exprType expr)) res_ty
                   [(DataAlt con, vars, body)]

    -- Fresh binders for the fields of one constructor, together with the
    -- product constructor applied to them.
    prod_alt (ProdRepr { prod_components = tys
                       , prod_data_con   = data_con })
      = do
          vars <- mapM (newLocalVar FSLIT("r")) tys
          return (vars, mkConApp data_con (map Type tys ++ map Var vars))
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
373

374
375
376
377
378
-- | Build a function expression that takes a value of the PRepr type of
-- the vectorised tycon and rebuilds it with the tycon's own data
-- constructors.  The scrutinee is unwrapped with 'unwrapFamInstScrut' for
-- the PRepr tycon.  The last (array tycon) argument is ignored.
buildFromPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
buildFromPRepr repr vect_tc prepr_tc _
  = do
      arg_ty <- mkPReprType res_ty
      arg    <- newLocalVar FSLIT("x") arg_ty

      liftM (Lam arg)
           . from_repr repr
           $ unwrapFamInstScrut prepr_tc var_tys (Var arg)
  where
    var_tys = mkTyVarTys $ tyConTyVars vect_tc
    res_ty  = mkTyConApp vect_tc var_tys

    -- The tycon's data constructors, already applied to its type variables.
    cons    = map (`mkConApp` map Type var_tys) (tyConDataCons vect_tc)
    -- Only demanded in the product case, where the tycon must have exactly
    -- one data constructor.
    [con]   = cons

    -- Sum: case on the sum constructors; each alternative binds one
    -- variable and hands it to 'from_prod' with the matching constructor.
    from_repr repr@(SumRepr { sum_components = prods
                            , sum_tycon      = tycon })
              expr
      = do
          vars   <- mapM (newLocalVar FSLIT("x")) (map reprType prods)
          bodies <- sequence . zipWith3 from_prod prods cons
                             $ map Var vars
          return . Case expr (mkWildId (reprType repr)) res_ty
                 $ zipWith3 sum_alt (tyConDataCons tycon) vars bodies
      where
        sum_alt data_con var body = (DataAlt data_con, [var], body)

    from_repr repr expr = from_prod repr con expr

    -- Product: case on the product constructor and apply the given
    -- constructor expression to the bound fields.
    from_prod prod@(ProdRepr { prod_components = tys
                             , prod_data_con   = data_con })
              con
              expr
      = do
          vars <- mapM (newLocalVar FSLIT("y")) tys
          return $ Case expr (mkWildId (reprType prod)) res_ty
                   [(DataAlt data_con, vars, con `mkVarApps` vars)]

-- | Build a function expression that takes an array of the vectorised
-- tycon, takes its array constructor apart into shape and data fields, and
-- reassembles those fields with the array constructors of the
-- representation.  The result is coerced with a coercion built from the
-- builtin parray coercion and the symmetric PRepr family coercion.
buildToArrPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
buildToArrPRepr repr vect_tc prepr_tc arr_tc
  = do
      arg_ty     <- mkPArrayType el_ty
      arg        <- newLocalVar FSLIT("xs") arg_ty

      res_ty     <- mkPArrayType (reprType repr)

      shape_vars <- arrShapeVars repr
      repr_vars  <- arrReprVars  repr

      parray_co  <- mkBuiltinCo parrayTyCon

      -- The PRepr tycon is expected to carry a family coercion; the
      -- pattern fails otherwise once the coercion is demanded.
      let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
          co           = mkAppCoercion parray_co
                       . mkSymCoercion
                       $ mkTyConApp repr_co var_tys

          scrut   = unwrapFamInstScrut arr_tc var_tys (Var arg)

      result <- to_repr shape_vars repr_vars repr

      return . Lam arg
             . mkCoerce co
             $ Case scrut (mkWildId (mkTyConApp arr_tc var_tys)) res_ty
               [(DataAlt arr_dc, shape_vars ++ concat repr_vars, result)]
  where
    var_tys = mkTyVarTys $ tyConTyVars vect_tc
    el_ty   = mkTyConApp vect_tc var_tys

    [arr_dc] = tyConDataCons arr_tc

    -- Sum: the sum's array constructor applied to all shape variables and
    -- one product array per component; every product array reuses the
    -- first shape variable as its length.
    to_repr shape_vars@(len_var : _)
            repr_vars
            (SumRepr { sum_components   = prods
                     , sum_arr_tycon    = tycon
                     , sum_arr_data_con = data_con })
      = do
          exprs <- zipWithM (to_prod len_var) repr_vars prods

          return . wrapFamInstBody tycon tys
                 . mkConApp data_con
                 $ map Type tys ++ map Var shape_vars ++ exprs
      where
        tys = map reprType prods

    to_repr [len_var] [repr_vars] prod = to_prod len_var repr_vars prod

    -- Product: the product's array constructor applied to the length
    -- variable followed by the data variables.
    to_prod len_var
            repr_vars
            (ProdRepr { prod_components   = tys
                      , prod_arr_tycon    = tycon
                      , prod_arr_data_con = data_con })
      = return . wrapFamInstBody tycon tys
               . mkConApp data_con
               $ map Type tys ++ map Var (len_var : repr_vars)
469

470
471
-- | Build a function expression that takes an array of the PRepr type of
-- the vectorised tycon, coerces it, takes the representation's array
-- constructors apart, and rebuilds the fields with the array constructor
-- of the vectorised tycon.
buildFromArrPRepr :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
buildFromArrPRepr repr vect_tc prepr_tc arr_tc
  = do
      arg_ty     <- mkPArrayType =<< mkPReprType el_ty
      arg        <- newLocalVar FSLIT("xs") arg_ty

      res_ty     <- mkPArrayType el_ty

      shape_vars <- arrShapeVars repr
      repr_vars  <- arrReprVars  repr

      parray_co  <- mkBuiltinCo parrayTyCon

      -- The PRepr tycon is expected to carry a family coercion; the
      -- pattern fails otherwise once the coercion is demanded.
      let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
          co           = mkAppCoercion parray_co
                       $ mkTyConApp repr_co var_tys

          scrut  = mkCoerce co (Var arg)

          result = wrapFamInstBody arr_tc var_tys
                 . mkConApp arr_dc
                 $ map Type var_tys ++ map Var (shape_vars ++ concat repr_vars)

      liftM (Lam arg)
            (from_repr repr scrut shape_vars repr_vars res_ty result)
  where
    var_tys = mkTyVarTys $ tyConTyVars vect_tc
    el_ty   = mkTyConApp vect_tc var_tys

    [arr_dc] = tyConDataCons arr_tc

    -- Sum: one case on the sum's array constructor binding the shape
    -- variables and one array variable per component; 'go' then nests one
    -- case per component around the final body.
    from_repr (SumRepr { sum_components   = prods
                       , sum_arr_tycon    = tycon
                       , sum_arr_data_con = data_con })
              expr
              shape_vars
              repr_vars
              res_ty
              body
      = do
          vars <- mapM (newLocalVar FSLIT("xs")) =<< mapM arrReprType prods
          result <- go prods repr_vars vars body

          let scrut = unwrapFamInstScrut tycon ty_args expr
          return . Case scrut (mkWildId scrut_ty) res_ty
                 $ [(DataAlt data_con, shape_vars ++ vars, result)]
      where
        ty_args  = map reprType prods
        scrut_ty = mkTyConApp tycon ty_args

        -- The three lists are expected to have equal length; there is no
        -- clause for a mismatch.
        go [] [] [] body = return body
        go (prod : prods) (repr_vars : rss) (var : vars) body
          = do
              -- Each component array gets fresh binders for its own
              -- shape fields.
              shape_vars <- mapM (newLocalVar FSLIT("s")) =<< arrShapeTys prod

              from_prod prod (Var var) shape_vars repr_vars res_ty
                =<< go prods rss vars body

    from_repr repr expr shape_vars [repr_vars] res_ty body
      = from_prod repr expr shape_vars repr_vars res_ty body

    -- Product: a single case on the product's array constructor binding
    -- the shape and data variables around the body.
    from_prod prod@(ProdRepr { prod_components = tys
                             , prod_arr_tycon  = tycon
                             , prod_arr_data_con = data_con })
              expr
              shape_vars
              repr_vars
              res_ty
              body
      = do
          let scrut    = unwrapFamInstScrut tycon tys expr
              scrut_ty = mkTyConApp tycon tys
          -- NOTE(review): the result is unused; the action is kept in case
          -- it can fail in VM -- confirm before removing.
          ty <- arrReprType prod

          return $ Case scrut (mkWildId scrut_ty) res_ty
                   [(DataAlt data_con, shape_vars ++ repr_vars, body)]

-- | Build the PR dictionary expression for a representation: the
-- dictionary function found by 'prDFunOfTyCon' for the representation's
-- tycon, applied to the component types and then to the component
-- dictionaries ('mkPR' for the types of a product, a recursive call for
-- the components of a sum).
buildPRDictRepr :: Repr -> VM CoreExpr
buildPRDictRepr (ProdRepr {
                   prod_components = tys
                 , prod_tycon      = tycon
                 })
  = do
      prs  <- mapM mkPR tys
      dfun <- prDFunOfTyCon tycon
      return $ dfun `mkTyApps` tys `mkApps` prs

buildPRDictRepr (SumRepr {
                   sum_components = prods
                 , sum_tycon      = tycon })
  = do
      prs  <- mapM buildPRDictRepr prods
      dfun <- prDFunOfTyCon tycon
      return $ dfun `mkTyApps` map reprType prods `mkApps` prs
564

565
-- | Build the PR dictionary for the representation and coerce it with a
-- coercion made from the builtin PR coercion and the symmetric PRepr
-- family coercion.  The last (array tycon) argument is ignored.
buildPRDict :: Repr -> TyCon -> TyCon -> TyCon -> VM CoreExpr
buildPRDict repr vect_tc prepr_tc _
  = do
      dict  <- buildPRDictRepr repr

      pr_co <- mkBuiltinCo prTyCon
      let co = mkAppCoercion pr_co
             . mkSymCoercion
             $ mkTyConApp arg_co var_tys

      return $ mkCoerce co dict
  where
    var_tys = mkTyVarTys $ tyConTyVars vect_tc

    -- The PRepr tycon is expected to carry a family coercion; the pattern
    -- fails otherwise once the coercion is demanded.
    Just arg_co = tyConFamilyCoercion_maybe prepr_tc

rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
581
582
-- | Build the algebraic tycon that serves as the instance of the builtin
-- parray family for a vectorised tycon.
--
-- The tycon under construction is fed back in through 'fixV' because its
-- right-hand side (see 'buildPArrayTyConRhs') refers to the tycon itself.
buildPArrayTyCon :: TyCon -> TyCon -> VM TyCon
buildPArrayTyCon orig_tc vect_tc = fixV $ \repr_tc ->
  do
    name'  <- cloneName mkPArrayTyConOcc orig_name
    rhs    <- buildPArrayTyConRhs orig_name vect_tc repr_tc
    parray <- builtin parrayTyCon

    liftDs $ buildAlgTyCon name'
                           tyvars
                           []          -- no stupid theta
                           rhs
                           rec_flag    -- FIXME: is this ok?
                           False       -- FIXME: no generics
                           False       -- not GADT syntax
                           (Just $ mk_fam_inst parray vect_tc)
  where
    orig_name = tyConName orig_tc
    tyvars = tyConTyVars vect_tc
    rec_flag = boolToRecFlag (isRecursiveTyCon vect_tc)
600

601

602
603
-- | Right-hand side of an array tycon: a non-enumeration data type with
-- the single data constructor built by 'buildPArrayDataCon'.
buildPArrayTyConRhs :: Name -> TyCon -> TyCon -> VM AlgTyConRhs
buildPArrayTyConRhs orig_name vect_tc repr_tc
  = do
      data_con <- buildPArrayDataCon orig_name vect_tc repr_tc
      return $ DataTyCon { data_cons = [data_con], is_enum = False }
607

608
609
-- | Build the data constructor of an array tycon.
--
-- Its fields are the shape types of the vectorised tycon's representation
-- followed by all of the representation's array types, none marked
-- strict.  The constructor is quantified over the vectorised tycon's type
-- variables and belongs to the tycon passed as the third argument.
buildPArrayDataCon :: Name -> TyCon -> TyCon -> VM DataCon
buildPArrayDataCon orig_name vect_tc repr_tc
  = do
      dc_name  <- cloneName mkPArrayDataConOcc orig_name
      repr     <- mkRepr vect_tc

      shape_tys <- arrShapeTys repr
      repr_tys  <- arrReprTys  repr

      let tys = shape_tys ++ concat repr_tys

      liftDs $ buildDataCon dc_name
                            False                  -- not infix
                            (map (const NotMarkedStrict) tys)
                            []                     -- no field labels
                            (tyConTyVars vect_tc)
                            []                     -- no existentials
                            []                     -- no eq spec
                            []                     -- no context
                            tys
                            repr_tc
629

630
631
632
633
-- | Create the exported variable for the PA dictionary function of a
-- vectorised tycon; its type comes from 'paDFunType'.
mkPADFun :: TyCon -> VM Var
mkPADFun vect_tc
  = do
      dfun_ty <- paDFunType vect_tc
      newExportedVar (mkPADFunOcc $ getOccName vect_tc) dfun_ty

634
635
636
-- | Produce the bindings belonging to one vectorised tycon.
--
-- The data constructor workers are vectorised first (their bindings are
-- hoisted), then the PA dictionary is built.  The result is the dictionary
-- binding followed by everything collected by 'takeHoisted'.
buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> Var
                   -> VM [(Var, CoreExpr)]
buildTyConBindings orig_tc vect_tc prepr_tc arr_tc dfun
  = do
      repr  <- mkRepr vect_tc
      vectDataConWorkers repr orig_tc vect_tc arr_tc
      dict <- buildPADict repr vect_tc prepr_tc arr_tc dfun
      binds <- takeHoisted
      return $ (dfun, dict) : binds

650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
-- | Define vectorised versions of the worker functions of all data
-- constructors of a tycon and hoist their bindings.
--
-- For every constructor a pair of expressions is built: the vectorised
-- constructor applied to the tycon's type variables, and a lifted version
-- that builds a value of the array tycon.  'def_worker' turns each pair
-- into a closure, clones the original worker's identifier and registers
-- the clone with 'defGlobalVar'.
vectDataConWorkers :: Repr -> TyCon -> TyCon -> TyCon
                   -> VM ()
vectDataConWorkers repr orig_tc vect_tc arr_tc
  = do
      bs <- sequence
          . zipWith3 def_worker  (tyConDataCons orig_tc) rep_tys
          $ zipWith4 mk_data_con (tyConDataCons vect_tc)
                                 rep_tys
                                 (inits arr_tys)
                                 (tail $ tails arr_tys)
      mapM_ (uncurry hoistBinding) bs
  where
    tyvars   = tyConTyVars vect_tc
    var_tys  = mkTyVarTys tyvars
    ty_args  = map Type var_tys

    res_ty   = mkTyConApp vect_tc var_tys

    rep_tys  = map dataConRepArgTys $ tyConDataCons vect_tc
    arr_tys  = arrReprElemTys repr

    [arr_dc] = tyConDataCons arr_tc

    -- 'pre' and 'post' are the array element types of the constructors
    -- before and after this one.
    mk_data_con con tys pre post
      = liftM2 (,) (vect_data_con con)
                   (lift_data_con tys (concat pre)
                                      (concat post)
                                      (mkDataConTag con))

    vect_data_con con = return $ mkConApp con ty_args

    -- Lifted constructor: abstracts over the lifting context and one
    -- array per field, and fills the array slots of all other
    -- constructors with 'emptyPA'.
    lift_data_con tys pre_tys post_tys tag
      = do
          len  <- builtin liftingContext
          args <- mapM (newLocalVar FSLIT("xs"))
                  =<< mapM mkPArrayType tys

          shape    <- replicateShape repr (Var len) tag
          arr_repr <- mk_arr_repr (Var len) (map Var args)

          pre   <- mapM emptyPA pre_tys
          post  <- mapM emptyPA post_tys

          return . mkLams (len : args)
                 . wrapFamInstBody arr_tc var_tys
                 . mkConApp arr_dc
                 $ ty_args ++ shape ++ pre ++ arr_repr ++ post

    -- A constructor without fields is represented by a replicated unit
    -- value (matching the unit element type used by 'arrProdElemTys').
    mk_arr_repr len []
      = do
          units <- replicatePA len (Var unitDataConId)
          return [units]

    mk_arr_repr _ arrs = return arrs

    def_worker data_con arg_tys mk_body
      = do
          body <- closedV
                . inBind orig_worker
                . polyAbstract tyvars $ \abstract ->
                  liftM (abstract . vectorised)
                $ buildClosures tyvars [] arg_tys res_ty mk_body

          vect_worker <- cloneId mkVectOcc orig_worker (exprType body)
          defGlobalVar orig_worker vect_worker
          return (vect_worker, body)
      where
        orig_worker = dataConWorkId data_con

718
-- | Build the PA dictionary expression for a vectorised tycon.
--
-- Every entry of 'paMethods' is built into a local binding; the
-- dictionary is the builtin PA data constructor applied to the vectorised
-- type and the method variables, wrapped in a recursive let of the method
-- bindings and abstracted (via 'polyAbstract') over the array tycon's type
-- variables.  The dictionary function variable argument is not used here.
buildPADict :: Repr -> TyCon -> TyCon -> TyCon -> Var -> VM CoreExpr
buildPADict repr vect_tc prepr_tc arr_tc dfun
  = polyAbstract tvs $ \abstract ->
    do
      meth_binds <- mapM (mk_method repr) paMethods
      let meth_exprs = map (Var . fst) meth_binds

      pa_dc <- builtin paDataCon
      let dict = mkConApp pa_dc (Type (mkTyConApp vect_tc arg_tys) : meth_exprs)
          body = Let (Rec meth_binds) dict
      return . mkInlineMe $ abstract body
  where
    tvs = tyConTyVars arr_tc
    arg_tys = mkTyVarTys tvs

    -- Build one method body and bind it to a fresh local variable named
    -- after the method.
    mk_method repr (name, build)
      = localV
      $ do
          body <- build repr vect_tc prepr_tc arr_tc
          var  <- newLocalVar name (exprType body)
          return (var, mkInlineMe body)
739

740
741
742
743
744
-- Names of the PA dictionary methods paired with the functions that build
-- their bodies, in the order in which 'buildPADict' passes them to the
-- dictionary constructor.
paMethods
  = [ (FSLIT("toPRepr"),      buildToPRepr)
    , (FSLIT("fromPRepr"),    buildFromPRepr)
    , (FSLIT("toArrPRepr"),   buildToArrPRepr)
    , (FSLIT("fromArrPRepr"), buildFromArrPRepr)
    , (FSLIT("dictPRepr"),    buildPRDict)
    ]
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
745

746
747
748
-- | Split the given tycons into two sets depending on whether they have to be
-- converted (first list) or not (second list). The first argument contains
-- information about the conversion status of external tycons:
--
--   * tycons which have converted versions are mapped to True
--   * tycons which are not changed by vectorisation are mapped to False
--   * tycons which can't be converted are not elements of the map
--
-- Groups are processed in the order given; the status of every classified
-- group is added to the map before the next group is examined.  Groups
-- that cannot be converted appear in neither result list.
classifyTyCons :: UniqFM Bool -> [TyConGroup] -> ([TyCon], [TyCon])
classifyTyCons = classify [] []
  where
    classify conv keep _ [] = (conv, keep)
    classify conv keep cs ((tcs, ds) : rs)
      | can_convert && must_convert
        = classify (tcs ++ conv) keep (cs `addListToUFM` [(tc,True) | tc <- tcs]) rs
      | can_convert
        = classify conv (tcs ++ keep) (cs `addListToUFM` [(tc,False) | tc <- tcs]) rs
      | otherwise
        = classify conv keep cs rs
      where
        -- Tycons referenced by the group, excluding the group itself.
        refs = ds `delListFromUniqSet` tcs

        -- Every referenced tycon has a known status and every member is a
        -- data tycon with only vanilla data constructors.
        can_convert  = isNullUFM (refs `minusUFM` cs) && all convertable tcs
        -- At least one referenced tycon has a converted version.
        must_convert = foldUFM (||) False (intersectUFM_C const cs refs)

        convertable tc = isDataTyCon tc && all isVanillaDataCon (tyConDataCons tc)
772

773
774
775
776
777
778
779
780
781
782
783
784
785
786
-- | Compute mutually recursive groups of tycons in topological order
--
-- Each group carries the union of the tycon sets that 'tyConsOfTyCon'
-- yields for its members.
tyConGroups :: [TyCon] -> [TyConGroup]
tyConGroups tycons = [to_group scc | scc <- stronglyConnComp nodes]
  where
    -- One graph node per tycon, keyed by the tycon itself, with edges to
    -- the tycons it mentions.
    nodes = [ ((tycon, deps), tycon, uniqSetToList deps)
            | tycon <- tycons
            , let deps = tyConsOfTyCon tycon ]

    to_group (AcyclicSCC (tycon, deps)) = ([tycon], deps)
    to_group (CyclicSCC members)
      = let (grp_tcs, grp_deps) = unzip members
        in  (grp_tcs, unionManyUniqSets grp_deps)

-- | The tycons occurring in the representation argument types of all data
-- constructors of a tycon.
tyConsOfTyCon :: TyCon -> UniqSet TyCon
tyConsOfTyCon
  = tyConsOfTypes . concatMap dataConRepArgTys . tyConDataCons

-- | The set of tycons occurring in a type, after looking through whatever
-- 'coreView' exposes.
--
-- Unlifted and tuple tycons are not recorded themselves, though their
-- arguments are still traversed.  A function type contributes 'funTyCon'
-- in addition to the tycons of its argument and result.
tyConsOfType :: Type -> UniqSet TyCon
tyConsOfType ty
  | Just ty' <- coreView ty    = tyConsOfType ty'
tyConsOfType (TyVarTy _)       = emptyUniqSet
tyConsOfType (TyConApp tc tys) = extend (tyConsOfTypes tys)
  where
    extend | isUnLiftedTyCon tc
           || isTupleTyCon   tc = id

           | otherwise          = (`addOneToUniqSet` tc)

tyConsOfType (AppTy a b)       = tyConsOfType a `unionUniqSets` tyConsOfType b
tyConsOfType (FunTy a b)       = (tyConsOfType a `unionUniqSets` tyConsOfType b)
                                 `addOneToUniqSet` funTyCon
tyConsOfType (ForAllTy _ ty)   = tyConsOfType ty
-- Any remaining form of type is not handled.
tyConsOfType other             = pprPanic "VectType.tyConsOfType" $ ppr other

-- | Union of the tycon sets of all the given types.
tyConsOfTypes :: [Type] -> UniqSet TyCon
tyConsOfTypes tys = unionManyUniqSets [tyConsOfType ty | ty <- tys]