Vectorise.hs 9.93 KB
Newer Older
1
{-# OPTIONS -fno-warn-missing-signatures -fno-warn-unused-do-bind #-}
2

3
module Vectorise ( vectorise )
4 5
where

6 7 8
import Vectorise.Type.Env
import Vectorise.Type.Type
import Vectorise.Convert
9
import Vectorise.Utils.Hoisting
10
import Vectorise.Exp
11
import Vectorise.Vect
12
import Vectorise.Env
13
import Vectorise.Monad
14

15
import HscTypes hiding      ( MonadThings(..) )
16
import CoreUnfold           ( mkInlineUnfolding )
17
import CoreFVs
18 19
import PprCore
import CoreSyn
Ian Lynagh's avatar
Ian Lynagh committed
20
import CoreMonad            ( CoreM, getHscEnv )
21
import Type
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
22
import Var
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
23
import Id
24
import OccName
25
import DynFlags
26
import BasicTypes           ( isLoopBreaker )
27
import Outputable
28
import Util                 ( zipLazy )
29 30
import MonadUtils

31
import Control.Monad
32 33 34


-- | Vectorise a single module.
--
--   Fetches the current 'HscEnv' from the 'CoreM' monad and runs the
--   'IO'-level driver 'vectoriseIO' on the given module guts.
--
vectorise :: ModGuts -> CoreM ModGuts
vectorise guts
  = getHscEnv >>= \hsc_env ->
      liftIO (vectoriseIO hsc_env guts)
41

42 43 44 45 46 47
-- | Vectorise a single module, given the 'HscEnv'.
--
--   The vectorisation info of the home package table is combined with that of
--   the external package state, the 'VM' computation 'vectModule' is run via
--   'initV', and the resulting vectorisation info is stored in the returned
--   module guts.
--
--   Note that the result of 'initV' is bound with a @Just@ pattern: if
--   'initV' yields anything else, the pattern bind fails in 'IO'.
--
vectoriseIO :: HscEnv -> ModGuts -> IO ModGuts
vectoriseIO hsc_env guts
  = do
      -- Get information about currently loaded external packages.
      eps <- hscEPS hsc_env

      -- Combine vectorisation info from the current module, and external ones.
      let localInfo    = hptVectInfo hsc_env
          externalInfo = eps_vect_info eps
          info         = localInfo `plusVectInfo` externalInfo

      -- Run the main VM computation.
      Just (info', guts') <- initV hsc_env guts info (vectModule guts)
      return $ guts' { mg_vect_info = info' }
56 57

-- | Vectorise a single module, in the VM monad.
--
--   The steps are performed in this order:
--
--   1. dump the unvectorised bindings (under @Opt_D_dump_vt_trace@),
--   2. vectorise the type environment,
--   3. read the family instance environment from the global environment,
--   4. vectorise every top-level binding.
--
--   The returned guts carry the new type environment, the bindings produced
--   by 'vectTypeEnv' (as one recursive group) in front of the vectorised
--   top-level bindings, and the extended family instances.
--
vectModule :: ModGuts -> VM ModGuts
vectModule guts@(ModGuts { mg_types     = tyEnv
                         , mg_binds     = topBinds
                         , mg_fam_insts = oldFamInsts
                         })
  = do
      dumpOptVt Opt_D_dump_vt_trace "Before vectorisation" $
        pprCoreBindings topBinds

      -- Vectorise the type environment.
      -- This may add new TyCons and DataCons.
      (tyEnv', newFamInsts, tyConBinds) <- vectTypeEnv tyEnv

      -- Read the family instance environment only after the type environment
      -- has been vectorised.
      (_, famInstEnv) <- readGEnv global_fam_inst_env

      -- dicts   <- mapM buildPADict pa_insts
      -- workers <- mapM vectDataConWorkers pa_insts

      -- Vectorise all the top level bindings.
      topBinds' <- mapM vectTopBind topBinds

      return $ guts { mg_types        = tyEnv'
                    , mg_binds        = Rec tyConBinds : topBinds'
                    , mg_fam_inst_env = famInstEnv
                    , mg_fam_insts    = oldFamInsts ++ newFamInsts
                    }
85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121

-- | Try to vectorise a top-level binding.
--   If it doesn't vectorise then return it unharmed.
--
--   For example, for the binding
--
--   @
--      foo :: Int -> Int
--      foo = \x -> x + x
--   @
--
--   we get
--   @
--      foo  :: Int -> Int
--      foo  = \x -> vfoo $: x
--
--      v_foo :: Closure void vfoo lfoo
--      v_foo = closure vfoo lfoo void
--
--      vfoo :: Void -> Int -> Int
--      vfoo = ...
--
--      lfoo :: PData Void -> PData Int -> PData Int
--      lfoo = ...
--   @
--
--   @vfoo@ is the "vectorised", or scalar, version that does the same as the original
--   function foo, but takes an explicit environment.
--
--   @lfoo@ is the "lifted" version that works on arrays.
--
--   @v_foo@ combines both of these into a `Closure` that also contains the
--   environment.
--
--   The original binding @foo@ is rewritten to call the vectorised version
--   present in the closure.
--
--   In both equations the whole computation is wrapped in @`orElseV` return b@,
--   so the original binding @b@ is returned whenever the vectorisation
--   computation is abandoned.
--
vectTopBind :: CoreBind -> VM CoreBind
vectTopBind b@(NonRec var expr)
 = do {   -- Vectorise the right-hand side, create an appropriate top-level binding and add it to
          -- the vectorisation map.
      ; (inline, isScalar, expr') <- vectTopRhs [] var expr
      ; var' <- vectTopBinder var inline expr'
      ; when isScalar $
          addGlobalScalar var

          -- We replace the original top-level binding by a value projected from the vectorised
          -- closure and add any newly created hoisted top-level bindings.
      ; cexpr <- tryConvert var var' expr
      ; hs <- takeHoisted
      ; return . Rec $ (var, cexpr) : (var', expr') : hs
      }
  `orElseV`
    return b
vectTopBind b@(Rec bs)
 = let (vars, exprs) = unzip bs
   in
   do {   -- The vectorised binders need the inline specifications and right-hand sides that are
          -- only produced further down, so the results are fed back lazily through 'fixV'.
          -- The lambda pattern must stay irrefutable (@~@) for this to work.
      ; (vars', _, exprs', hs) <- fixV $
          \ ~(_, inlines, rhss, _) ->
            do {   -- Vectorise the right-hand sides, create appropriate top-level bindings and
                   --  add them to the vectorisation map.
               ; vars' <- sequence [vectTopBinder var inline rhs
                                   | (var, ~(inline, rhs)) <- zipLazy vars (zip inlines rhss)]
               ; (inlines, areScalars, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs vars) bs
               ; hs <- takeHoisted
               ; if and areScalars
                 then      -- (1) Entire recursive group is scalar
                           --      => add all variables to the global set of scalars
                      do { mapM_ addGlobalScalar vars   -- only the effect is wanted, not a result list
                         ; return (vars', inlines, exprs', hs)
                         }
                 else      -- (2) At least one binding is not scalar
                           --     => vectorise again with empty set of local scalars;
                           --        the results of the first pass (including its hoisted
                           --        bindings) are shadowed and dropped
                      do { (inlines, _, exprs') <- mapAndUnzip3M (uncurry $ vectTopRhs []) bs
                         ; hs <- takeHoisted
                         ; return (vars', inlines, exprs', hs)
                         }
               }

          -- Replace the original top-level bindings by values projected from the vectorised
          -- closures and add any newly created hoisted top-level bindings to the group.
      ; cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
      ; return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
      }
  `orElseV`
    return b
171
    
172 173 174 175 176 177
-- | Make the vectorised version of this top level binder, and add the mapping
--   between it and the original to the state. For some binder @foo@ the vectorised
--   version is @$v_foo@
--
--   If a vectorisation declaration exists for the binder, the type it carries
--   must equal (by 'coreEqType') the vectorised type of the binder; otherwise
--   'cantVectorise' is invoked.
--
--   NOTE: vectTopBinder *MUST* be lazy in inline and expr because of how it is
--   used inside of fixV in vectTopBind
--
vectTopBinder :: Var      -- ^ Name of the binding.
              -> Inline   -- ^ Whether it should be inlined, used to annotate it.
              -> CoreExpr -- ^ RHS of binding, used to set the 'Unfolding' of the returned 'Var'.
              -> VM Var   -- ^ Name of the vectorised binding.
vectTopBinder var inline expr
  = do
      -- Vectorise the type attached to the var.
      vty <- vectType (idType var)

      -- If there is a vectorisation declaration for this binding, make sure
      -- that its type matches.
      vectDecl <- lookupVectDecl var
      case vectDecl of
        Just (vdty, _)
          | not (coreEqType vty vdty) ->
              cantVectorise ("Type mismatch in vectorisation pragma for " ++ show var) $
                (text "Expected type" <+> ppr vty)
                $$
                (text "Inferred type" <+> ppr vdty)
        _ -> return ()

      -- Make the vectorised version of the binding's name, and set the
      -- unfolding used for inlining.  The unfolding is attached lazily so
      -- that neither @inline@ nor @expr@ is demanded here.
      cloned <- cloneId mkVectOcc var vty
      let var' = cloned `setIdUnfoldingLazily` unfoldingFor inline

      -- Add the mapping between the plain and vectorised name to the state.
      defGlobalVar var var'

      return var'
  where
    -- Unfolding to attach, determined by the inline specification.
    unfoldingFor (Inline arity) = mkInlineUnfolding (Just arity) expr
    unfoldingFor DontInline     = noUnfolding
Ian Lynagh's avatar
Ian Lynagh committed
213

214
-- | Vectorise the RHS of a top-level binding, in an empty local environment.
--
-- We need to distinguish three cases:
--
-- (1) We have a (non-scalar) vectorisation declaration for the variable (which explicitly provides
--     vectorised code implemented by the user)
--     => no automatic vectorisation & instead use the user-supplied code
--
-- (2) We have a scalar vectorisation declaration for the variable
--     => generate vectorised code that uses a scalar 'map'/'zipWith' to lift the computation
--
-- (3) There is no vectorisation declaration for the variable
--     => perform automatic vectorisation of the RHS
--
-- The vectorisation declaration is inspected first; whether the variable is a
-- global scalar only matters when no declaration is found.
--
vectTopRhs :: [Var]           -- ^ Names of all functions in the rec block
           -> Var             -- ^ Name of the binding.
           -> CoreExpr        -- ^ Body of the binding.
           -> VM ( Inline     -- (1) inline specification for the binding
                 , Bool       -- (2) whether the right-hand side is a scalar computation
                 , CoreExpr)  -- (3) the vectorised right-hand side
vectTopRhs recFs var expr
  = closedV $ do
      traceVt ("vectTopRhs of " ++ show var) $ ppr expr

      globalScalar <- isGlobalScalar var
      vectDecl     <- lookupVectDecl var

      case vectDecl of
        -- Case (1): use the code supplied by the declaration.
        Just (_, declExpr)
          -> return (inlineMe, False, declExpr)

        Nothing
          -- Case (2): lift the scalar computation.
          | globalScalar
          -> do scalarExpr <- vectScalarFun True recFs expr
                return (inlineMe, True, vectorised scalarExpr)

          -- Case (3): automatic vectorisation of the annotated RHS.
          | otherwise
          -> do let annotated = freeVars expr
                (inline, isScalar, vexpr) <-
                  inBind var $
                    vectPolyExpr (isLoopBreaker $ idOccInfo var) recFs annotated
                return (inline, isScalar, vectorised vexpr)
255 256

-- | Project out the vectorised version of a binding from some closure,
--   or return the original body if that doesn't work or the binding is scalar.
--
--   For a global scalar the original body is returned without attempting the
--   conversion at all; otherwise 'fromVect' is tried and the original body is
--   the fallback (via 'orElseV').
--
tryConvert :: Var       -- ^ Name of the original binding (eg @foo@)
           -> Var       -- ^ Name of vectorised version of binding (eg @$vfoo@)
           -> CoreExpr  -- ^ The original body of the binding.
           -> VM CoreExpr
tryConvert var vect_var rhs
  = isGlobalScalar var >>= convert
  where
    -- Scalar bindings keep their original right-hand side.
    convert True  = return rhs
    -- Everything else is projected from the vectorised variable if possible.
    convert False = fromVect (idType var) (Var vect_var) `orElseV` return rhs