-- Main entry point to the vectoriser.  It is invoked iff the option '-fvectorise' is passed.
--
-- This module provides the function 'vectorise', which vectorises an entire (desugared) module.
-- It vectorises all type declarations and value bindings.  It also processes all VECTORISE pragmas
-- (aka vectorisation declarations), which can lead to the vectorisation of imported data types
-- and the enrichment of imported functions with vectorised versions.
7

8
module Vectorise ( vectorise )
9 10
where

11 12 13
import Vectorise.Type.Env
import Vectorise.Type.Type
import Vectorise.Convert
14
import Vectorise.Utils.Hoisting
15
import Vectorise.Exp
16
import Vectorise.Vect
17
import Vectorise.Env
18
import Vectorise.Monad
19

20
import HscTypes hiding      ( MonadThings(..) )
21
import CoreUnfold           ( mkInlineUnfolding )
22
import CoreFVs
23 24
import PprCore
import CoreSyn
Ian Lynagh's avatar
Ian Lynagh committed
25
import CoreMonad            ( CoreM, getHscEnv )
26
import Type
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
27
import Id
28
import DynFlags
29
import BasicTypes           ( isStrongLoopBreaker )
30
import Outputable
31
import Util                 ( zipLazy )
32 33
import MonadUtils

34
import Control.Monad
35
import Data.Maybe
36 37


38
-- |Vectorise a single module.
--
-- Obtains the current 'HscEnv' from the 'CoreM' monad via 'getHscEnv' and then runs
-- 'vectoriseIO' on the module through 'liftIO'.
--
vectorise :: ModGuts -> CoreM ModGuts
vectorise guts
 = do { hsc_env <- getHscEnv
      ; liftIO $ vectoriseIO hsc_env guts
      }
45

46
-- Vectorise a single module, given the 'HscEnv'.
--
-- The vectorisation info obtained from 'hptVectInfo' and from the external package state is
-- combined and passed to 'initV', which runs 'vectModule'.  The resulting vectorisation info is
-- stored in the 'mg_vect_info' field of the returned module.
--
vectoriseIO :: HscEnv -> ModGuts -> IO ModGuts
vectoriseIO hsc_env guts
 = do {   -- Get information about currently loaded external packages.
      ; eps <- hscEPS hsc_env

          -- Combine vectorisation info from the current module, and external ones.
      ; let info = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps

          -- Run the main VM computation.
          -- NB: This pattern bind is partial: if 'initV' yields 'Nothing', the match fails in 'IO'.
      ; Just (info', guts') <- initV hsc_env guts info (vectModule guts)
      ; return (guts' { mg_vect_info = info' })
      }
60

61
-- Vectorise a single module, in the VM monad.
--
-- The type environment is vectorised first, then the imported identifiers carrying a
-- vectorisation declaration, and finally the module's own top-level bindings.  The returned
-- 'ModGuts' has its type constructors, bindings, family instances, and family instance
-- environment updated accordingly.
--
vectModule :: ModGuts -> VM ModGuts
vectModule guts@(ModGuts { mg_tcs        = tycons
                         , mg_binds      = binds
                         , mg_fam_insts  = fam_insts
                         , mg_vect_decls = vect_decls
                         })
 = do { dumpOptVt Opt_D_dump_vt_trace "Before vectorisation" $
          pprCoreBindings binds

          -- Pick out all 'VECTORISE type' and 'VECTORISE class' pragmas
      ; let ty_vect_decls  = [vd | vd@(VectType _ _ _) <- vect_decls]
            cls_vect_decls = [vd | vd@(VectClass _)    <- vect_decls]

          -- Vectorise the type environment.  This will add vectorised
          -- type constructors, their representations, and the
          -- corresponding data constructors.  Moreover, we produce
          -- bindings for dfuns and family instances of the classes
          -- and type families used in the DPH library to represent
          -- array types.
      ; (new_tycons, new_fam_insts, tc_binds) <- vectTypeEnv tycons ty_vect_decls cls_vect_decls

          -- Family instance environment for /all/ home-package modules including those instances
          -- generated by 'vectTypeEnv'.
      ; (_, fam_inst_env) <- readGEnv global_fam_inst_env

          -- Vectorise all the top level bindings and VECTORISE declarations on imported identifiers
          -- NB: Need to vectorise the imported bindings first (local bindings may depend on them).
      ; let impBinds = [imp_id | Vect     imp_id _ <- vect_decls, isGlobalId imp_id] ++
                       [imp_id | VectInst imp_id   <- vect_decls, isGlobalId imp_id]
      ; binds_imp <- mapM vectImpBind impBinds
      ; binds_top <- mapM vectTopBind binds

      ; return $ guts { mg_tcs          = tycons ++ new_tycons
                        -- we produce no new classes or instances, only new class type constructors
                        -- and dfuns
                      , mg_binds        = Rec tc_binds : (binds_top ++ binds_imp)
                      , mg_fam_inst_env = fam_inst_env
                      , mg_fam_insts    = fam_insts ++ new_fam_insts
                      }
      }
103

104
-- Try to vectorise a top-level binding.  If it doesn't vectorise then return it unharmed.
--
-- For example, for the binding
--
-- @
--    foo :: Int -> Int
--    foo = \x -> x + x
-- @
--
-- we get
-- @
--    foo  :: Int -> Int
--    foo  = \x -> vfoo $: x
--
--    v_foo :: Closure void vfoo lfoo
--    v_foo = closure vfoo lfoo void
--
--    vfoo :: Void -> Int -> Int
--    vfoo = ...
--
--    lfoo :: PData Void -> PData Int -> PData Int
--    lfoo = ...
-- @
--
-- @vfoo@ is the "vectorised", or scalar, version that does the same as the original
-- function foo, but takes an explicit environment.
--
-- @lfoo@ is the "lifted" version that works on arrays.
--
-- @v_foo@ combines both of these into a `Closure` that also contains the
-- environment.
--
-- The original binding @foo@ is rewritten to call the vectorised version
-- present in the closure.
--
-- Vectorisation may be suppressed by annotating a binding with a 'NOVECTORISE' pragma.  If this
-- pragma is used in a group of mutually recursive bindings, either all or no binding must have
-- the pragma.  If only some bindings are annotated, a fatal error is raised.
-- FIXME: Once we support partial vectorisation, we may be able to vectorise parts of a group, or
--   we may emit a warning and refrain from vectorising the entire group.
--
-- In both the non-recursive and the recursive case, a message is emitted when vectorisation
-- fails and the original binding is returned instead.
--
vectTopBind :: CoreBind -> VM CoreBind
vectTopBind b@(NonRec var expr)
  = unlessNoVectDecl $
      do {   -- Vectorise the right-hand side, create an appropriate top-level binding and add it
             -- to the vectorisation map.
         ; (inline, isScalar, expr') <- vectTopRhs [] var expr
         ; var' <- vectTopBinder var inline expr'
         ; when isScalar $
             addGlobalScalarVar var

             -- We replace the original top-level binding by a value projected from the vectorised
             -- closure and add any newly created hoisted top-level bindings.
         ; cexpr <- tryConvert var var' expr
         ; hs <- takeHoisted
         ; return . Rec $ (var, cexpr) : (var', expr') : hs
         }
     `orElseErrV`
     do { emitVt "  Could NOT vectorise top-level binding" $ ppr var
        ; return b
        }
  where
    -- Skip vectorisation (returning the binding as is) if the variable has a 'NOVECTORISE' pragma.
    unlessNoVectDecl vectorise
      = do { hasNoVectDecl <- noVectDecl var
           ; when hasNoVectDecl $
               traceVt "NOVECTORISE" $ ppr var
           ; if hasNoVectDecl then return b else vectorise
           }
vectTopBind b@(Rec bs)
  = unlessSomeNoVectDecl $
      do { (vars', _, exprs', hs) <- fixV $
             \ ~(_, inlines, rhss, _) ->
               do {   -- Vectorise the right-hand sides, create appropriate top-level bindings
                      -- and add them to the vectorisation map.
                      -- NB: 'inlines' and 'rhss' come out of the knot tied by 'fixV'; hence, they
                      --     must only be consumed lazily here.
                  ; vars' <- sequence [vectTopBinder var inline rhs
                                      | (var, ~(inline, rhs)) <- zipLazy vars (zip inlines rhss)]
                  ; (inlines, areScalars, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs vars) bs
                  ; hs <- takeHoisted
                  ; if and areScalars
                    then      -- (1) Entire recursive group is scalar
                              --      => add all variables to the global set of scalars
                         do { mapM_ addGlobalScalarVar vars
                            ; return (vars', inlines, exprs', hs)
                            }
                    else      -- (2) At least one binding is not scalar
                              --     => vectorise again with empty set of local scalars
                         do { (inlines, _, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs []) bs
                            ; hs <- takeHoisted
                            ; return (vars', inlines, exprs', hs)
                            }
                  }

             -- Replace the original top-level bindings by values projected from the vectorised
             -- closures and add any newly created hoisted top-level bindings to the group.
         ; cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
         ; return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
         }
     `orElseErrV`
     -- Report the failure in the same manner as the non-recursive case instead of falling back
     -- silently.
     do { emitVt "  Could NOT vectorise top-level bindings" $ ppr vars
        ; return b
        }
  where
    (vars, exprs) = unzip bs

    -- Skip vectorisation if all bindings of the group have a 'NOVECTORISE' pragma; raise an error
    -- if only some of them have one.
    unlessSomeNoVectDecl vectorise
      = do { hasNoVectDecls <- mapM noVectDecl vars
           ; when (and hasNoVectDecls) $
               traceVt "NOVECTORISE" $ ppr vars
           ; if and hasNoVectDecls
             then return b                              -- all bindings have 'NOVECTORISE'
             else if or hasNoVectDecls
             then do dflags <- getDynFlags
                     cantVectorise dflags noVectoriseErr (ppr b)  -- some (but not all) have 'NOVECTORISE'
             else vectorise                             -- no binding has a 'NOVECTORISE' decl
           }
    noVectoriseErr = "NOVECTORISE must be used on all or no bindings of a recursive group"
218 219 220 221

-- Add a vectorised binding to an imported top-level variable that has a VECTORISE [SCALAR] pragma
-- in this module.
--
-- RESTRICTION: Currently, we cannot use the pragma for mutually recursive definitions.
--
-- The result is a recursive binding group consisting of the vectorised binding together with any
-- bindings that were hoisted while producing it.
--
vectImpBind :: Id -> VM CoreBind
vectImpBind var
  = do {   -- Vectorise the right-hand side, create an appropriate top-level binding and add it
           -- to the vectorisation map.  For the non-lifted version, we refer to the original
           -- definition — i.e., 'Var var'.
           -- NB: To support recursive definitions, we tie a lazy knot.
       ; (var', _, expr') <- fixV $
           \ ~(_, inline, rhs) ->
             do { var' <- vectTopBinder var inline rhs
                ; (inline, isScalar, expr') <- vectTopRhs [] var (Var var)

                ; when isScalar $
                    addGlobalScalarVar var
                ; return (var', inline, expr')
                }

           -- We add any newly created hoisted top-level bindings.
       ; hs <- takeHoisted
       ; return . Rec $ (var', expr') : hs
       }

245 246 247 248
-- | Make the vectorised version of this top level binder, and add the mapping
--   between it and the original to the state. For some binder @foo@ the vectorised
--   version is @$v_foo@
--
--   If there is a vectorisation declaration for the binder, the type given in the declaration
--   must equal the vectorised type of the binder; otherwise, vectorisation is aborted with a
--   type mismatch error.
--
--   NOTE: 'vectTopBinder' *MUST* be lazy in inline and expr because of how it is
--   used inside of 'fixV' in 'vectTopBind' and 'vectImpBind'.
--
vectTopBinder :: Var      -- ^ Name of the binding.
              -> Inline   -- ^ Whether it should be inlined, used to annotate it.
              -> CoreExpr -- ^ RHS of binding, used to set the 'Unfolding' of the returned 'Var'.
              -> VM Var   -- ^ Name of the vectorised binding.
vectTopBinder var inline expr
 = do {   -- Vectorise the type attached to the var.
      ; vty  <- vectType (idType var)

          -- If there is a vectorisation declaration for this binding, make sure that its type
          -- matches
      ; vectDecl <- lookupVectDecl var
      ; case vectDecl of
          Nothing             -> return ()
          Just (vdty, _)
            | eqType vty vdty -> return ()
            | otherwise       ->
              do dflags <- getDynFlags
                 cantVectorise dflags ("Type mismatch in vectorisation pragma for " ++ showPpr dflags var) $
                   (text "Expected type" <+> ppr vty)
                   $$
                   (text "Inferred type" <+> ppr vdty)

          -- Make the vectorised version of binding's name, and set the unfolding used for inlining
          -- (the unfolding is attached lazily, in line with the laziness requirement above)
      ; var' <- liftM (`setIdUnfoldingLazily` unfolding)
                $  mkVectId var vty

          -- Add the mapping between the plain and vectorised name to the state.
      ; defGlobalVar var var'

      ; return var'
    }
  where
    unfolding = case inline of
                  Inline arity -> mkInlineUnfolding (Just arity) expr
                  DontInline   -> noUnfolding
{-
!!!TODO: dfuns and unfoldings:
           -- Do not inline the dfun; instead give it a magic DFunUnfolding
           -- See Note [ClassOp/DFun selection]
           -- See also note [Single-method classes]
        dfun_id_w_fun
           | isNewTyCon class_tc
           = dfun_id `setInlinePragma` alwaysInlinePragma { inl_sat = Just 0 }
           | otherwise
           = dfun_id `setIdUnfolding`  mkDFunUnfolding dfun_ty dfun_args
                     `setInlinePragma` dfunInlinePragma
 -}
Ian Lynagh's avatar
Ian Lynagh committed
299

300
-- | Vectorise the RHS of a top-level binding, in an empty local environment.
--
-- We need to distinguish four cases:
--
-- (1) We have a (non-scalar) vectorisation declaration for the variable (which explicitly provides
--     vectorised code implemented by the user)
--     => no automatic vectorisation & instead use the user-supplied code
--
-- (2) We have a scalar vectorisation declaration for a variable that is no dfun
--     => generate vectorised code that uses a scalar 'map'/'zipWith' to lift the computation
--
-- (3) We have a scalar vectorisation declaration for a variable that *is* a dfun
--     => generate vectorised code according to the "Note [Scalar dfuns]" below
--
-- (4) There is no vectorisation declaration for the variable
--     => perform automatic vectorisation of the RHS (the definition may or may not be a dfun;
--        vectorisation proceeds differently depending on which it is)
--
-- Note [Scalar dfuns]
-- ~~~~~~~~~~~~~~~~~~~
--
-- Here is the translation scheme for scalar dfuns — assume the instance declaration:
--
--   instance Num Int where
--     (+) = primAdd
--   {-# VECTORISE SCALAR instance Num Int #-}
--
-- It desugars to
--
--   $dNumInt :: Num Int
--   $dNumInt = D:Num primAdd
--
-- We vectorise it to
--
--   $v$dNumInt :: V:Num Int
--   $v$dNumInt = D:V:Num (closure2 ((+) $dNumInt) (scalar_zipWith ((+) $dNumInt)))
--
-- while adding the following entry to the vectorisation map: '$dNumInt' --> '$v$dNumInt'.
--
-- See "Note [Vectorising classes]" in 'Vectorise.Type.Env' for the definition of 'V:Num'.
--
-- NB: The outlined vectorisation scheme does not require the right-hand side of the original dfun.
--     In fact, we definitely want to refer to the dfun variable instead of the right-hand side to
--     ensure that the dictionary selection rules fire.
--
vectTopRhs :: [Var]           -- ^ Names of all functions in the rec block
           -> Var             -- ^ Name of the binding.
           -> CoreExpr        -- ^ Body of the binding.
           -> VM ( Inline     -- (1) inline specification for the binding
                 , Bool       -- (2) whether the right-hand side is a scalar computation
                 , CoreExpr)  -- (3) the vectorised right-hand side
vectTopRhs recFs var expr
  = closedV
  $ do { globalScalar <- isGlobalScalarVar var
       ; vectDecl     <- lookupVectDecl var
       ; dflags       <- getDynFlags
       ; let isDFun = isDFunId var

       ; traceVt ("vectTopRhs of " ++ showPpr dflags var ++ info globalScalar isDFun vectDecl ++ ":") $
           ppr expr

       ; rhs globalScalar isDFun vectDecl
       }
  where
    -- Select the vectorisation strategy; the case numbers refer to the list in the header comment.
    -- A vectorisation declaration (Case (1)) takes precedence over the two flags.
    rhs _globalScalar _isDFun (Just (_, expr'))               -- Case (1)
      = return (inlineMe, False, expr')
    rhs True          False   Nothing                         -- Case (2)
      = do { expr' <- vectScalarFun expr
           ; return (inlineMe, True, vectorised expr')
           }
    rhs True          True    Nothing                         -- Case (3)
      = do { expr' <- vectScalarDFun var
           ; return (DontInline, True, expr')
           }
    rhs False         False   Nothing                         -- Case (4) — not a dfun
      = do { let exprFvs = freeVars expr
           ; (inline, isScalar, vexpr)
               <- inBind var $
                    vectPolyExpr (isStrongLoopBreaker $ idOccInfo var) recFs exprFvs Nothing
           ; return (inline, isScalar, vectorised vexpr)
           }
    rhs False         True    Nothing                         -- Case (4) — is a dfun
      = do { expr' <- vectDictExpr expr
           ; return  (DontInline, True, expr')
           }

    -- Suffix for the trace message indicating which kind of pragma (if any) applies.
    info True  False _                          = " [VECTORISE SCALAR]"
    info True  True  _                          = " [VECTORISE SCALAR instance]"
    info False _     vectDecl | isJust vectDecl = " [VECTORISE]"
                              | otherwise       = " (no pragma)"
390

391 392
-- |Project out the vectorised version of a binding from some closure,
-- or return the original body if that doesn't work or the binding is scalar.
--
-- If the projection via 'fromVect' fails, a message naming the original binding is emitted before
-- the original body is returned.
--
tryConvert :: Var       -- ^ Name of the original binding (eg @foo@)
           -> Var       -- ^ Name of vectorised version of binding (eg @$vfoo@)
           -> CoreExpr  -- ^ The original body of the binding.
           -> VM CoreExpr
tryConvert var vect_var rhs
  = do { globalScalar <- isGlobalScalarVar var
       ; if globalScalar
         then
             -- scalar bindings keep their original right-hand side
           return rhs
         else
           fromVect (idType var) (Var vect_var)
           `orElseErrV`
           do { emitVt "  Could NOT call vectorised from original version" $ ppr var
              ; return rhs
              }
       }