Monad.hs 6.62 KB
Newer Older
1
module Vectorise.Monad (
2
3
4
5
6
7
8
9
10
11
12
13
14
15
  module Vectorise.Monad.Base,
  module Vectorise.Monad.Naming,
  module Vectorise.Monad.Local,
  module Vectorise.Monad.Global,
  module Vectorise.Monad.InstEnv,
  initV,

  -- * Builtins
  liftBuiltinDs,
  builtin,
  builtins,
  
  -- * Variables
  lookupVar,
16
  lookupVar_maybe,
17
18
  addGlobalParallelVar, 
  addGlobalParallelTyCon, 
19
20
) where

21
22
23
24
25
26
27
28
import Vectorise.Monad.Base
import Vectorise.Monad.Naming
import Vectorise.Monad.Local
import Vectorise.Monad.Global
import Vectorise.Monad.InstEnv
import Vectorise.Builtins
import Vectorise.Env

29
30
import CoreSyn
import DsMonad
31
32
import HscTypes hiding ( MonadThings(..) )
import DynFlags
33
import MonadUtils (liftIO)
34
35
import InstEnv
import Class
36
37
import TyCon
import NameSet
38
import VarSet
39
import VarEnv
40
import Var
41
import Id
42
import Name
43
import ErrUtils
44
import Outputable
45
import Module
46

47

48
-- |Run a vectorisation computation.
--
-- Inside the desugaring monad this:
--
-- * sets up the tables of builtin entities,
-- * builds the class-instance and family-instance environments,
-- * constructs the initial global environment from the incoming 'VectInfo', and
-- * runs the given 'VM' computation.
--
-- On success it returns the updated 'VectInfo' together with the computation's result.
-- If the 'VM' computation fails, the reason is printed as a warning and 'Nothing' is returned.
-- If the desugaring monad itself produces no result, the errors it collected are printed
-- before panicking.
--
initV :: HscEnv
      -> ModGuts
      -> VectInfo
      -> VM a
      -> IO (Maybe (VectInfo, a))
initV hsc_env guts info thing_inside
  = do { dumpIfVtTrace "Incoming VectInfo" (ppr info)

       ; let type_env = typeEnvFromEntities ids (mg_tcs guts) (mg_fam_insts guts)
       ; ((_, ds_errs), mb_res) <- initDs hsc_env (mg_module guts)
                                          (mg_rdr_env guts) type_env
                                          (mg_fam_inst_env guts) go

         -- If 'initDs' produced no result, report the errors it collected instead of failing
         -- with an anonymous pattern-match failure (as a partial '(_, Just res) <-' bind would).
       ; res <- case mb_res of
                  Just r  -> return r
                  Nothing -> do { printBagOfErrors dflags ds_errs
                                ; pprPanic "initV"
                                    (text "desugaring monad failed while vectorising")
                                }

       ; case res of
           Nothing
             -> dumpIfVtTrace "Vectorisation FAILED!" empty
           Just (info', _)
             -> dumpIfVtTrace "Outgoing VectInfo" (ppr info')

       ; return res
       }
  where
    dflags = hsc_dflags hsc_env

    dumpIfVtTrace = dumpIfSet_dyn dflags Opt_D_dump_vt_trace

    -- Binders of all top-level bindings of the module; used for the type environment and
    -- when computing the outgoing 'VectInfo'.
    bindsToIds (NonRec v _)   = [v]
    bindsToIds (Rec    binds) = map fst binds

    ids = concatMap bindsToIds (mg_binds guts)

    -- The computation executed inside the desugaring monad.
    go
      = do {   -- set up tables of builtin entities
           ; bis             <- initBuiltins
           ; builtin_vars    <- initBuiltinVars bis

               -- set up class and type family environments
           ; eps <- liftIO $ hscEPS hsc_env
           ; let famInstEnvs = (eps_fam_inst_env eps, mg_fam_inst_env guts)
                 instEnvs    = InstEnvs (eps_inst_env     eps)
                                        (mg_inst_env     guts)
                                        (mkModuleSet (dep_orphs (mg_deps guts)))
                 builtin_pas = initClassDicts instEnvs (paClass bis)  -- grab all 'PA' and..
                 builtin_prs = initClassDicts instEnvs (prClass bis)  -- ..'PR' class instances

               -- construct the initial global environment
           ; let genv = extendImportedVarsEnv builtin_vars
                        . setPAFunsEnv        builtin_pas
                        . setPRFunsEnv        builtin_prs
                        $ initGlobalEnv (gopt Opt_VectorisationAvoidance dflags)
                                        info (mg_vect_decls guts) instEnvs famInstEnvs

               -- perform vectorisation
           ; r <- runVM thing_inside bis genv emptyLocalEnv
           ; case r of
               Yes genv' _ x -> return $ Just (new_info genv', x)
               No reason     -> do { unqual <- mkPrintUnqualifiedDs
                                   ; liftIO $
                                       printOutputForUser dflags unqual $
                                         mkDumpDoc "Warning: vectorisation failure:" reason
                                   ; return Nothing
                                   }
           }

    -- Outgoing vectorisation info, computed from the final global environment.
    new_info genv = modVectInfo genv ids (mg_tcs guts) (mg_vect_decls guts) info

    -- For a given DPH class, produce a mapping from type constructor (in head position) to the
    -- instance dfun for that type constructor and class.  (DPH class instances cannot overlap in
    -- head constructors.)
    --
    -- Panics on any instance whose rough head is not exactly one known type constructor.
    --
    initClassDicts :: InstEnvs -> Class -> [(Name, Var)]
    initClassDicts insts cls = map find $ classInstances insts cls
      where
        find i | [Just tc] <- instanceRoughTcs i = (tc, instanceDFunId i)
               | otherwise                       = pprPanic invalidInstance (ppr i)

    invalidInstance = "Invalid DPH instance (overlapping in head constructor)"
127
128
129


-- Builtins -------------------------------------------------------------------
130
131
132

-- |Lift a desugaring computation that needs the 'Builtins' into the vectorisation monad.
--
-- Both the global and the local environment are threaded through unchanged; only the
-- result of the desugaring computation is wrapped in a successful 'Yes'.
--
liftBuiltinDs :: (Builtins -> DsM a) -> VM a
liftBuiltinDs dsAction
  = VM (\bis genv lenv -> fmap (Yes genv lenv) (dsAction bis))

136
137
-- |Project something from the set of builtins.
--
-- A pure read of the 'Builtins'; the vectorisation environments are left untouched.
--
builtin :: (Builtins -> a) -> VM a
builtin project = liftBuiltinDs (return . project)

141
142
-- |Lift a function whose last argument is the 'Builtins' into the vectorisation monad.
--
-- The resulting function has the builtins already supplied, so it only expects the
-- remaining argument.
--
builtins :: (a -> Builtins -> b) -> VM (a -> b)
builtins f = builtin (flip f)


-- Var ------------------------------------------------------------------------
148

149
-- |Lookup the vectorised, and if local, also the lifted version of a variable.
--
-- * If it's in the global environment we get the vectorised version.
-- * If it's in the local environment we get both the vectorised and lifted version.
--
-- A variable found in neither environment is handed to 'dumpVar', whose fully polymorphic
-- result type means it does not return normally.
--
lookupVar :: Var -> VM (Scope Var (Var, Var))
lookupVar v
  = lookupVar_maybe v >>= maybe notFound return
  where
    notFound = getDynFlags >>= \dflags -> dumpVar dflags v

-- |Lookup a variable, consulting the local environment before the global one.
--
-- Returns 'Nothing' if the variable occurs in neither environment.
--
lookupVar_maybe :: Var -> VM (Maybe (Scope Var (Var, Var)))
lookupVar_maybe v
  = do { mb_local <- readLEnv (\lenv -> lookupVarEnv (local_vars lenv) v)
       ; maybe fromGlobal (return . Just . Local) mb_local
       }
  where
    -- global entries carry only the vectorised variable
    fromGlobal = fmap (fmap Global) (readGEnv (\genv -> lookupVarEnv (global_vars genv) v))
171

Ian Lynagh's avatar
Ian Lynagh committed
172
173
-- |Report, via 'cantVectorise', that a variable has no vectorised counterpart.
--
-- Class-operation ids (as recognised by 'isClassOpId_maybe') get a more specific headline.
-- The result type is fully polymorphic, so this does not return normally.
--
dumpVar :: DynFlags -> Var -> a
dumpVar dflags var = cantVectorise dflags headline (ppr var)
  where
    headline = case isClassOpId_maybe var of
                 Just _  -> "ClassOpId not vectorised:"
                 Nothing -> "Variable not vectorised:"
178

179

180
-- Global parallel entities ----------------------------------------------------
181

182
-- |Mark the given variable as parallel — i.e., executing the associated code might involve
-- parallel array computations.
--
-- Emits a vectorisation trace message, then adds the variable to the global environment's
-- set of parallel variables.
--
addGlobalParallelVar :: Var -> VM ()
addGlobalParallelVar var
  =  traceVt "addGlobalParallelVar" (ppr var)
  >> updGEnv markParallel
  where
    markParallel genv
      = genv { global_parallel_vars = global_parallel_vars genv `extendVarSet` var }
190

191
-- |Mark the given type constructor as parallel — i.e., its values might embed parallel arrays.
--
-- Emits a vectorisation trace message, then adds the type constructor's 'Name' to the
-- global environment's set of parallel type constructors.
--
addGlobalParallelTyCon :: TyCon -> VM ()
addGlobalParallelTyCon tycon
  =  traceVt "addGlobalParallelTyCon" (ppr tycon)
  >> updGEnv markParallel
  where
    markParallel genv
      = genv { global_parallel_tycons
                 = global_parallel_tycons genv `extendNameSet` tyConName tycon }