VectUtils.hs 18.9 KB
Newer Older
1
-- Utility functions shared by the vectoriser: type/expression splitting
-- helpers, generic representation construction, PA/PR dictionary lookup,
-- hoisting of bindings and closure construction.
module VectUtils (
  collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
  collectAnnValBinders,
  mkDataConTag,
  splitClosureTy,

  TyConRepr(..), mkTyConRepr,
  mkToPRepr, mkToArrPRepr, mkFromPRepr, mkFromArrPRepr,
  mkPADictType, mkPArrayType, mkPReprType,

  parrayCoerce, parrayReprTyCon, parrayReprDataCon, mkVScrut,
  prDictOfType, prCoerce,
  paDictArgType, paDictOfType, paDFunType,
  paMethod, lengthPA, replicatePA, emptyPA, liftPA,
  polyAbstract, polyApply, polyVApply,
  hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
  buildClosure, buildClosures,
  mkClosureApp
) where

#include "HsVersions.h"

23
import VectCore
24
25
import VectMonad

26
import DsUtils
27
import CoreSyn
28
import CoreUtils
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
29
import Coercion
30
31
import Type
import TypeRep
32
import TyCon
33
import DataCon
34
import Var
35
36
import Id                 ( mkWildId )
import MkId               ( unwrapFamInstScrut )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
37
import Name               ( Name )
38
import PrelNames
39
40
import TysWiredIn
import BasicTypes         ( Boxity(..) )
41

42
import Outputable
43
import FastString
44

45
import Data.List             ( zipWith4 )
46
import Control.Monad         ( liftM, liftM2, zipWithM_ )
47

48
49
50
51
52
53
54
55
56
57
58
59
-- | Strip the trailing type arguments from an annotated application.
-- Returns the remaining head expression together with the removed types,
-- listed in the order in which they were applied.
collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
collectAnnTypeArgs = peel []
  where
    peel acc (_, AnnApp fun (_, AnnType ty)) = peel (ty : acc) fun
    peel acc other                           = (other, acc)

-- | Collect the leading type-variable binders of an annotated lambda.
-- Returns the binders (outermost first) and the expression underneath them.
collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnTypeBinders expr
  = case expr of
      (_, AnnLam bndr body)
        | isTyVar bndr -> let (rest, inner) = collectAnnTypeBinders body
                          in (bndr : rest, inner)
      _                -> ([], expr)

60
61
62
63
64
65
-- | Collect the leading value (Id) binders of an annotated lambda.
-- Returns the binders (outermost first) and the expression underneath them.
collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnValBinders expr
  = case expr of
      (_, AnnLam bndr body)
        | isId bndr -> let (rest, inner) = collectAnnValBinders body
                       in (bndr : rest, inner)
      _             -> ([], expr)

66
67
68
69
-- | Is the annotated expression a type argument?
isAnnTypeArg :: AnnExpr b ann -> Bool
isAnnTypeArg (_, expr)
  = case expr of
      AnnType _ -> True
      _         -> False

70
71
72
-- | Build a boxed 'Int' expression holding the tag of the data constructor.
mkDataConTag :: DataCon -> CoreExpr
mkDataConTag dc = mkConApp intDataCon [tag_lit]
  where
    tag_lit = mkIntLitInt (dataConTag dc)

rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
73
74
75
76
77
-- | Take apart an application of the type constructor called @name@ to
-- exactly one argument and return that argument.  Any other type causes a
-- panic whose message is labelled with @s@.
splitUnTy :: String -> Name -> Type -> Type
splitUnTy s name ty
  | Just (tc, [ty']) <- splitTyConApp_maybe ty
  , tyConName tc == name
  = ty'

  | otherwise = pprPanic s (ppr ty)
80

rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
81
82
83
84
85
-- | Take apart an application of the type constructor called @name@ to
-- exactly two arguments and return both arguments.  Any other type causes a
-- panic whose message is labelled with @s@.
splitBinTy :: String -> Name -> Type -> (Type, Type)
splitBinTy s name ty
  | Just (tc, [ty1, ty2]) <- splitTyConApp_maybe ty
  , tyConName tc == name
  = (ty1, ty2)

  | otherwise = pprPanic s (ppr ty)
88

89
90
91
92
93
-- | Return the arguments of a type that must be an application of the given
-- type constructor.  Panics (showing the expected tycon and the offending
-- type) if the type is headed by anything else.
splitFixedTyConApp :: TyCon -> Type -> [Type]
splitFixedTyConApp tc ty
  | Just (tc', tys) <- splitTyConApp_maybe ty
  , tc == tc'
  = tys

  | otherwise = pprPanic "splitFixedTyConApp" (ppr tc <+> ppr ty)
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
96
97
98
99
100
101

-- | Split a closure type into its two type arguments; panics on any type
-- that is not an application of the closure tycon to two arguments.
splitClosureTy :: Type -> (Type, Type)
splitClosureTy ty = splitBinTy "splitClosureTy" closureTyConName ty

-- | Extract the single type argument of a parallel array type; panics on any
-- type that is not an application of the parallel array tycon to one argument.
splitPArrayTy :: Type -> Type
splitPArrayTy ty = splitUnTy "splitPArrayTy" parrayTyConName ty
102

rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
103
104
-- | Apply a builtin type constructor, selected from the 'Builtins' record by
-- @get_tc@, to the given argument types.
mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
mkBuiltinTyConApp get_tc tys
  = do
      tc <- builtin get_tc
      return $ mkTyConApp tc tys
108

rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
109
110
-- | Right-nest applications of a binary builtin type constructor:
-- for @tys = [t1, ..., tn]@ the result is @tc t1 (tc t2 (... (tc tn ty)))@,
-- and just @ty@ when @tys@ is empty.
mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
mkBuiltinTyConApps get_tc tys ty
  = do
      tc <- builtin get_tc
      return $ foldr (mk tc) ty tys
  where
    mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
116

rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
117
118
119
-- | Like 'mkBuiltinTyConApps' but without a final type: a non-empty list
-- @[t1, ..., tn]@ becomes @tc t1 (tc t2 (... tn))@.  For an empty list the
-- default type @dft@ is returned and the builtin tycon is never looked up.
mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
mkBuiltinTyConApps1 _      dft [] = return dft
mkBuiltinTyConApps1 get_tc _   tys
  = do
      tc <- builtin get_tc
      -- tys is non-empty here (the empty case is handled above), so foldr1
      -- is safe.
      return $ foldr1 (mk tc) tys
  where
    mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]

128
129
130
131
132
133
134
135
136
137
138
139
140
141
-- | The pieces of the generic representation of a type constructor, as
-- computed by 'mkTyConRepr'.
data TyConRepr = TyConRepr {
                   repr_tyvars      :: [TyVar]        -- type variables of the tycon
                 , repr_tys         :: [[Type]]       -- representation argument types, one list per data constructor

                 , repr_prod_tycons :: [Maybe TyCon]  -- per constructor: product tycon, if it has more than one argument
                 , repr_prod_tys    :: [Type]         -- per constructor: type representing its arguments
                 , repr_sum_tycon   :: Maybe TyCon    -- sum tycon, if there is more than one constructor
                 , repr_type        :: Type           -- representation type of the whole tycon
                 }

-- | Compute the generic (sum of products) representation of a vectorised
-- type constructor.
--
-- Each data constructor's argument types are combined with a builtin product
-- tycon when there is more than one of them; the resulting per-constructor
-- types are combined with a builtin sum tycon when there is more than one
-- constructor.  An empty list is represented by the unit type and a
-- singleton list by its only element.
mkTyConRepr :: TyCon -> VM TyConRepr
mkTyConRepr vect_tc
  = do
      prod_tycons <- mapM (mk_tycon prodTyCon) rep_tys
      let prod_tys = zipWith mk_tc_app_maybe prod_tycons rep_tys
      sum_tycon   <- mk_tycon sumTyCon prod_tys

      return $ TyConRepr {
                 repr_tyvars      = tyvars
               , repr_tys         = rep_tys

               , repr_prod_tycons = prod_tycons
               , repr_prod_tys    = prod_tys
               , repr_sum_tycon   = sum_tycon
               , repr_type        = mk_tc_app_maybe sum_tycon prod_tys
               }
  where
    tyvars = tyConTyVars vect_tc
    data_cons = tyConDataCons vect_tc
    rep_tys   = map dataConRepArgTys data_cons

    -- Only look up a builtin tycon of arity n when there is something to
    -- combine, i.e. for more than one component.
    mk_tycon get_tc tys
      | n > 1     = builtin (Just . get_tc n)
      | otherwise = return Nothing
      where n = length tys

    -- Relies on mk_tycon: 'Nothing' only ever accompanies zero or one types.
    mk_tc_app_maybe Nothing   []   = unitTy
    mk_tc_app_maybe Nothing   [ty] = ty
    mk_tc_app_maybe (Just tc) tys  = mkTyConApp tc tys

168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
-- | Build, for each data constructor, the expression of its generic
-- representation.  The argument holds the field expressions of every
-- constructor, in constructor order.
--
-- Fields are first packed per constructor ('mk_prod'): no fields give the
-- unit value, a single field is used as is, several fields are wrapped in
-- the data constructor of the corresponding product tycon.  The results are
-- then injected into the sum ('mk_sum'): no constructors give the unit
-- value, a single constructor is left alone, several are each wrapped in the
-- matching data constructor of the sum tycon.
mkToPRepr :: TyConRepr -> [[CoreExpr]] -> [CoreExpr]
mkToPRepr (TyConRepr {
             repr_tys         = repr_tys
           , repr_prod_tycons = prod_tycons
           , repr_prod_tys    = prod_tys
           , repr_sum_tycon   = repr_sum_tycon
           })
  = mk_sum . zipWith3 mk_prod prod_tycons repr_tys
  where
    -- Only forced in the multi-constructor case of mk_sum.
    Just sum_tycon = repr_sum_tycon

    mk_sum []     = [Var unitDataConId]
    mk_sum [expr] = [expr]
    mk_sum exprs  = zipWith (mk_alt prod_tys) (tyConDataCons sum_tycon) exprs

    mk_alt tys dc expr = mk_con_app dc tys [expr]

    mk_prod _         _   []     = Var unitDataConId
    mk_prod _         _   [expr] = expr
    mk_prod (Just tc) tys exprs  = mk_con_app dc tys exprs
      where
        [dc] = tyConDataCons tc

    mk_con_app dc tys exprs = mkConApp dc (map Type tys ++ exprs)
192

193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
-- | Array counterpart of 'mkToPRepr': combine the per-constructor lists of
-- array expressions in @ess@ into a single array expression.
--
-- @len@ is passed as the first value argument of every product and sum
-- array constructor built here; @sel@ is additionally passed to the sum
-- array constructor.  Element types are recovered from the expressions with
-- 'splitPArrayTy', so every expression must have a parallel array type.
mkToArrPRepr :: CoreExpr -> CoreExpr -> [[CoreExpr]] -> VM CoreExpr
mkToArrPRepr len sel ess
  = do
      let mk_sum [(expr, ty)] = return (expr, ty)
          mk_sum es
            = do
                sum_tc <- builtin . sumTyCon $ length es
                (sum_rtc, _) <- parrayReprTyCon (mkTyConApp sum_tc tys)
                let [sum_rdc] = tyConDataCons sum_rtc

                return (mkConApp sum_rdc (map Type tys ++ (len : sel : exprs)),
                        mkTyConApp sum_tc tys)
            where
              (exprs, tys) = unzip es

          mk_prod [expr] = return (expr, splitPArrayTy (exprType expr))
          mk_prod exprs
            = do
                prod_tc <- builtin . prodTyCon $ length exprs
                (prod_rtc, _) <- parrayReprTyCon (mkTyConApp prod_tc tys)
                let [prod_rdc] = tyConDataCons prod_rtc

                return (mkConApp prod_rdc (map Type tys ++ (len : exprs)),
                        mkTyConApp prod_tc tys)
            where
              tys = map (splitPArrayTy . exprType) exprs

      -- Only the expression is returned; the element type computed alongside
      -- it is needed just for building the enclosing sum.
      liftM fst (mk_sum =<< mapM mk_prod ess)
221

222
223
224
-- | Build an expression that takes a generic representation value apart.
--
-- @scrut@ is the representation value, @res_ty@ the type of the result, and
-- each element of @alts@ pairs the variables to bind for one constructor
-- with the expression to evaluate under those bindings.
--
-- With several alternatives a case on the sum tycon is generated, binding a
-- fresh variable per alternative which is then unpacked by 'un_prod'.  A
-- product is unpacked by nothing (no variables), a @let@ (one variable) or a
-- case on the product data constructor (several variables).
mkFromPRepr :: CoreExpr -> Type -> [([Var], CoreExpr)] -> VM CoreExpr
mkFromPRepr scrut res_ty alts
  = do
      sum_tcs  <- builtins sumTyCon
      prod_tcs <- builtins prodTyCon

      let un_sum expr ty [(vars, res)] = un_prod expr ty vars res
          un_sum expr ty bs
            = do
                ps     <- mapM (newLocalVar FSLIT("p")) tys
                bodies <- sequence
                        $ zipWith4 un_prod (map Var ps) tys vars rs
                return . Case expr (mkWildId ty) res_ty
                       $ zipWith3 mk_alt sum_dcs ps bodies
            where
              (vars, rs) = unzip bs
              tys        = splitFixedTyConApp sum_tc ty
              sum_tc     = sum_tcs $ length bs
              sum_dcs    = tyConDataCons sum_tc

              mk_alt dc p body = (DataAlt dc, [p], body)

          un_prod expr ty []    r = return r
          un_prod expr ty [var] r = return $ Let (NonRec var expr) r
          un_prod expr ty vars  r
            = return $ Case expr (mkWildId ty) res_ty
                       [(DataAlt prod_dc, vars, r)]
            where
              prod_tc   = prod_tcs $ length vars
              [prod_dc] = tyConDataCons prod_tc

      un_sum scrut (exprType scrut) alts

255
256
257
258
259
-- | Placeholder: every argument is ignored and the unit data constructor is
-- returned.
-- NOTE(review): presumably meant to become the array counterpart of
-- 'mkFromPRepr' -- confirm before relying on the result.
mkFromArrPRepr :: CoreExpr -> Type -> Var -> Var -> [[Var]] -> CoreExpr
               -> VM CoreExpr
mkFromArrPRepr scrut res_ty len sel vars res
  = return (Var unitDataConId)

rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
260
261
262
263
264
265
-- | The builtin closure tycon applied to an argument and a result type.
mkClosureType :: Type -> Type -> VM Type
mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon clo_args
  where
    clo_args = [arg_ty, res_ty]

-- | Right-nested closure type taking the given argument types in turn and
-- finally producing the result type.
mkClosureTypes :: [Type] -> Type -> VM Type
mkClosureTypes arg_tys res_ty = mkBuiltinTyConApps closureTyCon arg_tys res_ty

266
267
268
-- | The builtin PRepr tycon applied to the given type.
mkPReprType :: Type -> VM Type
mkPReprType = mkBuiltinTyConApp preprTyCon . (: [])

rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
269
270
-- | The type of a PA dictionary for the given type, i.e. the builtin PA
-- tycon applied to it.
mkPADictType :: Type -> VM Type
mkPADictType = mkBuiltinTyConApp paTyCon . (: [])
271
272

-- | The parallel array type with the given element type, i.e. the builtin
-- parallel array tycon applied to it.
mkPArrayType :: Type -> VM Type
mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
274

275
276
277
278
279
280
281
282
283
284
285
-- | Coerce an expression using the symmetric family-instance coercion of
-- @repr_tc@ (instantiated at @args@), lifted under the parallel array tycon.
--
-- Panics if @repr_tc@ has no family coercion, i.e. is not a family instance
-- representation tycon.
parrayCoerce :: TyCon -> [Type] -> CoreExpr -> VM CoreExpr
parrayCoerce repr_tc args expr
  | Just arg_co <- tyConFamilyCoercion_maybe repr_tc
  = do
      parray <- builtin parrayTyCon

      let co = mkAppCoercion (mkTyConApp parray [])
                             (mkSymCoercion (mkTyConApp arg_co args))

      return $ mkCoerce co expr

  -- Report which tycon was at fault instead of failing with an anonymous
  -- non-exhaustive-guards error.
  | otherwise = pprPanic "parrayCoerce" (ppr repr_tc)

286
287
288
289
290
291
292
293
294
295
-- | Look up the family instance of the builtin parallel array tycon at the
-- given element type, returning what 'lookupFamInst' yields for it.
parrayReprTyCon :: Type -> VM (TyCon, [Type])
parrayReprTyCon ty
  = do
      parray <- builtin parrayTyCon
      lookupFamInst parray [ty]

-- | Like 'parrayReprTyCon' but returns the single data constructor of the
-- representation tycon instead of the tycon itself.
parrayReprDataCon :: Type -> VM (DataCon, [Type])
parrayReprDataCon ty
  = do
      (repr_tc, repr_args) <- parrayReprTyCon ty
      return (only_con repr_tc, repr_args)
  where
    -- The representation tycon is expected to have exactly one constructor.
    only_con tc = let [dc] = tyConDataCons tc in dc

296
297
298
299
300
301
-- | Prepare a vectorised expression for use as a case scrutinee: the lifted
-- component is passed through 'unwrapFamInstScrut' for the array
-- representation tycon of the vectorised component's type.  The
-- representation tycon and its arguments are returned as well.
mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
mkVScrut (vscrut, lscrut)
  = do
      (repr_tc, repr_args) <- parrayReprTyCon (exprType vscrut)
      let lscrut' = unwrapFamInstScrut repr_tc repr_args lscrut
      return ((vscrut, lscrut'), repr_tc, repr_args)

302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
-- | Construct the PR dictionary for a type headed by a type constructor, by
-- looking up the tycon's PR dictionary function and applying it to the
-- type arguments (see 'prDFunApply').
--
-- Panics if the type is not a type constructor application.
prDictOfType :: Type -> VM CoreExpr
prDictOfType orig_ty
  | Just (tycon, ty_args) <- splitTyConApp_maybe orig_ty
  = do
      dfun <- traceMaybeV "prDictOfType" (ppr tycon) (lookupTyConPR tycon)
      prDFunApply (Var dfun) ty_args

  -- Report the offending type instead of failing with an anonymous
  -- non-exhaustive-guards error.
  | otherwise = pprPanic "prDictOfType" (ppr orig_ty)

-- | Apply a dictionary function to the given types and then to one
-- dictionary argument (built by 'mkDFunArg') for every value argument that
-- its instantiated type expects.
prDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
prDFunApply dfun tys
  = liftM (mkApps applied) (mapM mkDFunArg dict_tys)
  where
    applied  = mkTyApps dfun tys
    dict_tys = fst (splitFunTys (exprType applied))

-- | Build the dictionary demanded by a dictionary function argument of the
-- given type.  The type must be the PA or the PR tycon applied to a single
-- type; anything else is a panic.
mkDFunArg :: Type -> VM CoreExpr
mkDFunArg ty
  = case splitTyConApp_maybe ty of
      Just (tycon, [arg])
        | tyConName tycon == paTyConName -> paDictOfType arg
        | tyConName tycon == prTyConName -> prDictOfType arg
      -- Also reached when the guards above fail.
      _ -> pprPanic "mkDFunArg" (ppr ty)

rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
332
333
334
335
336
337
338
339
340
341
342
-- | Coerce an expression using the symmetric family-instance coercion of
-- @repr_tc@ (instantiated at @args@), lifted under the builtin PR tycon.
--
-- Panics if @repr_tc@ has no family coercion, i.e. is not a family instance
-- representation tycon.
prCoerce :: TyCon -> [Type] -> CoreExpr -> VM CoreExpr
prCoerce repr_tc args expr
  | Just arg_co <- tyConFamilyCoercion_maybe repr_tc
  = do
      pr_tc <- builtin prTyCon

      let co = mkAppCoercion (mkTyConApp pr_tc [])
                             (mkSymCoercion (mkTyConApp arg_co args))

      return $ mkCoerce co expr

  -- Report which tycon was at fault instead of failing with an anonymous
  -- non-exhaustive-guards error.
  | otherwise = pprPanic "prCoerce" (ppr repr_tc)

343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
-- | Compute the type of the PA dictionary argument required for a type
-- variable, driven by the variable's kind.
--
-- For a lifted type kind this is the PA dictionary type of the type itself.
-- For a function kind @k1 -> k2@ a fresh type variable of kind @k1@ is
-- introduced; if that variable needs a dictionary the result is a
-- quantified function from its dictionary to the dictionary of the
-- application, otherwise the argument is skipped.  'Nothing' means no
-- dictionary is needed.
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
  where
    -- Look through kind synonyms first.
    go ty k | Just k' <- kindView k = go ty k'
    go ty (FunTy k1 k2)
      = do
          tv   <- newTyVar FSLIT("a") k1
          mty1 <- go (TyVarTy tv) k1
          case mty1 of
            Just ty1 -> do
                          mty2 <- go (AppTy ty (TyVarTy tv)) k2
                          return $ fmap (ForAllTy tv . FunTy ty1) mty2
            Nothing  -> go ty k2

    go ty k
      | isLiftedTypeKind k
      = liftM Just (mkPADictType ty)

    go _ _ = return Nothing

363
364
365
366
367
368
369
370
371
372
373
374
375
376
-- | Construct the PA dictionary for a type by splitting it into its head
-- and arguments and delegating to 'paDictOfTyApp'.
paDictOfType :: Type -> VM CoreExpr
paDictOfType ty = uncurry paDictOfTyApp (splitAppTys ty)

-- | Construct the PA dictionary for the application of a head type to
-- argument types.
--
-- The head is first expanded with 'coreView'.  A type variable head uses the
-- dictionary recorded for that variable; a tycon head uses the tycon's PA
-- dictionary function.  Either is then applied to the argument types and
-- their dictionaries by 'paDFunApply'.  Any other head is a panic.
paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
paDictOfTyApp ty_fn ty_args
  | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
paDictOfTyApp (TyVarTy tv) ty_args
  = do
      dfun <- maybeV (lookupTyVarPA tv)
      paDFunApply dfun ty_args
-- NOTE(review): any arguments carried by the TyConApp itself are ignored
-- here; presumably the head always arrives without arguments -- confirm.
paDictOfTyApp (TyConApp tc _) ty_args
  = do
      dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
      paDFunApply (Var dfun) ty_args
paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)

381
382
383
384
385
386
387
388
389
390
391
-- | The type of the PA dictionary function of a type constructor: quantified
-- over the tycon's type variables, taking the dictionaries those variables
-- require (see 'paDictArgType') and returning the PA dictionary of the tycon
-- applied to its type variables.
paDFunType :: TyCon -> VM Type
paDFunType tc
  = do
      dict_tys <- liftM (\ms -> [d | Just d <- ms]) (mapM paDictArgType tvs)
      res_ty   <- mkPADictType (mkTyConApp tc (mkTyVarTys tvs))
      return (mkForAllTys tvs (mkFunTys dict_tys res_ty))
  where
    tvs = tyConTyVars tc

392
393
394
395
396
397
-- | Apply a PA dictionary function to the given types followed by the PA
-- dictionaries of those types.
paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
paDFunApply dfun tys
  = liftM (mkApps (mkTyApps dfun tys)) (mapM paDictOfType tys)

rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
398
399
400
401
402
403
404
-- | Instantiate a builtin PA method at a type: the method variable selected
-- by @method@ applied to the type and to the type's PA dictionary.
paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
paMethod method ty
  = liftM2 instantiate (builtin method) (paDictOfType ty)
  where
    instantiate fn dict = Var fn `mkApps` [Type ty, dict]

405
406
407
-- | The builtin 'mkPRVar' method instantiated at the given type.
mkPR :: Type -> VM CoreExpr
mkPR ty = paMethod mkPRVar ty

rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
408
-- | Apply the builtin length method to an array expression.  The element
-- type is taken from the expression's type, which must therefore be a
-- parallel array type.
lengthPA :: CoreExpr -> VM CoreExpr
lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
  where
    ty = splitPArrayTy (exprType x)
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
412
413
414
415
416

-- | Apply the builtin replicate method, instantiated at the type of @x@, to
-- a length expression and @x@.
replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
replicatePA len x
  = do
      replicate_fn <- paMethod replicatePAVar (exprType x)
      return (mkApps replicate_fn [len, x])

417
418
419
-- | The builtin empty-array method instantiated at the given type.
emptyPA :: Type -> VM CoreExpr
emptyPA ty = paMethod emptyPAVar ty

420
421
422
423
424
425
-- | Replicate an expression, using the builtin lifting context variable as
-- the length.
liftPA :: CoreExpr -> VM CoreExpr
liftPA x = builtin liftingContext >>= \lc -> replicatePA (Var lc) x

426
427
428
429
430
431
432
433
-- | Create a pair of fresh local variables sharing one name: the first has
-- the given type, the second the corresponding parallel array type.
newLocalVVar :: FastString -> Type -> VM VVar
newLocalVVar fs vty
  = mkPArrayType vty >>= \lty ->
    liftM2 (,) (newLocalVar fs vty) (newLocalVar fs lty)

rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
434
435
436
437
-- | Run a computation in a local environment in which the given type
-- variables are defined, each together with a fresh PA dictionary variable
-- where 'paDictArgType' says one is required.
--
-- The computation receives a function that abstracts an expression over the
-- type variables followed by the dictionary variables that were created.
polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
polyAbstract tvs p
  = localV
  $ do
      mdicts <- mapM mk_dict_var tvs
      zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
      p (mk_lams mdicts)
  where
    mk_dict_var tv = do
                       r <- paDictArgType tv
                       case r of
                         Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
                         Nothing -> return Nothing

    mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])

rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
450
451
-- | Apply an expression to the given types followed by the PA dictionaries
-- of those types.
polyApply :: CoreExpr -> [Type] -> VM CoreExpr
polyApply expr tys
  = do
      dicts <- mapM paDictOfType tys
      return $ expr `mkTyApps` tys `mkApps` dicts

456
457
458
459
460
461
-- | 'polyApply' for vectorised expressions: both components are applied to
-- the given types and to the PA dictionaries of those types.
polyVApply :: VExpr -> [Type] -> VM VExpr
polyVApply expr tys
  = liftM apply_to (mapM paDictOfType tys)
  where
    apply_to dicts = mapVect (\e -> mkApps (mkTyApps e tys) dicts) expr

rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
462
463
464
465
-- | Record a binding in the global environment's list of hoisted bindings
-- (most recent first).
hoistBinding :: Var -> CoreExpr -> VM ()
hoistBinding v e = updGEnv add_binding
  where
    add_binding env = env { global_bindings = (v,e) : global_bindings env }

466
467
468
469
-- | Hoist an expression: bind it to a fresh variable of the expression's
-- type, record the binding with 'hoistBinding' and return the variable.
hoistExpr :: FastString -> CoreExpr -> VM Var
hoistExpr fs expr
  = do
      var <- newLocalVar fs (exprType expr)
      hoistBinding var expr
      return var

473
474
-- | Hoist both components of a vectorised expression.  The variable names
-- are the current binding name prefixed with @v@ for the first component
-- and @l@ for the second.
hoistVExpr :: VExpr -> VM VVar
hoistVExpr (ve, le)
  = do
      fs <- getBindName
      vv <- hoistExpr ('v' `consFS` fs) ve
      lv <- hoistExpr ('l' `consFS` fs) le
      return (vv, lv)
480

481
482
-- | Hoist a vectorised expression that is polymorphic in the given type
-- variables.
--
-- The expression is produced under 'closedV' and 'polyAbstract', abstracted
-- over the type variables and their dictionaries, and hoisted.  The result
-- is the hoisted variable pair applied back to the type variables and their
-- dictionaries.
hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
hoistPolyVExpr tvs p
  = do
      expr <- closedV . polyAbstract tvs $ \abstract ->
              liftM (mapVect abstract) p
      fn   <- hoistVExpr expr
      polyVApply (vVar fn) (mkTyVarTys tvs)
488

489
490
491
492
493
494
495
-- | Return all bindings recorded by 'hoistBinding' so far and reset the
-- list in the global environment to empty.
takeHoisted :: VM [(Var, CoreExpr)]
takeHoisted
  = readGEnv id >>= \old_env ->
    setGEnv (old_env { global_bindings = [] }) >>
    return (global_bindings old_env)

496
497
-- | Build a vectorised closure from a function pair and an environment pair.
--
-- Both components apply a builtin closure constructor to the argument,
-- result and environment types, the PA dictionary of the environment type
-- and both function components; the first component then takes the
-- vectorised environment and the second the lifted one.
mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
  = do
      dict <- paDictOfType env_ty
      mkv  <- builtin mkClosureVar
      mkl  <- builtin mkClosurePVar
      return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
              Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
504

rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
505
506
507
508
509
510
511
512
513
514
-- | Apply a vectorised closure to a vectorised argument, component by
-- component, using the builtin closure application functions.  Argument
-- and result types are read off the type of the closure's first component.
mkClosureApp :: VExpr -> VExpr -> VM VExpr
mkClosureApp (vclo, lclo) (varg, larg)
  = liftM2 build (builtin applyClosureVar) (builtin applyClosurePVar)
  where
    (arg_ty, res_ty) = splitClosureTy (exprType vclo)

    apply fn clo arg = Var fn `mkTyApps` [arg_ty, res_ty] `mkApps` [clo, arg]

    build vapply lapply = (apply vapply vclo varg, apply lapply lclo larg)

515
-- | Build a chain of nested closures, one per argument type.
--
-- With no argument types the body is returned as is; with one,
-- 'buildClosure' is used directly.  Otherwise a closure for the first
-- argument is built whose body is a hoisted function that constructs the
-- closures for the remaining arguments, with the new argument variable
-- added to the captured variables.
buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
buildClosures tvs vars [] res_ty mk_body
  = mk_body
buildClosures tvs vars [arg_ty] res_ty mk_body
  = buildClosure tvs vars arg_ty res_ty mk_body
buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
  = do
      res_ty' <- mkClosureTypes arg_tys res_ty
      arg <- newLocalVVar FSLIT("x") arg_ty
      buildClosure tvs vars arg_ty res_ty'
        . hoistPolyVExpr tvs
        $ do
            lc <- builtin liftingContext
            clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
            return $ vLams lc (vars ++ [arg]) clo

531
532
533
534
-- Build a closure capturing the given variables.  The result has the shape
--
-- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
--   where
--     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
--     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
--
-- The function pair <f,f^> is hoisted; the environment and the code that
-- unpacks it come from 'buildEnv'.
buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
buildClosure tvs vars arg_ty res_ty mk_body
  = do
      (env_ty, env, bind) <- buildEnv vars
      env_bndr <- newLocalVVar FSLIT("env") env_ty
      arg_bndr <- newLocalVVar FSLIT("arg") arg_ty

      fn <- hoistPolyVExpr tvs
          $ do
              lc    <- builtin liftingContext
              body  <- mk_body
              body' <- bind (vVar env_bndr)
                            (vVarApps lc body (vars ++ [arg_bndr]))
              return (vLamsWithoutLC [env_bndr, arg_bndr] body')

      mkClosure arg_ty res_ty env_ty fn env
550

551
      mkClosure arg_ty res_ty env_ty fn env
552

553
554
-- | Build the environment for a closure over the given variables.
--
-- Returns the environment type, the pair of environment expressions
-- (vectorised and lifted), and a function that, given an environment pair
-- and a body pair, wraps each body in the code that unpacks the matching
-- environment (see 'mkVectEnv' and 'mkLiftEnv').
buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
buildEnv vvs
  = do
      lc <- builtin liftingContext
      let (ty, venv, vbind) = mkVectEnv tys vs
      (lenv, lbind) <- mkLiftEnv lc tys ls
      return (ty, (venv, lenv),
              \(venv,lenv) (vbody,lbody) ->
              do
                let vbody' = vbind venv vbody
                lbody' <- lbind lenv lbody
                return (vbody', lbody'))
  where
    (vs,ls) = unzip vvs
    tys     = map idType vs
568
569
570
571
572
573
574
575
576
577

-- | Build the vectorised environment for the given variables and their
-- types.  Returns the environment type, the environment expression and a
-- function that unpacks an environment expression around a body.
--
-- No variables: the unit value, and the body is left untouched.  One
-- variable: the variable itself, rebound with a @let@.  Several variables:
-- a tuple, taken apart with a case on the boxed tuple constructor.
mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
mkVectEnv []   []  = (unitTy, Var unitDataConId, \_ body -> body)
mkVectEnv [ty] [v] = (ty, Var v, bind_one)
  where
    bind_one env body = Let (NonRec v env) body
mkVectEnv tys  vs  = (tup_ty, mkCoreTup (map Var vs), unpack)
  where
    tup_ty  = mkCoreTupTy tys
    tup_con = tupleCon Boxed (length vs)

    unpack env body = Case env (mkWildId tup_ty) (exprType body)
                        [(DataAlt tup_con, vs, body)]

578
-- | Build the lifted environment for the given variables and their types.
-- Returns the environment expression and a function that unpacks an
-- environment expression around a body.
--
-- One variable: the variable itself; unpacking rebinds it with a @let@ and
-- binds @lc@ as the binder of a case on the variable's length.
--
-- Otherwise: the data constructor of the array representation of the tuple
-- type, applied to @lc@ and the variables; unpacking is a case on that
-- constructor binding @lc@ and the variables.
mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
mkLiftEnv lc [ty] [v]
  = return (Var v, \env body ->
                   do
                     len <- lengthPA (Var v)
                     return . Let (NonRec v env)
                            $ Case len lc (exprType body) [(DEFAULT, [], body)])

-- NOTE: this transparently deals with empty environments
mkLiftEnv lc tys vs
  = do
      (env_tc, env_tyargs) <- parrayReprTyCon vty
      let [env_con] = tyConDataCons env_tc

          env = Var (dataConWrapId env_con)
                `mkTyApps`  env_tyargs
                `mkVarApps` (lc : vs)

          bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
                          in
                          return $ Case scrut (mkWildId (exprType scrut))
                                        (exprType body)
                                        [(DataAlt env_con, lc : bndrs, body)]
      return (env, bind)
  where
    vty = mkCoreTupTy tys

    -- With no variables a single wildcard binder of unit type is bound in
    -- the case alternative instead.
    bndrs | null vs   = [mkWildId unitTy]
          | otherwise = vs