-- Main entry point to the vectoriser.  It is invoked iff the option '-fvectorise' is passed.
--
-- This module provides the function 'vectorise', which vectorises an entire (desugared) module.
-- It vectorises all type declarations and value bindings.  It also processes all VECTORISE pragmas
-- (aka vectorisation declarations), which can lead to the vectorisation of imported data types
-- and the enrichment of imported functions with vectorised versions.
7

8
module Vectorise ( vectorise )
9 10
where

11 12 13
import Vectorise.Type.Env
import Vectorise.Type.Type
import Vectorise.Convert
14
import Vectorise.Utils.Hoisting
15
import Vectorise.Exp
16
import Vectorise.Env
17
import Vectorise.Monad
18

19
import HscTypes hiding      ( MonadThings(..) )
20
import CoreUnfold           ( mkInlineUnfoldingWithArity )
21 22
import PprCore
import CoreSyn
Ian Lynagh's avatar
Ian Lynagh committed
23
import CoreMonad            ( CoreM, getHscEnv )
24
import Type
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
25
import Id
26
import DynFlags
27
import Outputable
28
import Util                 ( zipLazy )
29 30
import MonadUtils

31
import Control.Monad
32 33


34
-- |Vectorise a single module.
--
-- This is the 'CoreM'-level entry point: it fetches the current 'HscEnv' and runs 'vectoriseIO'
-- on the module guts in 'IO'.
--
vectorise :: ModGuts -> CoreM ModGuts
vectorise guts
  = getHscEnv >>= \hscEnv -> liftIO (vectoriseIO hscEnv guts)

-- |Vectorise a single module in 'IO', given the 'HscEnv'.
--
-- Combines the vectorisation info from the home package table ('hptVectInfo') with that of the
-- currently loaded external packages ('eps_vect_info'), runs 'vectModule' in the VM monad via
-- 'initV', and stores the resulting vectorisation info in the returned 'ModGuts'.
--
-- If 'initV' yields 'Nothing', an 'IOError' built with 'userError' is raised.  The previous code
-- did the same implicitly through a failing @Just ... <-@ pattern bind, but with an
-- uninformative "pattern match failure" message.
-- NOTE(review): 'Nothing' presumably signals that the VM computation failed; confirm against
-- 'initV'.
--
vectoriseIO :: HscEnv -> ModGuts -> IO ModGuts
vectoriseIO hsc_env guts
 = do {   -- Get information about currently loaded external packages.
      ; eps <- hscEPS hsc_env

          -- Combine vectorisation info from the home-package modules and the external ones.
      ; let info = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps

          -- Run the main VM computation.
      ; result <- initV hsc_env guts info (vectModule guts)
      ; case result of
          Just (info', guts') -> return (guts' { mg_vect_info = info' })
          Nothing             ->
            ioError (userError "Vectorise.vectoriseIO: vectorisation failed (initV returned Nothing)")
      }

-- |Vectorise a single module, in the VM monad.
--
-- The steps, in order:
--
--   * Vectorise the type environment, driven by the 'VECTORISE [SCALAR] type' and
--     'VECTORISE class' pragmas.  This adds vectorised type constructors, their representations
--     and the corresponding data constructors.  It also produces bindings for dfuns and for
--     family instances of the classes and type families used in the DPH library to represent
--     array types.
--
--   * Read the family instance environment for /all/ home-package modules, including the
--     instances generated by 'vectTypeEnv'.
--
--   * Vectorise the VECTORISE declarations on imported identifiers, then all top-level bindings of
--     this module.  The imported ones go first because local bindings may depend on them.
--
-- No new classes or instances are produced, only new class type constructors and dfuns.
--
vectModule :: ModGuts -> VM ModGuts
vectModule guts@(ModGuts { mg_tcs        = origTyCons
                         , mg_binds      = origBinds
                         , mg_fam_insts  = origFamInsts
                         , mg_vect_decls = vectDecls
                         })
  = do
      dumpOptVt Opt_D_dump_vt_trace "Before vectorisation" $
        pprCoreBindings origBinds

      (newTyCons, newFamInsts, tyConBinds) <- vectTypeEnv origTyCons tyVectDecls clsVectDecls

      (_, famInstEnv) <- readGEnv global_fam_inst_env

      impBinds <- forM importedVects vectImpBind
      topBinds <- forM origBinds vectTopBind

      return (guts { mg_tcs          = origTyCons ++ newTyCons
                   , mg_binds        = Rec tyConBinds : (topBinds ++ impBinds)
                   , mg_fam_inst_env = famInstEnv
                   , mg_fam_insts    = origFamInsts ++ newFamInsts
                   })
  where
    -- The 'VECTORISE [SCALAR] type' and the 'VECTORISE class' pragmas, respectively.
    tyVectDecls  = [decl | decl@(VectType _ _ _) <- vectDecls]
    clsVectDecls = [decl | decl@(VectClass _)    <- vectDecls]

    -- 'VECTORISE' declarations attached to imported (global) identifiers.
    importedVects = [(impId, rhs) | Vect impId rhs <- vectDecls, isGlobalId impId]

-- |Try to vectorise a top-level binding.  If it doesn't vectorise, or if it is entirely scalar,
-- then omit vectorisation of that binding.
--
-- For example, for the binding
--
-- @
--    foo :: Int -> Int
--    foo = \x -> x + x
-- @
--
-- we get
-- @
--    foo  :: Int -> Int
--    foo  = \x -> vfoo $: x
--
--    v_foo :: Closure void vfoo lfoo
--    v_foo = closure vfoo lfoo void
--
--    vfoo :: Void -> Int -> Int
--    vfoo = ...
--
--    lfoo :: PData Void -> PData Int -> PData Int
--    lfoo = ...
-- @
--
-- @vfoo@ is the "vectorised", or scalar, version that does the same as the original function
-- @foo@, but takes an explicit environment.
--
-- @lfoo@ is the "lifted" version that works on arrays.
--
-- @v_foo@ combines both of these into a `Closure` that also contains the environment.
--
-- The original binding @foo@ is rewritten to call the vectorised version present in the closure.
--
-- Vectorisation may be suppressed by annotating a binding with a 'NOVECTORISE' pragma.  If this
-- pragma is used in a group of mutually recursive bindings, either all or no binding must have
-- the pragma.  If only some bindings are annotated, a fatal error is raised.  (In the case of
-- scalar bindings, we only omit vectorisation if all bindings in a group are scalar.)
--
-- The result is always a binding to keep in the module.  For a vectorised binding, it is a 'Rec'
-- group holding the rewritten original binding(s), the vectorised binding(s) and any hoisted
-- bindings.  For a skipped binding, it is the original binding itself.  If vectorisation fails
-- (caught by 'orElseErrV'), a message is emitted and the original binding is returned.
--
-- FIXME: Once we support partial vectorisation, we may be able to vectorise parts of a group, or
--   we may emit a warning and refrain from vectorising the entire group.
--
vectTopBind :: CoreBind -> VM CoreBind
vectTopBind b@(NonRec var expr)
  = do
    { traceVt "= Vectorise non-recursive top-level variable" (ppr var)

    ; (hasNoVect, vectDecl) <- lookupVectDecl var
    ; if hasNoVect
      then do
      {   -- 'NOVECTORISE' pragma => leave this binding as it is
      ; traceVt "NOVECTORISE" $ ppr var
      ; return b
      }
      else do
    { vectRhs <- case vectDecl of
        Just (_, expr') ->
            -- 'VECTORISE' pragma => just use the provided vectorised rhs; the variable is recorded
            -- as parallel here, so 'parBind' is 'False' below to avoid recording it twice
          do
          { traceVt "VECTORISE" $ ppr var
          ; addGlobalParallelVar var
          ; return $ Just (False, inlineMe, expr')
          }
        Nothing         ->
            -- no pragma => standard vectorisation of rhs
          do
          { traceVt "[Vanilla]" $ ppr var <+> char '=' <+> ppr expr
          ; vectTopExpr var expr
          }
    ; hs <- takeHoisted -- make sure we clean those out (even if we skip)
    ; case vectRhs of
      { Nothing ->
          -- scalar binding => leave this binding as it is
          do
          { traceVt "scalar binding [skip]" $ ppr var
          ; return b
          }
      ; Just (parBind, inline, expr') -> do
    {
       -- vanilla case => create an appropriate top-level binding & add it to the vectorisation map
    ; when parBind $
        addGlobalParallelVar var
    ; var' <- vectTopBinder var inline expr'

        -- We replace the original top-level binding by a value projected from the vectorised
        -- closure and add any newly created hoisted top-level bindings.
    ; cexpr <- tryConvert var var' expr
    ; return . Rec $ (var, cexpr) : (var', expr') : hs
    } } } }
    `orElseErrV`
    do
    { emitVt "  Could NOT vectorise top-level binding" $ ppr var
    ; return b
    }
vectTopBind b@(Rec binds)
  = do
    { traceVt "= Vectorise recursive top-level variables" $ ppr vars

    ; vectDecls <- mapM lookupVectDecl vars
    ; let hasNoVects = map fst vectDecls
    ; if and hasNoVects
      then do
      {   -- 'NOVECTORISE' pragmas => leave this entire binding group as it is
      ; traceVt "NOVECTORISE" $ ppr vars
      ; return b
      }
      else do
    { if or hasNoVects
      then do
        {   -- Inconsistent 'NOVECTORISE' pragmas => bail out
        ; dflags <- getDynFlags
        ; cantVectorise dflags noVectoriseErr (ppr b)
        }
      else do
    { traceVt "[Vanilla]" $ vcat [ppr var <+> char '=' <+> ppr expr | (var, expr) <- binds]

       -- For all bindings *with* a pragma, just use the pragma-supplied vectorised expression
    ; newBindsWPragma  <- concat <$>
                          sequence [ vectTopBindAndConvert bind inlineMe expr'
                                   | (bind, (_, Just (_, expr'))) <- zip binds vectDecls]

        -- Standard vectorisation of all rhses that are *without* a pragma.
        -- NB: The reason for 'fixV' is rather subtle: 'vectTopBindAndConvert' adds entries for
        --     the bound variables in the recursive group to the vectorisation map, which in turn
        --     are needed by 'vectTopExprs' (unless it returns 'Nothing').
    ; let bindsWOPragma = [bind | (bind, (_, Nothing)) <- zip binds vectDecls]
    ; (newBinds, _) <- fixV $
        \ ~(_, exprs') ->
          do
          {   -- Create appropriate top-level bindings, enter them into the vectorisation map, and
              -- vectorise the right-hand sides
          ; newBindsWOPragma <- concat <$>
                                sequence [vectTopBindAndConvert bind inline expr
                                         | (bind, ~(inline, expr)) <- zipLazy bindsWOPragma exprs']
                                         -- irrefutable pattern and 'zipLazy' to tie the knot;
                                         -- hence, can't use 'zipWithM'
          ; vectRhses <- vectTopExprs bindsWOPragma
          ; hs <- takeHoisted -- make sure we clean those out (even if we skip)

          ; case vectRhses of
              Nothing ->
                -- scalar bindings => skip all bindings except those with pragmas and retract the
                --   entries into the vectorisation map for the scalar bindings
                do
                { traceVt "scalar bindings [skip]" $ ppr vars
                ; mapM_ (undefGlobalVar . fst) bindsWOPragma
                ; return (bindsWOPragma ++ newBindsWPragma, exprs')
                }
              Just (parBind, exprs') ->
                -- vanilla case => record parallel variables and return the final bindings
                do
                { when parBind $
                    mapM_ addGlobalParallelVar vars
                ; return (newBindsWOPragma ++ newBindsWPragma ++ hs, exprs')
                }
          }
    ; return $ Rec newBinds
    } } }
    `orElseErrV`
    do
    { emitVt "  Could NOT vectorise top-level bindings" $ ppr vars
    ; return b
    }
  where
    vars = map fst binds
    noVectoriseErr = "NOVECTORISE must be used on all or no bindings of a recursive group"

    -- Replace the original top-level bindings by values projected from the vectorised closures,
    -- returning both the rewritten original binding and the new vectorised binding.  The
    -- 'inline' and 'expr'' arguments must not be forced here: inside 'fixV' they come from the
    -- knot-tied result ('vectTopBinder' is lazy in them).
    vectTopBindAndConvert (var, expr) inline expr'
      = do
        { var'  <- vectTopBinder var inline expr'
        ; cexpr <- tryConvert var var' expr
        ; return [(var, cexpr), (var', expr')]
        }

-- |Add a vectorised binding for an imported top-level variable that has a VECTORISE pragma in
-- this module.
--
-- The vectorised binder is made by 'vectTopBinder' (with 'inlineMe'), which also enters it into
-- the vectorisation map.  It is bound non-recursively to the pragma-supplied expression.
--
-- RESTRICTION: Currently, we cannot use the pragma for mutually recursive definitions.
--
vectImpBind :: (Id, CoreExpr) -> VM CoreBind
vectImpBind (var, expr)
  = traceVt "= Add vectorised binding to imported variable" (ppr var)
    >> fmap (`NonRec` expr) (vectTopBinder var inlineMe expr)

-- |Make the vectorised version of a top-level binder, and record the mapping between it and the
-- original in the vectorisation state.  For some binder @foo@ the vectorised version is @$v_foo@.
--
-- If the binder carries a vectorisation declaration, the declared type must equal the vectorised
-- type of the binder; otherwise vectorisation is aborted via 'cantVectorise'.
--
-- NOTE: 'vectTopBinder' *MUST* be lazy in @inline@ and @expr@ because of how it is used inside of
--       'fixV' in 'vectTopBind'.  Both are only consulted through the 'unfolding' thunk.
--
vectTopBinder :: Var      -- ^ Name of the binding.
              -> Inline   -- ^ Whether it should be inlined, used to annotate it.
              -> CoreExpr -- ^ RHS of binding, used to set the 'Unfolding' of the returned 'Var'.
              -> VM Var   -- ^ Name of the vectorised binding.
vectTopBinder var inline expr
  = do -- Vectorise the type attached to the binder.
       vty <- vectType (idType var)

       -- A vectorisation declaration for this binder must agree with the vectorised type.
       (_, vectDecl) <- lookupVectDecl var
       maybe (return ()) (checkDeclType vty) vectDecl

       -- Create the vectorised binder and attach the unfolding used for inlining.
       plainVectId <- mkVectId var vty
       let var' = plainVectId `setIdUnfolding` unfolding

       -- Record the plain -> vectorised mapping in the state.
       defGlobalVar var var'
       return var'
  where
    -- Abort when the type from the vectorisation declaration differs from the vectorised type.
    checkDeclType vty (declTy, _)
      = unless (eqType vty declTy) $
          do dflags <- getDynFlags
             cantVectorise dflags ("Type mismatch in vectorisation pragma for " ++ showPpr dflags var) $
               (text "Expected type" <+> ppr vty)
               $$
               (text "Inferred type" <+> ppr declTy)

    unfolding = case inline of
                  Inline arity -> mkInlineUnfoldingWithArity arity expr
                  DontInline   -> noUnfolding
{-
!!!TODO: dfuns and unfoldings:
           -- Do not inline the dfun; instead give it a magic DFunFunfolding
           -- See Note [ClassOp/DFun selection]
           -- See also note [Single-method classes]
        dfun_id_w_fun
           | isNewTyCon class_tc
           = dfun_id `setInlinePragma` alwaysInlinePragma { inl_sat = Just 0 }
           | otherwise
           = dfun_id `setIdUnfolding`  mkDFunUnfolding dfun_ty dfun_args
                     `setInlinePragma` dfunInlinePragma
 -}

-- |Project out the vectorised version of a binding from some closure, or return the original body
-- if that doesn't work.
--
-- The projection is built by 'fromVect' from the type of the original binder.  If that fails, a
-- message is emitted and the original right-hand side is returned unchanged.
--
tryConvert :: Var       -- ^Name of the original binding (eg @foo@)
           -> Var       -- ^Name of vectorised version of binding (eg @$vfoo@)
           -> CoreExpr  -- ^The original body of the binding.
           -> VM CoreExpr
tryConvert var vect_var origRhs
  = fromVect origTy (Var vect_var) `orElseErrV` keepOriginal
  where
    origTy = idType var

    -- Fallback: report the failure and keep the original right-hand side.
    keepOriginal
      = do emitVt "  Could NOT call vectorised from original version" $
             ppr var <+> dcolon <+> ppr origTy
           return origRhs