Vectorise.hs 13.7 KB
Newer Older
1
{-# OPTIONS -w #-}
2 3 4
-- The above warning suppression flag is a temporary kludge.
-- While working on this module you are encouraged to remove it and fix
-- any warnings in the module. See
--     http://hackage.haskell.org/trac/ghc/wiki/Commentary/CodingStyle#Warnings
-- for details

8 9 10 11 12
module Vectorise( vectorise )
where

#include "HsVersions.h"

13
import VectMonad
14
import VectUtils
15
import VectType
16
import VectCore
17

18 19 20
import DynFlags
import HscTypes

rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
21
import CoreLint             ( showPass, endPass )
22
import CoreSyn
23 24
import CoreUtils
import CoreFVs
25 26
import SimplMonad           ( SimplCount, zeroSimplCount )
import Rules                ( RuleBase )
27
import DataCon
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
28
import TyCon
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
29
import Type
30 31
import FamInstEnv           ( extendFamInstEnvList )
import InstEnv              ( extendInstEnvList )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
32 33
import Var
import VarEnv
34
import VarSet
35
import Name                 ( Name, mkSysTvName, getName )
36
import NameEnv
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
37
import Id
38
import MkId                 ( unwrapFamInstScrut )
39
import OccName
40
import Module               ( Module )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
41

rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
42
import DsMonad hiding (mapAndUnzipM)
43
import DsUtils              ( mkCoreTup, mkCoreTupTy )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
44

45
import Literal              ( Literal )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
46
import PrelNames
47
import TysWiredIn
48
import TysPrim              ( intPrimTy )
49
import BasicTypes           ( Boxity(..) )
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
50

51
import Outputable
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
52
import FastString
53
import Control.Monad        ( liftM, liftM2, zipWithM, mapAndUnzipM )
54
import Data.List            ( sortBy, unzip4 )
55

56 57 58
-- | Entry point of the vectorisation pass.
--
-- Combines the vectorisation info of the home package table with that of
-- the external package state, runs 'vectModule' on the module under
-- 'initV', and stores the resulting info in the returned 'ModGuts'.
-- The 'UniqSupply' and 'RuleBase' arguments are not used, and the reported
-- 'SimplCount' is always 'zeroSimplCount'.
vectorise :: HscEnv -> UniqSupply -> RuleBase -> ModGuts
          -> IO (SimplCount, ModGuts)
vectorise hsc_env _ _ guts
  = do
      let dflags = hsc_dflags hsc_env
      showPass dflags "Vectorisation"
      eps <- hscEPS hsc_env
      let home_info = hptVectInfo hsc_env
          info      = home_info `plusVectInfo` eps_vect_info eps
      -- NOTE(review): if 'initV' yields 'Nothing' this pattern fails in IO
      -- with a generic pattern-match error; confirm that is intended.
      Just (info', guts') <- initV hsc_env guts info (vectModule guts)
      endPass dflags "Vectorisation" Opt_D_dump_vect (mg_binds guts')
      let guts'' = guts' { mg_vect_info = info' }
      return (zeroSimplCount dflags, guts'')

rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
69
-- | Vectorise a whole module.
--
-- The type environment is vectorised first ('vectTypeEnv'), yielding the new
-- type environment, new family instances and type-constructor bindings. The
-- family instances are added to the global family-instance environment
-- before any top-level binding is vectorised. The type-constructor bindings
-- are prepended to the module's bindings as a single recursive group.
vectModule :: ModGuts -> VM ModGuts
vectModule guts
  = do
      (types', fam_insts, tc_binds) <- vectTypeEnv (mg_types guts)

      let fam_inst_env' = extendFamInstEnvList (mg_fam_inst_env guts) fam_insts
      updGEnv (setFamInstEnv fam_inst_env')

      -- dicts   <- mapM buildPADict pa_insts
      -- workers <- mapM vectDataConWorkers pa_insts
      top_binds <- mapM vectTopBind (mg_binds guts)

      let all_binds     = Rec tc_binds : top_binds
          all_fam_insts = mg_fam_insts guts ++ fam_insts
      return (guts { mg_types        = types'
                   , mg_binds        = all_binds
                   , mg_fam_inst_env = fam_inst_env'
                   , mg_fam_insts    = all_fam_insts
                   })
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
85

86
-- | Vectorise a top-level binding group.
--
-- On success, the original bindings are kept and placed in one recursive
-- group together with their vectorised copies and any bindings collected by
-- 'takeHoisted'. If vectorisation fails (see 'orElseV'), the original
-- binding is returned unchanged.
vectTopBind :: CoreBind -> VM CoreBind
vectTopBind b@(NonRec var expr)
  = vect_nonrec `orElseV` return b
  where
    vect_nonrec
      = do
          var'    <- vectTopBinder var
          expr'   <- vectTopRhs var expr
          hoisted <- takeHoisted
          return (Rec ((var, expr) : (var', expr') : hoisted))

vectTopBind b@(Rec bs)
  = vect_rec `orElseV` return b
  where
    (vars, exprs) = unzip bs

    vect_rec
      = do
          vars'   <- mapM vectTopBinder vars
          exprs'  <- zipWithM vectTopRhs vars exprs
          hoisted <- takeHoisted
          return (Rec (bs ++ zip vars' exprs' ++ hoisted))

-- | Make the vectorised counterpart of a top-level binder.
--
-- The result is a clone of @var@ built with 'mkVectOcc' and given the
-- vectorised type of @var@. It is registered for @var@ via 'defGlobalVar'
-- before being returned.
vectTopBinder :: Var -> VM Var
vectTopBinder var
  = do
      var' <- vectType (idType var) >>= cloneId mkVectOcc var
      defGlobalVar var var'
      return var'
    
115 116
-- | Vectorise the right-hand side of a top-level binding for @var@.
--
-- The expression is annotated with its free variables and vectorised with
-- 'vectPolyExpr' under 'inBind' and 'closedV'. Only the vectorised component
-- of the result is kept; the lifted one is dropped.
vectTopRhs :: Var -> CoreExpr -> VM CoreExpr
vectTopRhs var expr
  = closedV $ do
      vexpr <- inBind var (vectPolyExpr (freeVars expr))
      return (vectorised vexpr)
121

122 123
-- ----------------------------------------------------------------------------
-- Bindings
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
124

125
-- | Vectorise a local binder.
--
-- Returns two copies of @v@: one retyped with the vectorised type, one
-- retyped with the corresponding 'mkPArrayType' (lifted) type. The pair is
-- also recorded for @v@ in the local environment's 'local_vars'.
vectBndr :: Var -> VM VVar
vectBndr v
  = do
      vty <- vectType (idType v)
      lty <- mkPArrayType vty
      let vvar = (v `Id.setIdType` vty, v `Id.setIdType` lty)
      updLEnv $ \env -> env { local_vars = extendVarEnv (local_vars env) v vvar }
      return vvar
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
136

137 138 139 140 141 142 143 144 145 146
-- | Like 'vectBndr', but the vectorised/lifted pair is a fresh local pair
-- named after @fs@ (via 'newLocalVVar') instead of retyped copies of @v@.
-- The fresh pair is recorded for @v@ in 'local_vars'.
vectBndrNew :: Var -> FastString -> VM VVar
vectBndrNew v fs
  = do
      vvar <- vectType (idType v) >>= newLocalVVar fs
      updLEnv $ \env -> env { local_vars = extendVarEnv (local_vars env) v vvar }
      return vvar

147
-- | Vectorise binder @v@ with 'vectBndr', then run @p@. Both steps run
-- inside 'localV'. Returns the binder pair together with @p@'s result.
vectBndrIn :: Var -> VM a -> VM (VVar, a)
vectBndrIn v p = localV (liftM2 (,) (vectBndr v) p)
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
154

155 156 157 158 159 160 161 162
-- | Like 'vectBndrIn', but the binder is vectorised with 'vectBndrNew' using
-- the name hint @fs@.
vectBndrNewIn :: Var -> FastString -> VM a -> VM (VVar, a)
vectBndrNewIn v fs p = localV (liftM2 (,) (vectBndrNew v fs) p)

163 164 165 166 167 168 169 170
-- | Like 'vectBndrIn', but @p@ also receives the vectorised binder pair.
vectBndrIn' :: Var -> (VVar -> VM a) -> VM (VVar, a)
vectBndrIn' v p
  = localV (vectBndr v >>= \vv -> liftM ((,) vv) (p vv))

171
-- | Vectorise each binder in @vs@ in order with 'vectBndr', then run @p@.
-- Everything runs inside 'localV'. Returns the binder pairs together with
-- @p@'s result.
vectBndrsIn :: [Var] -> VM a -> VM ([VVar], a)
vectBndrsIn vs p = localV (liftM2 (,) (mapM vectBndr vs) p)
rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
178

rl@cse.unsw.edu.au's avatar
rl@cse.unsw.edu.au committed
179
-- ----------------------------------------------------------------------------
-- Expressions

182 183
-- | Vectorise a variable occurrence.
--
-- A local variable maps directly to its recorded vectorised/lifted pair.
-- A global one maps to its vectorised variable, whose lifted form is
-- obtained with 'liftPA'.
vectVar :: Var -> VM VExpr
vectVar v
  = lookupVar v >>= from_ref
  where
    from_ref (Local (vv, lv)) = return (Var vv, Var lv)
    from_ref (Global vv)      = liftM ((,) (Var vv)) (liftPA (Var vv))
192

193 194
-- | Vectorise a variable applied to the type arguments @tys@.
--
-- The type arguments are vectorised and applied with 'polyApply'. A local
-- variable gets the application on both its vectorised and its lifted
-- variable. A global one gets it on the vectorised variable only, and the
-- lifted form comes from 'liftPA'.
vectPolyVar :: Var -> [Type] -> VM VExpr
vectPolyVar v tys
  = do
      vtys <- mapM vectType tys
      let inst var = polyApply (Var var) vtys
      r <- lookupVar v
      case r of
        Local (vv, lv) -> liftM2 (,) (inst vv) (inst lv)
        Global poly    -> do
                            vexpr <- inst poly
                            lexpr <- liftPA vexpr
                            return (vexpr, lexpr)
205

206 207
-- | Vectorise a literal: the vectorised form is the literal itself, and the
-- lifted form is obtained with 'liftPA'.
vectLiteral :: Literal -> VM VExpr
vectLiteral lit
  = liftM ((,) (Lit lit)) (liftPA (Lit lit))

212 213
-- | Vectorise a possibly polymorphic expression.
--
-- Leading type binders are split off with 'collectAnnTypeBinders'. The
-- remaining monomorphic part is vectorised with 'vectExpr', and the
-- abstraction produced by 'polyAbstract' for those type variables is applied
-- to both components of the result via 'mapVect'.
vectPolyExpr :: CoreExprWithFVs -> VM VExpr
vectPolyExpr expr
  = polyAbstract tvs $ \abstract ->
      liftM (mapVect abstract) (vectExpr mono)
  where
    (tvs, mono) = collectAnnTypeBinders expr
220
                
221 222
-- | Vectorise an expression annotated with its free variables, producing both
-- its vectorised and its lifted form ('VExpr').
--
-- Lambdas whose binder is not an 'Id' (e.g. type lambdas) panic here; such
-- expressions must go through 'vectPolyExpr'. Case expressions on a
-- non-algebraic scrutinee type are not supported and panic.
vectExpr :: CoreExprWithFVs -> VM VExpr
vectExpr (_, AnnType ty)
  = liftM vType (vectType ty)

vectExpr (_, AnnVar v) = vectVar v

vectExpr (_, AnnLit lit) = vectLiteral lit

vectExpr (_, AnnNote note expr)
  = liftM (vNote note) (vectExpr expr)

-- Application to type arguments: handled as a polymorphic variable
-- instantiation.
vectExpr e@(_, AnnApp _ arg)
  | isAnnTypeArg arg
  = vectTyAppExpr fn tys
  where
    (fn, tys) = collectAnnTypeArgs e

-- Value application: becomes a closure application at the vectorised
-- argument and result types of the function.
vectExpr (_, AnnApp fn arg)
  = do
      arg_ty' <- vectType arg_ty
      res_ty' <- vectType res_ty
      fn'     <- vectExpr fn
      arg'    <- vectExpr arg
      mkClosureApp arg_ty' res_ty' fn' arg'
  where
    (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn

vectExpr (_, AnnCase scrut bndr ty alts)
  | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
  , isAlgTyCon tycon
  = vectAlgCase tycon ty_args scrut bndr ty alts
  where
    scrut_ty = exprType (deAnnotate scrut)

vectExpr (_, AnnCase _ _ _ _)
  = panic "vectExpr: case"

vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
  = do
      vrhs <- localV . inBind bndr $ vectPolyExpr rhs
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
      return $ vLet (vNonRec vbndr vrhs) vbody

-- The binders are brought into scope before the right-hand sides are
-- vectorised, since each right-hand side may refer to any of them.
vectExpr (_, AnnLet (AnnRec bs) body)
  = do
      (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
                                $ liftM2 (,)
                                  (zipWithM vect_rhs bndrs rhss)
                                  (vectPolyExpr body)
      return $ vLet (vRec vbndrs vrhss) vbody
  where
    (bndrs, rhss) = unzip bs

    -- A recursive right-hand side may be polymorphic (start with type
    -- lambdas), just like a non-recursive one, so it goes through
    -- 'vectPolyExpr'; 'vectExpr' would panic on the type lambda.
    vect_rhs bndr rhs = localV
                      . inBind bndr
                      $ vectPolyExpr rhs

vectExpr e@(fvs, AnnLam bndr _)
  | not (isId bndr) = pprPanic "vectExpr" (ppr $ deAnnotate e)
  | otherwise = vectLam fvs bs body
  where
    (bs, body) = collectAnnValBinders e

-- Any other expression form is not supported: fail with the offending
-- expression rather than an anonymous pattern-match failure.
vectExpr e
  = pprPanic "vectExpr: unsupported expression" (ppr $ deAnnotate e)
283

284 285
-- | Vectorise a lambda with value binders @bs@ and body @body@; @fvs@ are
-- the free variables of the whole lambda.
--
-- Free variables that have an entry in the local environment are captured.
-- The body is vectorised with those captured variables and @bs@ as binders,
-- and wrapped in 'vLams' over the builtin lifting-context variable. The
-- result is passed through 'hoistPolyVExpr' and turned into closures by
-- 'buildClosures' over the local type variables.
vectLam :: VarSet -> [Var] -> CoreExprWithFVs -> VM VExpr
vectLam fvs bs body
  = do
      tyvars   <- localTyVars
      captures <- readLEnv (captured . local_vars)
      let (vs, vvs) = unzip captures

      arg_tys <- mapM (vectType . idType) bs
      res_ty  <- vectType (exprType (deAnnotate body))

      let lam_body = do
            lc <- builtin liftingContext
            (vbndrs, vbody) <- vectBndrsIn (vs ++ bs) (vectExpr body)
            return (vLams lc vbndrs vbody)
      buildClosures tyvars vvs arg_tys res_ty (hoistPolyVExpr tyvars lam_body)
  where
    -- Free variables paired with their local vectorised/lifted variables;
    -- free variables without a local entry are skipped.
    captured lenv = [ (var, vv) | var     <- varSetElems fvs
                                , Just vv <- [lookupVarEnv lenv var] ]
302
  
303 304 305
-- | Vectorise an expression applied to type arguments. Only a variable in
-- function position is supported; anything else panics.
vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
vectTyAppExpr e tys
  = case e of
      (_, AnnVar v) -> vectPolyVar v tys
      _             -> pprPanic "vectTyAppExpr" (ppr $ deAnnotate e)
306

307 308 309 310 311 312 313 314
-- | A case alternative whose expressions are annotated with free variables.
type CoreAltWithFVs = AnnAlt Var VarSet

-- We convert
--
--   case e :: t of v { ... }
--
-- to
--
--   V:    let v' = e in case v' of _ { ... }
--   L:    let v' = e in case v' `cast` ... of _ { ... }
--
-- When lifting, we have to do it this way because v must have the type
-- [:V(T):] but the scrutinee must be cast to the representation type. We also
-- have to handle the case where v is a wild var correctly.
--

-- | Vectorise a case expression whose scrutinee type is headed by the
-- algebraic type constructor @tycon@.
--
-- Three shapes are handled:
--
--   * a single DEFAULT alternative;
--   * a single data-constructor alternative, built with 'vCaseProd';
--   * several alternatives, sorted by constructor tag. Each alternative's
--     lifted body runs in a packed lifting context ('packLiftingContext'),
--     and the lifted bodies are recombined with 'combinePA'.
--
-- In the multi-alternative shape only data-constructor alternatives are
-- supported; any other alternative panics with a descriptive message.
--
-- FIXME: this is too lazy
vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs -> Var -> Type
            -> [CoreAltWithFVs] -> VM VExpr
vectAlgCase tycon ty_args scrut bndr ty [(DEFAULT, [], body)]
  = do
      vscrut <- vectExpr scrut
      vty    <- vectType ty
      lty    <- mkPArrayType vty
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
      return $ vCaseDEFAULT vscrut vbndr vty lty vbody

vectAlgCase tycon ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
  = do
      vect_tc <- maybeV (lookupTyCon tycon)
      vty <- vectType ty
      lty <- mkPArrayType vty
      vexpr <- vectExpr scrut
      (vbndr, (vbndrs, vbody)) <- vect_scrut_bndr
                                . vectBndrsIn bndrs
                                $ vectExpr body

      (vscrut, arr_tc, arg_tys) <- mkVScrut (vVar vbndr)
      vect_dc <- maybeV (lookupDataCon dc)
      let [arr_dc] = tyConDataCons arr_tc
      repr <- mkRepr vect_tc
      shape_bndrs <- arrShapeVars repr
      return . vLet (vNonRec vbndr vexpr)
             $ vCaseProd vscrut vty lty vect_dc arr_dc shape_bndrs vbndrs vbody
  where
    -- A dead case binder still needs a variable for the scrutinee, so a
    -- fresh one named "scrut" is made for it.
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr FSLIT("scrut")
                    | otherwise         = vectBndrIn bndr

vectAlgCase tycon ty_args scrut bndr ty alts
  = do
      vect_tc <- maybeV (lookupTyCon tycon)
      vty               <- vectType ty
      lty               <- mkPArrayType vty

      repr        <- mkRepr vect_tc
      shape_bndrs <- arrShapeVars repr
      (len, sel, indices) <- arrSelector repr (map Var shape_bndrs)

      (vbndr, valts) <- vect_scrut_bndr $ mapM (proc_alt sel lty) alts'
      let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts

      vexpr <- vectExpr scrut
      (vscrut, arr_tc, arg_tys) <- mkVScrut (vVar vbndr)
      let [arr_dc] = tyConDataCons arr_tc

      let (vect_scrut,  lift_scrut)  = vscrut
          (vect_bodies, lift_bodies) = unzip vbodies

      let vect_case = Case vect_scrut (mkWildId (exprType vect_scrut)) vty
                           (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)

      lbody <- combinePA vty len sel indices lift_bodies
      let lift_case = Case lift_scrut (mkWildId (exprType lift_scrut)) lty
                           [(DataAlt arr_dc, shape_bndrs ++ concat lift_bndrss,
                             lbody)]

      return . vLet (vNonRec vbndr vexpr)
             $ (vect_case, lift_case)
  where
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr FSLIT("scrut")
                    | otherwise         = vectBndrIn bndr

    alts' = sortBy (\(alt1, _, _) (alt2, _, _) -> cmp alt1 alt2) alts

    -- DEFAULT sorts first, constructor alternatives by tag.
    cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
    cmp DEFAULT       DEFAULT       = EQ
    cmp DEFAULT       _             = LT
    cmp _             DEFAULT       = GT
    -- Literal alternatives are not expected for an algebraic scrutinee.
    cmp _             _             = panic "vectAlgCase/cmp: unexpected alternative"

    -- Vectorise one constructor alternative. Its lifted body runs in the
    -- lifting context restricted (by tag) to the elements that chose it.
    proc_alt sel lty (DataAlt dc, bndrs, body)
      = do
          vect_dc <- maybeV (lookupDataCon dc)
          let tag = mkDataConTag vect_dc
              fvs = freeVarsOf body `delVarSetList` bndrs
          (vect_bndrs, lift_bndrs, vbody)
            <- vect_alt_bndrs bndrs
             $ \len -> packLiftingContext len sel tag fvs lty
             $ vectExpr body

          return (vect_dc, vect_bndrs, lift_bndrs, vbody)
    -- A DEFAULT (or any other non-constructor) alternative mixed with
    -- constructor alternatives is not supported: fail with a descriptive
    -- message instead of an anonymous pattern-match failure.
    proc_alt _ _ _
      = panic "vectAlgCase/proc_alt: only data-constructor alternatives are supported"

    -- A nullary constructor has no field arrays to take the length from, so
    -- a fresh array of the builtin void type supplies it.
    vect_alt_bndrs [] p
      = do
          void_tc <- builtin voidTyCon
          let void_ty = mkTyConApp void_tc []
          arr_ty <- mkPArrayType void_ty
          bndr   <- newLocalVar FSLIT("voids") arr_ty
          len    <- lengthPA void_ty (Var bndr)
          e      <- p len
          return ([], [bndr], e)

    -- Otherwise the length is taken from the first field's lifted array.
    vect_alt_bndrs bndrs p
       = localV
       $ do
           vbndrs <- mapM vectBndr bndrs
           let (vect_bndrs, lift_bndrs) = unzip vbndrs
               vv : _ = vect_bndrs
               lv : _ = lift_bndrs
           len <- lengthPA (idType vv) (Var lv)
           e   <- p len
           return (vect_bndrs, lift_bndrs, e)

    mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)

-- | Run the vectorisation @p@ of a case alternative in a lifting context
-- restricted to the elements that chose this alternative.
--
-- Parameters:
--
--   * @len@: the length that the lifting-context variable is bound to
--     around the lifted result;
--   * @shape@, @tag@: applied to 'selectPAIntPrimVar' to build the selector;
--   * @fvs@: the free variables of the alternative's body;
--   * @res_ty@: the result type of the lifted 'Case';
--   * @p@: the vectorisation of the body.
--
-- Every local free variable is packed with 'packFreeVar', which rebinds it to
-- a fresh lifted variable seen by @p@. The resulting packing bindings are
-- placed around the lifted body inside the 'Case' that binds the
-- lifting-context variable, because the packing expressions refer to that
-- variable.
packLiftingContext :: CoreExpr -> CoreExpr -> CoreExpr -> VarSet -> Type -> VM VExpr -> VM VExpr
packLiftingContext len shape tag fvs res_ty p
  = do
      select <- builtin selectPAIntPrimVar
      let sel_expr = mkApps (Var select) [shape, tag]
      sel_var <- newLocalVar FSLIT("sel#") (exprType sel_expr)
      lc_var <- builtin liftingContext
      localV $
        do
          bnds <- mapM (packFreeVar (Var lc_var) (Var sel_var))
                . filter isLocalId
                $ varSetElems fvs
          (vexpr, lexpr) <- p
          -- The lifted body refers to the fresh variables introduced by
          -- 'packFreeVar', so their bindings must scope over it.
          return (vexpr, Let (NonRec sel_var sel_expr)
                         $ Case len lc_var res_ty
                             [(DEFAULT, [], mkLets (concat bnds) lexpr)])

-- | Pack the lifted value of free variable @v@ for a narrower lifting context.
--
-- If @v@ is local, its lifted variable is cloned and the clone is bound to
-- 'packPA' applied to the original lifted variable, @len@ and @sel@. The
-- local environment is updated so that @v@ maps to the clone, and the single
-- binding is returned. Any other variable yields no bindings.
packFreeVar :: CoreExpr -> CoreExpr -> Var -> VM [CoreBind]
packFreeVar len sel v
  = lookupVar v >>= pack
  where
    pack (Local (vv, lv))
      = do
          lv'    <- cloneVar lv
          packed <- packPA (idType vv) (Var lv) len sel
          updLEnv $ \env ->
            env { local_vars = extendVarEnv (local_vars env) v (vv, lv') }
          return [NonRec lv' packed]
    pack _ = return []