VectUtils.hs 17.4 KB
Newer Older
1
module VectUtils (
2
  collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3
  collectAnnValBinders,
4
  mkDataConTag,
5
  splitClosureTy,
6
7
  mkPRepr, mkToPRepr, mkFromPRepr,
  mkPADictType, mkPArrayType, mkPReprType,
8
  parrayReprTyCon, parrayReprDataCon, mkVScrut,
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
9
  prDictOfType, prCoerce,
10
  paDictArgType, paDictOfType, paDFunType,
11
  paMethod, lengthPA, replicatePA, emptyPA, liftPA,
12
  polyAbstract, polyApply, polyVApply,
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
13
  hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
14
15
  buildClosure, buildClosures,
  mkClosureApp
16
17
18
19
) where

#include "HsVersions.h"

20
import VectCore
21
22
import VectMonad

23
import DsUtils
24
import CoreSyn
25
import CoreUtils
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
26
import Coercion
27
28
import Type
import TypeRep
29
import TyCon
30
import DataCon            ( DataCon, dataConWrapId, dataConTag )
31
import Var
32
33
import Id                 ( mkWildId )
import MkId               ( unwrapFamInstScrut )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
34
import Name               ( Name )
35
import PrelNames
36
37
import TysWiredIn
import BasicTypes         ( Boxity(..) )
38

39
import Outputable
40
import FastString
41

42
import Control.Monad         ( liftM, liftM2, zipWithM_ )
43

44
45
46
47
48
49
50
51
52
53
54
55
-- | Peel the type applications off an annotated expression, returning the
-- remaining head together with the type arguments in application order.
collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
collectAnnTypeArgs = peel []
  where
    peel acc (_, AnnApp fun (_, AnnType arg)) = peel (arg : acc) fun
    peel acc other                            = (other, acc)

-- | Split the leading type lambdas off an annotated expression, returning the
-- type binders (outermost first) and the remaining body.
collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnTypeBinders = strip []
  where
    strip acc (_, AnnLam bndr body)
      | isTyVar bndr = strip (bndr : acc) body
    strip acc rest   = (reverse acc, rest)

-- | Split the leading value (Id) lambdas off an annotated expression,
-- returning the binders (outermost first) and the remaining body.
collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnValBinders = strip []
  where
    strip acc (_, AnnLam bndr body)
      | isId bndr = strip (bndr : acc) body
    strip acc rest = (reverse acc, rest)

-- | Is this annotated expression a type argument (an 'AnnType' node)?
isAnnTypeArg :: AnnExpr b ann -> Bool
isAnnTypeArg (_, node)
  = case node of
      AnnType _ -> True
      _         -> False

-- | An Int expression (built with 'intDataCon') holding the constructor tag
-- of the given data constructor.
mkDataConTag :: DataCon -> CoreExpr
mkDataConTag dc = mkConApp intDataCon [tag_lit]
  where
    tag_lit = mkIntLitInt (dataConTag dc)

-- | Take apart an application of the type constructor named @name@ to exactly
-- one argument and return that argument.  Any other type panics, with the
-- message tagged by @s@.
splitUnTy :: String -> Name -> Type -> Type
splitUnTy s name ty
  = case splitTyConApp_maybe ty of
      Just (tc, [arg_ty])
        | tyConName tc == name -> arg_ty
      _                        -> pprPanic s (ppr ty)

-- | Take apart an application of the type constructor named @name@ to exactly
-- two arguments and return them as a pair.  Any other type panics, with the
-- message tagged by @s@.
splitBinTy :: String -> Name -> Type -> (Type, Type)
splitBinTy s name ty
  = case splitTyConApp_maybe ty of
      Just (tc, [left_ty, right_ty])
        | tyConName tc == name -> (left_ty, right_ty)
      _                        -> pprPanic s (ppr ty)

-- | Split an application of the NDP cross (product) tycon into its two factors.
splitCrossTy :: Type -> (Type, Type)
splitCrossTy ty = splitBinTy "splitCrossTy" ndpCrossTyConName ty

-- | Split an application of the NDP plus (sum) tycon into its two summands.
-- As with the other split*Ty helpers, the panic tag is this function's own
-- name, so a failure points at the right place.
splitPlusTy :: Type -> (Type, Type)
splitPlusTy = splitBinTy "splitPlusTy" ndpPlusTyConName

-- | Return the argument of an application of the embed tycon.
splitEmbedTy :: Type -> Type
splitEmbedTy ty = splitUnTy "splitEmbedTy" embedTyConName ty

-- | Split a closure type into its argument and result types.
splitClosureTy :: Type -> (Type, Type)
splitClosureTy ty = splitBinTy "splitClosureTy" closureTyConName ty

-- | Return the element type of a parallel array type.
splitPArrayTy :: Type -> Type
splitPArrayTy ty = splitUnTy "splitPArrayTy" parrayTyConName ty

-- | Apply the builtin type constructor selected by @get_tc@ to @tys@.
mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
mkBuiltinTyConApp get_tc tys
  = liftM (\tc -> mkTyConApp tc tys) (builtin get_tc)

-- | Fold the binary builtin tycon selected by @get_tc@ right-associatively
-- over @tys@, with @ty@ innermost: @t1 `tc` (t2 `tc` (... `tc` ty))@.
mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
mkBuiltinTyConApps get_tc tys ty
  = do
      tc <- builtin get_tc
      let app2 l r = mkTyConApp tc [l, r]
      return (foldr app2 ty tys)

-- | Fold the binary builtin tycon selected by @get_tc@ right-associatively
-- over a non-empty @tys@, with the last element innermost.  An empty list
-- yields the default type @dft@ (and does not consult the builtins).
mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
mkBuiltinTyConApps1 _      dft [] = return dft
mkBuiltinTyConApps1 get_tc _   tys
  = do
      -- tys is non-empty here: the empty case is handled above, so
      -- foldr1 is safe and needs no extra check.
      tc <- builtin get_tc
      return $ foldr1 (mk tc) tys
  where
    mk tc ty1 ty2 = mkTyConApp tc [ty1, ty2]

-- | Build the representation type of a data type from the field types of its
-- constructors.  Each field is wrapped in the embed tycon.  The fields of one
-- constructor are combined with the cross tycon (unit when there are none).
-- The constructors are combined with the plus tycon.  No constructors at all
-- gives unit.
mkPRepr :: [[Type]] -> VM Type
mkPRepr [] = return unitTy
mkPRepr tys
  = do
      embed <- builtin embedTyCon
      cross <- builtin crossTyCon
      plus  <- builtin plusTyCon

      let binary tc l r = mkTyConApp tc [l, r]

          -- Right-nested fold of a binary tycon; unit for an empty list.
          nest _  []    = unitTy
          nest op parts = foldr1 op parts

          con_repr fields = nest (binary cross) [mkTyConApp embed [f] | f <- fields]

      return (nest (binary plus) (map con_repr tys))

-- | Given the field expressions of each constructor alternative, build one
-- expression per alternative that injects it into the representation type,
-- together with that representation type.  Fields are wrapped with the embed
-- constructor and combined with the cross constructor (unit for no fields).
-- Alternatives are injected via the left/right constructors of the plus
-- tycon.  An empty list of alternatives yields a single unit expression.
mkToPRepr :: [[CoreExpr]] -> VM ([CoreExpr], Type)
mkToPRepr ess
  = do
      embed_tc <- builtin embedTyCon
      embed_dc <- builtin embedDataCon
      cross_tc <- builtin crossTyCon
      cross_dc <- builtin crossDataCon
      plus_tc  <- builtin plusTyCon
      left_dc  <- builtin leftDataCon
      right_dc <- builtin rightDataCon

      let embed_field e
            = let ty = exprType e
              in (mkConApp embed_dc [Type ty, e], mkTyConApp embed_tc [ty])

          -- Pair two (expression, type) components with the cross constructor.
          cross_pair (e1, ty1) (e2, ty2)
            = (mkConApp cross_dc [Type ty1, Type ty2, e1, e2],
               mkTyConApp cross_tc [ty1, ty2])

          -- Product of all fields of one alternative; unit when it has none.
          product_of []     = (Var unitDataConId, unitTy)
          product_of fields = foldr1 cross_pair fields

          -- Inject e into one side of the sum (lty + rty) using dc.
          inject dc lty rty e = mkConApp dc [Type lty, Type rty, e]

          sum_of []               = ([Var unitDataConId], unitTy)
          sum_of [(e, ty)]        = ([e], ty)
          sum_of ((e, lty) : rest)
            = let (inner, rty) = sum_of rest
              in (inject left_dc lty rty e : map (inject right_dc lty rty) inner,
                  mkTyConApp plus_tc [lty, rty])

      return (sum_of [product_of (map embed_field es) | es <- ess])

-- | Build a nest of case expressions that deconstructs @scrut@, a value of a
-- representation type, and selects the matching alternative.  Each
-- alternative lists the variables to bind to its fields and the body to run.
-- Every generated case has result type @res_ty@.  The structure mirrors
-- 'mkToPRepr': plus -> left/right, cross -> pair, embed -> single field.
-- An empty alternative list has nothing to select and panics explicitly.
mkFromPRepr :: CoreExpr -> Type -> [([Var], CoreExpr)] -> VM CoreExpr
mkFromPRepr scrut res_ty alts
  = do
      embed_dc <- builtin embedDataCon
      cross_dc <- builtin crossDataCon
      left_dc  <- builtin leftDataCon
      right_dc <- builtin rightDataCon

      let un_embed expr ty var res
            = Case expr (mkWildId ty) res_ty
                   [(DataAlt embed_dc, [var], res)]

          un_cross expr ty var1 var2 res
            = Case expr (mkWildId ty) res_ty
                [(DataAlt cross_dc, [var1, var2], res)]

          -- An alternative without fields is represented by unit, so there
          -- is nothing to unpack.
          un_tup _    _  []    res = return res
          un_tup expr ty [var] res = return $ un_embed expr ty var res
          un_tup expr ty (var : vars) res
            = do
                lv <- newLocalVar FSLIT("x") lty
                rv <- newLocalVar FSLIT("y") rty
                liftM (un_cross expr ty lv rv
                      . un_embed (Var lv) lty var)
                      (un_tup (Var rv) rty vars res)
            where
              (lty, rty) = splitCrossTy ty

          un_plus expr ty var1 var2 res1 res2
            = Case expr (mkWildId ty) res_ty
                [(DataAlt left_dc,  [var1], res1),
                 (DataAlt right_dc, [var2], res2)]

          -- Only reachable from the top-level call: recursive calls always
          -- pass at least one alternative.
          un_sum _    _  []            = pprPanic "mkFromPRepr" (ppr res_ty)
          un_sum expr ty [(vars, res)] = un_tup expr ty vars res
          un_sum expr ty ((vars, res) : rest)
            = do
                lv <- newLocalVar FSLIT("l") lty
                rv <- newLocalVar FSLIT("r") rty
                liftM2 (un_plus expr ty lv rv)
                         (un_tup (Var lv) lty vars res)
                         (un_sum (Var rv) rty rest)
            where
              (lty, rty) = splitPlusTy ty

      un_sum scrut (exprType scrut) alts

-- | The closure type from @arg_ty@ to @res_ty@.
mkClosureType :: Type -> Type -> VM Type
mkClosureType arg_ty res_ty
  = mkBuiltinTyConApp closureTyCon tys
  where
    tys = [arg_ty, res_ty]

-- | Curried closure type over several arguments, right-nested, ending in
-- @res_ty@.
mkClosureTypes :: [Type] -> Type -> VM Type
mkClosureTypes arg_tys res_ty = mkBuiltinTyConApps closureTyCon arg_tys res_ty

-- | Apply the builtin PRepr tycon to a type.
mkPReprType :: Type -> VM Type
mkPReprType = mkBuiltinTyConApp preprTyCon . (: [])

-- | The type of a PA dictionary for @ty@ (builtin PA tycon applied to it).
mkPADictType :: Type -> VM Type
mkPADictType ty = mkBuiltinTyConApp paTyCon args
  where
    args = [ty]

-- | The parallel array type with element type @ty@.
mkPArrayType :: Type -> VM Type
mkPArrayType = mkBuiltinTyConApp parrayTyCon . (: [])

-- | Look up the family instance of the builtin PArray tycon at element type
-- @ty@, returning the instance tycon and its type arguments.
parrayReprTyCon :: Type -> VM (TyCon, [Type])
parrayReprTyCon ty
  = do
      parray_tc <- builtin parrayTyCon
      lookupFamInst parray_tc [ty]

-- | The sole data constructor of the PArray representation tycon for @ty@,
-- together with the instance type arguments.  The data constructor is
-- extracted lazily, as before.  If the tycon does not have exactly one
-- constructor, using it panics with the tycon named, instead of failing
-- with an anonymous irrefutable-pattern error.
parrayReprDataCon :: Type -> VM (DataCon, [Type])
parrayReprDataCon ty
  = do
      (tc, arg_tys) <- parrayReprTyCon ty
      let dc = case tyConDataCons tc of
                 [only_dc] -> only_dc
                 _         -> pprPanic "parrayReprDataCon" (ppr tc)
      return (dc, arg_tys)

-- | Prepare a vectorised scrutinee.  Look up the PArray representation tycon
-- for the type of the vectorised part.  Unwrap the lifted part with
-- 'unwrapFamInstScrut'.  Return the tycon and its arguments as well.
mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
mkVScrut (ve, le)
  = liftM wrap (parrayReprTyCon (exprType ve))
  where
    wrap (tc, arg_tys) = ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)

-- | Build the PR dictionary for a type-constructor application.  Look up the
-- PR dfun for the tycon and apply it to the type arguments via 'prDFunApply'.
-- Any type that is not a tycon application panics with the type shown,
-- instead of falling off a non-exhaustive guard.
prDictOfType :: Type -> VM CoreExpr
prDictOfType orig_ty
  | Just (tycon, ty_args) <- splitTyConApp_maybe orig_ty
  = do
      dfun <- traceMaybeV "prDictOfType" (ppr tycon) (lookupTyConPR tycon)
      prDFunApply (Var dfun) ty_args
  | otherwise
  = pprPanic "prDictOfType" (ppr orig_ty)

-- | Instantiate a dfun at @tys@ and then apply it to one dictionary per
-- argument of its instantiated function type, each built by 'mkDFunArg'.
prDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
prDFunApply dfun tys
  = liftM (mkApps mono_dfun) (mapM mkDFunArg dict_tys)
  where
    mono_dfun = mkTyApps dfun tys
    dict_tys  = fst (splitFunTys (exprType mono_dfun))

-- | Build a dictionary argument from its type.  @PA t@ is built with
-- 'paDictOfType' and @PR t@ with 'prDictOfType'.  Any other type panics.
mkDFunArg :: Type -> VM CoreExpr
mkDFunArg ty
  = case splitTyConApp_maybe ty of
      Just (tycon, [arg])
        | tyConName tycon == paTyConName -> paDictOfType arg
        | tyConName tycon == prTyConName -> prDictOfType arg
      _                                  -> pprPanic "mkDFunArg" (ppr ty)

-- | Coerce @expr@ along the builtin PR tycon applied to the symmetric form of
-- @repr_tc@'s family coercion, instantiated at @args@.  A tycon with no family
-- coercion panics with the tycon shown, instead of falling off a
-- non-exhaustive guard.
prCoerce :: TyCon -> [Type] -> CoreExpr -> VM CoreExpr
prCoerce repr_tc args expr
  | Just arg_co <- tyConFamilyCoercion_maybe repr_tc
  = do
      pr_tc <- builtin prTyCon

      let co = mkAppCoercion (mkTyConApp pr_tc [])
                             (mkSymCoercion (mkTyConApp arg_co args))

      return $ mkCoerce co expr
  | otherwise
  = pprPanic "prCoerce" (ppr repr_tc)

-- | The type of the PA dictionary argument needed for type variable @tv@,
-- worked out from its kind:
--
--   * a lifted kind gives the PA dictionary type of the variable itself;
--
--   * an arrow kind @k1 -> k2@ introduces a fresh tyvar @a :: k1@.  If @a@
--     needs a dictionary, the result is @forall a. dict(a) -> dict(tv a)@.
--     Otherwise the walk continues at @k2@ without abstracting;
--
--   * any other kind needs no dictionary ('Nothing').
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = dict_ty (TyVarTy tv) (tyVarKind tv)
  where
    dict_ty ty kind
      | Just kind' <- kindView kind = dict_ty ty kind'

    dict_ty ty (FunTy arg_kind res_kind)
      = do
          a     <- newTyVar FSLIT("a") arg_kind
          m_arg <- dict_ty (TyVarTy a) arg_kind
          case m_arg of
            Nothing -> dict_ty ty res_kind
            Just arg_dict
              -> liftM (fmap (ForAllTy a . FunTy arg_dict))
                       (dict_ty (AppTy ty (TyVarTy a)) res_kind)

    dict_ty ty kind
      | isLiftedTypeKind kind = liftM Just (mkPADictType ty)
      | otherwise             = return Nothing

-- | Build the PA dictionary for @ty@.  Split it into an applied head and its
-- argument types, then hand off to 'paDictOfTyApp'.
paDictOfType :: Type -> VM CoreExpr
paDictOfType ty = uncurry paDictOfTyApp (splitAppTys ty)

-- | Build the PA dictionary for head @ty_fn@ applied to @ty_args@.
-- Synonyms are expanded with 'coreView'.  A type-variable head uses its
-- in-scope dictionary ('lookupTyVarPA').  A tycon head uses the tycon PA
-- dfun ('lookupTyConPA').  Arguments stored inside the 'TyConApp' are
-- ignored; only @ty_args@ are used.  Any other head panics.
paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
paDictOfTyApp ty_fn ty_args
  | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args

paDictOfTyApp (TyVarTy tv) ty_args
  = maybeV (lookupTyVarPA tv) >>= \dfun ->
    paDFunApply dfun ty_args

paDictOfTyApp (TyConApp tc _) ty_args
  = traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc) >>= \dfun ->
    paDFunApply (Var dfun) ty_args

paDictOfTyApp ty _ = pprPanic "paDictOfTyApp" (ppr ty)

-- | The type of the PA dfun for tycon @tc@.  It quantifies over the tycon's
-- type variables and takes one dictionary argument for each variable that
-- needs one ('paDictArgType').  It returns the PA dictionary type of @tc@
-- applied to those variables.
paDFunType :: TyCon -> VM Type
paDFunType tc
  = do
      m_dicts <- mapM paDictArgType tvs
      res     <- mkPADictType (mkTyConApp tc (mkTyVarTys tvs))
      let dict_args = [d | Just d <- m_dicts]
      return (mkForAllTys tvs (mkFunTys dict_args res))
  where
    tvs = tyConTyVars tc

-- | Apply a PA dfun to the types @tys@ and then to the PA dictionaries of
-- those types.
paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
paDFunApply dfun tys
  = liftM (mkApps (mkTyApps dfun tys)) (mapM paDictOfType tys)

-- | Instantiate the builtin PA method selected by @method@ at @ty@.  Apply
-- it to the type and to the PA dictionary for that type.
paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
paMethod method ty
  = liftM2 apply_method (builtin method) (paDictOfType ty)
  where
    apply_method fn dict = mkApps (Var fn) [Type ty, dict]

-- | The length of a parallel array expression, via the lengthPA method
-- instantiated at the array element type.
lengthPA :: CoreExpr -> VM CoreExpr
lengthPA x
  = do
      len_fn <- paMethod lengthPAVar elt_ty
      return (App len_fn x)
  where
    elt_ty = splitPArrayTy (exprType x)

-- | Replicate @x@ @len@ times into a parallel array, via the replicatePA
-- method at the type of @x@.
replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
replicatePA len x
  = do
      rep_fn <- paMethod replicatePAVar (exprType x)
      return (mkApps rep_fn [len, x])

-- | An empty parallel array of element type @ty@ (the emptyPA method).
emptyPA :: Type -> VM CoreExpr
emptyPA ty = paMethod emptyPAVar ty

-- | Lift a scalar expression by replicating it, using the builtin lifting
-- context variable as the length.
liftPA :: CoreExpr -> VM CoreExpr
liftPA x = builtin liftingContext >>= \lc -> replicatePA (Var lc) x

-- | Create a fresh vectorised/lifted variable pair named @fs@.  The first has
-- type @vty@; the second has the parallel array type over @vty@.
newLocalVVar :: FastString -> Type -> VM VVar
newLocalVVar fs vty
  = do
      lty <- mkPArrayType vty
      liftM2 (,) (newLocalVar fs vty) (newLocalVar fs lty)

-- | Run @p@ in a local scope ('localV') where every tyvar of @tvs@ is
-- defined.  A tyvar gets a fresh "dPA" dictionary variable when
-- 'paDictArgType' says it needs one.  @p@ receives a function that wraps an
-- expression in lambdas over @tvs@ followed by those dictionary variables.
polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
polyAbstract tvs p
  = localV $ do
      mdicts <- mapM mk_dict_var tvs
      zipWithM_ bind_tyvar tvs mdicts
      p (mkLams (tvs ++ [dict | Just dict <- mdicts]))
  where
    mk_dict_var tv
      = do
          m_ty <- paDictArgType tv
          case m_ty of
            Nothing -> return Nothing
            Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)

    bind_tyvar tv Nothing     = defLocalTyVar tv
    bind_tyvar tv (Just dict) = defLocalTyVarWithPA tv (Var dict)

-- | Apply @expr@ to the types @tys@ and then to their PA dictionaries.
polyApply :: CoreExpr -> [Type] -> VM CoreExpr
polyApply expr tys
  = liftM (mkApps (mkTyApps expr tys)) (mapM paDictOfType tys)

-- | Like 'polyApply', but instantiates both components of a vectorised
-- expression (via 'mapVect').
polyVApply :: VExpr -> [Type] -> VM VExpr
polyVApply expr tys
  = do
      dicts <- mapM paDictOfType tys
      let instantiate e = mkApps (mkTyApps e tys) dicts
      return (mapVect instantiate expr)

-- | Prepend the binding @(v, e)@ to the hoisted bindings of the global
-- environment.
hoistBinding :: Var -> CoreExpr -> VM ()
hoistBinding v e = updGEnv add_binding
  where
    add_binding env = env { global_bindings = (v, e) : global_bindings env }

-- | Bind @expr@ to a fresh local variable named @fs@ and record that binding
-- as hoisted.  Returns the new variable.
hoistExpr :: FastString -> CoreExpr -> VM Var
hoistExpr fs expr
  = newLocalVar fs (exprType expr) >>= \var ->
    hoistBinding var expr >>
    return var

-- | Hoist both halves of a vectorised expression.  They are named after the
-- current bind name, prefixed with 'v' (vectorised) and 'l' (lifted).
hoistVExpr :: VExpr -> VM VVar
hoistVExpr (ve, le)
  = do
      name <- getBindName
      liftM2 (,) (hoistExpr ('v' `consFS` name) ve)
                 (hoistExpr ('l' `consFS` name) le)

-- | Run @p@ closed ('closedV') and abstracted over @tvs@ and their
-- dictionaries, then hoist the result.  Returns the hoisted variables
-- re-applied to @tvs@ and their dictionaries.
hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
hoistPolyVExpr tvs p
  = do
      abstracted <- closedV (polyAbstract tvs (\abstract -> liftM (mapVect abstract) p))
      hoisted    <- hoistVExpr abstracted
      polyVApply (vVar hoisted) (mkTyVarTys tvs)

-- | Return the hoisted bindings accumulated in the global environment and
-- reset that list to empty.
takeHoisted :: VM [(Var, CoreExpr)]
takeHoisted
  = do
      env <- readGEnv id
      let hoisted = global_bindings env
      setGEnv (env { global_bindings = [] })
      return hoisted

-- | Build a closure pair.  The vectorised closure comes from the builtin
-- mkClosureVar and the lifted one from mkClosurePVar.  Both are instantiated
-- at the argument, result and environment types.  Both receive the PA
-- dictionary of the environment type and both function halves.  Each gets
-- the matching half of the environment.
mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
  = do
      dict <- paDictOfType env_ty
      mkv  <- builtin mkClosureVar
      mkl  <- builtin mkClosurePVar
      let mk_clo maker env_arg
            = mkApps (mkTyApps (Var maker) [arg_ty, res_ty, env_ty])
                     [dict, vfn, lfn, env_arg]
      return (mk_clo mkv venv, mk_clo mkl lenv)

-- | Apply a closure pair to an argument pair, using the builtin
-- applyClosureVar (vectorised) and applyClosurePVar (lifted).  The argument
-- and result types are read from the type of the vectorised closure.
mkClosureApp :: VExpr -> VExpr -> VM VExpr
mkClosureApp (vclo, lclo) (varg, larg)
  = do
      vapply <- builtin applyClosureVar
      lapply <- builtin applyClosurePVar
      let apply fn clo arg = mkApps (mkTyApps (Var fn) [arg_ty, res_ty]) [clo, arg]
      return (apply vapply vclo varg, apply lapply lclo larg)
  where
    (arg_ty, res_ty) = splitClosureTy (exprType vclo)

-- | Build curried closures for a function that takes arguments of types
-- @arg_tys@ (in order), returns @res_ty@, and captures @vars@.
--
-- * With no arguments the body is returned as is.
-- * With one argument a single closure is built.
-- * Otherwise the outer closure's body is a hoisted, type-abstracted
--   computation.  It builds the remaining closures with the new argument
--   added to the captured variables, then abstracts over those variables.
buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
buildClosures _ _ [] _ mk_body
  = mk_body
buildClosures tvs vars [arg_ty] res_ty mk_body
  = buildClosure tvs vars arg_ty res_ty mk_body
buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
  = do
      inner_ty <- mkClosureTypes arg_tys res_ty
      arg      <- newLocalVVar FSLIT("x") arg_ty
      let vars' = vars ++ [arg]
          inner = do
                    lc  <- builtin liftingContext
                    clo <- buildClosures tvs vars' arg_tys res_ty mk_body
                    return (vLams lc vars' clo)
      buildClosure tvs vars arg_ty inner_ty (hoistPolyVExpr tvs inner)

-- Build a one-argument closure that captures @vars@ in its environment and
-- whose body is produced by @mk_body@.  The result has the shape
--
-- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
--   where
--     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
--     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
--
buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
buildClosure tvs vars arg_ty res_ty mk_body
  = do
      (env_ty, env, bind) <- buildEnv vars
      env_bndr <- newLocalVVar FSLIT("env") env_ty
      arg_bndr <- newLocalVVar FSLIT("arg") arg_ty

      -- The function halves: unpack the environment and apply the body to
      -- the captured variables and the argument.
      let fn_body
            = do
                lc    <- builtin liftingContext
                body  <- mk_body
                body' <- bind (vVar env_bndr)
                              (vVarApps lc body (vars ++ [arg_bndr]))
                return (vLamsWithoutLC [env_bndr, arg_bndr] body')

      fn <- hoistPolyVExpr tvs fn_body
      mkClosure arg_ty res_ty env_ty fn env

-- | Package the variable pairs @vvs@ into an environment.  The vectorised
-- half comes from 'mkVectEnv' and the lifted half from 'mkLiftEnv'.
-- Returns the environment type, the environment pair, and a binder that
-- unpacks an environment pair around a body pair.
buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
buildEnv vvs
  = do
      lc <- builtin liftingContext
      let (env_ty, venv, vbind) = mkVectEnv tys vs
      (lenv, lbind) <- mkLiftEnv lc tys ls
      let bind_env (venv', lenv') (vbody, lbody)
            = liftM ((,) (vbind venv' vbody)) (lbind lenv' lbody)
      return (env_ty, (venv, lenv), bind_env)
  where
    (vs, ls) = unzip vvs
    tys      = map idType vs

-- | Vectorised environment for variables @vs@ of types @tys@.  Returns the
-- environment type, the environment expression, and a binder that unpacks
-- an environment around a body:
--
-- * no variables: unit, and the binder ignores the environment;
-- * one variable: the variable itself, rebound with a non-recursive let;
-- * otherwise: a boxed tuple, unpacked with a case.
mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
mkVectEnv []   []  = (unitTy, Var unitDataConId, \_ body -> body)
mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
mkVectEnv tys  vs  = (tup_ty, mkCoreTup (map Var vs), unpack)
  where
    tup_ty = mkCoreTupTy tys
    tup_dc = tupleCon Boxed (length vs)
    unpack env body = Case env (mkWildId tup_ty) (exprType body)
                           [(DataAlt tup_dc, vs, body)]

-- | Lifted environment for variables @vs@ of types @tys@, given the lifting
-- context variable @lc@.  Returns the environment expression and a binder
-- that unpacks an environment around a body.
--
-- * One variable: the environment is the variable itself.  The binder
--   rebinds it and also rebinds @lc@ (as the case binder) to its lengthPA.
-- * Otherwise: the environment applies the sole constructor of the PArray
--   representation tycon for the tuple type to @lc@ and @vs@.  The binder
--   unwraps the scrutinee and matches that constructor.
mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
mkLiftEnv lc [_] [v]
  = return (Var v, bind_single)
  where
    bind_single env body
      = do
          len <- lengthPA (Var v)
          return (Let (NonRec v env)
                      (Case len lc (exprType body) [(DEFAULT, [], body)]))

-- NOTE: this transparently deals with empty environments
mkLiftEnv lc tys vs
  = do
      (env_tc, env_tyargs) <- parrayReprTyCon (mkCoreTupTy tys)
      -- Assumes the representation tycon has exactly one constructor;
      -- the lazy pattern only fails if env_con is actually used.
      let [env_con] = tyConDataCons env_tc

          env = mkVarApps (mkTyApps (Var (dataConWrapId env_con)) env_tyargs)
                          (lc : vs)

          bind env_expr body
            = return (Case scrut (mkWildId (exprType scrut)) (exprType body)
                           [(DataAlt env_con, lc : bndrs, body)])
            where
              scrut = unwrapFamInstScrut env_tc env_tyargs env_expr
      return (env, bind)
  where
    -- With no variables the constructor still has a field, matched by a
    -- unit-typed wildcard.
    bndrs | null vs   = [mkWildId unitTy]
          | otherwise = vs