Vectorise.hs 11.3 KB
Newer Older
1

2
module Vectorise ( vectorise )
3 4
where

5 6 7
import Vectorise.Type.Env
import Vectorise.Type.Type
import Vectorise.Convert
8
import Vectorise.Utils.Hoisting
9
import Vectorise.Exp
10
import Vectorise.Vect
11
import Vectorise.Env
12
import Vectorise.Monad
13

14
import HscTypes hiding      ( MonadThings(..) )
15
import CoreUnfold           ( mkInlineUnfolding )
16
import CoreFVs
17 18
import PprCore
import CoreSyn
Ian Lynagh's avatar
Ian Lynagh committed
19
import CoreMonad            ( CoreM, getHscEnv )
20
import Type
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
21
import Id
22
import OccName
23
import DynFlags
24
import BasicTypes           ( isStrongLoopBreaker )
25
import Outputable
26
import Util                 ( zipLazy )
27 28
import MonadUtils

29
import Control.Monad
30 31 32


-- | Entry point of the vectorisation pass: fetch the 'HscEnv' from 'CoreM' and run
--   'vectoriseIO' on the module.
--
vectorise :: ModGuts -> CoreM ModGuts
vectorise guts
  = getHscEnv >>= \hsc_env -> liftIO (vectoriseIO hsc_env guts)
39

40 41 42 43 44 45
-- | Vectorise a single module, given the 'HscEnv'.
--
--   The vectorisation info of the home package table and of the currently loaded external
--   packages is combined and handed to the 'VM' computation 'vectModule'.  The resulting
--   vectorisation info is stored in the returned module's 'mg_vect_info'.
--
--   If 'initV' produces no result, an IO 'userError' is thrown.  This is the same exception
--   type the previous partial pattern binding raised, but it now carries a descriptive message.
--
vectoriseIO :: HscEnv -> ModGuts -> IO ModGuts
vectoriseIO hsc_env guts
 = do {   -- Get information about currently loaded external packages.
      ; eps <- hscEPS hsc_env

          -- Combine vectorisation info from the current module, and external ones.
      ; let info = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps

          -- Run the main VM computation.  'initV' returns a 'Maybe'.  'Nothing' presumably
          -- signals a failed vectorisation (TODO confirm against 'initV').  Handle it explicitly
          -- instead of relying on an anonymous pattern-match failure.
      ; result <- initV hsc_env guts info (vectModule guts)
      ; case result of
          Just (info', guts') -> return (guts' { mg_vect_info = info' })
          Nothing             -> ioError (userError "vectoriseIO: vectorisation of the module failed")
      }
54 55

-- | Vectorise a single module inside the 'VM' monad.
--
--   The type environment is vectorised first.  That step may add new TyCons and DataCons,
--   new family instances and new bindings.  Then every top-level binding is vectorised.
--   The bindings produced for the type environment are prepended as one recursive group.
--   The original bindings are dumped under 'Opt_D_dump_vt_trace' before any work is done.
--
vectModule :: ModGuts -> VM ModGuts
vectModule guts@(ModGuts { mg_types     = tyEnv
                         , mg_binds     = topBinds
                         , mg_fam_insts = famInsts
                         })
  = do dumpOptVt Opt_D_dump_vt_trace "Before vectorisation" (pprCoreBindings topBinds)

       -- Vectorise the type environment; the family instance environment is read
       -- afterwards so that it reflects any instances added by this step.
       (tyEnv', newFamInsts, tyConBinds) <- vectTypeEnv tyEnv
       (_, famInstEnv) <- readGEnv global_fam_inst_env

       -- Vectorise each top-level binding in turn.
       topBinds' <- mapM vectTopBind topBinds

       return (guts { mg_types        = tyEnv'
                    , mg_binds        = Rec tyConBinds : topBinds'
                    , mg_fam_inst_env = famInstEnv
                    , mg_fam_insts    = famInsts ++ newFamInsts
                    })
83

84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
-- |Try to vectorise a top-level binding.  If it doesn't vectorise then return it unharmed.
--
-- For example, for the binding 
--
-- @  
--    foo :: Int -> Int
--    foo = \x -> x + x
-- @
--
-- we get
-- @
--    foo  :: Int -> Int
--    foo  = \x -> v_foo $: x
--
--    v_foo :: Closure void vfoo lfoo
--    v_foo = closure vfoo lfoo void
--
--    vfoo :: Void -> Int -> Int
--    vfoo = ...
--
--    lfoo :: PData Void -> PData Int -> PData Int
--    lfoo = ...
-- @ 
--
-- @vfoo@ is the "vectorised", or scalar, version that does the same as the original
-- function foo, but takes an explicit environment.
--
-- @lfoo@ is the "lifted" version that works on arrays.
--
-- @v_foo@ combines both of these into a `Closure` that also contains the
-- environment.
--
-- The original binding @foo@ is rewritten to call the vectorised version
-- present in the closure.
--
-- Vectorisation may be suppressed by annotating a binding with a 'NOVECTORISE' pragma.  If this
-- pragma is used in a group of mutually recursive bindings, either all or no binding must have
-- the pragma.  If only some bindings are annotated, a fatal error is raised.
-- FIXME: Once we support partial vectorisation, we may be able to vectorise parts of a group, or
--   we may emit a warning and refrain from vectorising the entire group.
--
vectTopBind :: CoreBind -> VM CoreBind
vectTopBind b@(NonRec var expr)
  = unlessNoVectDecl $
      do {   -- Vectorise the right-hand side, create an appropriate top-level binding and add it
             -- to the vectorisation map.
         ; (inline, isScalar, expr') <- vectTopRhs [] var expr
         ; var' <- vectTopBinder var inline expr'
         ; when isScalar $ 
             addGlobalScalar var
 
             -- We replace the original top-level binding by a value projected from the vectorised
             -- closure and add any newly created hoisted top-level bindings.
         ; cexpr <- tryConvert var var' expr
         ; hs <- takeHoisted
         ; return . Rec $ (var, cexpr) : (var', expr') : hs
         }
     `orElseV`
       return b
  where
    -- Run 'thing_inside' unless 'var' carries a 'NOVECTORISE' declaration; in that case the
    -- binding is traced and returned unchanged.  (The argument is deliberately not called
    -- 'vectorise', which would shadow this module's exported function.)
    unlessNoVectDecl thing_inside
      = do { hasNoVectDecl <- noVectDecl var
           ; if hasNoVectDecl
             then do { traceVt "NOVECTORISE" $ ppr var
                     ; return b
                     }
             else thing_inside
           }

vectTopBind b@(Rec bs)
  = unlessSomeNoVectDecl $
      do { (vars', _, exprs', hs) <- fixV $ 
             \ ~(_, inlines, rhss, _) ->
               do {   -- Vectorise the right-hand sides, create appropriate top-level bindings
                      -- and add them to the vectorisation map.  'vectTopBinder' is lazy in the
                      -- knot-tied 'inlines' and 'rhss', which is what makes 'fixV' safe here.
                  ; vars' <- sequence [vectTopBinder var inline rhs
                                      | (var, ~(inline, rhs)) <- zipLazy vars (zip inlines rhss)]
                  ; (inlines, areScalars, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs vars) bs
                  ; hs <- takeHoisted
                  ; if and areScalars
                    then      -- (1) Entire recursive group is scalar
                              --      => add all variables to the global set of scalars
                         do { mapM_ addGlobalScalar vars
                            ; return (vars', inlines, exprs', hs)
                            }
                    else      -- (2) At least one binding is not scalar
                              --     => vectorise again with empty set of local scalars; the
                              --        results of the first attempt, including the bindings
                              --        it hoisted (taken just above), are dropped
                         do { (inlines, _, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs []) bs
                            ; hs <- takeHoisted
                            ; return (vars', inlines, exprs', hs)
                            }
                  }

             -- Replace the original top-level bindings by values projected from the vectorised
             -- closures and add any newly created hoisted top-level bindings to the group.
         ; cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
         ; return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
         }
     `orElseV`
       return b
  where
    (vars, exprs) = unzip bs

    -- Run 'thing_inside' only if no binding of the group has a 'NOVECTORISE' declaration.
    -- If all of them have one, the group is traced and returned unchanged.  If only some do,
    -- vectorisation fails with 'noVectoriseErr'.
    unlessSomeNoVectDecl thing_inside
      = do { hasNoVectDecls <- mapM noVectDecl vars
           ; if and hasNoVectDecls
             then do { traceVt "NOVECTORISE" $ ppr vars  -- all bindings have 'NOVECTORISE'
                     ; return b
                     }
             else if or hasNoVectDecls
             then cantVectorise noVectoriseErr (ppr b)   -- some (but not all) have 'NOVECTORISE'
             else thing_inside                           -- no binding has a 'NOVECTORISE' decl
           }
    noVectoriseErr = "NOVECTORISE must be used on all or no bindings of a recursive group"
     
196 197 198 199
-- | Create the vectorised counterpart of a top-level binder and record the mapping from the
--   original binder to it in the state (via 'defGlobalVar').  For some binder @foo@ the
--   vectorised binder is @$v_foo@.
--
--   If there is a vectorisation declaration for the binder, its type must agree with the
--   vectorised type of the original binder.  Otherwise vectorisation fails with a
--   type-mismatch error.
--
--   NOTE: 'vectTopBinder' *MUST* be lazy in inline and expr because of how it is
--   used inside of 'fixV' in 'vectTopBind'.
--
vectTopBinder :: Var      -- ^ Name of the binding.
              -> Inline   -- ^ Whether it should be inlined, used to annotate it.
              -> CoreExpr -- ^ RHS of binding, used to set the 'Unfolding' of the returned 'Var'.
              -> VM Var   -- ^ Name of the vectorised binding.
vectTopBinder var inline expr
  = do { vty <- vectType (idType var)
       ; lookupVectDecl var >>= checkDeclType vty
       ; var' <- cloneId mkVectOcc var vty
       ; let vectVar = var' `setIdUnfoldingLazily` unfolding
       ; defGlobalVar var vectVar
       ; return vectVar
       }
  where
    -- Fail unless the type recorded in a vectorisation declaration (if present) coincides
    -- with the vectorised type of the binder.
    checkDeclType _   Nothing          = return ()
    checkDeclType vty (Just (vdty, _))
      | vty `eqType` vdty = return ()
      | otherwise
      = cantVectorise ("Type mismatch in vectorisation pragma for " ++ show var) $
          (text "Expected type" <+> ppr vty)
          $$
          (text "Inferred type" <+> ppr vdty)

    -- Unfolding attached lazily to the vectorised binder: an inline unfolding of the
    -- vectorised RHS when inlining is requested, and no unfolding otherwise.
    unfolding = case inline of
                  Inline arity -> mkInlineUnfolding (Just arity) expr
                  DontInline   -> noUnfolding
Ian Lynagh's avatar
Ian Lynagh committed
237

238
-- | Vectorise the right-hand side of a top-level binding in an empty local environment.
--
-- Three situations are handled:
--
-- (1) The variable has a vectorisation declaration supplying the vectorised code explicitly
--     => no automatic vectorisation; the user-supplied code is used as is
--
-- (2) No declaration, but the variable is a global scalar
--     => generate vectorised code that uses a scalar 'map'/'zipWith' to lift the computation
--
-- (3) Neither of the above
--     => perform automatic vectorisation of the RHS
--
vectTopRhs :: [Var]           -- ^ Names of all functions in the rec block
           -> Var             -- ^ Name of the binding.
           -> CoreExpr        -- ^ Body of the binding.
           -> VM ( Inline     -- (1) inline specification for the binding
                 , Bool       -- (2) whether the right-hand side is a scalar computation
                 , CoreExpr)  -- (3) the vectorised right-hand side
vectTopRhs recFs var expr
  = closedV $
    do { traceVt ("vectTopRhs of " ++ show var) $ ppr expr
       ; globalScalar <- isGlobalScalar var
       ; vectDecl     <- lookupVectDecl var
       ; case (vectDecl, globalScalar) of
           (Just (_, userRhs), _) ->                              -- situation (1)
             return (inlineMe, False, userRhs)
           (Nothing, True)        ->                              -- situation (2)
             liftM (\vexpr -> (inlineMe, True, vectorised vexpr)) $
               vectScalarFun True recFs expr
           (Nothing, False)       ->                              -- situation (3)
             do { (inline, isScalar, vexpr) <- inBind var $
                                                 vectPolyExpr loopBreaker recFs (freeVars expr)
                ; return (inline, isScalar, vectorised vexpr)
                }
       }
  where
    loopBreaker = isStrongLoopBreaker (idOccInfo var)
279 280

-- | Obtain the vectorised version of a binding by projecting it out of its closure.
--   The original right-hand side is kept when the binding is a global scalar, or when the
--   projection does not succeed.
--
tryConvert :: Var       -- ^ Name of the original binding (eg @foo@)
           -> Var       -- ^ Name of vectorised version of binding (eg @$vfoo@)
           -> CoreExpr  -- ^ The original body of the binding.
           -> VM CoreExpr
tryConvert var vect_var rhs
  = isGlobalScalar var >>= \isScalar ->
      if isScalar
        then keepOriginal
        else fromVect (idType var) (Var vect_var) `orElseV` keepOriginal
  where
    keepOriginal = return rhs