Env.hs 6.44 KB
Newer Older
Ian Lynagh's avatar
Ian Lynagh committed
1 2
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
{-# OPTIONS_GHC -XNoMonoLocalBinds #-}
3 4
-- Roman likes local bindings
-- If this module lives on I'd like to get rid of this flag in due course
5

6
module Vectorise.Type.Env ( 
7 8
	vectTypeEnv,
)
9
where
10
import Vectorise.Env
11
import Vectorise.Vect
12 13
import Vectorise.Monad
import Vectorise.Builtins
14
import Vectorise.Type.TyConDecl
15
import Vectorise.Type.Classify
16
import Vectorise.Type.PADict
17 18 19
import Vectorise.Type.PData
import Vectorise.Type.PRepr
import Vectorise.Type.Repr
20
import Vectorise.Utils
21

22
import HscTypes
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
23
import CoreSyn
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
24
import CoreUtils
25
import CoreUnfold
26
import DataCon
27 28
import TyCon
import Type
29
import FamInstEnv
30
import OccName
31
import Id
32
import MkId
33
import Var
34
import NameEnv
35

36
import Unique
37
import UniqFM
38
import Util
39
import Outputable
40
import FastString
41 42
import MonadUtils
import Control.Monad
43 44 45 46
import Data.List

-- | Compile-time switch for the 'dtrace' debugging output.
debug :: Bool
debug = False

-- | When 'debug' is set, print the document with 'pprTrace' under the
--   "VectType" heading before returning the second argument; otherwise
--   just return the second argument.
dtrace :: SDoc -> a -> a
dtrace msg result
  | debug     = pprTrace "VectType" msg result
  | otherwise = result

-- | Vectorise a type environment.
--   The type environment contains all the type things defined in a module.
--
--   The type constructors are split (by 'classifyTyCons') into those that
--   must be vectorised and those passed through unchanged.  For every
--   resulting non-class type constructor we build @PRepr@ and @PData@
--   representation tycons, register them as local family instances, and
--   build its data-constructor workers and PA dictionary
--   (see 'buildTyConBindings').
vectTypeEnv
        :: TypeEnv
        -> VM ( TypeEnv                 -- Vectorised type environment.
              , [FamInst]               -- New type family instances.
              , [(Var, CoreExpr)])      -- New top level bindings.

vectTypeEnv env
 = dtrace (ppr env)
 $ do
      cs <- readGEnv $ mk_map . global_tycons

      -- Split the list of TyCons into the ones we have to vectorise vs the
      -- ones we can pass through unchanged. We also pass through algebraic
      -- types that use non Haskell98 features, as we don't handle those.
      let (conv_tcs, keep_tcs) = classifyTyCons cs
                               . tyConGroups
                               $ typeEnvTyCons env
          keep_dcs             = concatMap tyConDataCons keep_tcs

      -- Just use the unvectorised versions of these constructors in vectorised code.
      zipWithM_ defTyCon   keep_tcs keep_tcs
      zipWithM_ defDataCon keep_dcs keep_dcs

      -- Vectorise all the declarations.
      new_tcs <- vectTyConDecls conv_tcs

      -- We don't need to make new representation types for dictionary
      -- constructors. The constructors are always fully applied, and we don't
      -- need to lift them to arrays as a dictionary of a particular type
      -- always has the same value.
      --
      -- Originals and vectorised tycons are paired *before* filtering.  The
      -- two lists are consumed positionally below (zipWith3M / zipWith5), so
      -- filtering only the vectorised list would pair each later vectorised
      -- tycon with the wrong original whenever a class tycon is present.
      let (orig_tcs, vect_tcs)
            = unzip [ (orig_tc, vect_tc)
                    | (orig_tc, vect_tc) <- zip (keep_tcs ++ conv_tcs)
                                                (keep_tcs ++ new_tcs)
                    , not (isClassTyCon vect_tc) ]

      -- Build the PRepr and PData representation tycons and register them
      -- as local family instances.
      reprs     <- mapM tyConRepr vect_tcs
      repr_tcs  <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
      pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
      let inst_tcs = repr_tcs ++ pdata_tcs
      updGEnv $ extendFamEnv $ map mkLocalFamInst inst_tcs

      -- Build the data-constructor workers and PA dictionaries.  'fixV' ties
      -- the knot: the dfuns produced by this computation are handed to
      -- 'defTyConPAs' through the lazy (~) pattern before they are built.
      (_, binds) <- fixV $ \ ~(dfuns', _) ->
        do
          defTyConPAs (zipLazy vect_tcs dfuns')
          -- NOTE(review): tyConRepr is re-run here, after defTyConPAs, rather
          -- than reusing 'reprs' from above; presumably its result can depend
          -- on the PA instances just registered -- confirm before merging.
          reprs'    <- mapM tyConRepr vect_tcs

          dfuns     <- sequence
                    $  zipWith5 buildTyConBindings
                               orig_tcs
                               vect_tcs
                               repr_tcs
                               pdata_tcs
                               reprs'

          binds     <- takeHoisted
          return (dfuns, binds)

      -- The new type constructors are the vectorised versions of the originals,
      -- plus the new type constructors that we use for the representations.
      let all_new_tcs = new_tcs ++ inst_tcs
          new_env     =  extendTypeEnvList env
                      $  map ATyCon all_new_tcs
                      ++ [ADataCon dc | tc <- all_new_tcs
                                      , dc <- tyConDataCons tc]

      return (new_env, map mkLocalFamInst inst_tcs, binds)

   where
    -- For each key of the global tycon environment: True when the tycon
    -- stored under it has a unique different from the key (presumably: it
    -- has a separate vectorised version -- confirm in Vectorise.Env).
    mk_map tycon_env = listToUFM_Directly
                         [ (u, getUnique n /= u)
                         | (u, n) <- nameEnvUniqueElts tycon_env ]

-- | Build the bindings for one vectorised type constructor.
--
--   First the vectorised data-constructor workers are defined (their bindings
--   are hoisted by 'vectDataConWorkers').  Then the PA dictionary is built,
--   and its result is returned.
buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr
  =  vectDataConWorkers orig_tc vect_tc pdata_tc
  >> buildPADict vect_tc prepr_tc pdata_tc repr

-- | Define a vectorised worker for every data constructor of a type.
--
--   @orig_tc@ is the original type constructor, @vect_tc@ its vectorised
--   version, and @arr_tc@ the PData tycon (expected to have exactly one data
--   constructor).  For the i-th constructor, the worker body is built with
--   'buildClosures' from a pair of expressions:
--
--     * the scalar version: the i-th data constructor of @vect_tc@ applied to
--       the tycon's type variables;
--
--     * the lifted version: a lambda over the lifting context and one PData
--       argument per field.  It builds an @arr_tc@ value with an optional
--       selector (only when there is more than one constructor).  The fields
--       of the other constructors are filled with 'emptyPD' values.
--
--   Each worker is registered with 'defGlobalVar', and its binding is hoisted
--   with 'hoistBinding'.
vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
vectDataConWorkers orig_tc vect_tc arr_tc
 = do bs <- sequence
          . zipWith3 def_worker  (tyConDataCons orig_tc) rep_tys
          $ zipWith4 mk_data_con cons
                                 rep_tys
                                 (inits rep_tys)
                                 (drop 1 (tails rep_tys))
      mapM_ (uncurry hoistBinding) bs
 where
    tyvars   = tyConTyVars vect_tc
    var_tys  = mkTyVarTys tyvars
    ty_args  = map Type var_tys
    res_ty   = mkTyConApp vect_tc var_tys

    cons     = tyConDataCons vect_tc
    -- Number of data constructors; a selector is only needed when > 1.
    num_cons = length cons

    -- The single data constructor of the PData tycon.  Panic with a useful
    -- message rather than an anonymous pattern-match failure if the PData
    -- tycon does not have exactly one.
    arr_dc   = case tyConDataCons arr_tc of
                 [dc] -> dc
                 _    -> pprPanic "vectDataConWorkers: expected exactly one PData data constructor"
                                  (ppr arr_tc)

    -- Representation argument types of each vectorised constructor.
    rep_tys  = map dataConRepArgTys cons

    -- Pair the scalar and lifted versions of one constructor.
    mk_data_con con tys pre post
      = liftM2 (,) (vect_data_con con)
                   (lift_data_con tys pre post (mkDataConTag con))

    -- Selector argument for the PData constructor; empty for types with a
    -- single constructor.
    sel_replicate len tag
      | num_cons > 1 = do
                         rep <- builtin (selReplicate num_cons)
                         return [rep `mkApps` [len, tag]]
      | otherwise    = return []

    vect_data_con con = return $ mkConApp con ty_args

    -- Lifted constructor: fields of the preceding (pre_tys) and following
    -- (post_tys) constructors are filled with 'emptyPD' values.
    lift_data_con tys pre_tys post_tys tag
      = do
          len  <- builtin liftingContext
          args <- mapM (newLocalVar (fsLit "xs"))
                  =<< mapM mkPDataType tys

          sel  <- sel_replicate (Var len) tag

          pre   <- mapM emptyPD (concat pre_tys)
          post  <- mapM emptyPD (concat post_tys)

          return . mkLams (len : args)
                 . wrapFamInstBody arr_tc var_tys
                 . mkConApp arr_dc
                 $ ty_args ++ sel ++ pre ++ map Var args ++ post

    -- Build, register and return the vectorised worker for one constructor.
    def_worker data_con arg_tys mk_body
      = do
          worker_arity <- polyArity tyvars
          body <- closedV
                . inBind orig_worker
                . polyAbstract tyvars $ \args ->
                  liftM (mkLams (tyvars ++ args) . vectorised)
                $ buildClosures tyvars [] arg_tys res_ty mk_body

          raw_worker <- cloneId mkVectOcc orig_worker (exprType body)
          let vect_worker = raw_worker `setIdUnfolding`
                              mkInlineUnfolding (Just worker_arity) body
          defGlobalVar orig_worker vect_worker
          return (vect_worker, body)
      where
        orig_worker = dataConWorkId data_con