VectUtils.hs 19.4 KB
Newer Older
1
module VectUtils (
2
  collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3
  collectAnnValBinders,
4
  mkDataConTag,
5
  splitClosureTy,
6
7
8

  TyConRepr(..), mkTyConRepr,
  mkToPRepr, mkToArrPRepr, mkFromPRepr, mkFromArrPRepr,
9
  mkPADictType, mkPArrayType, mkPReprType,
10

11
  parrayCoerce, parrayReprTyCon, parrayReprDataCon, mkVScrut,
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
12
  prDictOfType, prCoerce,
13
  paDictArgType, paDictOfType, paDFunType,
14
  paMethod, lengthPA, replicatePA, emptyPA, liftPA,
15
  polyAbstract, polyApply, polyVApply,
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
16
  hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
17
18
  buildClosure, buildClosures,
  mkClosureApp
19
20
21
22
) where

#include "HsVersions.h"

23
import VectCore
24
25
import VectMonad

26
import DsUtils
27
import CoreSyn
28
import CoreUtils
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
29
import Coercion
30
31
import Type
import TypeRep
32
import TyCon
33
import DataCon
34
import Var
35
36
import Id                 ( mkWildId )
import MkId               ( unwrapFamInstScrut )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
37
import Name               ( Name )
38
import PrelNames
39
40
import TysWiredIn
import BasicTypes         ( Boxity(..) )
41

42
import Outputable
43
import FastString
44

45
import Data.List             ( zipWith4 )
46
import Control.Monad         ( liftM, liftM2, zipWithM_ )
47

48
49
50
51
52
53
54
55
56
57
58
59
-- | Peel type-argument applications off an annotated expression.
--
-- Returns the remaining function head together with the peeled type
-- arguments, in application order (leftmost argument first).
collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
collectAnnTypeArgs = peel []
  where
    peel acc (_, AnnApp fun (_, AnnType arg_ty)) = peel (arg_ty : acc) fun
    peel acc other                                = (other, acc)

-- | Split off the leading lambdas whose binders are type variables.
--
-- Returns those binders (outermost first) and the remaining body.
collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnTypeBinders = strip []
  where
    strip acc (_, AnnLam bndr body)
      | isTyVar bndr = strip (bndr : acc) body
    strip acc rest   = (reverse acc, rest)

-- | Split off the leading lambdas whose binders are term variables (isId).
--
-- Returns those binders (outermost first) and the remaining body.
collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnValBinders = strip []
  where
    strip acc (_, AnnLam bndr body)
      | isId bndr = strip (bndr : acc) body
    strip acc rest = (reverse acc, rest)

-- | True exactly when the annotated expression is an AnnType node.
isAnnTypeArg :: AnnExpr b ann -> Bool
isAnnTypeArg (_, node)
  = case node of
      AnnType _ -> True
      _         -> False

-- | A Core expression applying intDataCon to an Int literal holding the
-- tag (dataConTag) of the given data constructor.
mkDataConTag :: DataCon -> CoreExpr
mkDataConTag dc = mkConApp intDataCon [tag_lit]
  where
    tag_lit = mkIntLitInt (dataConTag dc)

-- | Deconstruct an application of the type constructor named @name@ to
-- exactly one argument and return that argument.
--
-- Panics with message @s@ (showing the type) for any other shape.
splitUnTy :: String -> Name -> Type -> Type
splitUnTy s name ty
  = case splitTyConApp_maybe ty of
      Just (tc, [arg]) | tyConName tc == name -> arg
      _                                       -> pprPanic s (ppr ty)

-- | Deconstruct an application of the type constructor named @name@ to
-- exactly two arguments and return both arguments.
--
-- Panics with message @s@ (showing the type) for any other shape.
splitBinTy :: String -> Name -> Type -> (Type, Type)
splitBinTy s name ty
  = case splitTyConApp_maybe ty of
      Just (tc, [lhs, rhs]) | tyConName tc == name -> (lhs, rhs)
      _                                            -> pprPanic s (ppr ty)

-- | Arguments of a type that must be an application of exactly the type
-- constructor @tc@.
--
-- Panics (showing both @tc@ and the type) when the type has another head.
splitFixedTyConApp :: TyCon -> Type -> [Type]
splitFixedTyConApp tc ty
  = case splitTyConApp_maybe ty of
      Just (tc', args) | tc == tc' -> args
      _                            -> pprPanic "splitFixedTyConApp"
                                               (ppr tc <+> ppr ty)

-- | Argument and result types of a closure type, i.e. a binary application
-- of the type constructor named closureTyConName; panics otherwise.
splitClosureTy :: Type -> (Type, Type)
splitClosureTy ty = splitBinTy "splitClosureTy" closureTyConName ty

-- | Element type of a parallel-array type, i.e. a unary application of the
-- type constructor named parrayTyConName; panics otherwise.
splitPArrayTy :: Type -> Type
splitPArrayTy ty = splitUnTy "splitPArrayTy" parrayTyConName ty

-- | Apply the builtin type constructor selected by @get_tc@ to @tys@.
mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
mkBuiltinTyConApp get_tc tys
  = liftM (`mkTyConApp` tys) (builtin get_tc)

-- | Right-nest the binary builtin type constructor selected by @get_tc@:
-- @[t1, t2, ..., tn]@ and @ty@ give @tc t1 (tc t2 (... (tc tn ty)))@.
-- An empty list yields @ty@ unchanged.
mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
mkBuiltinTyConApps get_tc tys ty
  = do
      tc <- builtin get_tc
      let app lhs rhs = mkTyConApp tc [lhs, rhs]
      return (foldr app ty tys)

-- | Right-nest the binary builtin type constructor selected by @get_tc@,
-- using the last list element as the innermost type:
-- @[t1, ..., tn]@ gives @tc t1 (tc t2 (... (tc t(n-1) tn)))@.
--
-- An empty list yields the default type @dft@, and in that case the
-- builtin type constructor is not looked up at all.
mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
mkBuiltinTyConApps1 _      dft [] = return dft
mkBuiltinTyConApps1 get_tc _   tys
  = do
      tc <- builtin get_tc
      -- The empty list is handled by the first equation, so tys is
      -- non-empty here and foldr1 is total.
      return $ foldr1 (mk tc) tys
  where
    mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]

-- | Generic sum-of-products representation of an algebraic type
-- constructor, as assembled by mkTyConRepr.
data TyConRepr = TyConRepr
  { repr_tyvars      :: [TyVar]
      -- ^ type variables of the type constructor
  , repr_tys         :: [[Type]]
      -- ^ for each data constructor, its representation argument types
  , repr_prod_tycons :: [Maybe TyCon]
      -- ^ for each data constructor, the builtin product tycon used for
      --   its fields (Nothing when it has fewer than two fields)
  , repr_prod_tys    :: [Type]
      -- ^ for each data constructor, the type of its product
  , repr_sum_tycon   :: Maybe TyCon
      -- ^ builtin sum tycon over the products (Nothing when there are
      --   fewer than two data constructors)
  , repr_type        :: Type
      -- ^ the complete representation type
  }

-- | Compute the generic sum-of-products representation of a type
-- constructor.
--
-- Each data constructor contributes one product of its representation
-- argument types, and the products are combined by a sum.  A builtin
-- n-ary product or sum tycon is only used when there are at least two
-- components.  Otherwise no tycon is recorded (Nothing), an empty
-- product or sum becomes the unit type, and a singleton is its only
-- component.
mkTyConRepr :: TyCon -> VM TyConRepr
mkTyConRepr vect_tc
  = do
      prod_tcs <- mapM (arity_tycon prodTyCon) field_tys
      let prod_tys = zipWith apply_maybe prod_tcs field_tys
      sum_tc   <- arity_tycon sumTyCon prod_tys
      return TyConRepr { repr_tyvars      = tyConTyVars vect_tc
                       , repr_tys         = field_tys
                       , repr_prod_tycons = prod_tcs
                       , repr_prod_tys    = prod_tys
                       , repr_sum_tycon   = sum_tc
                       , repr_type        = apply_maybe sum_tc prod_tys
                       }
  where
    field_tys = map dataConRepArgTys (tyConDataCons vect_tc)

    -- Look up the builtin tycon of the right arity, but only when there
    -- are at least two components to combine.
    arity_tycon get_tc tys
      | arity > 1 = builtin (Just . get_tc arity)
      | otherwise = return Nothing
      where
        arity = length tys

    apply_maybe Nothing   []   = unitTy
    apply_maybe Nothing   [ty] = ty
    apply_maybe (Just tc) tys  = mkTyConApp tc tys

{-
mkPRepr :: [[Type]] -> VM Type
mkPRepr tys
  = do
      embed_tc <- builtin embedTyCon
      sum_tcs  <- builtins sumTyCon
      prod_tcs <- builtins prodTyCon

      let mk_sum []   = unitTy
          mk_sum [ty] = ty
          mk_sum tys  = mkTyConApp (sum_tcs $ length tys) tys

          mk_prod []   = unitTy
          mk_prod [ty] = ty
          mk_prod tys  = mkTyConApp (prod_tcs $ length tys) tys

          mk_embed ty = mkTyConApp embed_tc [ty]

      return . mk_sum
             . map (mk_prod . map mk_embed)
             $ tys
-}

-- | Given, for each data constructor, the expressions for its
-- representation fields, build one expression per data constructor of the
-- representation type described by the TyConRepr.
--
-- The fields are first packed into the product of their constructor (no
-- wrapping for a single field, the unit data constructor for none).  When
-- there are several constructors, each product is then injected into the
-- matching alternative of the sum (the i-th sum data constructor).
mkToPRepr :: TyConRepr -> [[CoreExpr]] -> [CoreExpr]
mkToPRepr (TyConRepr { repr_tys         = field_tys
                     , repr_prod_tycons = prod_tcs
                     , repr_prod_tys    = prod_tys
                     , repr_sum_tycon   = m_sum_tc })
  = inject . zipWith3 pack prod_tcs field_tys
  where
    con_app dc tys args = mkConApp dc (map Type tys ++ args)

    -- Build the product for the fields of one constructor.
    pack _         _   []    = Var unitDataConId
    pack _         _   [arg] = arg
    pack (Just tc) tys args  = con_app prod_dc tys args
      where
        [prod_dc] = tyConDataCons tc

    -- Inject each product into its alternative of the sum, if there is one.
    inject []     = [Var unitDataConId]
    inject [prod] = [prod]
    inject prods  = zipWith (\dc prod -> con_app dc prod_tys [prod])
                            (tyConDataCons sum_tc) prods

    -- Only demanded in the multi-constructor case above.
    Just sum_tc = m_sum_tc

-- | Build the parallel-array representation from, for each data
-- constructor, the list of array expressions for its fields.
--
-- Each list with more than one element is combined using the sole data
-- constructor of the PArray representation tycon for the builtin product
-- type, applied to the element types, then @len@, then the arrays.  The
-- per-constructor results are combined the same way using the sum
-- representation, applied to the types, @len@, @sel@ and the products.
-- Singleton lists are passed through unchanged at both levels.
mkToArrPRepr :: CoreExpr -> CoreExpr -> [[CoreExpr]] -> VM CoreExpr
mkToArrPRepr len sel ess
  = do
      prods <- mapM pack ess
      liftM fst (inject prods)
  where
    -- Product of the field arrays of one constructor, with its element type.
    pack [arr] = return (arr, splitPArrayTy (exprType arr))
    pack arrs
      = do
          prod_tc       <- builtin (prodTyCon (length arrs))
          (prod_rtc, _) <- parrayReprTyCon (mkTyConApp prod_tc elt_tys)
          let [prod_rdc] = tyConDataCons prod_rtc
          return ( mkConApp prod_rdc (map Type elt_tys ++ (len : arrs))
                 , mkTyConApp prod_tc elt_tys )
      where
        elt_tys = map (splitPArrayTy . exprType) arrs

    -- Sum over the per-constructor products.
    inject [(arr, ty)] = return (arr, ty)
    inject alts
      = do
          sum_tc       <- builtin (sumTyCon (length alts))
          (sum_rtc, _) <- parrayReprTyCon (mkTyConApp sum_tc alt_tys)
          let [sum_rdc] = tyConDataCons sum_rtc
          return ( mkConApp sum_rdc (map Type alt_tys ++ (len : sel : alt_arrs))
                 , mkTyConApp sum_tc alt_tys )
      where
        (alt_arrs, alt_tys) = unzip alts

-- | Deconstruct @scrut@, a value of a sum-of-products representation type,
-- and continue with the matching alternative (result type @res_ty@).
--
-- Each alternative gives the variables to bind to the components of its
-- product and the right-hand side.  With a single alternative no sum case
-- is generated.  Otherwise fresh variables named "p" are bound by a case on
-- the sum data constructors.  A product with no components binds nothing,
-- one component is let-bound, and more are taken apart by a case on the
-- product data constructor.
mkFromPRepr :: CoreExpr -> Type -> [([Var], CoreExpr)] -> VM CoreExpr
mkFromPRepr scrut res_ty alts
  = do
      sum_tcs  <- builtins sumTyCon
      prod_tcs <- builtins prodTyCon

      let -- Bind the components of a product expression around a body.
          unpack expr ty fields body
            = case fields of
                []      -> body
                [field] -> Let (NonRec field expr) body
                _       -> Case expr (mkWildId ty) res_ty
                                [(DataAlt prod_dc, fields, body)]
            where
              [prod_dc] = tyConDataCons (prod_tcs (length fields))

          -- Dispatch on the sum, then unpack the selected product.
          unsum expr ty [(fields, body)]
            = return (unpack expr ty fields body)
          unsum expr ty branches
            = do
                prod_vars <- mapM (newLocalVar FSLIT("p")) prod_tys
                let bodies = zipWith4 unpack (map Var prod_vars)
                                      prod_tys field_lists rhss
                    mk_alt dc v body = (DataAlt dc, [v], body)
                return (Case expr (mkWildId ty) res_ty
                             (zipWith3 mk_alt (tyConDataCons sum_tc)
                                       prod_vars bodies))
            where
              (field_lists, rhss) = unzip branches
              sum_tc              = sum_tcs (length branches)
              prod_tys            = splitFixedTyConApp sum_tc ty

      unsum scrut (exprType scrut) alts

-- | Placeholder: ignores all of its arguments and always returns the unit
-- data constructor.
--
-- NOTE(review): presumably not implemented yet; the result will not in
-- general have type @res_ty@, so confirm that no caller relies on it.
mkFromArrPRepr :: CoreExpr -> Type -> Var -> Var -> [[Var]] -> CoreExpr
               -> VM CoreExpr
mkFromArrPRepr _scrut _res_ty _len _sel _vars _res
  = return unit_expr
  where
    unit_expr = Var unitDataConId

-- | The builtin closure type from @arg_ty@ to @res_ty@.
mkClosureType :: Type -> Type -> VM Type
mkClosureType arg_ty res_ty
  = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]

-- | Curried closure type over several argument types: the builtin closure
-- tycon nested to the right, ending in the result type.
mkClosureTypes :: [Type] -> Type -> VM Type
mkClosureTypes arg_tys res_ty
  = mkBuiltinTyConApps closureTyCon arg_tys res_ty

-- | The builtin PRepr type constructor applied to @ty@.
mkPReprType :: Type -> VM Type
mkPReprType ty = mkBuiltinTyConApp preprTyCon args
  where
    args = [ty]

-- | The type of a PA dictionary for @ty@ (the builtin PA tycon applied to it).
mkPADictType :: Type -> VM Type
mkPADictType ty = mkBuiltinTyConApp paTyCon args
  where
    args = [ty]

-- | The builtin parallel-array type constructor applied to @ty@.
mkPArrayType :: Type -> VM Type
mkPArrayType ty = mkBuiltinTyConApp parrayTyCon args
  where
    args = [ty]

-- | Wrap @expr@ in a coercion built from the family-instance coercion of
-- the representation tycon @repr_tc@ at @args@.  The coercion is made
-- symmetric and then lifted through the builtin PArray type constructor.
--
-- Panics (naming @repr_tc@) when @repr_tc@ has no family coercion.
-- Previously that case fell off the end of the guards and failed with an
-- uninformative pattern-match error.
parrayCoerce :: TyCon -> [Type] -> CoreExpr -> VM CoreExpr
parrayCoerce repr_tc args expr
  | Just arg_co <- tyConFamilyCoercion_maybe repr_tc
  = do
      parray <- builtin parrayTyCon

      let co = mkAppCoercion (mkTyConApp parray [])
                             (mkSymCoercion (mkTyConApp arg_co args))

      return $ mkCoerce co expr
  | otherwise
  = pprPanic "parrayCoerce" (ppr repr_tc)

-- | Look up the family instance of the builtin PArray tycon at @ty@,
-- returning the instance tycon and its argument types (as given by
-- lookupFamInst).
parrayReprTyCon :: Type -> VM (TyCon, [Type])
parrayReprTyCon ty
  = do
      parray <- builtin parrayTyCon
      lookupFamInst parray [ty]

-- | The data constructor of the PArray representation tycon for @ty@,
-- together with the argument types of the instance.
--
-- The tycon is assumed to have exactly one data constructor.  This is a
-- lazy pattern, so it only fails when the constructor is demanded.
parrayReprDataCon :: Type -> VM (DataCon, [Type])
parrayReprDataCon ty = liftM pick (parrayReprTyCon ty)
  where
    pick (repr_tc, repr_args) = (only_dc, repr_args)
      where
        [only_dc] = tyConDataCons repr_tc

-- | Prepare a vectorised scrutinee.  The PArray representation tycon is
-- looked up at the type of the vectorised half, and the lifted half is
-- unwrapped from that family instance (unwrapFamInstScrut).  Also returns
-- the representation tycon and its argument types.
mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
mkVScrut (ve, le)
  = do
      (repr_tc, repr_args) <- parrayReprTyCon (exprType ve)
      let le' = unwrapFamInstScrut repr_tc repr_args le
      return ((ve, le'), repr_tc, repr_args)

-- | Build the PR dictionary for a type headed by a type constructor.  The
-- dfun registered for the tycon (lookupTyConPR) is applied to the type
-- arguments and their dictionaries.
--
-- Panics (showing the type) when the type is not a type-constructor
-- application.  Previously such a type fell off the end of the guards.
prDictOfType :: Type -> VM CoreExpr
prDictOfType orig_ty
  | Just (tycon, ty_args) <- splitTyConApp_maybe orig_ty
  = do
      dfun <- traceMaybeV "prDictOfType" (ppr tycon) (lookupTyConPR tycon)
      prDFunApply (Var dfun) ty_args
  | otherwise
  = pprPanic "prDictOfType" (ppr orig_ty)

-- | Instantiate a PR dfun at @tys@, then apply it to one dictionary per
-- function argument of the instantiated type, each built by mkDFunArg.
prDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
prDFunApply dfun tys
  = liftM (mkApps inst_dfun) (mapM mkDFunArg dict_tys)
  where
    inst_dfun = mkTyApps dfun tys
    dict_tys  = fst (splitFunTys (exprType inst_dfun))

-- | Dictionary for a dfun argument whose type is a PA or PR tycon applied
-- to one type; the two cases are told apart by the name of the tycon.
-- Panics (showing the type) for any other argument type.
mkDFunArg :: Type -> VM CoreExpr
mkDFunArg ty
  = case splitTyConApp_maybe ty of
      Just (tycon, [arg])
        | tyConName tycon == paTyConName -> paDictOfType arg
        | tyConName tycon == prTyConName -> prDictOfType arg
      _ -> pprPanic "mkDFunArg" (ppr ty)

-- | Wrap @expr@ in a coercion built from the family-instance coercion of
-- the representation tycon @repr_tc@ at @args@.  The coercion is made
-- symmetric and then lifted through the builtin PR type constructor.
--
-- Panics (naming @repr_tc@) when @repr_tc@ has no family coercion.
-- Previously that case fell off the end of the guards and failed with an
-- uninformative pattern-match error.
prCoerce :: TyCon -> [Type] -> CoreExpr -> VM CoreExpr
prCoerce repr_tc args expr
  | Just arg_co <- tyConFamilyCoercion_maybe repr_tc
  = do
      pr_tc <- builtin prTyCon

      let co = mkAppCoercion (mkTyConApp pr_tc [])
                             (mkSymCoercion (mkTyConApp arg_co args))

      return $ mkCoerce co expr
  | otherwise
  = pprPanic "prCoerce" (ppr repr_tc)

-- | Type of the PA dictionary argument needed for a type variable, if any.
--
-- Kinds are first expanded with kindView.  For a lifted type kind the
-- answer is the PA dictionary type of the variable.  For an arrow kind
-- @k1 -> k2@ a fresh type variable @a :: k1@ is introduced.  If @a@ itself
-- needs a dictionary, the result is @forall a. dict(a) -> dict(tv a)@.
-- Otherwise the kind @k2@ is examined at the original type.  Any other
-- kind yields Nothing.
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = dict_ty (TyVarTy tv) (tyVarKind tv)
  where
    dict_ty ty kind
      | Just kind' <- kindView kind = dict_ty ty kind'
    dict_ty ty (FunTy arg_kind res_kind)
      = do
          a        <- newTyVar FSLIT("a") arg_kind
          m_arg_ty <- dict_ty (TyVarTy a) arg_kind
          maybe (dict_ty ty res_kind)
                (\arg_ty -> liftM (fmap (ForAllTy a . FunTy arg_ty))
                                  (dict_ty (AppTy ty (TyVarTy a)) res_kind))
                m_arg_ty
    dict_ty ty kind
      | isLiftedTypeKind kind = liftM Just (mkPADictType ty)
      | otherwise             = return Nothing

-- | PA dictionary for a type: split it into its head and arguments with
-- splitAppTys and delegate to paDictOfTyApp.
paDictOfType :: Type -> VM CoreExpr
paDictOfType ty = uncurry paDictOfTyApp (splitAppTys ty)

-- | PA dictionary for the type head @ty_fn@ applied to @ty_args@.
--
-- The head is expanded with coreView.  For a type variable the dfun comes
-- from the local environment (lookupTyVarPA).  For a type constructor it
-- comes from lookupTyConPA.  In both cases it is then applied to the
-- argument types and their dictionaries.  Any other head panics.
paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
paDictOfTyApp ty_fn ty_args
  | Just ty_fn' <- coreView ty_fn
  = paDictOfTyApp ty_fn' ty_args
  | otherwise
  = case ty_fn of
      TyVarTy tv ->
        maybeV (lookupTyVarPA tv) >>= \dfun -> paDFunApply dfun ty_args
      TyConApp tc _ ->
        traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
          >>= \dfun -> paDFunApply (Var dfun) ty_args
      _ -> pprPanic "paDictOfTyApp" (ppr ty_fn)

-- | Type of the PA dfun for @tc@: quantified over the tyvars of the
-- tycon, taking one dictionary argument per tyvar for which paDictArgType
-- yields a type, and returning the PA dictionary type of @tc@ applied to
-- its tyvars.
paDFunType :: TyCon -> VM Type
paDFunType tc
  = do
      m_dict_tys <- mapM paDictArgType tyvars
      dict_ty    <- mkPADictType (mkTyConApp tc (mkTyVarTys tyvars))
      let arg_tys = [arg_ty | Just arg_ty <- m_dict_tys]
      return (mkForAllTys tyvars (mkFunTys arg_tys dict_ty))
  where
    tyvars = tyConTyVars tc

-- | Instantiate a PA dfun at @tys@ and apply it to the PA dictionary of
-- each of those types.
paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
paDFunApply dfun tys
  = liftM (mkApps (mkTyApps dfun tys)) (mapM paDictOfType tys)

-- | Instantiate a builtin PA method at @ty@.  The result is the method
-- variable applied to the type argument and to the PA dictionary for @ty@.
paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
paMethod method ty
  = do
      method_var <- builtin method
      dict       <- paDictOfType ty
      return (Var method_var `mkApps` [Type ty, dict])

-- | The PA method selected by mkPRVar, instantiated at a type.
mkPR :: Type -> VM CoreExpr
mkPR ty = paMethod mkPRVar ty

-- | Apply the PA length method to @x@.  The type of @x@ must be a
-- parallel-array type; its element type selects the dictionary.
lengthPA :: CoreExpr -> VM CoreExpr
lengthPA x
  = do
      len_fn <- paMethod lengthPAVar (splitPArrayTy (exprType x))
      return (App len_fn x)

-- | @replicatePA len x@ applies the PA replicate method, instantiated at
-- the type of @x@, to @len@ and @x@.
replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
replicatePA len x
  = do
      rep_fn <- paMethod replicatePAVar (exprType x)
      return (mkApps rep_fn [len, x])

-- | The PA empty method instantiated at a type.
emptyPA :: Type -> VM CoreExpr
emptyPA ty = paMethod emptyPAVar ty

-- | Replicate @x@, using the builtin lifting-context variable as the length.
liftPA :: CoreExpr -> VM CoreExpr
liftPA x = builtin liftingContext >>= \lc -> replicatePA (Var lc) x

-- | A fresh pair of local variables sharing the base name @fs@.  The first
-- has type @vty@ and the second the parallel-array type over @vty@.
newLocalVVar :: FastString -> Type -> VM VVar
newLocalVVar fs vty
  = do
      lty <- mkPArrayType vty
      liftM2 (,) (newLocalVar fs vty) (newLocalVar fs lty)

-- | Run @p@ in a local scope (localV) in which every type variable of
-- @tvs@ is registered.
--
-- A type variable is registered together with a fresh PA dictionary
-- variable ("dPA") whenever paDictArgType yields a dictionary type for it.
-- @p@ receives a function that wraps an expression in lambdas over @tvs@
-- followed by those dictionary variables.
polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
polyAbstract tvs p
  = localV $ do
      m_dicts <- mapM dict_var tvs
      zipWithM_ register tvs m_dicts
      p (mkLams (tvs ++ [dict | Just dict <- m_dicts]))
  where
    dict_var tv = paDictArgType tv
                  >>= maybe (return Nothing)
                            (liftM Just . newLocalVar FSLIT("dPA"))

    register tv = maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)

-- | Instantiate a polymorphic expression at @tys@, also passing the PA
-- dictionary of each type.
polyApply :: CoreExpr -> [Type] -> VM CoreExpr
polyApply expr tys
  = liftM (mkApps (mkTyApps expr tys)) (mapM paDictOfType tys)

-- | Like polyApply, but instantiates both halves of a vectorised expression
-- with the same types and dictionaries.
polyVApply :: VExpr -> [Type] -> VM VExpr
polyVApply expr tys
  = do
      dicts <- mapM paDictOfType tys
      let instantiate e = mkApps (mkTyApps e tys) dicts
      return (mapVect instantiate expr)

-- | Record the binding @v = e@ by prepending it to global_bindings in the
-- global environment.
hoistBinding :: Var -> CoreExpr -> VM ()
hoistBinding v e = updGEnv add
  where
    add env = env { global_bindings = (v, e) : global_bindings env }

-- | Hoist @expr@ into a binding of a fresh variable named from @fs@ and
-- return that variable.
hoistExpr :: FastString -> CoreExpr -> VM Var
hoistExpr fs expr
  = newLocalVar fs (exprType expr) >>= \var ->
    hoistBinding var expr >> return var

-- | Hoist both halves of a vectorised expression.  The names come from
-- the current bind name (getBindName), prefixed with @v@ for the
-- vectorised half and @l@ for the lifted half.
hoistVExpr :: VExpr -> VM VVar
hoistVExpr (ve, le)
  = do
      name <- getBindName
      liftM2 (,) (hoistExpr ('v' `consFS` name) ve)
                 (hoistExpr ('l' `consFS` name) le)

-- | Run @p@ under closedV, abstracted over @tvs@ and their PA dictionaries
-- via polyAbstract.  The result is hoisted, and the hoisted reference is
-- returned re-applied to @tvs@ and their dictionaries.
hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
hoistPolyVExpr tvs p
  = do
      abstracted <- closedV (polyAbstract tvs (\abstract ->
                                liftM (mapVect abstract) p))
      hoisted    <- hoistVExpr abstracted
      polyVApply (vVar hoisted) (mkTyVarTys tvs)

-- | Return the accumulated hoisted bindings and reset the list to empty.
takeHoisted :: VM [(Var, CoreExpr)]
takeHoisted
  = do
      env <- readGEnv id
      let hoisted = global_bindings env
      setGEnv env { global_bindings = [] }
      return hoisted

-- | Build both halves of a closure.  mkClosureVar (vectorised half) and
-- mkClosurePVar (lifted half) are instantiated at the argument, result and
-- environment types.  Each is applied to the PA dictionary of the
-- environment type, both function halves, and the matching half of the
-- environment.
mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
mkClosure arg_ty res_ty env_ty (vfn, lfn) (venv, lenv)
  = do
      dict    <- paDictOfType env_ty
      mk_vclo <- builtin mkClosureVar
      mk_lclo <- builtin mkClosurePVar
      let build fn env = mkApps (mkTyApps (Var fn) [arg_ty, res_ty, env_ty])
                                [dict, vfn, lfn, env]
      return (build mk_vclo venv, build mk_lclo lenv)

-- | Apply a vectorised closure to a vectorised argument, half by half.
-- The argument and result types come from the type of the vectorised
-- closure.
mkClosureApp :: VExpr -> VExpr -> VM VExpr
mkClosureApp (vclo, lclo) (varg, larg)
  = do
      vapply <- builtin applyClosureVar
      lapply <- builtin applyClosurePVar
      let apply fn clo arg = mkApps (mkTyApps (Var fn) [arg_ty, res_ty]) [clo, arg]
      return (apply vapply vclo varg, apply lapply lclo larg)
  where
    (arg_ty, res_ty) = splitClosureTy (exprType vclo)

-- | Build a curried closure taking arguments of types @arg_tys@ and
-- capturing @vars@.
--
-- With no argument types this just runs @mk_body@.  With one it is a
-- single buildClosure.  Otherwise the outer closure takes the first
-- argument (a fresh "x" variable).  Its hoisted body builds the closure
-- for the remaining arguments over @vars@ plus that argument.
buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
buildClosures tvs vars arg_tys res_ty mk_body
  = case arg_tys of
      []          -> mk_body
      [ty]        -> buildClosure tvs vars ty res_ty mk_body
      (ty : rest) -> do
        inner_ty <- mkClosureTypes rest res_ty
        arg      <- newLocalVVar FSLIT("x") ty
        let vars' = vars ++ [arg]
            body  = hoistPolyVExpr tvs $ do
                      lc  <- builtin liftingContext
                      clo <- buildClosures tvs vars' rest res_ty mk_body
                      return (vLams lc vars' clo)
        buildClosure tvs vars ty inner_ty body

-- | Build a closure from @arg_ty@ to @res_ty@ that captures @vars@ in its
-- environment (see buildEnv).
--
-- The code of the closure is hoisted (abstracted over @tvs@) and takes an
-- environment binder and an argument binder.  Inside, the environment is
-- unpacked and the body from @mk_body@ is applied, via vVarApps, to the
-- lifting context, the captured variables and the argument.
--
-- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
--   where
--     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
--     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
--
buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
buildClosure tvs vars arg_ty res_ty mk_body
  = do
      (env_ty, env, bind) <- buildEnv vars
      env_bndr            <- newLocalVVar FSLIT("env") env_ty
      arg_bndr            <- newLocalVVar FSLIT("arg") arg_ty
      let code = do
                   lc    <- builtin liftingContext
                   body  <- mk_body
                   body' <- bind (vVar env_bndr)
                                 (vVarApps lc body (vars ++ [arg_bndr]))
                   return (vLamsWithoutLC [env_bndr, arg_bndr] body')
      fn <- hoistPolyVExpr tvs code
      mkClosure arg_ty res_ty env_ty fn env

-- | Package vectorised variables into a closure environment.
--
-- Returns the vectorised environment type, the environment pair
-- (vectorised half from mkVectEnv, lifted half from mkLiftEnv), and a
-- function that unpacks such a pair around a pair of bodies.
buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
buildEnv vvs
  = do
      lc            <- builtin liftingContext
      (lenv, lbind) <- mkLiftEnv lc tys lvars
      let (env_ty, venv, vbind) = mkVectEnv tys vvars
          bind (venv', lenv') (vbody, lbody)
            = liftM ((,) (vbind venv' vbody)) (lbind lenv' lbody)
      return (env_ty, (venv, lenv), bind)
  where
    (vvars, lvars) = unzip vvs
    tys            = map idType vvars

-- | Vectorised environment for variables @vs@ of types @tys@: its type,
-- the expression that builds it, and a binder that unpacks it around a body.
--
-- No variables: unit.  One variable: the variable itself, unpacked with a
-- let.  Several: a boxed tuple, unpacked with a case on the tuple
-- constructor.
mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
mkVectEnv tys vs
  = case (tys, vs) of
      ([],   [])  -> (unitTy, Var unitDataConId, \_ body -> body)
      ([ty], [v]) -> (ty, Var v, \env body -> Let (NonRec v env) body)
      _           -> (tup_ty, mkCoreTup (map Var vs), unpack)
  where
    tup_ty = mkCoreTupTy tys

    unpack env body = Case env (mkWildId tup_ty) (exprType body)
                           [(DataAlt (tupleCon Boxed (length vs)), vs, body)]

-- | Lifted environment for the lifted variables @vs@ of element types @tys@:
-- the expression that builds it and a binder that unpacks it around a body.
mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
-- A single variable is its own environment.  Unpacking rebinds the
-- variable and binds @lc@ to its PA length.
mkLiftEnv lc [_] [v]
  = return (Var v, bind)
  where
    bind env body
      = do
          len <- lengthPA (Var v)
          return (Let (NonRec v env)
                      (Case len lc (exprType body) [(DEFAULT, [], body)]))

-- NOTE: this transparently deals with empty environments
-- Otherwise the environment is the data constructor of the PArray
-- representation for the tuple type of @tys@, applied to @lc@ and the
-- variables.  Unpacking cases on the unwrapped scrutinee, rebinding @lc@
-- and the variables (a single unit-typed wildcard when there are none).
-- NOTE(review): for an empty environment the constructor is applied to
-- @lc@ only, while the pattern binds @lc@ plus one wildcard.  Confirm the
-- constructor arity this assumes.
mkLiftEnv lc tys vs
  = do
      (env_tc, env_tyargs) <- parrayReprTyCon (mkCoreTupTy tys)
      let [env_con] = tyConDataCons env_tc
          env       = mkVarApps (mkTyApps (Var (dataConWrapId env_con)) env_tyargs)
                                (lc : vs)
          bind arr body
            = return (Case scrut (mkWildId (exprType scrut)) (exprType body)
                           [(DataAlt env_con, lc : bndrs, body)])
            where
              scrut = unwrapFamInstScrut env_tc env_tyargs arr
      return (env, bind)
  where
    bndrs | null vs   = [mkWildId unitTy]
          | otherwise = vs