{-# LANGUAGE TupleSections #-}

-- |Vectorisation of expressions.
module Vectorise.Exp
  (   -- * Vectorise polymorphic expressions with special cases for right-hand sides of particular 
      --   variable bindings
    vectPolyExpr
  , vectDictExpr
  , vectScalarFun
  , vectScalarDFun
  ) 
where
14 15 16

#include "HsVersions.h"

17
import Vectorise.Type.Type
18
import Vectorise.Var
19
import Vectorise.Convert
20 21 22 23
import Vectorise.Vect
import Vectorise.Env
import Vectorise.Monad
import Vectorise.Builtins
24
import Vectorise.Utils
25 26 27

import CoreUtils
import MkCore
28
import CoreSyn
29
import CoreFVs
30
import Class
31 32
import DataCon
import TyCon
33
import TcType
34
import Type
35
import PrelNames
36 37 38 39
import Var
import VarEnv
import VarSet
import Id
40
import BasicTypes( isStrongLoopBreaker )
41 42 43 44 45 46
import Literal
import TysWiredIn
import TysPrim
import Outputable
import FastString
import Control.Monad
47
import Control.Applicative
48
import Data.Maybe
49
import Data.List
gckeller's avatar
gckeller committed
50 51
import TcRnMonad (doptM)
import DynFlags (DynFlag(Opt_AvoidVect))
gckeller's avatar
gckeller committed
52 53


54
-- Main entry point to vectorise expressions -----------------------------------
gckeller's avatar
gckeller committed
55

56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
-- |Vectorise a polymorphic expression.
--
-- If not yet available (the 'Maybe VITree' argument is 'Nothing'), precompute vectorisation
-- avoidance information before vectorising.  If the vectorisation avoidance optimisation
-- ('Opt_AvoidVect') is enabled, also use the vectorisation avoidance information to encapsulate
-- subexpressions that do not need to be vectorised (see 'encapsulateScalars').
--
-- The first argument is passed on to 'vectFnExpr' as the loop-breaker flag and the second one as
-- the functions of the same recursive binding group.  The result consists of the inline
-- information (adjusted by 'addInlineArity' with the 'polyArity' of the type binders), a flag
-- telling whether the expression was vectorised as a scalar function, and the vectorised
-- expression.
--
vectPolyExpr :: Bool -> [Var] -> CoreExprWithFVs -> Maybe VITree
             -> VM (Inline, Bool, VExpr)
  -- precompute vectorisation avoidance information (and possibly encapsulated subexpressions)
vectPolyExpr loop_breaker recFns expr Nothing
  = do
    { vectAvoidance <- liftDs $ doptM Opt_AvoidVect
    ; vi <- vectAvoidInfo expr  
    ; (expr', vi') <- 
        if vectAvoidance
        then do 
             { (expr', vi') <- encapsulateScalars vi expr
             ; traceVt "vectPolyExpr encapsulated:" (ppr $ deAnnotate expr')
             ; return (expr', vi')
             }
        else return (expr, vi)
      -- continue with the (possibly encapsulated) expression and its vectorisation information
    ; vectPolyExpr loop_breaker recFns expr' (Just vi')
    }

  -- traverse through ticks (the tick's 'VITree' node must have exactly one child)
vectPolyExpr loop_breaker recFns (_, AnnTick tickish expr) (Just (VITNode _ [vit])) 
  = do 
    { (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr (Just vit)
    ; return (inline, isScalarFn, vTick tickish expr')
    }

  -- collect and vectorise type abstractions; then, descend into the body
vectPolyExpr loop_breaker recFns expr (Just vit)
  = do 
    { let (tvs, mono) = collectAnnTypeBinders expr
          vit'        = stripLevels (length tvs) vit
    ; arity <- polyArity tvs
    ; polyAbstract tvs $ \args ->
        do 
        { (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono vit'
        ; return (addInlineArity inline arity, isScalarFn, mapVect (mkLams $ tvs ++ args) mono')
        }
    }
  where
      -- skip one 'VITree' level per type binder; each skipped level must have exactly one child,
      -- otherwise the tree does not match the expression and we panic
    stripLevels 0 vit               = vit
    stripLevels n (VITNode _ [vit]) = stripLevels (n - 1) vit
    stripLevels _ vit               = pprPanic "vectPolyExpr: stripLevels:" (text (show vit))
gckeller's avatar
gckeller committed
103

104 105 106 107 108 109 110 111 112
-- |Encapsulate every purely sequential subexpression of a (potentially) parallel expression into a
-- lambda abstraction over all its free variables, followed by the corresponding application to
-- those variables.  We can then avoid the vectorisation of the encapsulated subexpressions.
--
-- Preconditions:
--
-- * All free variables and the result type must be /simple/ types.
-- * The expression is sufficiently complex (to warrant special treatment).  For now, that is
--   every expression that is not constant and contains at least one operation.
--
-- The 'VITree' must mirror the structure of the expression; a mismatch is a panic.  A lambda,
-- application, case or let is encapsulated (via 'liftSimple') when its node is labelled
-- 'VISimple' and 'varsSimple' holds for its free variables; otherwise we recurse into its
-- children.  Types, variables and literals are returned unchanged; ticks and casts are always
-- traversed.
--
encapsulateScalars :: VITree -> CoreExprWithFVs -> VM (CoreExprWithFVs, VITree)
encapsulateScalars  vit ce@(_, AnnType _ty) 
  = return (ce, vit)

encapsulateScalars  vit ce@(_, AnnVar _v)  
  = return (ce, vit)

encapsulateScalars vit ce@(_, AnnLit _)
  = return (ce, vit) 

encapsulateScalars (VITNode vi [vit]) (fvs, AnnTick tck expr)
  = do { (extExpr, vit') <- encapsulateScalars vit expr
       ; return ((fvs, AnnTick tck extExpr), VITNode vi [vit'])
       }

encapsulateScalars _ (_fvs, AnnTick _tck _expr)
  = panic "encapsulateScalar AnnTick doesn't match up"

encapsulateScalars vt@(VITNode vi [vit]) ce@(fvs, AnnLam bndr expr) 
  = do { varsS <- varsSimple fvs 
       ; case (vi, varsS) of
             -- lift with the tree of the whole lambda (as for all other constructs), not the tree
             -- of its body; otherwise the resulting tree lacks the lambda's level
           (VISimple, True) -> return $ liftSimple vt ce
           _                -> do { (extExpr, vit') <- encapsulateScalars vit expr
                                  ; return ((fvs, AnnLam bndr extExpr), VITNode vi [vit'])
                                  }
       }

encapsulateScalars _ (_fvs, AnnLam _bndr _expr) 
  = panic "encapsulateScalars AnnLam doesn't match up"

encapsulateScalars vt@(VITNode vi [vit1, vit2]) ce@(fvs, AnnApp ce1 ce2) 
  = do { varsS <- varsSimple fvs 
       ; case (vi, varsS) of
           (VISimple, True) -> do { let (e', vt') = liftSimple vt ce
                                  -- ; checkTreeAnnM vt' e'
                                  -- ; traceVt "Passed checkTree test!!" (ppr $ deAnnotate e')
                                  ; return (e', vt')
                                  }
           _                -> do { (etaCe1, vit1') <- encapsulateScalars vit1 ce1
                                  ; (etaCe2, vit2') <- encapsulateScalars vit2 ce2
                                  ; return ((fvs, AnnApp etaCe1 etaCe2), VITNode vi [vit1', vit2'])
                                  }
       }

encapsulateScalars _  (_fvs, AnnApp _ce1 _ce2)                           
  = panic "encapsulateScalars AnnApp doesn't match up"

encapsulateScalars vt@(VITNode vi (scrutVit : altVits)) ce@(fvs, AnnCase scrut bndr ty alts) 
  = do { varsS <- varsSimple fvs 
       ; case (vi, varsS) of
           (VISimple, True) -> return $ liftSimple vt ce
           _                -> do { (extScrut, scrutVit') <- encapsulateScalars scrutVit scrut
                                  ; extAltsVits  <- zipWithM expAlt altVits alts
                                  ; let (extAlts, altVits') = unzip extAltsVits
                                  ; return ((fvs, AnnCase extScrut bndr ty extAlts), VITNode vi (scrutVit': altVits'))
                                  }
       }
  where
      -- encapsulate within a single alternative, using that alternative's tree
    expAlt vt (con, bndrs, expr) 
      = do { (extExpr, vt') <- encapsulateScalars vt expr
           ; return ((con, bndrs, extExpr), vt')
           }

encapsulateScalars _ (_fvs, AnnCase _scrut _bndr _ty _alts) 
  = panic "encapsulateScalars AnnCase doesn't match up"

encapsulateScalars vt@(VITNode vi [vt1, vt2]) ce@(fvs, AnnLet (AnnNonRec bndr expr1) expr2) 
  = do { varsS <- varsSimple fvs 
       ; case (vi, varsS) of
           (VISimple, True) -> return $ liftSimple vt ce
           _                -> do { (extExpr1, vt1') <- encapsulateScalars vt1 expr1
                                  ; (extExpr2, vt2') <- encapsulateScalars vt2 expr2
                                  ; return ((fvs, AnnLet (AnnNonRec bndr extExpr1) extExpr2), VITNode vi [vt1', vt2'])
                                  }
       }

encapsulateScalars _ (_fvs, AnnLet (AnnNonRec _bndr _expr1) _expr2)       
  = panic "encapsulateScalars AnnLet nonrec doesn't match up"

  -- the first child of a recursive let's tree belongs to the body; the remaining ones to the
  -- bindings, in order
encapsulateScalars vt@(VITNode vi (vtB : vtBnds)) ce@(fvs, AnnLet (AnnRec bndngs) expr) 
  = do { varsS <- varsSimple fvs 
       ; case (vi, varsS) of 
           (VISimple, True) -> return $ liftSimple vt ce
           _                -> do { extBndsVts <- zipWithM expBndg vtBnds bndngs
                                  ; let (extBnds, vtBnds') = unzip extBndsVts
                                  ; (extExpr, vtB') <- encapsulateScalars vtB expr
                                  ; let vt' = VITNode vi (vtB':vtBnds')
                                  ; return ((fvs, AnnLet (AnnRec extBnds) extExpr), vt')
                                  }
       }                            
    where
      expBndg vit (bndr, expr) 
        = do { (extExpr, vit') <- encapsulateScalars vit expr
             ; return  ((bndr, extExpr), vit')
             }

encapsulateScalars _ (_fvs, AnnLet (AnnRec _) _expr2)       
  = panic "encapsulateScalars AnnLet rec doesn't match up"

encapsulateScalars (VITNode vi [vit]) (fvs, AnnCast expr coercion)
  = do { (extExpr, vit') <- encapsulateScalars  vit expr
       ; return ((fvs, AnnCast extExpr coercion), VITNode vi [vit'])
       }

encapsulateScalars  _ (_fvs, AnnCast _expr _coercion) 
  = panic "encapsulateScalars AnnCast rec doesn't match up"

encapsulateScalars _ _  
  = panic "encapsulateScalars case not handled"
gckeller's avatar
gckeller committed
225

226 227 228 229 230 231 232
-- |Lambda-lift the given expression and apply it to the abstracted free variables.
--
-- If the expression is a case expression scrutinising anything but a primitive type, then lift
-- each alternative individually.
--
-- The returned 'VITree' matches the shape of the returned expression: the new abstractions get
-- 'VIEncaps' nodes (the innermost one relabels the root of the given tree), and the new
-- applications and their variable arguments get 'VISimple' nodes.
--
liftSimple :: VITree -> CoreExprWithFVs -> (CoreExprWithFVs, VITree)
liftSimple (VITNode vi (scrutVit : altVits)) (fvs, AnnCase expr bndr t alts) 
  | Just (c,_) <- splitTyConApp_maybe (exprType $ deAnnotate $ expr),  
    (not $ elem c [boolTyCon, intTyCon, doubleTyCon, floatTyCon])   -- FIXME: shouldn't be hardcoded
     = ((fvs, AnnCase expr bndr t alts'), VITNode vi (scrutVit : altVits'))      
  where 
    (alts', altVits') = unzip $ map (\(ac,bndrs, (alt, avi)) -> ((ac,bndrs,alt), avi)) $ 
                        zipWith  (\(ac, bndrs, aex) -> \altVi -> (ac, bndrs, liftSimple altVi aex)) alts altVits
          
liftSimple viTree ae@(fvs, _annEx) 
  = (mkAnnApps (mkAnnLams ae vars) vars, viTree')
  where
    mkViTreeLams (VITNode _ vits) [] = VITNode VIEncaps vits
    mkViTreeLams vi (_:vs) = VITNode VIEncaps [mkViTreeLams vi vs]

    mkViTreeApps vi []      = vi
    mkViTreeApps vi (_:vs)  = VITNode VISimple [mkViTreeApps vi vs, VITNode VISimple []]
    
    vars    = varSetElems fvs
    viTree' = mkViTreeApps (mkViTreeLams viTree vars) vars
    
    mkAnnLam :: bndr -> AnnExpr bndr VarSet -> AnnExpr' bndr VarSet
    mkAnnLam bndr ce = AnnLam bndr ce         

      -- Abstract over the given variables (the first one innermost).  Each new lambda is annotated
      -- with the free variables of its body minus its binder; the body keeps its own annotation
      -- (it still has the binder free).
    mkAnnLams :: CoreExprWithFVs -> [Var] -> CoreExprWithFVs
    mkAnnLams (fv, aex') []        = (fv, aex')  -- fv should be empty. check!
    mkAnnLams body@(fv, _) (v:vs)  = mkAnnLams (delVarSet fv v, mkAnnLam v body) vs

    mkAnnApp :: (AnnExpr bndr VarSet) -> Var -> (AnnExpr' bndr VarSet)
    mkAnnApp aex v = AnnApp aex (unitVarSet v, (AnnVar v))

      -- Apply to the given variables (the first one outermost, matching 'mkAnnLams').
    mkAnnApps:: CoreExprWithFVs -> [Var] -> CoreExprWithFVs
    mkAnnApps (fv, aex') [] = (fv, aex')
    mkAnnApps ae (v:vs) = 
      let
        (fv, aex') = mkAnnApps ae vs
      in (extendVarSet fv v, mkAnnApp (fv, aex') v)
gckeller's avatar
gckeller committed
268

269 270
-- |Vectorise an expression.
--
-- The 'VITree' carries the vectorisation avoidance information for the expression and must mirror
-- its structure.  The equations below match on the tree shape expected for each construct; any
-- combination that is not covered is rejected with 'cantVectorise' by the last equation.
--
vectExpr :: CoreExprWithFVs -> VITree -> VM VExpr
-- vectExpr e vi | not (checkTree vi (deAnnotate e))
--   = pprPanic "vectExpr" (ppr $ deAnnotate e)

vectExpr (_, AnnVar v)  _ 
  = vectVar v

vectExpr (_, AnnLit lit) _
  = vectConst $ Lit lit

vectExpr e@(_, AnnLam bndr _) vt
  | isId bndr = (\(_, _, ve) -> ve) <$> vectFnExpr True False [] e vt
  | otherwise = cantVectorise "Unexpected type lambda (vectExpr)" (ppr (deAnnotate e))

  -- SPECIAL CASE: Vectorise/lift 'patError @ ty err' by only vectorising/lifting the type 'ty';
  --   its only purpose is to abort the program, but we need to adjust the type to keep CoreLint
  --   happy.
-- FIXME: can't we do this with a VECTORISE pragma on 'pAT_ERROR_ID' now?
vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err)  _
  | v == pAT_ERROR_ID
  = do { (vty, lty) <- vectAndLiftType ty
       ; return (mkCoreApps (Var v) [Type vty, err'], mkCoreApps (Var v) [Type lty, err'])
       }
  where
    err' = deAnnotate err

  -- type application (handle multiple consecutive type applications simultaneously to ensure the
  -- PA dictionaries are put at the right places)
vectExpr e@(_, AnnApp _ arg) (VITNode _ [_, _])
  | isAnnTypeArg arg
  = vectPolyApp e
    
  -- 'Int', 'Float', or 'Double' literal
  -- FIXME: this needs to be generalised
vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit)) _
  | Just con <- isDataConId_maybe v
  , is_special_con con
  = do
      let vexpr = App (Var v) (Lit lit)
      lexpr <- liftPD vexpr
      return (vexpr, lexpr)
  where
    is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon]

  -- value application (dictionary or user value)
vectExpr e@(_, AnnApp fn arg) (VITNode _ [vit1, vit2]) 
  | isPredTy arg_ty   -- dictionary application (whose result is not a dictionary)
  = vectPolyApp e
  | otherwise         -- user value
  = do {   -- vectorise the types
       ; varg_ty <- vectType arg_ty 
       ; vres_ty <- vectType res_ty

           -- vectorise the function and argument expression
       ; vfn  <- vectExpr fn  vit1
       ; varg <- vectExpr arg vit2

           -- the vectorised function is a closure; apply it to the vectorised argument
       ; mkClosureApp varg_ty vres_ty vfn varg
       }
  where
    (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn

  -- case expressions are only supported on scrutinees of algebraic type
vectExpr (_, AnnCase scrut bndr ty alts)  vt
  | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
  , isAlgTyCon tycon
  = vectAlgCase tycon ty_args scrut bndr ty alts vt
  | otherwise = cantVectorise "Can't vectorise expression" (ppr scrut_ty) 
  where
    scrut_ty = exprType (deAnnotate scrut)

vectExpr (_, AnnLet (AnnNonRec bndr rhs) body) (VITNode _ [vt1, vt2]) 
  = do
      vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs (Just vt1)
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body vt2)
      return $ vLet (vNonRec vbndr vrhs) vbody

  -- the first child of a recursive let's tree belongs to the body; the remaining ones to the
  -- bindings, in order
vectExpr (_, AnnLet (AnnRec bs) body) (VITNode _ (vtB : vtBnds))
  = do
      (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
                                $ liftM2 (,)
                                  (zipWith3M vect_rhs bndrs rhss vtBnds)
                                  (vectExpr body vtB)
      return $ vLet (vRec vbndrs vrhss) vbody
  where
    (bndrs, rhss) = unzip bs

    vect_rhs bndr rhs vt = localV
                         . inBind bndr
                         . liftM (\(_,_,z)->z)
                         $ vectPolyExpr (isStrongLoopBreaker $ idOccInfo bndr) [] rhs (Just vt)
    zipWith3M f xs ys zs = zipWithM (\x -> \(y,z) -> (f x y z)) xs (zip ys zs)

vectExpr (_, AnnTick tickish expr)  (VITNode _ [vit])
  = liftM (vTick tickish) (vectExpr expr vit)

vectExpr (_, AnnType ty) _
  = liftM vType (vectType ty)

vectExpr e vit = cantVectorise "Can't vectorise expression (vectExpr)" (ppr (deAnnotate e) $$ text ("  " ++ show vit))
371

372 373 374 375 376
-- |Vectorise an expression that *may* have an outer lambda abstraction.
--
-- We do not handle type variables at this point, as they will already have been stripped off by
-- 'vectPolyExpr'.  We also only have to worry about one set of dictionary arguments as we (1) only
-- deal with Haskell 2010 and (2) class selectors are vectorised elsewhere.
--
-- A non-predicate lambda is first tried as a scalar function ('vectScalarFunMaybe'); if that
-- fails, it is vectorised with 'vectLam' instead.  The returned 'Bool' records whether the scalar
-- route was taken.  Anything that is not a value lambda (with a single-child 'VITree' node) is
-- handed to 'vectExpr'.
--
vectFnExpr :: Bool             -- ^ If we process the RHS of a binding, whether that binding should
                               --   be inlined
           -> Bool             -- ^ Whether the binding is a loop breaker
           -> [Var]            -- ^ Names of the functions in the same recursive binding group
           -> CoreExprWithFVs  -- ^ Expression to vectorise; possibly with an outer `AnnLam`
           -> VITree           -- ^ Vectorisation avoidance information for the expression
           -> VM (Inline, Bool, VExpr)
-- vectFnExpr _ _ _ e vi | not (checkTree vi (deAnnotate e))
--   = pprPanic "vectFnExpr" (ppr $ deAnnotate e)
vectFnExpr inline loop_breaker recFns expr@(_fvs, AnnLam bndr body) vt@(VITNode _ [vt'])
      -- predicate abstraction: leave as a normal abstraction, but vectorise the predicate type
  | isId bndr
    && isPredTy (idType bndr)
  = do { vBndr <- vectBndr bndr
       ; (inline, isScalarFn, vbody) <- vectFnExpr inline loop_breaker recFns body vt'
       ; return (inline, isScalarFn, mapVect (mkLams [vectorised vBndr]) vbody)
       }
      -- non-predicate abstraction: vectorise (try to vectorise as a scalar computation)
  | isId bndr
  = mark DontInline True (vectScalarFunMaybe (deAnnotate expr) vt)
    `orElseV` 
    mark inlineMe False (vectLam inline loop_breaker expr vt)
vectFnExpr _ _ _  e vt
      -- not an abstraction: vectorise as a vanilla expression
  = mark DontInline False $ vectExpr e vt
403

404 405
-- |Run a vectorisation computation and tag its result with the given inline information and
-- scalar-function flag.
mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
mark inl isScalarFn comp = fmap ((,,) inl isScalarFn) comp
406

407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518
-- |Vectorise type and dictionary applications.
--
-- These are always headed by a variable (as we don't support higher-rank polymorphism), but may
-- involve two sets of type variables and dictionaries.  Consider,
--
-- > class C a where
-- >   m :: D b => b -> a
--
-- The type of 'm' is 'm :: forall a. C a => forall b. D b => b -> a'.
--
-- The arguments are peeled off the outside of the application spine in the order dictionaries,
-- types, dictionaries, types (see the @where@ clause).  After that, the head must be a variable;
-- any other head is reported with 'pprSorry'.  Only class-op selectors and dfuns (see
-- 'isDictComp') may have a non-empty inner set of dictionaries; the other branches assert this.
--
vectPolyApp :: CoreExprWithFVs -> VM VExpr
vectPolyApp e0
  = case e4 of
      (_, AnnVar var)
        -> do {   -- get the vectorised form of the variable
              ; vVar <- lookupVar var
              ; traceVt "vectPolyApp of" (ppr var)

                  -- vectorise type and dictionary arguments
              ; vDictsOuter <- mapM vectDictExpr (map deAnnotate dictsOuter)
              ; vDictsInner <- mapM vectDictExpr (map deAnnotate dictsInner)
              ; vTysOuter   <- mapM vectType     tysOuter
              ; vTysInner   <- mapM vectType     tysInner
              
                  -- re-apply the vectorised outer type arguments (via 'polyApply') and then the
                  -- vectorised outer dictionaries
              ; let reconstructOuter v = (`mkApps` vDictsOuter) <$> polyApply v vTysOuter

              ; case vVar of
                  Local (vv, lv)
                    -> do { MASSERT( null dictsInner )    -- local vars cannot be class selectors
                          ; traceVt "  LOCAL" (text "")
                          ; (,) <$> reconstructOuter (Var vv) <*> reconstructOuter (Var lv)
                          }
                  Global vv
                    | isDictComp var                      -- dictionary computation
                    -> do {   -- in a dictionary computation, the innermost, non-empty set of
                              -- arguments are non-vectorised arguments, where no 'PA' dictionaries
                              -- are needed for the type variables
                          ; ve <- if null dictsInner
                                  then 
                                    return $ Var vv `mkTyApps` vTysOuter `mkApps` vDictsOuter
                                  else 
                                    reconstructOuter 
                                      (Var vv `mkTyApps` vTysInner `mkApps` vDictsInner)
                          ; traceVt "  GLOBAL (dict):" (ppr ve)
                          ; vectConst ve
                          }
                    | otherwise                           -- non-dictionary computation
                    -> do { MASSERT( null dictsInner )
                          ; ve <- reconstructOuter (Var vv)
                          ; traceVt "  GLOBAL (non-dict):" (ppr ve)
                          ; vectConst ve
                          }
              }
      _ -> pprSorry "Cannot vectorise programs with higher-rank types:" (ppr . deAnnotate $ e0)
  where
    -- if there is only one set of variables or dictionaries, it will be the outer set
    (e1, dictsOuter) = collectAnnDictArgs e0
    (e2, tysOuter)   = collectAnnTypeArgs e1
    (e3, dictsInner) = collectAnnDictArgs e2
    (e4, tysInner)   = collectAnnTypeArgs e3
    --
    -- class-op selectors and dfuns are treated as dictionary computations
    isDictComp var = (isJust . isClassOpId_maybe $ var) || isDFunId var
    
-- |Vectorise the body of a dfun (or any other dictionary computation).
--
-- Dictionary computations are special in two ways.  Applications of dictionary functions are
-- always saturated, so no closures need to be created.  And, as dictionary computations don't
-- depend on array values, they are always scalar computations whose result can be replicated
-- (instead of executing them in parallel).
--
-- NB: To keep things simple, binders introduced inside a dictionary computation are not
--     rewritten.  Hence, a variable may or may not be in the vectoriser environment; variables
--     that are not are left unchanged.  Literals, casts and coercions are not supported.
--
vectDictExpr :: CoreExpr -> VM CoreExpr
vectDictExpr expr
  = case expr of
      Var var           -> vectDictVar var
      Lit lit           -> pprPanic "Vectorise.Exp.vectDictExpr: literal in dictionary computation" (ppr lit)
      Lam bndr body     -> Lam bndr <$> vectDictExpr body
      App fun arg       -> App <$> vectDictExpr fun <*> vectDictExpr arg
      Case scrut bndr ty alts
        -> Case <$> vectDictExpr scrut <*> pure bndr <*> vectType ty <*> mapM vectDictAlt alts
      Let bind body     -> Let <$> vectDictBind bind <*> vectDictExpr body
      Cast _ _          -> pprSorry "Vectorise.Exp.vectDictExpr: cast" (ppr expr)
      Tick tickish body -> Tick tickish <$> vectDictExpr body
      Type ty           -> Type <$> vectType ty
      Coercion coe      -> pprSorry "Vectorise.Exp.vectDictExpr: coercion" (ppr coe)
  where
      -- binders from within the dictionary computation are not in scope and stay as they are;
      -- local and global vectorised variables are replaced by their vectorised version
    vectDictVar var
      = do { mbScope <- lookupVar_maybe var
           ; return $ Var (maybe var vectorisedVar mbScope)
           }
    vectorisedVar (Local (vVar, _)) = vVar
    vectorisedVar (Global vVar)     = vVar

    vectDictAlt (altCon, binders, rhs)
      = (,,) <$> vectDictAltCon altCon <*> pure binders <*> vectDictExpr rhs

    vectDictAltCon (DataAlt datacon)
      = DataAlt <$> maybeV (ptext (sLit "Cannot vectorise data constructor:") <+> ppr datacon)
                           (lookupDataCon datacon)
    vectDictAltCon (LitAlt lit) = return $ LitAlt lit
    vectDictAltCon DEFAULT      = return DEFAULT

    vectDictBind (NonRec bndr rhs) = NonRec bndr <$> vectDictExpr rhs
    vectDictBind (Rec pairs)       = Rec <$> mapM (\(bndr, rhs) -> (,) bndr <$> vectDictExpr rhs) pairs

519 520 521
-- |Vectorise an expression of functional type as a scalar function if its vectorisation
-- avoidance information marks it as encapsulated ('VIEncaps'); otherwise, fail with 'noV'.
-- 'vectFnExpr' then falls back to 'vectLam' via 'orElseV'.
--
-- Scalar functions are functions whose arguments and result are of primitive types (i.e., 'Int',
-- 'Float', 'Double' etc., which have instances of the 'Scalar' type class) and which do not
-- contain any subcomputations that involve parallel arrays.  Such functions do not require the
-- full-blown vectorisation transformation; instead, they can be lifted by application of a member
-- of the zipWith family (i.e., 'map', 'zipWith', 'zipWith3', etc.)
--
-- Dictionary functions are also scalar functions (as dictionaries themselves are not vectorised,
-- instead they become dictionaries of vectorised methods).  We treat them differently, though;
-- see "Note [Scalar dfuns]" in 'Vectorise'.
--
vectScalarFunMaybe :: CoreExpr   -- ^ Expression to be vectorised
                   -> VITree     -- ^ Vectorisation information
                   -> VM VExpr
vectScalarFunMaybe expr  (VITNode VIEncaps _) = vectScalarFun expr
vectScalarFunMaybe _expr _                    = noV $ ptext (sLit "not a scalar function")

-- |Vectorise an expression of functional type by lifting it with a member of the zipWith family
-- (i.e., 'map', 'zipWith', 'zipWith3', etc.).  This is only a valid strategy if the function
-- does not contain parallel subcomputations and has only 'Scalar' types in its result and
-- arguments — this is a precondition for calling this function.
--
-- The argument and result types are read off the expression's type with 'splitFunTys' and passed
-- to 'mkScalarFun', which treats dictionary functions (predicate result type) specially; see
-- "Note [Scalar dfuns]" in 'Vectorise'.
--
vectScalarFun :: CoreExpr -> VM VExpr
vectScalarFun expr 
  = traceVt "vectScalarFun" (ppr expr) >> mkScalarFun argTys resTy expr
  where
    (argTys, resTy) = splitFunTys (exprType expr)
551

552 553 554
-- |Generate code for a scalar function by generating a scalar closure.  If the function is a
-- dictionary function (its result type is a predicate type), vectorise it as dictionary code
-- instead.  Dictionary expressions are not lifted, so the lifted component of the result must
-- not be demanded in that case.
--
mkScalarFun :: [Type] -> Type -> CoreExpr -> VM VExpr
mkScalarFun arg_tys res_ty expr
  | isPredTy res_ty
  = do { vExpr <- vectDictExpr expr
       ; return (vExpr, unused)
       }
  | otherwise
  = do { traceVt "mkScalarFun: " $ ppr expr $$ ptext (sLit "  ::") <+> ppr (mkFunTys arg_tys res_ty)

       ; fn_var  <- hoistExpr (fsLit "fn") expr DontInline
       ; zipf    <- zipScalars arg_tys res_ty
           -- closure combining the scalar function with its lifted version 'zipf fn'
       ; clo     <- scalarClosure arg_tys res_ty (Var fn_var) (zipf `App` Var fn_var)
       ; clo_var <- hoistExpr (fsLit "clo") clo DontInline
       ; lclo    <- liftPD (Var clo_var)
       ; return (Var clo_var, lclo)
       }
  where
    unused = error "Vectorise.Exp.mkScalarFun: we don't lift dictionary expressions"
573

574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607
-- |Vectorise a dictionary function that has a 'VECTORISE SCALAR instance' pragma.
-- 
-- In other words, all methods in that dictionary are scalar functions — to be vectorised with
-- 'vectScalarFun'.  The dictionary "function" itself may be a constant, though.
--
-- NB: You may think that we could implement this function guided by the structure of the Core
--     expression of the right-hand side of the dictionary function.  We cannot proceed like this
--     as 'vectScalarDFun' must also work for *imported* dfuns, where we don't necessarily have
--     access to the Core code of the unvectorised dfun.
--
-- Here is an example — assume,
--
-- > class Eq a where { (==) :: a -> a -> Bool }
-- > instance (Eq a, Eq b) => Eq (a, b) where { (==) = ... }
-- > {-# VECTORISE SCALAR instance Eq (a, b) }
--
-- The unvectorised dfun for the above instance has the following signature:
--
-- > $dEqPair :: forall a b. Eq a -> Eq b -> Eq (a, b)
--
-- We generate the following (scalar) vectorised dfun (liberally using TH notation):
--
-- > $v$dEqPair :: forall a b. V:Eq a -> V:Eq b -> V:Eq (a, b)
-- > $v$dEqPair = /\a b -> \dEqa :: V:Eq a -> \dEqb :: V:Eq b ->
-- >                D:V:Eq $(vectScalarFun
-- >                         [| (==) @(a, b) ($dEqPair @a @b $(unVect dEqa) $(unVect dEqb)) |])
--
-- NB:
-- * '(,)' vectorises to '(,)' — hence, the type constructor in the result type remains the same.
-- * We share the '$(unVect di)' sub-expressions between the different selectors, but duplicate
--   the application of the unvectorised dfun, to enable the dictionary selection rules to fire.
--
-- If the class's data constructor has no vectorised counterpart, we fail with 'maybeV'.
--
vectScalarDFun :: Var        -- ^ Original dfun
               -> VM CoreExpr
vectScalarDFun var
  = do {   -- bring the type variables into scope
       ; mapM_ defLocalTyVar tvs

           -- vectorise dictionary argument types and generate variables for them
       ; vTheta     <- mapM vectType theta
       ; vThetaBndr <- mapM (newLocalVar (fsLit "vd")) vTheta
       ; let vThetaVars = varsToCoreExprs vThetaBndr
       
           -- vectorise superclass dictionaries and methods as scalar expressions
       ; thetaVars  <- mapM (newLocalVar (fsLit "d")) theta
       ; thetaExprs <- zipWithM unVectDict theta vThetaVars
       ; let thetaDictBinds = zipWith NonRec thetaVars thetaExprs
             dict           = Var var `mkTyApps` (mkTyVarTys tvs) `mkVarApps` thetaVars
             scsOps         = map (\selId -> varToCoreExpr selId `mkTyApps` tys `mkApps` [dict])
                                  selIds
       ; vScsOps <- mapM (\e -> vectorised <$> vectScalarFun e) scsOps

           -- vectorised applications of the class-dictionary data constructor
       ; vDataCon <- maybeV dataConErr (lookupDataCon dataCon)
       ; vTys     <- mapM vectType tys
       ; let vBody = thetaDictBinds `mkLets` mkCoreConApps vDataCon (map Type vTys ++ vScsOps)

       ; return $ mkLams (tvs ++ vThetaBndr) vBody
       }
  where
    ty                = varType var
    (tvs, theta, pty) = tcSplitSigmaTy  ty        -- 'theta' is the instance context
    (cls, tys)        = tcSplitDFunHead pty       -- 'pty' is the instance head
    selIds            = classAllSelIds cls
    dataCon           = classDataCon cls
    dataConErr        = ptext (sLit "Cannot vectorise data constructor:") <+> ppr dataCon

-- |Rebuild the unvectorised form of a dictionary from an expression that computes its
-- vectorised form.
--
-- Given the original (unvectorised) dictionary type 'ty', i.e. 'C t1..tn', and an expression
-- 'vd :: V:C vt1..vtn', generate
--
-- > D:C @t1..tn op1 .. opm
-- > where
-- >   opi = $(fromVect opTyi [| seli @vt1..vtn vd |])
--
-- where 'seli' is the i-th selector (superclass or method) of class 'C' and 'opTyi' is the
-- type of the i-th field of the unvectorised dictionary.  Panics if the type constructor of
-- 'ty' is not a class type constructor.
--
unVectDict :: Type -> CoreExpr -> VM CoreExpr
unVectDict ty e
  = do { vTys   <- mapM vectType tys
       ; let vSelApps = [Var selId `mkTyApps` vTys `mkApps` [e] | selId <- selIds]
       ; fields <- zipWithM fromVect methTys vSelApps
       ; return $ mkCoreConApps dataCon (map Type tys ++ fields)
       }
  where
    (tycon, tys, dataCon, methTys) = splitProductType "unVectDict: original type" ty
    selIds = classAllSelIds cls
    cls    = fromMaybe (panic "Vectorise.Exp.unVectDict: no class") (tyConClass_maybe tycon)

-- |Vectorise an 'n'-ary lambda abstraction by building a set of 'n' explicit closures.
--
-- All non-dictionary free variables go into the closure's environment, whereas the dictionary
-- variables are passed explicitly (as conventional arguments) into the body during closure
-- construction.
--
vectLam :: Bool             -- ^ When the RHS of a binding, whether that binding should be inlined.
        -> Bool             -- ^ Whether the binding is a loop breaker.
        -> CoreExprWithFVs  -- ^ Body of abstraction.
        -> VITree           -- ^ Vectorisation avoidance information mirroring the abstraction.
        -> VM VExpr
vectLam inline loop_breaker expr@(fvs, AnnLam _ _) vi
 = do { let (bndrs, body) = collectAnnValBinders expr

          -- grab the in-scope type variables
      ; tyvars <- localTyVars

          -- collect and vectorise all /local/ free variables (those in the local var env)
      ; vfvs <- readLEnv $ \env ->
                  [ (var, vv)
                  | var     <- varSetElems fvs
                  , Just vv <- [lookupVarEnv (local_vars env) var]
                  ]

          -- separate dictionary from non-dictionary variables in the free variable set
      ; let (vvs_dict, vvs_nondict)     = partition (isPredTy . varType . fst) vfvs
            vfvs_dict                   = map snd vvs_dict
            (fvs_nondict, vfvs_nondict) = unzip vvs_nondict

          -- compute the type of the vectorised closure
      ; arg_tys <- mapM (vectType . idType) bndrs
      ; res_ty  <- vectType (exprType $ deAnnotate body)

      ; let arity      = length fvs_nondict + length bndrs
            vfvs_dict' = map vectorised vfvs_dict
      ; buildClosures tyvars vfvs_dict' vfvs_nondict arg_tys res_ty
        . hoistPolyVExpr tyvars vfvs_dict' (maybe_inline arity)
        $ do {   -- generate the vectorised body of the lambda abstraction
             ; lc              <- builtin liftingContext
             ; let viBody = stripLams expr vi
             ; (vbndrs, vbody) <- vectBndrsIn (fvs_nondict ++ bndrs) (vectExpr body viBody)

             ; vbody' <- break_loop lc res_ty vbody
             ; return $ vLams lc vbndrs vbody'
             }
      }
  where
    -- Descend through the annotation tree in lock step with 'collectAnnValBinders': only value
    -- lambdas are stripped, so that the returned tree mirrors exactly the 'body' that
    -- 'collectAnnValBinders' returned (it stops at the first type lambda).
    stripLams (_, AnnLam b e) (VITNode _ [vt]) | isId b = stripLams e vt
    stripLams _               vt                        = vt

    maybe_inline n | inline    = Inline n
                   | otherwise = DontInline

    -- If this is the body of a binding marked as a loop breaker, add a recursion termination test
    -- to the /lifted/ version of the function body.  The termination tests checks if the lifting
    -- context is empty.  If so, it returns an empty array of the (lifted) result type instead of
    -- executing the function body.  This is the test from the last line (defining \mathcal{L}')
    -- in Figure 6 of HtM.
    break_loop lc ty (ve, le)
      | loop_breaker
      = do { empty <- emptyPD ty
           ; lty   <- mkPDataType ty
           ; return (ve, mkWildCase (Var lc) intPrimTy lty
                           [(DEFAULT, [], le),
                            (LitAlt (mkMachInt 0), [], empty)])
           }
      | otherwise = return (ve, le)

vectLam _ _ _ _ = panic "vectLam"

-- Vectorise an algebraic case expression.
--
-- We convert
--
--   case e :: t of v { ... }
--
-- to
--
--   V:    let v' = e in case v' of _ { ... }
--   L:    let v' = e in case v' `cast` ... of _ { ... }
--
--   When lifting, we have to do it this way because v must have the type
--   [:V(T):] but the scrutinee must be cast to the representation type. We also
--   have to handle the case where v is a wild var correctly.
--
-- The 'VITree' argument mirrors the case expression: its first child annotates the scrutinee
-- and the remaining children annotate the alternatives, in the order of the alternatives list.
--

-- FIXME: this is too lazy
vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs -> Var -> Type
            -> [(AltCon, [Var], CoreExprWithFVs)] -> VITree
            -> VM VExpr
vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)] (VITNode _ (scrutVit : [altVit]))
  = do
      vscrut         <- vectExpr scrut scrutVit
      (vty, lty)     <- vectAndLiftType ty
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body altVit)
      return $ vCaseDEFAULT vscrut vbndr vty lty vbody

vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)] (VITNode _ (scrutVit : [altVit]))
  = do
      vscrut         <- vectExpr scrut scrutVit
      (vty, lty)     <- vectAndLiftType ty
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body altVit)
      return $ vCaseDEFAULT vscrut vbndr vty lty vbody

vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)] (VITNode _ (scrutVit : [altVit]))
  = do
      (vty, lty) <- vectAndLiftType ty
      vexpr      <- vectExpr scrut scrutVit
      (vbndr, (vbndrs, (vect_body, lift_body)))
         <- vect_scrut_bndr
          . vectBndrsIn bndrs
          $ vectExpr body altVit
      let (vect_bndrs, lift_bndrs) = unzip vbndrs
      (vscrut, lscrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
      vect_dc <- maybeV dataConErr (lookupDataCon dc)

      let vcase = mk_wild_case vscrut vty vect_dc  vect_bndrs vect_body
          lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body

      return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
  where
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
                    | otherwise         = vectBndrIn bndr

    mk_wild_case expr ty dc bndrs body
      = mkWildCase expr (exprType expr) ty [(DataAlt dc, bndrs, body)]

    dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc)

vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits))
  = do
      vect_tc     <- maybeV tyConErr (lookupTyCon tycon)
      (vty, lty)  <- vectAndLiftType ty

      let arity = length (tyConDataCons vect_tc)
      sel_ty <- builtin (selTy arity)
      sel_bndr <- newLocalVar (fsLit "sel") sel_ty
      let sel = Var sel_bndr

      (vbndr, valts) <- vect_scrut_bndr
                      $ mapM (proc_alt arity sel vty lty) sorted_alts
      let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts

      vexpr <- vectExpr scrut scrutVit
      (vect_scrut, lift_scrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)

      let (vect_bodies, lift_bodies) = unzip vbodies

      vdummy <- newDummyVar (exprType vect_scrut)
      ldummy <- newDummyVar (exprType lift_scrut)
      let vect_case = Case vect_scrut vdummy vty
                           (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)

      lc <- builtin liftingContext
      lbody <- combinePD vty (Var lc) sel lift_bodies
      let lift_case = Case lift_scrut ldummy lty
                           [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
                             lbody)]

      return . vLet (vNonRec vbndr vexpr)
             $ (vect_case, lift_case)
  where
    tyConErr = (text "vectAlgCase: type constructor not vectorised" <+> ppr tycon)

    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
                    | otherwise         = vectBndrIn bndr

    -- Sort the alternatives by constructor tag.  Each alternative is sorted together with its
    -- vectorisation avoidance tree, so that alternatives and trees stay correctly paired even
    -- if the incoming alternatives are not already in tag order.
    sorted_alts = sortBy (\((alt1, _, _), _) ((alt2, _, _), _) -> cmp alt1 alt2)
                         (zip alts altVits)

    cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
    cmp DEFAULT       DEFAULT       = EQ
    cmp DEFAULT       _             = LT
    cmp _             DEFAULT       = GT
    cmp _             _             = panic "vectAlgCase/cmp"

    proc_alt arity sel _ lty ((DataAlt dc, bndrs, body), vi)
      = do
          vect_dc <- maybeV dataConErr (lookupDataCon dc)
          let ntag = dataConTagZ vect_dc
              tag  = mkDataConTag vect_dc
              fvs  = freeVarsOf body `delVarSetList` bndrs

          sel_tags  <- liftM (`App` sel) (builtin (selTags arity))
          lc        <- builtin liftingContext
          elems     <- builtin (selElements arity ntag)

          (vbndrs, vbody)
            <- vectBndrsIn bndrs
             . localV
             $ do
                 binds    <- mapM (pack_var (Var lc) sel_tags tag)
                           . filter isLocalId
                           $ varSetElems fvs
                 (ve, le) <- vectExpr body vi
                 -- the case binder deliberately rebinds 'lc' to the number of elements
                 -- selected for this alternative
                 return (ve, Case (elems `App` sel) lc lty
                             [(DEFAULT, [], (mkLets (concat binds) le))])
                 -- empty    <- emptyPD vty
                 -- return (ve, Case (elems `App` sel) lc lty
                 --             [(DEFAULT, [], Let (NonRec flags_var flags_expr)
                 --                             $ mkLets (concat binds) le),
                 --               (LitAlt (mkMachInt 0), [], empty)])
          let (vect_bndrs, lift_bndrs) = unzip vbndrs
          return (vect_dc, vect_bndrs, lift_bndrs, vbody)
      where
        dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc)

    proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"

    mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)

    pack_var len tags t v
      = do
          r <- lookupVar v
          case r of
            Local (vv, lv) ->
              do
                lv'  <- cloneVar lv
                expr <- packByTagPD (idType vv) (Var lv) len tags t
                updLEnv (\env -> env { local_vars = extendVarEnv
                                                (local_vars env) v (vv, lv') })
                return [(NonRec lv' expr)]

            _ -> return []

vectAlgCase tycon _ty_args _scrut _bndr _ty _alts (VITNode _ _)
  = pprPanic "vectAlgCase (mismatched node information)" (ppr tycon)



-- Support to compute information for vectorisation avoidance ------------------

-- |Annotation for Core AST nodes that describes how they should be handled during vectorisation
-- and, in particular, whether vectorisation of the corresponding computation can be avoided.
--
-- NB: 'vectAvoidInfo' itself only ever produces 'VIParr', 'VISimple' and 'VIComplex'.
--
data VectAvoidInfo = VIParr       -- ^ tree contains parallel computations
                   | VISimple     -- ^ result type is scalar & no parallel subcomputation
                   | VIComplex    -- ^ any result type, no parallel subcomputation
                   | VIEncaps     -- ^ tree encapsulated by 'liftSimple'
                   deriving (Eq, Show)

-- |Vectorisation avoidance information for a Core expression.
--
-- Instead of integrating the vectorisation avoidance information into the Core expression, we
-- keep it in a separate tree that structurally mirrors the Core expression it annotates: each
-- node carries the 'VectAvoidInfo' of one Core node and has one child tree per annotated
-- sub-expression (see 'vectAvoidInfo' for the child order of each expression form).
--
data VITree = VITNode VectAvoidInfo [VITree] 
            deriving (Show)

-- |Is the root of at least one of the given trees annotated with 'VIParr'?
--
-- Only the root node of each tree is inspected; subtrees are not examined.
--
anyVIPArr :: [VITree] -> Bool
anyVIPArr = any isParrRoot
  where
    isParrRoot (VITNode info _) = info == VIParr

-- |Compute Core annotations to determine for which subexpressions we can avoid vectorisation.
--
-- The resulting tree mirrors the expression.  For applications, lets and cases, a node is
-- 'VIParr' whenever one of its children is; otherwise its annotation is derived from the type
-- of the node via 'vectAvoidInfoType'.  Lambdas are 'VIParr' if their body is and 'VIComplex'
-- otherwise; casts and ticks inherit the annotation of their sub-expression; types and
-- coercions are 'VISimple'.  The children of a node are, in order:
--
-- * application: function, argument;
-- * non-recursive let: right-hand side, body;
-- * recursive let: body, then the right-hand sides in binding order;
-- * case: scrutinee, then the right-hand sides of the alternatives in order;
-- * lambda, cast, tick: the single sub-expression.
--
-- FIXME: free scalar vars don't actually need to be passed through, since encapsulations makes sure,
--        that there are no free variables in encapsulated lambda expressions
vectAvoidInfo :: CoreExprWithFVs -> VM VITree
vectAvoidInfo ce@(_, AnnVar v)
  = do { vi <- vectAvoidInfoType $ exprType $ deAnnotate ce
       ; viTrace ce vi []
       ; traceVt "vectAvoidInfo AnnVar" ((ppr v) <+> (ppr $ exprType $ deAnnotate ce))
       ; return $ VITNode vi []
       }

vectAvoidInfo ce@(_, AnnLit _)
  = do { vi <- vectAvoidInfoType $ exprType $ deAnnotate ce
       ; viTrace ce vi []
       ; traceVt "vectAvoidInfo AnnLit" (ppr $ exprType $ deAnnotate ce)
       ; return $ VITNode vi []
       }

vectAvoidInfo ce@(_, AnnApp e1 e2)
  = do { vt1 <- vectAvoidInfo e1
       ; vt2 <- vectAvoidInfo e2
       ; vi  <- if anyVIPArr [vt1, vt2]
                  then return VIParr
                  else vectAvoidInfoType $ exprType $ deAnnotate ce
       ; viTrace ce vi [vt1, vt2]
       ; return $ VITNode vi [vt1, vt2]
       }

vectAvoidInfo ce@(_, AnnLam _var body)
  = do { vt@(VITNode vi _) <- vectAvoidInfo body
       ; let resultVI | vi == VIParr = VIParr
                      | otherwise    = VIComplex
         -- trace the annotation of this node (not of its body), like all other cases do
       ; viTrace ce resultVI [vt]
       ; return $ VITNode resultVI [vt]
       }

vectAvoidInfo ce@(_, AnnLet (AnnNonRec _var expr) body)
  = do { vtE <- vectAvoidInfo expr
       ; vtB <- vectAvoidInfo body
       ; vi  <- if anyVIPArr [vtE, vtB]
                  then return VIParr
                  else vectAvoidInfoType $ exprType $ deAnnotate ce
       ; viTrace ce vi [vtE, vtB]
       ; return $ VITNode vi [vtE, vtB]
       }

vectAvoidInfo ce@(_, AnnLet (AnnRec bnds) body)
  = do {   -- analyse every right-hand side exactly once; the result is reused below rather
           -- than recomputed (recomputation is exponential in the nesting depth of letrecs)
       ; vtBnds <- mapM (vectAvoidInfo . snd) bnds
       ; vtB@(VITNode vib _) <- vectAvoidInfo body
       ; ni <- if anyVIPArr vtBnds || vib == VIParr
                 then return VIParr
                 else vectAvoidInfoType $ exprType $ deAnnotate ce
       ; viTrace ce ni (vtB : vtBnds)
       ; return $ VITNode ni (vtB : vtBnds)
       }

vectAvoidInfo ce@(_, AnnCase expr _var _ty alts)
  = do { vtExpr <- vectAvoidInfo expr
       ; vtAlts <- mapM (\(_, _, e) -> vectAvoidInfo e) alts
       ; ni <- if anyVIPArr (vtExpr : vtAlts)
                 then return VIParr
                 else vectAvoidInfoType $ exprType $ deAnnotate ce
       ; viTrace ce ni (vtExpr : vtAlts)
       ; return $ VITNode ni (vtExpr : vtAlts)
       }

vectAvoidInfo (_, AnnCast expr _)
  = do { vt@(VITNode vi _) <- vectAvoidInfo expr
       ; return $ VITNode vi [vt]
       }

vectAvoidInfo (_, AnnTick _ expr)
  = do { vt@(VITNode vi _) <- vectAvoidInfo expr
       ; return $ VITNode vi [vt]
       }

vectAvoidInfo (_, AnnType {})
  = return $ VITNode VISimple []

vectAvoidInfo (_, AnnCoercion {})
  = return $ VITNode VISimple []

-- |Compute vectorisation avoidance information for a type: 'VIParr' if the type may be (or
-- contain) a parallel array type, otherwise 'VISimple' for types accepted by 'isSimpleType'
-- and 'VIComplex' for everything else.
--
vectAvoidInfoType :: Type -> VM VectAvoidInfo
vectAvoidInfoType ty
  | maybeParrTy ty = return VIParr
  | otherwise      = classify <$> isSimpleType ty
  where
    classify True  = VISimple
    classify False = VIComplex

-- |Checks whether the type might be a parallel array type, or be a type constructor application
-- with such a type among its (recursively checked) arguments.  Type synonyms are expanded
-- first.  If a constructor is a type family, we conservatively assume that it may be a
-- parallel array type.  Types that 'splitTyConApp_maybe' cannot decompose are assumed not to be.
--
maybeParrTy :: Type -> Bool
maybeParrTy ty
  | Just expanded <- coreView ty
  = maybeParrTy expanded
  | Just (tc, argTys) <- splitTyConApp_maybe ty
  = isPArrTyCon tc || isSynFamilyTyCon tc || any maybeParrTy argTys
  | otherwise
  = False

-- |Is the given type an application of one of a fixed set of simple scalar type constructors
-- ('Bool', 'Int', 'Word8', 'Double', 'Float')?  Types that are not type constructor applications
-- are never simple.
--
-- FIXME: This should not be hardcoded.  Presumably the set of scalar type constructors recorded
--        in the vectorisation environment (cf. 'globalScalarTyCons') should be consulted instead.
isSimpleType :: Type -> VM Bool
isSimpleType ty
  = return $ case splitTyConApp_maybe ty of
      Just (tc, _) -> tyConName tc `elem` simpleTyConNames
      Nothing      -> False
  where
    simpleTyConNames = [boolTyConName, intTyConName, word8TyConName, doubleTyConName, floatTyConName]

-- |Are all variables in the set of a simple type (in the sense of 'isSimpleType')?
--
varsSimple :: VarSet -> VM Bool
varsSimple vs
  = and <$> mapM (isSimpleType . varType) (varSetElems vs)

-- |Emit a vectorisation trace message showing the annotation 'vi' computed for 'ce' and the
-- root annotations of the child trees 'vTs', together with the de-annotated expression.
--
viTrace :: CoreExprWithFVs -> VectAvoidInfo -> [VITree] -> VM ()
viTrace ce vi vTs
  = traceVt header (ppr $ deAnnotate ce)
  where
    header                        = "vitrace " ++ show vi ++ "[" ++ concatMap childInfo vTs ++ "]"
    childInfo (VITNode childVi _) = show childVi ++ " "

{-
1058
---- Sanity check  of the tree, for debugging only
gckeller's avatar
gckeller committed
1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071
checkTree :: VITree -> CoreExpr -> Bool
checkTree  (VITNode _ []) (Type _ty) 
  = True
      
checkTree  (VITNode _ [])  (Var _v)  
  = True
  
checkTree  (VITNode _ [])  (Lit _)
  = True
         
checkTree (VITNode _ [vit]) (Tick _ expr)
  = checkTree vit expr
  
gckeller's avatar
gckeller committed
1072
checkTree (VITNode _ [vit]) (Lam _ expr) 
gckeller's avatar
gckeller committed
1073 1074 1075 1076
  = checkTree vit expr
  
checkTree (VITNode _ [vit1, vit2])  (App ce1 ce2) 
  = (checkTree vit1 ce1) && (checkTree vit2 ce2) 
1077
        
gckeller's avatar
gckeller committed
1078 1079 1080 1081 1082
checkTree (VITNode _ (scrutVit : altVits)) (Case scrut _ _ alts) 
  = (checkTree scrutVit scrut) && (and $ zipWith checkAlt altVits alts)
  where
    checkAlt vt (_, _, expr) = checkTree vt expr
    
gckeller's avatar
gckeller committed
1083
checkTree (VITNode _ [vt1, vt2]) (Let (NonRec _ expr1) expr2) 
gckeller's avatar
gckeller committed
1084 1085
  = (checkTree vt1 expr1) && (checkTree vt2 expr2) 

gckeller's avatar
gckeller committed
1086
checkTree (VITNode _ (vtB : vtBnds))  (Let (Rec bndngs) expr) 
gckeller's avatar
gckeller committed
1087 1088 1089 1090
  = (and $ zipWith checkBndr vtBnds bndngs) && 
    (checkTree vtB expr)
 where 
   checkBndr vt (_, e) = checkTree vt e
1091
              
gckeller's avatar
gckeller committed
1092 1093 1094 1095
checkTree (VITNode _ [vit]) (Cast expr _)
  = checkTree vit expr

checkTree _ _ = False
1096

gckeller's avatar
gckeller committed
1097 1098 1099