Vectorise.hs 9.91 KB
Newer Older
1
{-# OPTIONS -fno-warn-missing-signatures -fno-warn-unused-do-bind #-}
2

3
module Vectorise ( vectorise )
4 5
where

6 7 8
import Vectorise.Type.Env
import Vectorise.Type.Type
import Vectorise.Convert
9
import Vectorise.Utils.Hoisting
10
import Vectorise.Exp
11
import Vectorise.Vect
12
import Vectorise.Env
13
import Vectorise.Monad
14

15
import HscTypes hiding      ( MonadThings(..) )
16
import CoreUnfold           ( mkInlineUnfolding )
17
import CoreFVs
18 19
import PprCore
import CoreSyn
Ian Lynagh's avatar
Ian Lynagh committed
20
import CoreMonad            ( CoreM, getHscEnv )
21
import Type
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
22
import Id
23
import OccName
24
import DynFlags
25
import BasicTypes           ( isLoopBreaker )
26
import Outputable
27
import Util                 ( zipLazy )
28 29
import MonadUtils

30
import Control.Monad
31 32 33


-- | Vectorise a single module.
--
--   This is the 'CoreM' entry point: it fetches the 'HscEnv' from the core
--   monad and delegates the real work to 'vectoriseIO' in 'IO'.
--
vectorise :: ModGuts -> CoreM ModGuts
vectorise guts
  = getHscEnv >>= \hsc_env -> liftIO (vectoriseIO hsc_env guts)
40

41 42 43 44 45 46
-- | Vectorise a single module, given the 'HscEnv'.
--
--   Vectorisation info from the home package table and from the external
--   package state is combined and passed to the 'VM' computation
--   'vectModule'.  The resulting vectorisation info is stored in the
--   'mg_vect_info' field of the returned 'ModGuts'.
--
--   If 'initV' yields 'Nothing', this fails in 'IO' (via 'fail') with an
--   explicit error message rather than an anonymous pattern-match failure.
--
vectoriseIO :: HscEnv -> ModGuts -> IO ModGuts
vectoriseIO hsc_env guts
 = do {   -- Get information about currently loaded external packages.
      ; eps <- hscEPS hsc_env

          -- Combine vectorisation info from the current module, and external ones.
      ; let info = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps

          -- Run the main VM computation.  'initV' returns a 'Maybe'; handle the
          -- 'Nothing' case explicitly instead of relying on a refutable pattern,
          -- so that a failure reports what actually went wrong.
      ; result <- initV hsc_env guts info (vectModule guts)
      ; case result of
          Just (info', guts') -> return (guts' { mg_vect_info = info' })
          Nothing             -> fail "vectoriseIO: vectorisation of the module failed"
      }
55 56

-- | Vectorise a single module inside the 'VM' monad.
--
--   The type environment is vectorised first (this may add new 'TyCon's and
--   'DataCon's, and yields extra family instances plus some bindings).  After
--   that, every top-level binding of the module is vectorised.  The bindings
--   produced while vectorising the type environment are prepended to the
--   module's bindings as one recursive group.
--
vectModule :: ModGuts -> VM ModGuts
vectModule guts@(ModGuts { mg_types     = types
                         , mg_binds     = binds
                         , mg_fam_insts = fam_insts
                         })
 = do { dumpOptVt Opt_D_dump_vt_trace "Before vectorisation" (pprCoreBindings binds)

          -- The type environment must be handled before the bindings.
      ; (vectTypes, extraFamInsts, typeBinds) <- vectTypeEnv types

          -- Fetch the family-instance environment from the global VM state.
      ; (_, famInstEnv) <- readGEnv global_fam_inst_env

          -- Vectorise each top-level binding in turn.
      ; vectBinds <- mapM vectTopBind binds

      ; let guts' = guts { mg_types        = vectTypes
                         , mg_binds        = Rec typeBinds : vectBinds
                         , mg_fam_inst_env = famInstEnv
                         , mg_fam_insts    = fam_insts ++ extraFamInsts
                         }
      ; return guts'
      }
84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120

-- | Try to vectorise a top-level binding.
--   If it doesn't vectorise then return it unharmed.
--
--   For example, for the binding
--
--   @
--      foo :: Int -> Int
--      foo = \x -> x + x
--   @
--
--   we get
--
--   @
--      foo  :: Int -> Int
--      foo  = \x -> vfoo $: x
--
--      v_foo :: Closure void vfoo lfoo
--      v_foo = closure vfoo lfoo void
--
--      vfoo :: Void -> Int -> Int
--      vfoo = ...
--
--      lfoo :: PData Void -> PData Int -> PData Int
--      lfoo = ...
--   @
--
--   @vfoo@ is the "vectorised", or scalar, version that does the same as the original
--   function @foo@, but takes an explicit environment.
--
--   @lfoo@ is the "lifted" version that works on arrays.
--
--   @v_foo@ combines both of these into a @Closure@ that also contains the
--   environment.
--
--   The original binding @foo@ is rewritten to call the vectorised version
--   present in the closure.
--
--   On success the result is always a single 'Rec' group containing the
--   rewritten original binding(s), the vectorised binding(s) and any bindings
--   hoisted to the top level during vectorisation.  If vectorisation fails
--   (caught with 'orElseV'), the input binding is returned unchanged.
--
vectTopBind :: CoreBind -> VM CoreBind
vectTopBind b@(NonRec var expr)
 = do {   -- Vectorise the right-hand side, create an appropriate top-level binding and add it to
          -- the vectorisation map.
      ; (inline, isScalar, expr') <- vectTopRhs [] var expr
      ; var' <- vectTopBinder var inline expr'
      ; when isScalar $
          addGlobalScalar var

          -- We replace the original top-level binding by a value projected from the vectorised
          -- closure and add any newly created hoisted top-level bindings.
      ; cexpr <- tryConvert var var' expr
      ; hs <- takeHoisted
      ; return . Rec $ (var, cexpr) : (var', expr') : hs
      }
  `orElseV`
    return b
vectTopBind b@(Rec bs)
 = let (vars, exprs) = unzip bs
   in
   do { (vars', _, exprs', hs) <- fixV $
          \ ~(_, inlines, rhss, _) ->
            do {   -- Vectorise the right-hand sides, create appropriate top-level bindings and
                   -- add them to the vectorisation map.  The binders are built from the inline
                   -- specifications and right-hand sides fed back lazily by 'fixV'; the lazy
                   -- patterns here (and the laziness of 'vectTopBinder') are what make the knot work.
               ; vars' <- sequence [vectTopBinder var inline rhs
                                   | (var, ~(inline, rhs)) <- zipLazy vars (zip inlines rhss)]
                   -- First attempt: treat the other functions of the group as local scalars.
               ; (inlines1, areScalars, exprs1) <- mapAndUnzip3M (uncurry $ vectTopRhs vars) bs
               ; hs1 <- takeHoisted
               ; if and areScalars
                 then      -- (1) Entire recursive group is scalar
                           --      => add all variables to the global set of scalars
                      do { mapM_ addGlobalScalar vars
                         ; return (vars', inlines1, exprs1, hs1)
                         }
                 else      -- (2) At least one binding is not scalar
                           --     => vectorise again with empty set of local scalars; the results
                           --        (and hoisted bindings) of the first attempt are discarded
                      do { (inlines2, _, exprs2) <- mapAndUnzip3M (uncurry $ vectTopRhs []) bs
                         ; hs2 <- takeHoisted
                         ; return (vars', inlines2, exprs2, hs2)
                         }
               }

          -- Replace the original top-level bindings by values projected from the vectorised
          -- closures and add any newly created hoisted top-level bindings to the group.
      ; cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
      ; return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
      }
  `orElseV`
    return b
170
    
171 172 173 174 175 176
-- | Create the vectorised counterpart of a top-level binder and record the
--   mapping from the original binder to it in the global state.  The new
--   binder's name is derived from the original with 'mkVectOcc', and its type
--   is the vectorised type of the original.
--
--   If a vectorisation declaration exists for the binder, its type must be
--   equal to the vectorised type; otherwise vectorisation is aborted with
--   'cantVectorise'.
--
--   NOTE: vectTopBinder *MUST* be lazy in @inline@ and @expr@, because
--   'vectTopBind' calls it inside 'fixV', where both are fed back from the
--   result of the very computation that calls it.
--
vectTopBinder :: Var      -- ^ Name of the binding.
              -> Inline   -- ^ Whether it should be inlined, used to annotate it.
              -> CoreExpr -- ^ RHS of binding, used to set the 'Unfolding' of the returned 'Var'.
              -> VM Var   -- ^ Name of the vectorised binding.
vectTopBinder var inline expr
 = do {   -- Compute the vectorised type of the binder.
      ; vty <- vectType (idType var)

          -- A vectorisation declaration for this binder must agree with that type.
      ; vectDecl <- lookupVectDecl var
      ; case vectDecl of
          Just (vdty, _) | not (eqType vty vdty)
            -> cantVectorise ("Type mismatch in vectorisation pragma for " ++ show var) $
                 text "Expected type" <+> ppr vty
                 $$
                 text "Inferred type" <+> ppr vdty
          _ -> return ()

          -- Clone the binder under its vectorised name and type.  The unfolding is
          -- attached lazily, so neither 'inline' nor 'expr' is forced here.
      ; clone <- cloneId mkVectOcc var vty
      ; let var' = clone `setIdUnfoldingLazily` unfolding

          -- Record the original |-> vectorised mapping.
      ; defGlobalVar var var'
      ; return var'
      }
  where
    unfolding
      = case inline of
          DontInline   -> noUnfolding
          Inline arity -> mkInlineUnfolding (Just arity) expr
Ian Lynagh's avatar
Ian Lynagh committed
212

213
-- | Vectorise the right-hand side of a top-level binding in an empty local
--   environment (the whole computation runs under 'closedV').
--
--   Three situations are distinguished:
--
--   (1) There is a (non-scalar) vectorisation declaration for the variable, i.e.
--       the user supplied the vectorised code explicitly
--       => no automatic vectorisation; the user-supplied code is used as-is.
--
--   (2) The variable is a global scalar (scalar vectorisation declaration)
--       => generate vectorised code that uses a scalar 'map'/'zipWith' to lift
--          the computation.
--
--   (3) Neither of the above
--       => vectorise the right-hand side automatically.
--
vectTopRhs :: [Var]           -- ^ Names of all functions in the rec block
           -> Var             -- ^ Name of the binding.
           -> CoreExpr        -- ^ Body of the binding.
           -> VM ( Inline     -- (1) inline specification for the binding
                 , Bool       -- (2) whether the right-hand side is a scalar computation
                 , CoreExpr)  -- (3) the vectorised right-hand side
vectTopRhs recFs var expr
  = closedV $
    do { traceVt ("vectTopRhs of " ++ show var) (ppr expr)
       ; scalar   <- isGlobalScalar var
       ; userDecl <- lookupVectDecl var
       ; case userDecl of
           Just (_, userExpr) ->                                   -- Case (1)
             return (inlineMe, False, userExpr)
           Nothing
             | scalar ->                                           -- Case (2)
                 do { vexpr <- vectScalarFun True recFs expr
                    ; return (inlineMe, True, vectorised vexpr)
                    }
             | otherwise ->                                        -- Case (3)
                 do { (inline, isScalar, vexpr) <-
                        inBind var $
                          vectPolyExpr (isLoopBreaker $ idOccInfo var) recFs (freeVars expr)
                    ; return (inline, isScalar, vectorised vexpr)
                    }
       }
254 255

-- | Build the replacement right-hand side for an original top-level binding.
--
--   A global scalar keeps its original body.  Otherwise the value is
--   projected out of the vectorised binding with 'fromVect'; should that
--   fail, the original body is kept as well.
--
tryConvert :: Var       -- ^ Name of the original binding (eg @foo@)
           -> Var       -- ^ Name of vectorised version of binding (eg @$vfoo@)
           -> CoreExpr  -- ^ The original body of the binding.
           -> VM CoreExpr
tryConvert var vect_var rhs
  = isGlobalScalar var >>= \scalar ->
      if scalar
        then keepOriginal
        else fromVect (idType var) (Var vect_var) `orElseV` keepOriginal
  where
    keepOriginal = return rhs