-- Main entry point to the vectoriser.  It is invoked iff the option '-fvectorise' is passed.
--
-- This module provides the function 'vectorise', which vectorises an entire (desugared) module.
-- It vectorises all type declarations and value bindings.  It also processes all VECTORISE pragmas
-- (aka vectorisation declarations), which can lead to the vectorisation of imported data types
-- and the enrichment of imported functions with vectorised versions.

module Vectorise ( vectorise )
where

import Vectorise.Type.Env
import Vectorise.Type.Type
import Vectorise.Convert
import Vectorise.Utils.Hoisting
import Vectorise.Exp
import Vectorise.Env
import Vectorise.Monad

import HscTypes hiding      ( MonadThings(..) )
import CoreUnfold           ( mkInlineUnfolding )
import PprCore
import CoreSyn
import CoreMonad            ( CoreM, getHscEnv )
import Type
import Id
import DynFlags
import Outputable
import Util                 ( zipLazy )
import MonadUtils
import FamInstEnv           ( toBranchedFamInst )

import Control.Monad


-- |Vectorise a single module.
--
-- Fetches the 'HscEnv' from the 'CoreM' monad and hands it, together with the module guts, to
-- 'vectoriseIO'; the resulting 'IO' action is lifted back into 'CoreM'.
--
vectorise :: ModGuts -> CoreM ModGuts
vectorise guts
 = do { hsc_env <- getHscEnv
      ; liftIO $ vectoriseIO hsc_env guts
      }

-- Vectorise a single module, given the 'HscEnv'.
--
-- The vectorisation info handed to the VM computation is the combination of the info obtained via
-- 'hptVectInfo' and the info stored in the external package state.  The info returned by the VM
-- computation is stored in the 'mg_vect_info' field of the resulting module guts.
--
-- NB: The result of 'initV' is matched against 'Just' directly in the 'do' block; a 'Nothing'
--     result is not handled here and surfaces as a pattern-match failure in 'IO'.
--
vectoriseIO :: HscEnv -> ModGuts -> IO ModGuts
vectoriseIO hsc_env guts
 = do {   -- Get information about currently loaded external packages.
      ; eps <- hscEPS hsc_env

          -- Combine vectorisation info from the current module, and external ones.
      ; let info = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps

          -- Run the main VM computation.
      ; Just (info', guts') <- initV hsc_env guts info (vectModule guts)
      ; return (guts' { mg_vect_info = info' })
      }

-- Vectorise a single module, in the VM monad.
--
-- The returned module guts extend the original ones with the type constructors and family
-- instances produced by 'vectTypeEnv', and replace the bindings by the bindings produced by
-- 'vectTypeEnv', 'vectTopBind' (for local bindings) and 'vectImpBind' (for imported identifiers
-- with a VECTORISE pragma).
--
vectModule :: ModGuts -> VM ModGuts
vectModule guts@(ModGuts { mg_tcs        = tycons
                         , mg_binds      = binds
                         , mg_fam_insts  = fam_insts
                         , mg_vect_decls = vect_decls
                         })
 = do { dumpOptVt Opt_D_dump_vt_trace "Before vectorisation" $
          pprCoreBindings binds

          -- Pick out all 'VECTORISE [SCALAR] type' and 'VECTORISE class' pragmas
      ; let ty_vect_decls  = [vd | vd@(VectType _ _ _) <- vect_decls]
            cls_vect_decls = [vd | vd@(VectClass _)    <- vect_decls]

          -- Vectorise the type environment.  This will add vectorised
          -- type constructors, their representations, and the
          -- corresponding data constructors.  Moreover, we produce
          -- bindings for dfuns and family instances of the classes
          -- and type families used in the DPH library to represent
          -- array types.
      ; (new_tycons, new_fam_insts, tc_binds) <- vectTypeEnv tycons ty_vect_decls cls_vect_decls

          -- Family instance environment for /all/ home-package modules including those instances
          -- generated by 'vectTypeEnv'.
      ; (_, fam_inst_env) <- readGEnv global_fam_inst_env

          -- Vectorise all the top level bindings and VECTORISE declarations on imported identifiers
          -- NB: Need to vectorise the imported bindings first (local bindings may depend on them).
      ; let impBinds = [(imp_id, expr) | Vect imp_id expr <- vect_decls, isGlobalId imp_id]
      ; binds_imp <- mapM vectImpBind impBinds
      ; binds_top <- mapM vectTopBind binds

      ; return $ guts { mg_tcs          = tycons ++ new_tycons
                        -- we produce no new classes or instances, only new class type constructors
                        -- and dfuns
                      , mg_binds        = Rec tc_binds : (binds_top ++ binds_imp)
                      , mg_fam_inst_env = fam_inst_env
                      , mg_fam_insts    = fam_insts ++ (map toBranchedFamInst new_fam_insts)
                      }
      }

-- Try to vectorise a top-level binding.  If it doesn't vectorise, or if it is entirely scalar, then
-- omit vectorisation of that binding.
--
-- For example, for the binding
--
-- @
--    foo :: Int -> Int
--    foo = \x -> x + x
-- @
--
-- we get
-- @
--    foo  :: Int -> Int
--    foo  = \x -> vfoo $: x
--
--    v_foo :: Closure void vfoo lfoo
--    v_foo = closure vfoo lfoo void
--
--    vfoo :: Void -> Int -> Int
--    vfoo = ...
--
--    lfoo :: PData Void -> PData Int -> PData Int
--    lfoo = ...
-- @
--
-- @vfoo@ is the "vectorised", or scalar, version that does the same as the original function foo,
-- but takes an explicit environment.
--
-- @lfoo@ is the "lifted" version that works on arrays.
--
-- @v_foo@ combines both of these into a `Closure` that also contains the environment.
--
-- The original binding @foo@ is rewritten to call the vectorised version present in the closure.
--
-- Vectorisation may be suppressed by annotating a binding with a 'NOVECTORISE' pragma.  If this
-- pragma is used in a group of mutually recursive bindings, either all or no binding must have
-- the pragma.  If only some bindings are annotated, a fatal error is raised. (In the case of
-- scalar bindings, we only omit vectorisation if all bindings in a group are scalar.)
--
-- If vectorisation of a binding (group) fails, a message is emitted via 'emitVt' and the original
-- binding (group) is returned unchanged.
--
-- FIXME: Once we support partial vectorisation, we may be able to vectorise parts of a group, or
--   we may emit a warning and refrain from vectorising the entire group.
--
vectTopBind :: CoreBind -> VM CoreBind
vectTopBind b@(NonRec var expr)
  = do
    { traceVt "= Vectorise non-recursive top-level variable" (ppr var)

    ; (hasNoVect, vectDecl) <- lookupVectDecl var
    ; if hasNoVect
      then do
      {   -- 'NOVECTORISE' pragma => leave this binding as it is
      ; traceVt "NOVECTORISE" $ ppr var
      ; return b
      }
      else do
    { vectRhs <- case vectDecl of
        Just (_, expr') ->
            -- 'VECTORISE' pragma => just use the provided vectorised rhs
          do
          { traceVt "VECTORISE" $ ppr var
          ; addGlobalParallelVar var
          ; return $ Just (False, inlineMe, expr')
          }
        Nothing         ->
            -- no pragma => standard vectorisation of rhs
          do
          { traceVt "[Vanilla]" $ ppr var <+> char '=' <+> ppr expr
          ; vectTopExpr var expr
          }
    ; hs <- takeHoisted -- make sure we clean those out (even if we skip)
    ; case vectRhs of
      { Nothing ->
          -- scalar binding => leave this binding as it is
          do
          { traceVt "scalar binding [skip]" $ ppr var
          ; return b
          }
      ; Just (parBind, inline, expr') -> do
    {
       -- vanilla case => create an appropriate top-level binding & add it to the vectorisation map
    ; when parBind $
        addGlobalParallelVar var
    ; var' <- vectTopBinder var inline expr'

        -- We replace the original top-level binding by a value projected from the vectorised
        -- closure and add any newly created hoisted top-level bindings.
    ; cexpr <- tryConvert var var' expr
    ; return . Rec $ (var, cexpr) : (var', expr') : hs
    } } } }
    `orElseErrV`
    do
    { emitVt "  Could NOT vectorise top-level binding" $ ppr var
    ; return b
    }
vectTopBind b@(Rec binds)
  = do
    { traceVt "= Vectorise recursive top-level variables" $ ppr vars

    ; vectDecls <- mapM lookupVectDecl vars
    ; let hasNoVects = map fst vectDecls
    ; if and hasNoVects
      then do
      {   -- 'NOVECTORISE' pragmas => leave this entire binding group as it is
      ; traceVt "NOVECTORISE" $ ppr vars
      ; return b
      }
      else do
    { if or hasNoVects
      then do
        {   -- Inconsistent 'NOVECTORISE' pragmas => bail out
        ; dflags <- getDynFlags
        ; cantVectorise dflags noVectoriseErr (ppr b)
        }
      else do
    { traceVt "[Vanilla]" $ vcat [ppr var <+> char '=' <+> ppr expr | (var, expr) <- binds]

       -- For all bindings *with* a pragma, just use the pragma-supplied vectorised expression
    ; newBindsWPragma  <- concat <$>
                          sequence [ vectTopBindAndConvert bind inlineMe expr'
                                   | (bind, (_, Just (_, expr'))) <- zip binds vectDecls]

        -- Standard vectorisation of all rhses that are *without* a pragma.
        -- NB: The reason for 'fixV' is rather subtle: 'vectTopBindAndConvert' adds entries for
        --     the bound variables in the recursive group to the vectorisation map, which in turn
        --     are needed by 'vectPolyExprs' (unless it returns 'Nothing').
    ; let bindsWOPragma = [bind | (bind, (_, Nothing)) <- zip binds vectDecls]
    ; (newBinds, _) <- fixV $
        \ ~(_, exprs') ->
          do
          {   -- Create appropriate top-level bindings, enter them into the vectorisation map, and
              -- vectorise the right-hand sides
          ; newBindsWOPragma <- concat <$>
                                sequence [vectTopBindAndConvert bind inline expr
                                         | (bind, ~(inline, expr)) <- zipLazy bindsWOPragma exprs']
                                         -- irrefutable pattern and 'zipLazy' to tie the knot;
                                         -- hence, can't use 'zipWithM'
          ; vectRhses <- vectTopExprs bindsWOPragma
          ; hs <- takeHoisted -- make sure we clean those out (even if we skip)

          ; case vectRhses of
              Nothing ->
                -- scalar bindings => skip all bindings except those with pragmas and retract the
                --   entries into the vectorisation map for the scalar bindings
                do
                { traceVt "scalar bindings [skip]" $ ppr vars
                ; mapM_ (undefGlobalVar . fst) bindsWOPragma
                ; return (bindsWOPragma ++ newBindsWPragma, exprs')
                }
              Just (parBind, exprs') ->
                -- vanilla case => record parallel variables and return the final bindings
                do
                { when parBind $
                    mapM_ addGlobalParallelVar vars
                ; return (newBindsWOPragma ++ newBindsWPragma ++ hs, exprs')
                }
          }
    ; return $ Rec newBinds
    } } }
    `orElseErrV`
    do
    { emitVt "  Could NOT vectorise top-level bindings" $ ppr vars
    ; return b
    }
  where
    vars = map fst binds
    noVectoriseErr = "NOVECTORISE must be used on all or no bindings of a recursive group"

    -- Replace the original top-level bindings by values projected from the vectorised
    -- closures and add any newly created hoisted top-level bindings to the group.
    vectTopBindAndConvert (var, expr) inline expr'
      = do
        { var'  <- vectTopBinder var inline expr'
        ; cexpr <- tryConvert var var' expr
        ; return [(var, cexpr), (var', expr')]
        }

-- Add a vectorised binding to an imported top-level variable that has a VECTORISE pragma
-- in this module.
--
-- The vectorised binder is created by 'vectTopBinder' (with 'inlineMe') and bound to the
-- expression supplied by the pragma.
--
-- RESTRICTION: Currently, we cannot use the pragma for mutually recursive definitions.
--
vectImpBind :: (Id, CoreExpr) -> VM CoreBind
vectImpBind (var, expr)
  = do
    { traceVt "= Add vectorised binding to imported variable" (ppr var)

    ; var' <- vectTopBinder var inlineMe expr
    ; return $ NonRec var' expr
    }

-- |Make the vectorised version of this top level binder, and add the mapping between it and the
-- original to the state. For some binder @foo@ the vectorised version is @$v_foo@
--
-- If the binder has a vectorisation declaration whose type differs from the vectorised type of
-- the binder, vectorisation is aborted via 'cantVectorise'.
--
-- NOTE: 'vectTopBinder' *MUST* be lazy in inline and expr because of how it is used inside of
--       'fixV' in 'vectTopBind'.
--
vectTopBinder :: Var      -- ^ Name of the binding.
              -> Inline   -- ^ Whether it should be inlined, used to annotate it.
              -> CoreExpr -- ^ RHS of binding, used to set the 'Unfolding' of the returned 'Var'.
              -> VM Var   -- ^ Name of the vectorised binding.
vectTopBinder var inline expr
 = do {   -- Vectorise the type attached to the var.
      ; vty  <- vectType (idType var)

          -- If there is a vectorisation declaration for this binding, make sure its type matches
      ; (_, vectDecl) <- lookupVectDecl var
      ; case vectDecl of
          Nothing             -> return ()
          Just (vdty, _)
            | eqType vty vdty -> return ()
            | otherwise       ->
              do
              { dflags <- getDynFlags
              ; cantVectorise dflags ("Type mismatch in vectorisation pragma for " ++ showPpr dflags var) $
                  (text "Expected type" <+> ppr vty)
                  $$
                  (text "Inferred type" <+> ppr vdty)
              }

          -- Make the vectorised version of binding's name, and set the unfolding used for inlining
      ; var' <- liftM (`setIdUnfoldingLazily` unfolding)
                $  mkVectId var vty

          -- Add the mapping between the plain and vectorised name to the state.
      ; defGlobalVar var var'

      ; return var'
    }
  where
    -- Only referenced through 'setIdUnfoldingLazily', which keeps the binder lazy in 'inline'
    -- and 'expr'.
    unfolding = case inline of
                  Inline arity -> mkInlineUnfolding (Just arity) expr
                  DontInline   -> noUnfolding
{-
!!!TODO: dfuns and unfoldings:
           -- Do not inline the dfun; instead give it a magic DFunUnfolding
           -- See Note [ClassOp/DFun selection]
           -- See also note [Single-method classes]
        dfun_id_w_fun
           | isNewTyCon class_tc
           = dfun_id `setInlinePragma` alwaysInlinePragma { inl_sat = Just 0 }
           | otherwise
           = dfun_id `setIdUnfolding`  mkDFunUnfolding dfun_ty dfun_args
                     `setInlinePragma` dfunInlinePragma
 -}

-- |Project out the vectorised version of a binding from some closure, or return the original body
-- if that doesn't work.
--
-- On failure of the conversion, a message is emitted via 'emitVt' before the original body is
-- returned.
--
tryConvert :: Var       -- ^Name of the original binding (eg @foo@)
           -> Var       -- ^Name of vectorised version of binding (eg @$vfoo@)
           -> CoreExpr  -- ^The original body of the binding.
           -> VM CoreExpr
tryConvert var vect_var rhs
  = fromVect (idType var) (Var vect_var)
    `orElseErrV`
    do
    { emitVt "  Could NOT call vectorised from original version" $ ppr var
    ; return rhs
    }