PADict.hs 4.89 KB
Newer Older
1

2
module Vectorise.Generic.PADict
3 4 5
  ( buildPADict
  ) where

6 7
import Vectorise.Monad
import Vectorise.Builtins
8 9
import Vectorise.Generic.Description
import Vectorise.Generic.PAMethods ( buildPAScAndMethods )
10
import Vectorise.Utils
11 12

import BasicTypes
13 14
import CoreSyn
import CoreUtils
15
import CoreUnfold
16
import Module
17
import TyCon
18
import CoAxiom
19
import Type
20 21 22
import Id
import Var
import Name
23
import FastString
24

25

26 27
-- |Build the PA dictionary function for some type and hoist it to top level.
--
-- The PA dictionary holds functions that convert values to and from their
-- vectorised representations.  Recall the definition:
--
-- >   class PR (PRepr a) => PA a where
-- >     toPRepr       :: a -> PRepr a
-- >     fromPRepr     :: PRepr a -> a
-- >     toArrPRepr    :: PData a -> PData (PRepr a)
-- >     fromArrPRepr  :: PData (PRepr a) -> PData a
-- >     toArrPReprs   :: PDatas a         -> PDatas (PRepr a)
-- >     fromArrPReprs :: PDatas (PRepr a) -> PDatas a
--
-- Schematic example, for a polymorphic type constructor @T@ with a single
-- type variable @a@:
--
-- >   df :: forall a. PR (PRepr (T a)) -> PA a -> PA (T a)
-- >   df = /\a. \(c:PR (PRepr (T a))) (d:PA a).
-- >          MkPA (T a) c ($toPRepr a c d) ($fromPRepr a c d) ...
-- >
-- >   $toPRepr :: forall a. PR (PRepr (T a)) -> PA a -> T a -> PRepr (T a)
-- >   $toPRepr = ...
--
-- * The superclass dictionary @c@ is lambda-bound only when the tycon has
--   type variables.  For a monomorphic tycon it is instead a constant obtained
--   from 'prDictOfPReprInstTyCon', and neither @df@ nor the method functions
--   take it as an argument.
--
-- * Every method is hoisted as its own top-level binding, named after the
--   dictionary function followed by @$@ and the method name (shortened to
--   @$toPRepr@ etc. above), and is abstracted over the same type and value
--   arguments as @df@.
--
-- * The method bodies elided as @...@ above are produced by the builders
--   returned by 'buildPAScAndMethods'.
--
buildPADict
        :: TyCon        -- ^ tycon of the type being vectorised.
        -> CoAxiom Unbranched
                        -- ^ Coercion between the type and
                        --     its vectorised representation.
        -> TyCon        -- ^ PData  instance tycon
        -> TyCon        -- ^ PDatas instance tycon
        -> SumRepr      -- ^ representation used for the type being vectorised.
        -> VM Var       -- ^ the top-level dictionary function (its binding
                        --   has already been hoisted).

buildPADict vect_tc prepr_ax pdata_tc pdatas_tc repr
 = polyAbstract tvs $ \args ->    -- The args are the dictionaries we lambda abstract over; and they
                                  -- are put in the envt, so when we need a (PA a) we can find it in
                                  -- the envt; they don't include the silent superclass args yet
   do { this_mod <- liftDs getModule
      ; let dfun_name = mkLocalisedOccName this_mod mkPADFunOcc vect_tc_name

          -- The superclass dictionary, of type PR (PRepr inst_ty), is a
          -- (silent) argument if the tycon is polymorphic...
      ; let mk_super_ty = do { r <- mkPReprType inst_ty
                             ; pr_cls <- builtin prClass
                             ; return $ mkClassPred pr_cls [r]
                             }
          -- (the comprehension yields one type when tvs is non-empty, none otherwise)
      ; super_tys  <- sequence [mk_super_ty | not (null tvs)]
      ; super_args <- mapM (newLocalVar (fsLit "pr")) super_tys
      ; let val_args = super_args ++ args
            all_args = tvs ++ val_args

          -- ...it is constant otherwise
      ; super_consts <- sequence [prDictOfPReprInstTyCon inst_ty prepr_ax [] | null tvs]

          -- Get ids for each of the methods in the dictionary, including superclass.
          -- Each method is abstracted over the same val_args as the dfun itself.
      ; paMethodBuilders <- buildPAScAndMethods
      ; method_ids       <- mapM (method val_args dfun_name) paMethodBuilders

          -- Expression to build the dictionary.
      ; pa_dc  <- builtin paDataCon
      ; let dict = mkLams all_args (mkConApp pa_dc con_args)
            con_args = Type inst_ty
                     : map Var super_args  -- the superclass dictionary is either
                    ++ super_consts        -- lambda-bound or constant
                    ++ map (method_call val_args) method_ids

          -- Build the type of the dictionary function.
      ; pa_cls <- builtin paClass
      ; let dfun_ty = mkInvForAllTys tvs
                    $ mkFunTys (map varType val_args)
                               (mkClassPred pa_cls [inst_ty])

          -- Set the unfolding for the inliner.
      ; raw_dfun <- newExportedVar dfun_name dfun_ty
      ; let dfun_unf = mkDFunUnfolding all_args pa_dc con_args
            dfun = raw_dfun `setIdUnfolding`  dfun_unf
                            `setInlinePragma` dfunInlinePragma

          -- Add the new binding to the top-level environment.
      ; hoistBinding dfun dict
      ; return dfun
      }
  where
    tvs          = tyConTyVars vect_tc
    arg_tys      = mkTyVarTys tvs
    inst_ty      = mkTyConApp vect_tc arg_tys
    vect_tc_name = getName vect_tc

    -- Build the body of one method, hoist it to top level as its own binding
    -- (created with 'newExportedVar') carrying an always-inline unfolding,
    -- and return the new Var.  The unfolding arity is @length args@: only the
    -- value arguments are counted, not the type variables in @tvs@.
    method args dfun_name (name, build)
     = localV
     $ do  expr     <- build vect_tc prepr_ax pdata_tc pdatas_tc repr
           let body = mkLams (tvs ++ args) expr
           raw_var  <- newExportedVar (method_name dfun_name name) (exprType body)
           let var  = raw_var
                      `setIdUnfolding` mkInlineUnfoldingWithArity
                                         (length args) body
                      `setInlinePragma` alwaysInlinePragma
           hoistBinding var body
           return var

    -- Saturated call of a hoisted method: apply it to the type variables,
    -- then to the value arguments.
    method_call args meth_id   = mkApps (Var meth_id) (map Type arg_tys ++ map Var args)

    -- Name of a hoisted method: the dfun's OccName string, then '$', then
    -- the method name.
    method_name dfun_name name = mkVarOcc $ occNameString dfun_name ++ ('$' : name)