VectUtils.hs 19.3 KB
Newer Older
1
module VectUtils (
2
  collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3
  collectAnnValBinders,
4
  mkDataConTag,
5
  splitClosureTy,
6
  mkPRepr, mkToPRepr, mkToArrPRepr, mkFromPRepr, mkFromArrPRepr,
7
  mkPADictType, mkPArrayType, mkPReprType,
8
  parrayCoerce, parrayReprTyCon, parrayReprDataCon, mkVScrut,
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
9
  prDictOfType, prCoerce,
10
  paDictArgType, paDictOfType, paDFunType,
11
  paMethod, lengthPA, replicatePA, emptyPA, liftPA,
12
  polyAbstract, polyApply, polyVApply,
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
13
  hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
14
15
  buildClosure, buildClosures,
  mkClosureApp
16
17
18
19
) where

#include "HsVersions.h"

20
import VectCore
21
22
import VectMonad

23
import DsUtils
24
import CoreSyn
25
import CoreUtils
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
26
import Coercion
27
28
import Type
import TypeRep
29
import TyCon
30
import DataCon            ( DataCon, dataConWrapId, dataConTag )
31
import Var
32
33
import Id                 ( mkWildId )
import MkId               ( unwrapFamInstScrut )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
34
import Name               ( Name )
35
import PrelNames
36
37
import TysWiredIn
import BasicTypes         ( Boxity(..) )
38

39
import Outputable
40
import FastString
41

42
import Data.List             ( zipWith4 )
43
import Control.Monad         ( liftM, liftM2, zipWithM_ )
44

45
46
47
48
49
50
51
52
53
54
55
56
-- | Peel type arguments off the head of an annotated application.
-- Returns the head expression and the type arguments in the order in
-- which they were applied.
collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
collectAnnTypeArgs = peel []
  where
    -- The outermost application is seen first, so consing yields
    -- application order without a final reverse.
    peel acc (_, AnnApp fun (_, AnnType arg_ty)) = peel (arg_ty : acc) fun
    peel acc other                               = (other, acc)

-- | Strip the leading type lambdas from an annotated expression.
-- Returns the type binders in source order and the remaining body.
collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnTypeBinders = strip []
  where
    strip acc (_, AnnLam bndr body)
      | isTyVar bndr = strip (bndr : acc) body
    strip acc rest   = (reverse acc, rest)

-- | Strip the leading value (Id) lambdas from an annotated expression.
-- Returns the value binders in source order and the remaining body.
collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnValBinders = strip []
  where
    strip acc (_, AnnLam bndr body)
      | isId bndr = strip (bndr : acc) body
    strip acc rest = (reverse acc, rest)

-- | Does this annotated expression consist of a type argument?
isAnnTypeArg :: AnnExpr b ann -> Bool
isAnnTypeArg (_, expr) = case expr of
                           AnnType _ -> True
                           _         -> False

-- | Build a boxed Int literal holding the tag of a data constructor.
mkDataConTag :: DataCon -> CoreExpr
mkDataConTag dc = mkConApp intDataCon [tag_lit]
  where
    tag_lit = mkIntLitInt (dataConTag dc)

-- | Return the single argument of a type that is an application of the
-- type constructor called @name@ to exactly one argument.
-- Panics with message @s@ for any other type.
splitUnTy :: String -> Name -> Type -> Type
splitUnTy s name ty
  = case splitTyConApp_maybe ty of
      Just (tc, [arg]) | tyConName tc == name -> arg
      _                                       -> pprPanic s (ppr ty)

-- | Return both arguments of a type that is an application of the type
-- constructor called @name@ to exactly two arguments.
-- Panics with message @s@ for any other type.
splitBinTy :: String -> Name -> Type -> (Type, Type)
splitBinTy s name ty
  = case splitTyConApp_maybe ty of
      Just (tc, [lhs, rhs]) | tyConName tc == name -> (lhs, rhs)
      _                                            -> pprPanic s (ppr ty)

-- | Return the arguments of @ty@, which must be an application of exactly
-- the type constructor @tc@; panics otherwise.
splitFixedTyConApp :: TyCon -> Type -> [Type]
splitFixedTyConApp tc ty
  = case splitTyConApp_maybe ty of
      Just (tc', args) | tc == tc' -> args
      _ -> pprPanic "splitFixedTyConApp" (ppr tc <+> ppr ty)

-- | Argument of an application of the Embed type constructor
-- (panics on any other type).
splitEmbedTy :: Type -> Type
splitEmbedTy ty = splitUnTy "splitEmbedTy" embedTyConName ty

-- | Argument and result types of a closure type (panics otherwise).
splitClosureTy :: Type -> (Type, Type)
splitClosureTy ty = splitBinTy "splitClosureTy" closureTyConName ty

-- | Element type of a PArray type (panics otherwise).
splitPArrayTy :: Type -> Type
splitPArrayTy ty = splitUnTy "splitPArrayTy" parrayTyConName ty

-- | Apply the builtin type constructor selected by @get_tc@ to @tys@.
mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
mkBuiltinTyConApp get_tc tys = liftM (`mkTyConApp` tys) (builtin get_tc)

-- | Right-nest the binary builtin type constructor selected by @get_tc@
-- over @tys@, ending in @ty@: for [t1, ..., tn] the result is
-- tc t1 (tc t2 (... (tc tn ty))). An empty list yields @ty@ itself.
mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
mkBuiltinTyConApps get_tc tys ty
  = do
      tc <- builtin get_tc
      let app arg rest = mkTyConApp tc [arg, rest]
      return (foldr app ty tys)

-- | Right-nest the binary builtin type constructor selected by @get_tc@
-- over a list whose last element is the innermost type:
-- [t1, ..., tn] becomes tc t1 (tc t2 (... tn)), and a singleton list
-- yields its only element. The default @dft@ is returned for an empty list.
mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
mkBuiltinTyConApps1 _      dft [] = return dft
mkBuiltinTyConApps1 get_tc _   tys
  = do
      tc <- builtin get_tc
      -- tys is non-empty here because the first equation handles the empty
      -- list, so foldr1 is total and no separate empty-list branch is needed.
      return $ foldr1 (mk tc) tys
  where
    mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]

-- | Compute the generic representation type of a data type from the field
-- types of its constructors (one inner list per constructor).
--
-- Each field type is wrapped in Embed. The fields of a constructor are
-- combined with the product tycon of matching arity, and the constructors
-- with the sum tycon of matching arity. An empty list becomes (), and a
-- singleton list becomes its only element, at both levels.
mkPRepr :: [[Type]] -> VM Type
mkPRepr tys
  = do
      embed_tc <- builtin embedTyCon
      sum_tcs  <- builtins sumTyCon
      prod_tcs <- builtins prodTyCon

      let -- Collapse a list of types with the n-ary tycon chosen by tc_n.
          combine _    []  = unitTy
          combine _    [t] = t
          combine tc_n ts  = mkTyConApp (tc_n (length ts)) ts

          embed t     = mkTyConApp embed_tc [t]
          repr_con ts = combine prod_tcs (map embed ts)

      return (combine sum_tcs (map repr_con tys))

-- | Convert constructor fields into their generic representation.
--
-- @ess@ holds, for each constructor, the expressions for its fields. Every
-- field is wrapped with the Embed data constructor. The fields of a
-- constructor are combined with the data constructor of the product tycon
-- of matching arity. Finally, each product is injected with the matching
-- data constructor of the sum tycon. Returns one representation
-- expression per constructor, together with the representation type.
-- An empty list is represented by (), and a singleton by its only element.
mkToPRepr :: [[CoreExpr]] -> VM ([CoreExpr], Type)
mkToPRepr ess
  = do
      embed_tc <- builtin embedTyCon
      embed_dc <- builtin embedDataCon
      sum_tcs  <- builtins sumTyCon
      prod_tcs <- builtins prodTyCon

      let to_sum []       = ([Var unitDataConId], unitTy)
          to_sum [(e, t)] = ([e], t)
          to_sum alts
            = let (es, ts) = unzip alts
                  tc       = sum_tcs (length alts)
                  inj dc e = mkConApp dc (map Type ts ++ [e])
              in (zipWith inj (tyConDataCons tc) es, mkTyConApp tc ts)

          to_prod []       = (Var unitDataConId, unitTy)
          to_prod [(e, t)] = (e, t)
          to_prod fields
            = let (es, ts) = unzip fields
                  tc       = prod_tcs (length fields)
                  [dc]     = tyConDataCons tc
              in (mkConApp dc (map Type ts ++ es), mkTyConApp tc ts)

          to_embed e
            = let t = exprType e
              in (mkConApp embed_dc [Type t, e], mkTyConApp embed_tc [t])

      return (to_sum (map (to_prod . map to_embed) ess))

-- | Build the PArray representation of a data type from component arrays.
--
-- @ess@ holds, for each constructor, the PArray expressions for its fields.
-- Each field array is wrapped with the (single) data constructor of the
-- PArray representation of Embed (). The fields of a constructor are
-- combined with the PArray representation constructor of the matching
-- product tycon, which also receives @len@. The constructors are combined
-- with that of the matching sum tycon, which receives @len@ and @sel@.
-- A single component at either level is passed through unchanged.
--
-- NOTE(review): unlike mkToPRepr there is no case for an empty list, so a
-- constructor without fields requests the product tycon of arity 0 from
-- the builtins (likewise for a type without constructors); confirm intended.
mkToArrPRepr :: CoreExpr -> CoreExpr -> [[CoreExpr]] -> VM CoreExpr
mkToArrPRepr len sel ess
  = do
      embed_tc <- builtin embedTyCon
      (embed_rtc, _) <- parrayReprTyCon (mkTyConApp embed_tc [unitTy])
      let [embed_rdc] = tyConDataCons embed_rtc

      let -- Shared shape of sum and product nodes: look up the n-ary builtin
          -- tycon, take the single data constructor of its PArray
          -- representation, and apply it to the type arguments, the extra
          -- leading arguments and the component arrays.
          mk_node get_tc extra comps
            = do
                let (arrs, tys) = unzip comps
                node_tc       <- builtin (get_tc (length comps))
                (node_rtc, _) <- parrayReprTyCon (mkTyConApp node_tc tys)
                let [node_rdc] = tyConDataCons node_rtc
                return (mkConApp node_rdc (map Type tys ++ extra ++ arrs),
                        mkTyConApp node_tc tys)

          mk_sum [(arr, ty)] = return (arr, ty)
          mk_sum comps       = mk_node sumTyCon [len, sel] comps

          mk_prod [(arr, ty)] = return (arr, ty)
          mk_prod comps       = mk_node prodTyCon [len] comps

          mk_embed arr
            = let ty = splitPArrayTy (exprType arr)
              in (mkConApp embed_rdc [Type ty, arr], mkTyConApp embed_tc [ty])

      prods <- mapM (mk_prod . map mk_embed) ess
      liftM fst (mk_sum prods)

-- | Build a case expression that takes apart a value of representation type.
--
-- @scrut@ is scrutinised at its own type (presumably the shape produced by
-- mkPRepr). @alts@ has one entry per constructor, pairing the variables to
-- bind for its fields with the expression to return; @res_ty@ is the type
-- of those expressions. Sums are taken apart with one alternative per sum
-- data constructor, products with their single data constructor, and each
-- field with the Embed data constructor.
mkFromPRepr :: CoreExpr -> Type -> [([Var], CoreExpr)] -> VM CoreExpr
mkFromPRepr scrut res_ty alts
  = do
      embed_dc <- builtin embedDataCon
      sum_tcs  <- builtins sumTyCon
      prod_tcs <- builtins prodTyCon

      let -- Bind var to the contents of the Embed value expr (of type ty).
          un_embed expr ty var r
            = Case expr (mkWildId ty) res_ty [(DataAlt embed_dc, [var], r)]

          -- Take apart a product, binding one variable per field. With no
          -- fields nothing is scrutinised; with one field there is no
          -- product wrapper, only the Embed.
          un_prod _    _  []    r = return r
          un_prod expr ty [var] r = return (un_embed expr ty var r)
          un_prod expr ty vars  r
            = do
                let prod_tc   = prod_tcs (length vars)
                    [prod_dc] = tyConDataCons prod_tc
                    field_tys = splitFixedTyConApp prod_tc ty
                xs <- mapM (newLocalVar FSLIT("x")) field_tys
                let unpack (x, t, v) rest = un_embed x t v rest
                    body = foldr unpack r (zip3 (map Var xs) field_tys vars)
                return (Case expr (mkWildId ty) res_ty
                          [(DataAlt prod_dc, xs, body)])

          -- Take apart a sum; a single alternative has no sum wrapper.
          un_sum expr ty [(vars, r)] = un_prod expr ty vars r
          un_sum expr ty bs
            = do
                let (varss, rs)      = unzip bs
                    sum_tc           = sum_tcs (length bs)
                    alt_tys          = splitFixedTyConApp sum_tc ty
                    mk_alt dc p body = (DataAlt dc, [p], body)
                ps     <- mapM (newLocalVar FSLIT("p")) alt_tys
                bodies <- sequence (zipWith4 un_prod (map Var ps) alt_tys varss rs)
                return (Case expr (mkWildId ty) res_ty
                          (zipWith3 mk_alt (tyConDataCons sum_tc) ps bodies))

      un_sum scrut (exprType scrut) alts

-- | Presumably intended as the PArray counterpart of mkFromPRepr.
--
-- NOTE(review): this is a placeholder. It ignores all of its arguments and
-- always returns the unit data constructor.
mkFromArrPRepr :: CoreExpr -> Type -> Var -> Var -> [[Var]] -> CoreExpr
               -> VM CoreExpr
mkFromArrPRepr _scrut _res_ty _len _sel _vars _res
  = return (Var unitDataConId)

-- | Apply the builtin closure type constructor to an argument and a result type.
mkClosureType :: Type -> Type -> VM Type
mkClosureType arg_ty res_ty
  = do
      clo_tc <- builtin closureTyCon
      return (mkTyConApp clo_tc [arg_ty, res_ty])

-- | Right-nested closure type taking @arg_tys@ in turn and ending in @res_ty@.
mkClosureTypes :: [Type] -> Type -> VM Type
mkClosureTypes arg_tys res_ty = mkBuiltinTyConApps closureTyCon arg_tys res_ty

-- | Apply the builtin PRepr type constructor to @ty@.
mkPReprType :: Type -> VM Type
mkPReprType ty = liftM (\tc -> mkTyConApp tc [ty]) (builtin preprTyCon)

-- | Apply the builtin PA type constructor to @ty@ (the type of its PA dictionary).
mkPADictType :: Type -> VM Type
mkPADictType ty = liftM (\tc -> mkTyConApp tc [ty]) (builtin paTyCon)

-- | Apply the builtin PArray type constructor to @ty@.
mkPArrayType :: Type -> VM Type
mkPArrayType ty = liftM (\tc -> mkTyConApp tc [ty]) (builtin parrayTyCon)

-- | Cast @expr@ using the builtin PArray type constructor applied to the
-- symmetric family coercion of @repr_tc@, instantiated at @args@.
--
-- Panics if @repr_tc@ has no family coercion. Previously this case fell
-- off the end of the guards and produced an uninformative pattern-match
-- failure.
parrayCoerce :: TyCon -> [Type] -> CoreExpr -> VM CoreExpr
parrayCoerce repr_tc args expr
  | Just arg_co <- tyConFamilyCoercion_maybe repr_tc
  = do
      parray <- builtin parrayTyCon

      let co = mkAppCoercion (mkTyConApp parray [])
                             (mkSymCoercion (mkTyConApp arg_co args))

      return $ mkCoerce co expr
  | otherwise
  = pprPanic "parrayCoerce" (ppr repr_tc)

-- | Look up (via lookupFamInst) the instance of the builtin PArray type
-- constructor at @ty@, returning the instance tycon and its type arguments.
parrayReprTyCon :: Type -> VM (TyCon, [Type])
parrayReprTyCon ty
  = do
      parray <- builtin parrayTyCon
      lookupFamInst parray [ty]

-- | Like parrayReprTyCon, but return the sole data constructor of the
-- instance tycon (it must have exactly one) together with the type arguments.
parrayReprDataCon :: Type -> VM (DataCon, [Type])
parrayReprDataCon ty
  = parrayReprTyCon ty >>= \(repr_tc, inst_tys) ->
    let [repr_dc] = tyConDataCons repr_tc
    in  return (repr_dc, inst_tys)

-- | Prepare a vectorised scrutinee. The PArray instance tycon is looked up
-- at the type of the first component. The second component is then passed
-- through unwrapFamInstScrut with that tycon, and the tycon and its type
-- arguments are returned alongside the new pair.
mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
mkVScrut (vexpr, lexpr)
  = parrayReprTyCon (exprType vexpr) >>= \(repr_tc, inst_tys) ->
    return ((vexpr, unwrapFamInstScrut repr_tc inst_tys lexpr), repr_tc, inst_tys)

-- | Build a PR dictionary for a type constructor application. The dfun
-- registered for the head tycon is looked up and applied to the type
-- arguments and to whatever dictionaries its type demands (see prDFunApply).
--
-- Panics when @orig_ty@ is not a type constructor application. Previously
-- this fell off the end of the guard with a pattern-match failure.
prDictOfType :: Type -> VM CoreExpr
prDictOfType orig_ty
  | Just (tycon, ty_args) <- splitTyConApp_maybe orig_ty
  = do
      dfun <- traceMaybeV "prDictOfType" (ppr tycon) (lookupTyConPR tycon)
      prDFunApply (Var dfun) ty_args
  | otherwise
  = pprPanic "prDictOfType" (ppr orig_ty)

-- | Instantiate a dictionary function at @tys@ and apply it to one
-- dictionary per argument type of the instantiated function type.
-- Each dictionary is built by mkDFunArg.
prDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
prDFunApply dfun tys
  = liftM (mkApps mono_dfun) (mapM mkDFunArg arg_tys)
  where
    mono_dfun    = mkTyApps dfun tys
    (arg_tys, _) = splitFunTys (exprType mono_dfun)

-- | Build a dictionary for a dfun argument type, which must be PA t or PR t.
-- Any other type panics.
mkDFunArg :: Type -> VM CoreExpr
mkDFunArg ty
  = case splitTyConApp_maybe ty of
      Just (tycon, [arg])
        | name == paTyConName -> paDictOfType arg
        | name == prTyConName -> prDictOfType arg
        where
          name = tyConName tycon
      _ -> pprPanic "mkDFunArg" (ppr ty)

-- | Cast @expr@ using the builtin PR type constructor applied to the
-- symmetric family coercion of @repr_tc@, instantiated at @args@.
--
-- Panics if @repr_tc@ has no family coercion. Previously this case fell
-- off the end of the guards and produced an uninformative pattern-match
-- failure.
prCoerce :: TyCon -> [Type] -> CoreExpr -> VM CoreExpr
prCoerce repr_tc args expr
  | Just arg_co <- tyConFamilyCoercion_maybe repr_tc
  = do
      pr_tc <- builtin prTyCon

      let co = mkAppCoercion (mkTyConApp pr_tc [])
                             (mkSymCoercion (mkTyConApp arg_co args))

      return $ mkCoerce co expr
  | otherwise
  = pprPanic "prCoerce" (ppr repr_tc)

-- | The type of the PA dictionary argument needed for a type variable, or
-- Nothing if its kind needs none.
--
-- For a lifted kind this is PA applied to the variable. For an arrow kind
-- k1 -> k2, a fresh variable a of kind k1 is introduced. If a needs a
-- dictionary, the result is forall a. dict(a) -> dict(tv a). Otherwise the
-- result kind k2 is processed with the type left unapplied.
-- NOTE(review): that unapplied case looks ill-kinded; confirm intended.
-- Any other kind yields Nothing.
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = dict_ty (TyVarTy tv) (tyVarKind tv)
  where
    -- Expand kind synonyms first.
    dict_ty ty k
      | Just k' <- kindView k
      = dict_ty ty k'

    dict_ty ty (FunTy k1 k2)
      = do
          a      <- newTyVar FSLIT("a") k1
          a_dict <- dict_ty (TyVarTy a) k1
          maybe (dict_ty ty k2)
                (\a_dict_ty -> liftM (fmap (ForAllTy a . FunTy a_dict_ty))
                                     (dict_ty (AppTy ty (TyVarTy a)) k2))
                a_dict

    dict_ty ty k
      | isLiftedTypeKind k
      = liftM Just (mkPADictType ty)

    dict_ty _ _ = return Nothing

-- | Build a PA dictionary for @ty@ by splitting it into a head and its
-- arguments and delegating to paDictOfTyApp.
paDictOfType :: Type -> VM CoreExpr
paDictOfType ty = uncurry paDictOfTyApp (splitAppTys ty)

-- | Build a PA dictionary for the application of a type head to arguments.
-- Synonyms are expanded with coreView. A type variable head uses the
-- dictionary recorded for that variable; a type constructor head uses the
-- registered dfun. The dfun is then applied to the arguments and their
-- dictionaries (paDFunApply). Any other head panics.
paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
paDictOfTyApp fn args
  | Just fn' <- coreView fn = paDictOfTyApp fn' args
paDictOfTyApp (TyVarTy tv) args
  = maybeV (lookupTyVarPA tv) >>= \dfun -> paDFunApply dfun args
paDictOfTyApp (TyConApp tc _) args
  = traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc) >>= \dfun ->
    paDFunApply (Var dfun) args
paDictOfTyApp fn _ = pprPanic "paDictOfTyApp" (ppr fn)

-- | The type of the PA dfun for a type constructor:
-- forall tvs. (dictionary args for those tvs that need one) -> PA (tc tvs).
paDFunType :: TyCon -> VM Type
paDFunType tc
  = do
      let tvs = tyConTyVars tc
      dict_args <- mapM paDictArgType tvs
      inst_dict <- mkPADictType (mkTyConApp tc (mkTyVarTys tvs))
      let fun_ty = mkFunTys [arg | Just arg <- dict_args] inst_dict
      return (mkForAllTys tvs fun_ty)

-- | Instantiate a PA dfun at @tys@ and apply it to a PA dictionary for
-- each of those types.
paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
paDFunApply dfun tys
  = liftM (mkApps (mkTyApps dfun tys)) (mapM paDictOfType tys)

-- | Select a builtin PA method and apply it to @ty@ and its PA dictionary.
paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
paMethod method ty
  = liftM2 apply_method (builtin method) (paDictOfType ty)
  where
    apply_method fn dict = mkApps (Var fn) [Type ty, dict]

-- | Apply the PA length method to a PArray-typed expression.
lengthPA :: CoreExpr -> VM CoreExpr
lengthPA arr
  = do
      len_fn <- paMethod lengthPAVar (splitPArrayTy (exprType arr))
      return (App len_fn arr)

-- | Apply the PA replicate method, at the type of @x@, to @len@ and @x@.
replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
replicatePA len x
  = do
      rep_fn <- paMethod replicatePAVar (exprType x)
      return (rep_fn `mkApps` [len, x])

-- | The PA empty method instantiated at @ty@.
emptyPA :: Type -> VM CoreExpr
emptyPA ty = paMethod emptyPAVar ty

-- | Replicate @x@ using the builtin lifting-context variable as the length.
liftPA :: CoreExpr -> VM CoreExpr
liftPA x = builtin liftingContext >>= \lc -> replicatePA (Var lc) x

-- | Create a fresh pair of local variables with the same name: one of type
-- @vty@ and one of type PArray @vty@.
newLocalVVar :: FastString -> Type -> VM VVar
newLocalVVar fs vty
  = do
      lifted_ty <- mkPArrayType vty
      liftM2 (,) (newLocalVar fs vty) (newLocalVar fs lifted_ty)

-- | Run @p@ in a local environment that records the type variables @tvs@
-- together with PA dictionary variables for those that need one.
-- @p@ receives a function that wraps an expression in lambdas over the
-- type variables followed by those dictionary variables.
polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
polyAbstract tvs p
  = localV $
      do
        mdicts <- mapM new_dict_var tvs
        zipWithM_ bind_tyvar tvs mdicts
        p (abstract_over mdicts)
  where
    -- Allocate a dictionary variable for tv when its kind needs one.
    new_dict_var tv
      = paDictArgType tv >>=
        maybe (return Nothing) (liftM Just . newLocalVar FSLIT("dPA"))

    -- Record tv locally, with its dictionary variable if it has one.
    bind_tyvar tv Nothing     = defLocalTyVar tv
    bind_tyvar tv (Just dict) = defLocalTyVarWithPA tv (Var dict)

    abstract_over mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])

-- | Instantiate @expr@ at @tys@ and apply it to their PA dictionaries.
polyApply :: CoreExpr -> [Type] -> VM CoreExpr
polyApply expr tys = liftM (mkApps (mkTyApps expr tys)) (mapM paDictOfType tys)

-- | Like polyApply, but for both components of a vectorised expression.
polyVApply :: VExpr -> [Type] -> VM VExpr
polyVApply expr tys
  = do
      dicts <- mapM paDictOfType tys
      let instantiate e = mkApps (mkTyApps e tys) dicts
      return (mapVect instantiate expr)

-- | Push a binding onto the global_bindings list of the global environment.
hoistBinding :: Var -> CoreExpr -> VM ()
hoistBinding v e = updGEnv add_binding
  where
    add_binding env = env { global_bindings = (v, e) : global_bindings env }

-- | Bind @expr@ to a fresh variable named @fs@ in the hoisted global
-- bindings, and return that variable.
hoistExpr :: FastString -> CoreExpr -> VM Var
hoistExpr fs expr
  = newLocalVar fs (exprType expr) >>= \var ->
    hoistBinding var expr >> return var

-- | Hoist both components of a vectorised expression. The current binding
-- name is used, prefixed with v for the first component and l for the second.
hoistVExpr :: VExpr -> VM VVar
hoistVExpr (ve, le)
  = do
      name <- getBindName
      let hoist_with c e = hoistExpr (c `consFS` name) e
      liftM2 (,) (hoist_with 'v' ve) (hoist_with 'l' le)

-- | Run @p@ under closedV, abstract its result over @tvs@ (and their PA
-- dictionaries), hoist that to the top level, and return the hoisted
-- variables re-applied to @tvs@.
hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
hoistPolyVExpr tvs p
  = do
      let abstracted = polyAbstract tvs (\abstract -> liftM (mapVect abstract) p)
      fn <- hoistVExpr =<< closedV abstracted
      polyVApply (vVar fn) (mkTyVarTys tvs)

-- | Return the accumulated hoisted bindings and clear them.
takeHoisted :: VM [(Var, CoreExpr)]
takeHoisted
  = do
      env <- readGEnv id
      let hoisted = global_bindings env
      setGEnv (env { global_bindings = [] })
      return hoisted

-- | Build a vectorised closure from a function pair and an environment pair.
-- The first component applies the builtin mkClosureVar, and the second the
-- builtin mkClosurePVar. Both are instantiated at the argument, result and
-- environment types and receive the environment PA dictionary and both
-- functions.
mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
mkClosure arg_ty res_ty env_ty (vfn, lfn) (venv, lenv)
  = do
      dict <- paDictOfType env_ty
      mkv  <- builtin mkClosureVar
      mkl  <- builtin mkClosurePVar
      let build con env = mkApps (mkTyApps (Var con) [arg_ty, res_ty, env_ty])
                                 [dict, vfn, lfn, env]
      return (build mkv venv, build mkl lenv)

-- | Apply a vectorised closure to a vectorised argument, using the builtin
-- applyClosureVar for the first component and applyClosurePVar for the
-- second. The types are read off the first component of the closure.
mkClosureApp :: VExpr -> VExpr -> VM VExpr
mkClosureApp (vclo, lclo) (varg, larg)
  = do
      vapply <- builtin applyClosureVar
      lapply <- builtin applyClosurePVar
      let apply fn clo arg = mkApps (mkTyApps (Var fn) [arg_ty, res_ty]) [clo, arg]
      return (apply vapply vclo varg, apply lapply lclo larg)
  where
    (arg_ty, res_ty) = splitClosureTy (exprType vclo)

-- | Build a curried chain of closures, one per argument type. With no
-- arguments the body is returned as is; with one, buildClosure is used
-- directly. Otherwise the first closure returns a hoisted function over
-- the captured variables plus the new argument, which builds the rest of
-- the chain.
buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
buildClosures _ _ [] _ mk_body
  = mk_body
buildClosures tvs vars [arg_ty] res_ty mk_body
  = buildClosure tvs vars arg_ty res_ty mk_body
buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
  = do
      inner_ty <- mkClosureTypes arg_tys res_ty
      arg      <- newLocalVVar FSLIT("x") arg_ty
      let vars' = vars ++ [arg]
          inner = do
                    lc  <- builtin liftingContext
                    clo <- buildClosures tvs vars' arg_tys res_ty mk_body
                    return (vLams lc vars' clo)
      buildClosure tvs vars arg_ty inner_ty (hoistPolyVExpr tvs inner)

-- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
--   where
--     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
--     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l xs1 ... xsn v
--
-- | Build a closure capturing @vars@ in an environment. The hoisted code
-- takes the environment and the argument, unpacks the environment, and
-- applies the body from @mk_body@ to the captured variables and the
-- argument.
buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
buildClosure tvs vars arg_ty res_ty mk_body
  = do
      (env_ty, env, bind) <- buildEnv vars
      env_bndr <- newLocalVVar FSLIT("env") env_ty
      arg_bndr <- newLocalVVar FSLIT("arg") arg_ty

      let fun_body
            = do
                lc      <- builtin liftingContext
                body    <- mk_body
                let applied = vVarApps lc body (vars ++ [arg_bndr])
                wrapped <- bind (vVar env_bndr) applied
                return (vLamsWithoutLC [env_bndr, arg_bndr] wrapped)

      fn <- hoistPolyVExpr tvs fun_body
      mkClosure arg_ty res_ty env_ty fn env

-- | Package the given variables into an environment. Returns the type of
-- the first environment component, the environment pair, and a function
-- that unpacks an environment pair around a body pair. The first components
-- use mkVectEnv; the second use mkLiftEnv with the lifting context.
buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
buildEnv vvs
  = do
      lc <- builtin liftingContext
      let (vs, ls)              = unzip vvs
          tys                   = map idType vs
          (env_ty, venv, vbind) = mkVectEnv tys vs
      (lenv, lbind) <- mkLiftEnv lc tys ls
      let bind_both (venv', lenv') (vbody, lbody)
            = liftM ((,) (vbind venv' vbody)) (lbind lenv' lbody)
      return (env_ty, (venv, lenv), bind_both)

-- | Environment for the first (vectorised) components. No variables gives
-- the unit value and no unpacking. One variable is used directly and bound
-- with a let. Several variables are packed into a boxed tuple and unpacked
-- with a case.
mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
mkVectEnv []   []  = (unitTy, Var unitDataConId, \_ body -> body)
mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
mkVectEnv tys  vs  = (tup_ty, mkCoreTup (map Var vs), unpack)
  where
    tup_ty          = mkCoreTupTy tys
    tup_con         = tupleCon Boxed (length vs)
    unpack env body = Case env (mkWildId tup_ty) (exprType body)
                           [(DataAlt tup_con, vs, body)]

-- | Environment for the second (lifted, PArray) components.
--
-- A single variable is used directly. Unpacking binds it with a let and
-- then binds the lifting context @lc@ to its length via a case.
-- Otherwise the variables are packed with the single data constructor of
-- the PArray instance for the tuple of their element types, with @lc@ as
-- the first field.
mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
mkLiftEnv lc [_] [v]
  = return (Var v, \env body ->
                   do
                     len <- lengthPA (Var v)
                     return . Let (NonRec v env)
                            $ Case len lc (exprType body) [(DEFAULT, [], body)])

-- NOTE: this transparently deals with empty environments. The tuple type
-- of no types is unit, and a single unit field is both supplied when
-- building the environment and matched when taking it apart.
mkLiftEnv lc tys vs
  = do
      (env_tc, env_tyargs) <- parrayReprTyCon vty
      let [env_con] = tyConDataCons env_tc

          env = Var (dataConWrapId env_con)
                `mkTyApps` env_tyargs
                `mkApps`   (Var lc : args)

          bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
                          in
                          return $ Case scrut (mkWildId (exprType scrut))
                                        (exprType body)
                                        [(DataAlt env_con, lc : bndrs, body)]

      return (env, bind)
  where
    vty = mkCoreTupTy tys

    -- The construction arguments must match the case binders below. For an
    -- empty environment the pattern binds one unit-typed wildcard after lc,
    -- so the constructor is given the unit value there. Previously it was
    -- applied to lc alone, leaving it one field short.
    args  | null vs   = [Var unitDataConId]
          | otherwise = map Var vs

    bndrs | null vs   = [mkWildId unitTy]
          | otherwise = vs