Vectorise.hs 14.9 KB
Newer Older
1
{-# OPTIONS -w #-}
2 3 4
-- The above warning suppression flag is a temporary kludge.
-- While working on this module you are encouraged to remove it and fix
-- any warnings in the module. See
--     http://hackage.haskell.org/trac/ghc/wiki/Commentary/CodingStyle#Warnings
-- for details

8 9 10 11 12
module Vectorise( vectorise )
where

#include "HsVersions.h"

13
import VectMonad
14
import VectUtils
15
import VectType
16
import VectCore
17

18 19 20
import DynFlags
import HscTypes

rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
21
import CoreLint             ( showPass, endPass )
22
import CoreSyn
23 24
import CoreUtils
import CoreFVs
25 26
import SimplMonad           ( SimplCount, zeroSimplCount )
import Rules                ( RuleBase )
27
import DataCon
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
28
import TyCon
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
29
import Type
30 31
import FamInstEnv           ( extendFamInstEnvList )
import InstEnv              ( extendInstEnvList )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
32 33
import Var
import VarEnv
34
import VarSet
35
import Name                 ( Name, mkSysTvName, getName )
36
import NameEnv
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
37
import Id
38
import MkId                 ( unwrapFamInstScrut )
39
import OccName
40
import Module               ( Module )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
41

rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
42
import DsMonad hiding (mapAndUnzipM)
43
import DsUtils              ( mkCoreTup, mkCoreTupTy )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
44

45
import Literal              ( Literal, mkMachInt )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
46
import PrelNames
47
import TysWiredIn
48
import TysPrim              ( intPrimTy )
49
import BasicTypes           ( Boxity(..) )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
50

51
import Outputable
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
52
import FastString
53
import Control.Monad        ( liftM, liftM2, zipWithM, mapAndUnzipM )
54
import Data.List            ( sortBy, unzip4 )
55

56 57 58
-- | Entry point of the vectorisation pass.
--
-- The vectorisation info of the home package table is combined with that of
-- the external package state, the module is vectorised with 'vectModule'
-- under 'initV', and the info produced is stored in the returned 'ModGuts'.
-- The 'UniqSupply' and 'RuleBase' arguments are ignored and the returned
-- simplifier count is always zero.
--
-- NOTE(review): if 'initV' yields 'Nothing' the pattern bind below fails in
-- 'IO'; presumably that signals that vectorising the module failed.
vectorise :: HscEnv -> UniqSupply -> RuleBase -> ModGuts
          -> IO (SimplCount, ModGuts)
vectorise hsc_env _ _ guts
  = do
      let flags = hsc_dflags hsc_env
      showPass flags "Vectorisation"
      ext_state <- hscEPS hsc_env
      let start_info = hptVectInfo hsc_env `plusVectInfo` eps_vect_info ext_state
      Just (new_info, new_guts) <- initV hsc_env guts start_info (vectModule guts)
      endPass flags "Vectorisation" Opt_D_dump_vect (mg_binds new_guts)
      return (zeroSimplCount flags, new_guts { mg_vect_info = new_info })

rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
69
-- | Vectorise a whole module: first its type environment, then each of its
-- top-level bindings.
--
-- The family instances returned by 'vectTypeEnv' extend the module's family
-- instance environment, which is installed globally before any binding is
-- vectorised. The bindings 'vectTypeEnv' returns are prepended to the
-- vectorised module bindings as a single recursive group.
vectModule :: ModGuts -> VM ModGuts
vectModule guts
  = do
      (new_types, new_fam_insts, tycon_binds) <- vectTypeEnv (mg_types guts)
      let new_fam_env = extendFamInstEnvList (mg_fam_inst_env guts) new_fam_insts
      updGEnv (setFamInstEnv new_fam_env)
      new_binds <- mapM vectTopBind (mg_binds guts)
      let all_binds = Rec tycon_binds : new_binds
          all_insts = mg_fam_insts guts ++ new_fam_insts
      return guts { mg_types        = new_types
                  , mg_binds        = all_binds
                  , mg_fam_inst_env = new_fam_env
                  , mg_fam_insts    = all_insts
                  }
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
85

86
-- | Vectorise a top-level binding group.
--
-- A non-recursive binding is treated as a one-element group. Every binder
-- gets a vectorised twin ('vectTopBinder') whose right-hand side comes from
-- 'vectTopRhs'; the original binder is rebound to whatever 'tryConvert'
-- produces. The result is one recursive group holding the original binders,
-- their twins and any hoisted bindings. If vectorisation fails, the binding
-- is returned unchanged.
vectTopBind :: CoreBind -> VM CoreBind
vectTopBind b
  = vect_group `orElseV` return b
  where
    (vars, exprs) = case b of
                      NonRec var expr -> ([var], [expr])
                      Rec bs          -> unzip bs

    vect_group
      = do
          twins      <- mapM vectTopBinder vars
          twin_rhss  <- zipWithM vectTopRhs vars exprs
          hoisted    <- takeHoisted
          converted  <- sequence (zipWith3 tryConvert vars twins exprs)
          return (Rec (zip vars converted ++ zip twins twin_rhss ++ hoisted))

-- | Create the vectorised twin of a top-level binder and register it as the
-- binder's global vectorised counterpart.
--
-- The twin is a clone of the binder named with 'mkVectOcc' and typed with
-- the vectorised version of the binder's type.
vectTopBinder :: Var -> VM Var
vectTopBinder var
  = do
      twin <- cloneId mkVectOcc var =<< vectType (idType var)
      defGlobalVar var twin
      return twin
    
117 118
-- | Vectorise the right-hand side of a top-level binder.
--
-- The free-variable-annotated expression is vectorised with 'vectPolyExpr'
-- inside 'inBind' for the binder, under 'closedV'. Only the 'vectorised'
-- component of the result is kept.
vectTopRhs :: Var -> CoreExpr -> VM CoreExpr
vectTopRhs var expr
  = closedV (liftM vectorised (inBind var (vectPolyExpr annotated)))
  where
    annotated = freeVars expr
123

124 125 126 127
-- | Try to rebuild the original binder's value from its vectorised twin with
-- 'fromVect' at the original binder's type; if that fails, fall back to the
-- original right-hand side.
tryConvert :: Var -> Var -> CoreExpr -> VM CoreExpr
tryConvert var vect_var rhs
  = orElseV converted (return rhs)
  where
    converted = fromVect (idType var) (Var vect_var)

128 129
-- ----------------------------------------------------------------------------
-- Bindings
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
130

131
-- | Vectorise a local binder.
--
-- Two copies of the binder are made: one retyped to the vectorised type and
-- one retyped to the 'mkPArrayType' of that type. The pair is recorded as
-- the binder's entry in the local environment and returned.
vectBndr :: Var -> VM VVar
vectBndr v
  = do
      vect_ty <- vectType (idType v)
      lift_ty <- mkPArrayType vect_ty
      let pair = (v `Id.setIdType` vect_ty, v `Id.setIdType` lift_ty)
      updLEnv (\env -> env { local_vars = extendVarEnv (local_vars env) v pair })
      return pair
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
142

143 144 145 146 147 148 149 150 151 152
-- | Vectorise a local binder by creating fresh variables.
--
-- Instead of retyping the binder, a new pair is made with 'newLocalVVar'
-- from the given name and the binder's vectorised type. The original binder
-- is mapped to that pair in the local environment, and the pair is returned.
vectBndrNew :: Var -> FastString -> VM VVar
vectBndrNew v fs
  = do
      fresh <- newLocalVVar fs =<< vectType (idType v)
      updLEnv (\env -> env { local_vars = extendVarEnv (local_vars env) v fresh })
      return fresh

153
-- | Vectorise a binder with 'vectBndr', then run the given computation.
--
-- Both steps run inside 'localV'. The vectorised binder is returned paired
-- with the computation's result.
vectBndrIn :: Var -> VM a -> VM (VVar, a)
vectBndrIn v p = localV (liftM2 (,) (vectBndr v) p)
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
160

161 162 163 164 165 166 167 168
-- | Like 'vectBndrIn', but the binder is vectorised with 'vectBndrNew'
-- using the given name.
--
-- Both steps run inside 'localV'; the new binder pair is returned with the
-- computation's result.
vectBndrNewIn :: Var -> FastString -> VM a -> VM (VVar, a)
vectBndrNewIn v fs p = localV (liftM2 (,) (vectBndrNew v fs) p)

169 170 171 172 173 174 175 176
-- | Like 'vectBndrIn', but the computation receives the vectorised binder
-- as its argument.
--
-- Both steps run inside 'localV'.
vectBndrIn' :: Var -> (VVar -> VM a) -> VM (VVar, a)
vectBndrIn' v p
  = localV
  $ vectBndr v >>= \vv -> liftM ((,) vv) (p vv)

177
-- | Vectorise a list of binders (left to right, with 'vectBndr'), then run
-- the given computation.
--
-- Everything runs inside 'localV'. The vectorised binders are returned
-- paired with the computation's result.
vectBndrsIn :: [Var] -> VM a -> VM ([VVar], a)
vectBndrsIn vs p = localV (liftM2 (,) (mapM vectBndr vs) p)
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
184

rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
185
-- ----------------------------------------------------------------------------
186 187
-- Expressions

188 189
-- | Vectorise an occurrence of a variable.
--
-- A local variable maps straight to its vectorised and lifted copies. For a
-- global variable, the vectorised form is its vectorised counterpart, and
-- the lifted form is obtained from that with 'liftPA'.
vectVar :: Var -> VM VExpr
vectVar v = lookupVar v >>= from_entry
  where
    from_entry (Local (vv, lv)) = return (Var vv, Var lv)
    from_entry (Global vv)      = liftM ((,) (Var vv)) (liftPA (Var vv))
198

199 200
-- | Vectorise a variable applied to type arguments.
--
-- The type arguments are vectorised first and passed to 'polyApply'. For a
-- local variable, both its vectorised and its lifted copy are instantiated.
-- For a global variable, only the vectorised counterpart is instantiated and
-- the lifted form comes from 'liftPA'.
vectPolyVar :: Var -> [Type] -> VM VExpr
vectPolyVar v tys
  = do
      vect_tys <- mapM vectType tys
      entry    <- lookupVar v
      let inst x = polyApply (Var x) vect_tys
      case entry of
        Global poly    -> do
                            vexpr <- inst poly
                            liftM ((,) vexpr) (liftPA vexpr)
        Local (vv, lv) -> do
                            vexpr <- inst vv
                            lexpr <- inst lv
                            return (vexpr, lexpr)
211

212 213
-- | Vectorise a literal. The literal itself is the vectorised form, and
-- its lifted form is produced by 'liftPA'.
vectLiteral :: Literal -> VM VExpr
vectLiteral lit
  = liftM ((,) lit_expr) (liftPA lit_expr)
  where
    lit_expr = Lit lit

218
-- | Vectorise a possibly polymorphic expression.
--
-- Notes are vectorised through. Otherwise the leading type binders are
-- split off with 'collectAnnTypeBinders'. The remaining monomorphic body is
-- vectorised with 'vectExpr' under 'polyAbstract' for those binders, and the
-- abstraction it supplies is applied to both components with 'mapVect'.
vectPolyExpr :: CoreExprWithFVs -> VM VExpr
vectPolyExpr (_, AnnNote note expr)
  = vNote note `liftM` vectPolyExpr expr
vectPolyExpr expr
  = polyAbstract tvs (\abstract -> liftM (mapVect abstract) (vectExpr mono))
  where
    (tvs, mono) = collectAnnTypeBinders expr
228
                
229 230
-- | Vectorise a monomorphic expression, producing its vectorised and lifted
-- forms.
--
-- How each kind of expression is handled:
--
-- * Type arguments are vectorised as types.
-- * Variables and literals go through 'vectVar' and 'vectLiteral'.
-- * An application to type arguments goes through 'vectTyAppExpr'.
-- * Applying 'intDataCon', 'floatDataCon' or 'doubleDataCon' to a literal
--   is kept as is and lifted with 'liftPA'.
-- * Every other value application becomes a closure application.
-- * Cases on algebraic types are delegated to 'vectAlgCase', and value
--   lambdas to 'vectLam'.
-- * Anything else (other cases, type lambdas reaching this function, ...)
--   panics.
vectExpr :: CoreExprWithFVs -> VM VExpr
vectExpr (_, AnnType ty)
  = liftM vType (vectType ty)

vectExpr (_, AnnVar v) = vectVar v

vectExpr (_, AnnLit lit) = vectLiteral lit

vectExpr (_, AnnNote note expr)
  = liftM (vNote note) (vectExpr expr)

vectExpr e@(_, AnnApp _ arg)
  | isAnnTypeArg arg
  = vectTyAppExpr fn tys
  where
    (fn, tys) = collectAnnTypeArgs e

vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
  | Just con <- isDataConId_maybe v
  , is_special_con con
  = do
      let vexpr = App (Var v) (Lit lit)
      lexpr <- liftPA vexpr
      return (vexpr, lexpr)
  where
    is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon]

vectExpr (_, AnnApp fn arg)
  = do
      arg_ty' <- vectType arg_ty
      res_ty' <- vectType res_ty
      fn'     <- vectExpr fn
      arg'    <- vectExpr arg
      mkClosureApp arg_ty' res_ty' fn' arg'
  where
    (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn

vectExpr (_, AnnCase scrut bndr ty alts)
  | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
  , isAlgTyCon tycon
  = vectAlgCase tycon ty_args scrut bndr ty alts
  where
    scrut_ty = exprType (deAnnotate scrut)

vectExpr (_, AnnCase _ _ _ _)
  = panic "vectExpr: case"

vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
  = do
      -- The binder is not in scope in its own (non-recursive) right-hand side.
      vrhs <- localV . inBind bndr $ vectPolyExpr rhs
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
      return $ vLet (vNonRec vbndr vrhs) vbody

vectExpr (_, AnnLet (AnnRec bs) body)
  = do
      -- All binders are in scope in every right-hand side and in the body.
      (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
                                $ liftM2 (,)
                                  (zipWithM vect_rhs bndrs rhss)
                                  (vectExpr body)
      return $ vLet (vRec vbndrs vrhss) vbody
  where
    (bndrs, rhss) = unzip bs

    -- As in the non-recursive case, a right-hand side may be polymorphic
    -- (start with type lambdas), so it must go through 'vectPolyExpr';
    -- passing it to 'vectExpr' would panic on the type-lambda binder.
    vect_rhs bndr rhs = localV
                      . inBind bndr
                      $ vectPolyExpr rhs

vectExpr e@(fvs, AnnLam bndr _)
  | not (isId bndr) = pprPanic "vectExpr" (ppr $ deAnnotate e)
  | otherwise       = vectLam fvs bs body
  where
    (bs, body) = collectAnnValBinders e

vectExpr e = pprPanic "vectExpr" (ppr $ deAnnotate e)

305 306
-- | Vectorise a value lambda with binders @bs@ and body @body@. The set
-- @fvs@ holds the free variables of the lambda.
--
-- Free variables that have an entry in the local environment are captured.
-- Their vectorised/lifted pairs are handed to 'buildClosures' together with
-- the vectorised argument and result types. Inside the body hoisted with
-- 'hoistPolyVExpr', those variables are rebound alongside the lambda's own
-- binders. The vectorised body is wrapped with 'vLams', using the builtin
-- lifting-context variable.
vectLam :: VarSet -> [Var] -> CoreExprWithFVs -> VM VExpr
vectLam fvs bs body
  = do
      tyvars   <- localTyVars
      captured <- readLEnv $ \env ->
                    [ (var, vv) | var     <- varSetElems fvs
                                , Just vv <- [lookupVarEnv (local_vars env) var] ]
      let (vs, vvs) = unzip captured
      arg_tys <- mapM (vectType . idType) bs
      res_ty  <- vectType (exprType (deAnnotate body))
      let inner = do
                    lc <- builtin liftingContext
                    (vbndrs, vbody) <- vectBndrsIn (vs ++ bs) (vectExpr body)
                    return (vLams lc vbndrs vbody)
      buildClosures tyvars vvs arg_tys res_ty (hoistPolyVExpr tyvars inner)
323
  
324 325 326
-- | Vectorise a term applied to type arguments. Only a variable at the head
-- is supported (it is handled by 'vectPolyVar'); any other head panics.
vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
vectTyAppExpr e tys
  = case e of
      (_, AnnVar v) -> vectPolyVar v tys
      _             -> pprPanic "vectTyAppExpr" (ppr (deAnnotate e))
327

328 329 330 331 332 333 334 335
-- | A Core case alternative ('AnnAlt') over 'Id's whose expressions are
-- annotated with 'VarSet's.
type CoreAltWithFVs = AnnAlt Id VarSet

-- We convert
--
--   case e :: t of v { ... }
--
-- to
--
--   V:    let v' = e in case v' of _ { ... }
--   L:    let v' = e in case v' `cast` ... of _ { ... }
--
-- When lifting, we have to do it this way because v must have the type
-- [:V(T):], but the scrutinee must be cast to the representation type. We
-- also have to handle the case where v is a wild (dead) variable correctly.
--

-- FIXME: this is too lazy
-- | Vectorise a case expression.
--
-- The scrutinee has algebraic type @tycon@ (applied to @ty_args@), @bndr@
-- is the case binder, and @ty@ is the type of the whole case expression.
-- There are three equations:
--
-- * A single alternative that binds no fields is turned into
--   'vCaseDEFAULT'.
-- * A single constructor alternative with fields is turned into
--   'vCaseProd'.
-- * Otherwise the alternatives are sorted by constructor tag. Each body is
--   vectorised under a lifting context built by 'packLiftingContext', and
--   the lifted bodies are merged with 'combinePA'. Only 'DataAlt'
--   alternatives are supported here.
--
-- In the last two equations, the vectorised scrutinee is let-bound to the
-- vectorised case binder. When the original binder is dead, a fresh
-- variable named "scrut" is used instead.
vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs -> Var -> Type
            -> [CoreAltWithFVs]
            -> VM VExpr

-- A single alternative that binds nothing (DEFAULT, or a nullary
-- constructor): no case distinction is necessary. Literal alternatives fall
-- through to the equations below.
vectAlgCase _ _ scrut bndr ty [(alt, [], body)]
  | binds_nothing alt
  = do
      vscrut <- vectExpr scrut
      vty    <- vectType ty
      lty    <- mkPArrayType vty
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
      return $ vCaseDEFAULT vscrut vbndr vty lty vbody
  where
    binds_nothing DEFAULT     = True
    binds_nothing (DataAlt _) = True
    binds_nothing _           = False

-- A single constructor alternative with fields.
vectAlgCase tycon _ scrut bndr ty [(DataAlt dc, bndrs, body)]
  = do
      vect_tc <- maybeV (lookupTyCon tycon)
      vty <- vectType ty
      lty <- mkPArrayType vty
      vexpr <- vectExpr scrut
      (vbndr, (vbndrs, vbody)) <- vect_scrut_bndr
                                . vectBndrsIn bndrs
                                $ vectExpr body

      (vscrut, arr_tc, _arg_tys) <- mkVScrut (vVar vbndr)
      vect_dc <- maybeV (lookupDataCon dc)
      let [arr_dc] = tyConDataCons arr_tc
      repr <- mkRepr vect_tc
      shape_bndrs <- arrShapeVars repr
      return . vLet (vNonRec vbndr vexpr)
             $ vCaseProd vscrut vty lty vect_dc arr_dc shape_bndrs vbndrs vbody
  where
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr FSLIT("scrut")
                    | otherwise         = vectBndrIn bndr

-- The general case: several constructor alternatives.
vectAlgCase tycon _ scrut bndr ty alts
  = do
      vect_tc <- maybeV (lookupTyCon tycon)
      vty               <- vectType ty
      lty               <- mkPArrayType vty

      repr        <- mkRepr vect_tc
      shape_bndrs <- arrShapeVars repr
      (len, sel, indices) <- arrSelector repr (map Var shape_bndrs)

      (vbndr, valts) <- vect_scrut_bndr $ mapM (proc_alt sel vty lty) alts'
      let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts

      vexpr <- vectExpr scrut
      (vscrut, arr_tc, _arg_tys) <- mkVScrut (vVar vbndr)
      let [arr_dc] = tyConDataCons arr_tc

      let (vect_scrut,  lift_scrut)  = vscrut
          (vect_bodies, lift_bodies) = unzip vbodies

      let vect_case = Case vect_scrut (mkWildId (exprType vect_scrut)) vty
                           (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)

      lbody <- combinePA vty len sel indices lift_bodies
      let lift_case = Case lift_scrut (mkWildId (exprType lift_scrut)) lty
                           [(DataAlt arr_dc, shape_bndrs ++ concat lift_bndrss,
                             lbody)]

      return . vLet (vNonRec vbndr vexpr)
             $ (vect_case, lift_case)
  where
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr FSLIT("scrut")
                    | otherwise         = vectBndrIn bndr

    -- Alternatives in constructor-tag order (DEFAULT first).
    alts' = sortBy (\(alt1, _, _) (alt2, _, _) -> cmp alt1 alt2) alts

    cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
    cmp DEFAULT       DEFAULT       = EQ
    cmp DEFAULT       _             = LT
    cmp _             DEFAULT       = GT

    -- Vectorise one constructor alternative. The body is vectorised under
    -- a lifting context packed for this constructor's tag.
    proc_alt sel vty lty (DataAlt dc, bndrs, body)
      = do
          vect_dc <- maybeV (lookupDataCon dc)
          let tag = mkDataConTag vect_dc
              fvs = freeVarsOf body `delVarSetList` bndrs
          (vect_bndrs, lift_bndrs, vbody)
            <- vect_alt_bndrs bndrs
             $ \len -> packLiftingContext len sel tag fvs vty lty
             $ vectExpr body

          return (vect_dc, vect_bndrs, lift_bndrs, vbody)
    -- DEFAULT (or literal) alternatives used to hit a bare
    -- non-exhaustive-pattern failure here; report them explicitly instead.
    proc_alt _ _ _ _
      = panic "vectAlgCase: only constructor alternatives are supported"

    -- A nullary constructor has no field from which to compute the length,
    -- so a fresh "voids" array over the builtin void type provides it.
    vect_alt_bndrs [] p
      = do
          void_tc <- builtin voidTyCon
          let void_ty = mkTyConApp void_tc []
          arr_ty <- mkPArrayType void_ty
          bndr   <- newLocalVar FSLIT("voids") arr_ty
          len    <- lengthPA void_ty (Var bndr)
          e      <- p len
          return ([], [bndr], e)

    -- Otherwise the length is taken from the lifted copy of the first field.
    vect_alt_bndrs bndrs p
       = localV
       $ do
           vbndrs <- mapM vectBndr bndrs
           let (vect_bndrs, lift_bndrs) = unzip vbndrs
               vv : _ = vect_bndrs
               lv : _ = lift_bndrs
           len <- lengthPA (idType vv) (Var lv)
           e   <- p len
           return (vect_bndrs, lift_bndrs, e)

    mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)

458 459 460
-- | Run the vectorisation of one case alternative under a packed lifting
-- context.
--
-- Setup:
--
-- * @shape@ and @tag@ are passed to the builtin 'selectPAIntPrimVar'.
-- * That application is bound to a fresh variable named "sel#".
-- * Inside 'localV', every local free variable in @fvs@ is packed with
--   'packFreeVar' (using the builtin lifting-context variable and "sel#")
--   before @p@ runs.
--
-- The vectorised result of @p@ is returned unchanged. Its lifted result is
-- wrapped in the "sel#" binding, the packing bindings, and a case on @len@.
-- That case binds the lifting-context variable as its case binder and
-- yields 'emptyPA' of @vty@ when @len@ is the machine integer 0.
packLiftingContext :: CoreExpr -> CoreExpr -> CoreExpr -> VarSet
                   -> Type -> Type -> VM VExpr -> VM VExpr
packLiftingContext len shape tag fvs vty lty p
  = do
      select_fn <- builtin selectPAIntPrimVar
      let selection = mkApps (Var select_fn) [shape, tag]
      sel_var <- newLocalVar FSLIT("sel#") (exprType selection)
      lc_var  <- builtin liftingContext
      localV (in_context lc_var sel_var selection)
  where
    in_context lc_var sel_var selection
      = do
          packed <- mapM (packFreeVar (Var lc_var) (Var sel_var))
                         [ fv | fv <- varSetElems fvs, isLocalId fv ]
          (vexpr, lexpr) <- p
          empty_arr <- emptyPA vty
          let alts    = [ (DEFAULT, [], lexpr)
                        , (LitAlt (mkMachInt 0), [], empty_arr) ]
              guarded = Case len lc_var lty alts
          return ( vexpr
                 , Let (NonRec sel_var selection) (mkLets (concat packed) guarded) )
478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494

-- | Pack the lifted copy of a free variable for use in a case alternative.
--
-- For a variable with a local entry @(vv, lv)@:
--
-- * a clone @lv'@ of @lv@ is bound to 'packPA' applied to @lv@, @len@ and
--   @sel@;
-- * the local environment is updated so the variable maps to @(vv, lv')@.
--
-- A variable without a local entry produces no bindings.
packFreeVar :: CoreExpr -> CoreExpr -> Var -> VM [CoreBind]
packFreeVar len sel v = lookupVar v >>= pack
  where
    pack (Local (vv, lv))
      = do
          lv'    <- cloneVar lv
          packed <- packPA (idType vv) (Var lv) len sel
          updLEnv (\env -> env { local_vars = extendVarEnv (local_vars env) v (vv, lv') })
          return [NonRec lv' packed]
    pack _
      = return []