PrelRules.hs 64.5 KB
Newer Older
Austin Seipp's avatar
Austin Seipp committed
1 2 3
{-
(c) The GRASP/AQUA Project, Glasgow University, 1992-1998

4 5
\section[ConFold]{Constant Folder}

6 7 8 9
Conceptually, constant folding should be parameterized with the kind
of target machine to get identical behaviour during compilation time
and runtime. We cheat a little bit here...

10 11
ToDo:
   check boundaries before folding, e.g. we can fold the Float addition
12
   (i1 + i2) only if it results in a valid Float.
Austin Seipp's avatar
Austin Seipp committed
13
-}
14

15 16
{-# LANGUAGE CPP, RankNTypes #-}
{-# OPTIONS_GHC -optc-DNON_POSIX_SOURCE #-}
17

Sylvain Henry's avatar
Sylvain Henry committed
18 19 20 21 22 23
module PrelRules
   ( primOpRules
   , builtinRules
   , caseRules
   )
where
24 25

#include "HsVersions.h"
pcapriotti's avatar
pcapriotti committed
26
#include "../includes/MachDeps.h"
27

28 29
import GhcPrelude

30
import {-# SOURCE #-} MkId ( mkPrimOpId, magicDictId )
31

32
import CoreSyn
33
import MkCore
34 35
import Id
import Literal
36
import CoreOpt     ( exprIsLiteral_maybe )
37
import PrimOp      ( PrimOp(..), tagToEnumKey )
38
import TysWiredIn
39
import TysPrim
40 41
import TyCon       ( tyConDataCons_maybe, isAlgTyCon, isEnumerationTyCon
                   , isNewTyCon, unwrapNewTyCon_maybe, tyConDataCons )
42
import DataCon     ( DataCon, dataConTagZ, dataConTyCon, dataConWorkId )
43
import CoreUtils   ( cheapEqExpr, exprIsHNF, exprType )
44
import CoreUnfold  ( exprIsConApp_maybe )
45
import Type
46
import OccName     ( occNameFS )
47
import PrelNames
48 49
import Maybes      ( orElse )
import Name        ( Name, nameOccName )
50
import Outputable
51
import FastString
52
import BasicTypes
53
import DynFlags
54
import Platform
55
import Util
56
import Coercion     (mkUnbranchedAxInstCo,mkSymCo,Role(..))
57

58
import Control.Applicative ( Alternative(..) )
59

60
import Control.Monad
quchen's avatar
quchen committed
61
import qualified Control.Monad.Fail as MonadFail
62
import Data.Bits as Bits
63
import qualified Data.ByteString as BS
64
import Data.Int
65
import Data.Ratio
66
import Data.Word
67

Austin Seipp's avatar
Austin Seipp committed
68
{-
69 70
Note [Constant folding]
~~~~~~~~~~~~~~~~~~~~~~~
71
primOpRules generates a rewrite rule for each primop
72 73
These rules do what is often called "constant folding"
E.g. the rules for +# might say
74 75
        4 +# 5 = 9
Well, of course you'd need a lot of rules if you did it
76 77 78
like that, so we use a BuiltinRule instead, so that we
can match any two literal values.  So the rule is really
more like
dterei's avatar
dterei committed
79
        (Lit x) +# (Lit y) = Lit (x+#y)
80 81
where the (+#) on the rhs is done at compile time

82
That is why these rules are built in here.
Austin Seipp's avatar
Austin Seipp committed
83
-}
84

85
-- | The built-in rewrite rule for a primop, if it has one: constant folding
-- of literal arguments plus algebraic simplifications
-- (see Note [Constant folding]).  The 'Name' is the primop's name; it
-- becomes both the rule's name and its head.  Primops without any rules
-- map to 'Nothing'.
--
-- ToDo: something for integer-shift ops?
--       NotOp
primOpRules :: Name -> PrimOp -> Maybe CoreRule
primOpRules nm TagToEnumOp =
  mkPrimOpRule nm 2 [tagToEnumRule]
primOpRules nm DataToTagOp =
  mkPrimOpRule nm 2 [dataToTagRule]

-- Int operations
--
-- Each list is tried in order (msum): literal folding first, then
-- algebraic identities.
primOpRules nm IntAddOp    = mkPrimOpRule nm 2 [ binaryLit (intOp2 (+))
                                               , identityDynFlags zeroi ]
primOpRules nm IntSubOp    = mkPrimOpRule nm 2 [ binaryLit (intOp2 (-))
                                               , rightIdentityDynFlags zeroi
                                                 -- x - x = 0
                                               , equalArgs >> retLit zeroi ]
primOpRules nm IntMulOp    = mkPrimOpRule nm 2 [ binaryLit (intOp2 (*))
                                               , zeroElem zeroi
                                               , identityDynFlags onei ]
                                 -- Only fold when the divisor literal is non-zero
primOpRules nm IntQuotOp   = mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (intOp2 quot)
                                               , leftZero zeroi
                                               , rightIdentityDynFlags onei
                                                 -- x `quot` x = 1
                                               , equalArgs >> retLit onei ]
primOpRules nm IntRemOp    = mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (intOp2 rem)
                                               , leftZero zeroi
                                                 -- divisor (argument 1) is the literal 1 => 0
                                               , do l <- getLiteral 1
                                                    dflags <- getDynFlags
                                                    guard (l == onei dflags)
                                                    retLit zeroi
                                                 -- x `rem` x = 0 (listed once: a second
                                                 -- identical alternative could never fire)
                                               , equalArgs >> retLit zeroi ]
primOpRules nm AndIOp      = mkPrimOpRule nm 2 [ binaryLit (intOp2 (.&.))
                                               , idempotent
                                               , zeroElem zeroi ]
primOpRules nm OrIOp       = mkPrimOpRule nm 2 [ binaryLit (intOp2 (.|.))
                                               , idempotent
                                               , identityDynFlags zeroi ]
primOpRules nm XorIOp      = mkPrimOpRule nm 2 [ binaryLit (intOp2 xor)
                                               , identityDynFlags zeroi
                                                 -- x `xor` x = 0
                                               , equalArgs >> retLit zeroi ]
primOpRules nm NotIOp      = mkPrimOpRule nm 1 [ unaryLit complementOp
                                               , inversePrimOp NotIOp ]
primOpRules nm IntNegOp    = mkPrimOpRule nm 1 [ unaryLit negOp
                                               , inversePrimOp IntNegOp ]
primOpRules nm ISllOp      = mkPrimOpRule nm 2 [ shiftRule (const Bits.shiftL)
                                               , rightIdentityDynFlags zeroi ]
primOpRules nm ISraOp      = mkPrimOpRule nm 2 [ shiftRule (const Bits.shiftR)
                                               , rightIdentityDynFlags zeroi ]
primOpRules nm ISrlOp      = mkPrimOpRule nm 2 [ shiftRule shiftRightLogical
                                               , rightIdentityDynFlags zeroi ]
131 132 133

-- Word operations
--
-- These mirror the Int rules above.  In particular WordMulOp now has the
-- same zeroElem rule as IntMulOp (x *## 0## = 0##, 0## *## x = 0##).
primOpRules nm WordAddOp   = mkPrimOpRule nm 2 [ binaryLit (wordOp2 (+))
                                               , identityDynFlags zerow ]
primOpRules nm WordSubOp   = mkPrimOpRule nm 2 [ binaryLit (wordOp2 (-))
                                               , rightIdentityDynFlags zerow
                                                 -- x - x = 0
                                               , equalArgs >> retLit zerow ]
primOpRules nm WordMulOp   = mkPrimOpRule nm 2 [ binaryLit (wordOp2 (*))
                                               , zeroElem zerow
                                               , identityDynFlags onew ]
                                 -- Only fold when the divisor literal is non-zero
primOpRules nm WordQuotOp  = mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (wordOp2 quot)
                                               , rightIdentityDynFlags onew ]
primOpRules nm WordRemOp   = mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (wordOp2 rem)
                                               , leftZero zerow
                                                 -- divisor (argument 1) is the literal 1 => 0
                                               , do l <- getLiteral 1
                                                    dflags <- getDynFlags
                                                    guard (l == onew dflags)
                                                    retLit zerow
                                                 -- x `rem` x = 0
                                               , equalArgs >> retLit zerow ]
primOpRules nm AndOp       = mkPrimOpRule nm 2 [ binaryLit (wordOp2 (.&.))
                                               , idempotent
                                               , zeroElem zerow ]
primOpRules nm OrOp        = mkPrimOpRule nm 2 [ binaryLit (wordOp2 (.|.))
                                               , idempotent
                                               , identityDynFlags zerow ]
primOpRules nm XorOp       = mkPrimOpRule nm 2 [ binaryLit (wordOp2 xor)
                                               , identityDynFlags zerow
                                                 -- x `xor` x = 0
                                               , equalArgs >> retLit zerow ]
primOpRules nm NotOp       = mkPrimOpRule nm 1 [ unaryLit complementOp
                                               , inversePrimOp NotOp ]
primOpRules nm SllOp       = mkPrimOpRule nm 2 [ shiftRule (const Bits.shiftL) ]
primOpRules nm SrlOp       = mkPrimOpRule nm 2 [ shiftRule shiftRightLogical ]
162 163

-- Coercions between representations.  Each list first tries to fold a
-- literal argument, then to cancel or collapse nested conversions.
primOpRules nm Word2IntOp =
  mkPrimOpRule nm 1
    [ liftLitDynFlags word2IntLit
    , inversePrimOp Int2WordOp
    ]
primOpRules nm Int2WordOp =
  mkPrimOpRule nm 1
    [ liftLitDynFlags int2WordLit
    , inversePrimOp Word2IntOp
    ]

-- Narrowing Int#
primOpRules nm Narrow8IntOp =
  mkPrimOpRule nm 1
    [ liftLit narrow8IntLit
    , subsumedByPrimOp Narrow8IntOp
    , Narrow8IntOp `subsumesPrimOp` Narrow16IntOp
    , Narrow8IntOp `subsumesPrimOp` Narrow32IntOp
    ]
primOpRules nm Narrow16IntOp =
  mkPrimOpRule nm 1
    [ liftLit narrow16IntLit
    , subsumedByPrimOp Narrow8IntOp
    , subsumedByPrimOp Narrow16IntOp
    , Narrow16IntOp `subsumesPrimOp` Narrow32IntOp
    ]
primOpRules nm Narrow32IntOp =
  mkPrimOpRule nm 1
    [ liftLit narrow32IntLit
    , subsumedByPrimOp Narrow8IntOp
    , subsumedByPrimOp Narrow16IntOp
    , subsumedByPrimOp Narrow32IntOp
    , removeOp32
    ]

-- Narrowing Word#
primOpRules nm Narrow8WordOp =
  mkPrimOpRule nm 1
    [ liftLit narrow8WordLit
    , subsumedByPrimOp Narrow8WordOp
    , Narrow8WordOp `subsumesPrimOp` Narrow16WordOp
    , Narrow8WordOp `subsumesPrimOp` Narrow32WordOp
    ]
primOpRules nm Narrow16WordOp =
  mkPrimOpRule nm 1
    [ liftLit narrow16WordLit
    , subsumedByPrimOp Narrow8WordOp
    , subsumedByPrimOp Narrow16WordOp
    , Narrow16WordOp `subsumesPrimOp` Narrow32WordOp
    ]
primOpRules nm Narrow32WordOp =
  mkPrimOpRule nm 1
    [ liftLit narrow32WordLit
    , subsumedByPrimOp Narrow8WordOp
    , subsumedByPrimOp Narrow16WordOp
    , subsumedByPrimOp Narrow32WordOp
    , removeOp32
    ]

-- Char# <-> Int#
primOpRules nm OrdOp =
  mkPrimOpRule nm 1
    [ liftLit char2IntLit
    , inversePrimOp ChrOp
    ]
primOpRules nm ChrOp =
  mkPrimOpRule nm 1
    [ do [Lit lit] <- getArgs
         guard (litFitsInChar lit)
         liftLit int2CharLit
    , inversePrimOp OrdOp
    ]

-- Int# <-> floating point
primOpRules nm Float2IntOp    = mkPrimOpRule nm 1 [liftLit float2IntLit]
primOpRules nm Int2FloatOp    = mkPrimOpRule nm 1 [liftLit int2FloatLit]
primOpRules nm Double2IntOp   = mkPrimOpRule nm 1 [liftLit double2IntLit]
primOpRules nm Int2DoubleOp   = mkPrimOpRule nm 1 [liftLit int2DoubleLit]

-- SUP: Not sure what the standard says about precision in the following 2 cases
primOpRules nm Float2DoubleOp = mkPrimOpRule nm 1 [liftLit float2DoubleLit]
primOpRules nm Double2FloatOp = mkPrimOpRule nm 1 [liftLit double2FloatLit]

-- Float
--
-- zeroElem (x * 0 = 0) is deliberately absent from the multiplication
-- rules below: it doesn't hold because of NaN.
primOpRules nm FloatAddOp =
  mkPrimOpRule nm 2
    [ binaryLit (floatOp2 (+))
    , identity zerof
    ]
primOpRules nm FloatSubOp =
  mkPrimOpRule nm 2
    [ binaryLit (floatOp2 (-))
    , rightIdentity zerof
    ]
primOpRules nm FloatMulOp =
  mkPrimOpRule nm 2
    [ binaryLit (floatOp2 (*))
    , identity onef
    , strengthReduction twof FloatAddOp
    ]
primOpRules nm FloatDivOp =
  mkPrimOpRule nm 2
    [ guardFloatDiv >> binaryLit (floatOp2 (/))
    , rightIdentity onef
    ]
primOpRules nm FloatNegOp =
  mkPrimOpRule nm 1
    [ unaryLit negOp
    , inversePrimOp FloatNegOp
    ]

-- Double (same shape as Float; zeroElem omitted for the same NaN reason)
primOpRules nm DoubleAddOp =
  mkPrimOpRule nm 2
    [ binaryLit (doubleOp2 (+))
    , identity zerod
    ]
primOpRules nm DoubleSubOp =
  mkPrimOpRule nm 2
    [ binaryLit (doubleOp2 (-))
    , rightIdentity zerod
    ]
primOpRules nm DoubleMulOp =
  mkPrimOpRule nm 2
    [ binaryLit (doubleOp2 (*))
    , identity oned
    , strengthReduction twod DoubleAddOp
    ]
primOpRules nm DoubleDivOp =
  mkPrimOpRule nm 2
    [ guardDoubleDiv >> binaryLit (doubleOp2 (/))
    , rightIdentity oned
    ]
primOpRules nm DoubleNegOp =
  mkPrimOpRule nm 1
    [ unaryLit negOp
    , inversePrimOp DoubleNegOp
    ]
235 236

-- Relational operators
--
-- Integral and Char comparisons go through mkRelOpRule, which folds two
-- literals and also folds comparing an argument with itself.  Equality
-- against a literal becomes a case expression (litEq); ordering against
-- minBound/maxBound is decided statically (boundsCmp).

-- Equality / inequality
primOpRules nm IntEqOp    = mkRelOpRule nm (==) [litEq True]
primOpRules nm CharEqOp   = mkRelOpRule nm (==) [litEq True]
primOpRules nm WordEqOp   = mkRelOpRule nm (==) [litEq True]
primOpRules nm IntNeOp    = mkRelOpRule nm (/=) [litEq False]
primOpRules nm CharNeOp   = mkRelOpRule nm (/=) [litEq False]
primOpRules nm WordNeOp   = mkRelOpRule nm (/=) [litEq False]

-- Ordering
primOpRules nm IntGtOp    = mkRelOpRule nm (>)  [boundsCmp Gt]
primOpRules nm IntGeOp    = mkRelOpRule nm (>=) [boundsCmp Ge]
primOpRules nm IntLeOp    = mkRelOpRule nm (<=) [boundsCmp Le]
primOpRules nm IntLtOp    = mkRelOpRule nm (<)  [boundsCmp Lt]
primOpRules nm CharGtOp   = mkRelOpRule nm (>)  [boundsCmp Gt]
primOpRules nm CharGeOp   = mkRelOpRule nm (>=) [boundsCmp Ge]
primOpRules nm CharLeOp   = mkRelOpRule nm (<=) [boundsCmp Le]
primOpRules nm CharLtOp   = mkRelOpRule nm (<)  [boundsCmp Lt]
primOpRules nm WordGtOp   = mkRelOpRule nm (>)  [boundsCmp Gt]
primOpRules nm WordGeOp   = mkRelOpRule nm (>=) [boundsCmp Ge]
primOpRules nm WordLeOp   = mkRelOpRule nm (<=) [boundsCmp Le]
primOpRules nm WordLtOp   = mkRelOpRule nm (<)  [boundsCmp Lt]

-- Floating point: literal folding only.
-- See Note [Rules for floating-point comparisons]
primOpRules nm FloatGtOp  = mkFloatingRelOpRule nm (>)
primOpRules nm FloatGeOp  = mkFloatingRelOpRule nm (>=)
primOpRules nm FloatLeOp  = mkFloatingRelOpRule nm (<=)
primOpRules nm FloatLtOp  = mkFloatingRelOpRule nm (<)
primOpRules nm FloatEqOp  = mkFloatingRelOpRule nm (==)
primOpRules nm FloatNeOp  = mkFloatingRelOpRule nm (/=)
primOpRules nm DoubleGtOp = mkFloatingRelOpRule nm (>)
primOpRules nm DoubleGeOp = mkFloatingRelOpRule nm (>=)
primOpRules nm DoubleLeOp = mkFloatingRelOpRule nm (<=)
primOpRules nm DoubleLtOp = mkFloatingRelOpRule nm (<)
primOpRules nm DoubleEqOp = mkFloatingRelOpRule nm (==)
primOpRules nm DoubleNeOp = mkFloatingRelOpRule nm (/=)

-- Address arithmetic and evaluation primops
primOpRules nm AddrAddOp  = mkPrimOpRule nm 2 [rightIdentityDynFlags zeroi]
primOpRules nm SeqOp      = mkPrimOpRule nm 4 [seqRule]
primOpRules nm SparkOp    = mkPrimOpRule nm 4 [sparkRule]

-- Every other primop has no built-in rule
primOpRules _  _          = Nothing
280

Austin Seipp's avatar
Austin Seipp committed
281 282 283
{-
************************************************************************
*                                                                      *
284
\subsection{Doing the business}
Austin Seipp's avatar
Austin Seipp committed
285 286 287
*                                                                      *
************************************************************************
-}
288

289
-- useful shorthands
290 291 292 293 294 295
-- | Combine alternative rewrites into one built-in rule for the named
-- primop with the given arity.  The alternatives are tried in list order
-- and the first that succeeds wins ('msum').
mkPrimOpRule :: Name -> Int -> [RuleM CoreExpr] -> Maybe CoreRule
mkPrimOpRule nm arity rules =
  let combined = msum rules
  in  Just (mkBasicRule nm arity combined)

-- | Rule for an integral/Char comparison primop.  In order, it tries:
-- folding two literals with @cmp@, folding @x `cmp` x@ to a constant, and
-- then the caller-supplied extra rules.
mkRelOpRule :: Name -> (forall a . Ord a => a -> a -> Bool)
            -> [RuleM CoreExpr] -> Maybe CoreRule
mkRelOpRule nm cmp extra = mkPrimOpRule nm 2 (litRule : sameArgRule : extra)
  where
    litRule = binaryCmpLit cmp

    -- The result of x `cmp` x does not depend on x, so evaluate the
    -- comparison once at the arbitrary value 'True' and reuse the answer.
    sameArgRule = do
      equalArgs
      dflags <- getDynFlags
      let reflexive = cmp True True
      return $ if reflexive then trueValInt dflags else falseValInt dflags

{- Note [Rules for floating-point comparisons]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
We need different rules for floating-point values because for floats
it is not true that x = x (for NaNs); so we do not want the equal_rule
rule that mkRelOpRule uses.

Note also that, in the case of equality/inequality, we do /not/
want to switch to a case-expression.  For example, we do not want
to convert
   case (eqFloat# x 3.8#) of
     True -> this
     False -> that
to
  case x of
    3.8#::Float# -> this
    _            -> that
See Trac #9238.  Reason: comparing floating-point values for equality is
delicate, and we don't want to implement that delicacy in the code for
case expressions.  So we make it an invariant of Core that a case
expression never scrutinises a Float# or Double#.

This transformation is what the litEq rule does;
see Note [The litEq rule: converting equality to case].
So we /refrain/ from using litEq for mkFloatingRelOpRule.
-}
333 334

-- | Comparison rule for Float#/Double# primops.  It only folds two
-- literals; unlike 'mkRelOpRule' there is no @x `cmp` x@ rule and no
-- conversion to a case expression.
-- See Note [Rules for floating-point comparisons]
mkFloatingRelOpRule :: Name -> (forall a . Ord a => a -> a -> Bool)
                    -> Maybe CoreRule
mkFloatingRelOpRule nm cmp = mkPrimOpRule nm 2 literalOnly
  where
    literalOnly = [binaryCmpLit cmp]
339 340

-- common constants
341 342 343 344 345 346
-- | Int and Word literals 0 and 1, built with the target's 'DynFlags'.
zeroi, onei, zerow, onew :: DynFlags -> Literal
zeroi = flip mkMachInt  0
onei  = flip mkMachInt  1
zerow = flip mkMachWord 0
onew  = flip mkMachWord 1

347
-- | Float# and Double# literals 0, 1 and 2 (values given as 'Rational's).
zerof, onef, twof, zerod, oned, twod :: Literal
zerof = mkMachFloat 0
onef  = mkMachFloat 1
twof  = mkMachFloat 2
zerod = mkMachDouble 0
oned  = mkMachDouble 1
twod  = mkMachDouble 2
354

355
-- | Fold a comparison of two literals of the same kind to the Int#
-- True/False value.  Returns 'Nothing' when the literals are of different
-- kinds or of a kind not handled here.
cmpOp :: DynFlags -> (forall a . Ord a => a -> a -> Bool)
      -> Literal -> Literal -> Maybe CoreExpr
cmpOp dflags cmp l1 l2 = fmap toResult (compareLits l1 l2)
  where
    toResult b = if b then trueValInt dflags else falseValInt dflags

    -- Each pair is compared at its own payload type
    compareLits (MachChar a)   (MachChar b)   = Just (a `cmp` b)
    compareLits (MachInt a)    (MachInt b)    = Just (a `cmp` b)
    compareLits (MachInt64 a)  (MachInt64 b)  = Just (a `cmp` b)
    compareLits (MachWord a)   (MachWord b)   = Just (a `cmp` b)
    compareLits (MachWord64 a) (MachWord64 b) = Just (a `cmp` b)
    compareLits (MachFloat a)  (MachFloat b)  = Just (a `cmp` b)
    compareLits (MachDouble a) (MachDouble b) = Just (a `cmp` b)
    compareLits _              _              = Nothing
371 372

--------------------------
373

374 375
-- | Fold negation of a literal.  Float/Double zero is left alone because
-- -0.0 can't be represented as a 'Rational'.  Int negation goes through
-- 'intResult', which wraps into the target range.
negOp :: DynFlags -> Literal -> Maybe CoreExpr  -- Negate
negOp dflags lit = case lit of
  MachFloat f
    | f == 0    -> Nothing
    | otherwise -> Just (mkFloatVal dflags (negate f))
  MachDouble d
    | d == 0    -> Nothing
    | otherwise -> Just (mkDoubleVal dflags (negate d))
  MachInt i     -> intResult dflags (negate i)
  _             -> Nothing
381

382 383 384 385 386
-- | Fold bitwise complement of a Word or Int literal; the result is wrapped
-- into the target range by 'wordResult' / 'intResult'.
complementOp :: DynFlags -> Literal -> Maybe CoreExpr  -- Binary complement
complementOp dflags lit = case lit of
  MachWord w -> wordResult dflags (complement w)
  MachInt  n -> intResult  dflags (complement n)
  _          -> Nothing

387
--------------------------
388 389
-- | Lift a target-independent binary operator to a folder for two Int
-- literals (the 'DynFlags' are ignored by the operator itself).
intOp2 :: (Integral a, Integral b)
       => (a -> b -> Integer)
       -> DynFlags -> Literal -> Literal -> Maybe CoreExpr
intOp2 op = intOp2' (\_ -> op)
392

393 394 395 396 397 398 399 400 401
-- | Fold a binary operation on two Int literals.  The 'Integer' payloads
-- are converted to the operator's argument types, and the 'Integer'
-- result is wrapped into the target Int range.
intOp2' :: (Integral a, Integral b)
        => (DynFlags -> a -> b -> Integer)
        -> DynFlags -> Literal -> Literal -> Maybe CoreExpr
intOp2' op dflags lit1 lit2 = case (lit1, lit2) of
  (MachInt i1, MachInt i2) ->
    intResult dflags (op dflags (fromInteger i1) (fromInteger i2))
  _ -> Nothing  -- Could find LitLit

-- | Logical right shift: zeros are shifted in, rather than copies of the
-- sign bit as 'Bits.shiftR' on 'Integer' would do.  It works by going
-- through a fixed-width unsigned type of the target word size.  That is
-- only meaningful for values that fit in a target word, which is how it
-- is used here.  Panics on word sizes other than 32 or 64 bits.
shiftRightLogical :: DynFlags -> Integer -> Int -> Integer
shiftRightLogical dflags x n = case wordSizeInBits dflags of
  32 -> toInteger (fromInteger x `shiftR` n :: Word32)
  64 -> toInteger (fromInteger x `shiftR` n :: Word64)
  _  -> panic "shiftRightLogical: unsupported word size"
409

410
--------------------------
411 412 413 414
-- | Succeed with a literal built from the current 'DynFlags'.
retLit :: (DynFlags -> Literal) -> RuleM CoreExpr
retLit mkLit = Lit . mkLit <$> getDynFlags

415 416
-- | Fold a binary operation on two Word literals; the 'Integer' result is
-- wrapped into the target Word range.
wordOp2 :: (Integral a, Integral b)
        => (a -> b -> Integer)
        -> DynFlags -> Literal -> Literal -> Maybe CoreExpr
wordOp2 op dflags lit1 lit2 = case (lit1, lit2) of
  (MachWord w1, MachWord w2) ->
    wordResult dflags (op (fromInteger w1) (fromInteger w2))
  _ -> Nothing  -- Could find LitLit

422
-- | Rule for the shift primops.  The first argument is the value being
-- shifted; the second is the shift length, which must be an Int literal
-- for the rule to apply.
--
--  * A shift by 0 is the identity.
--  * A negative shift, or one longer than the target word, is replaced by
--    a runtime error; see Note [Guarding against silly shifts].  The error
--    call is given the type of the first argument.  That is also the
--    shift's result type (Int# for ISllOp & co, Word# for SllOp & co, as
--    the folding cases below show), so the rewritten expression stays
--    well-typed.
--  * A shift of an Int or Word literal is constant-folded with @shift_op@,
--    and the result is wrapped into the target range.
shiftRule :: (DynFlags -> Integer -> Int -> Integer) -> RuleM CoreExpr
                 -- Shifts take an Int; hence third arg of op is Int
shiftRule shift_op
  = do { dflags <- getDynFlags
       ; [e1, Lit (MachInt shift_len)] <- getArgs
       ; case e1 of
           _ | shift_len == 0
             -> return e1
             | shift_len < 0 || wordSizeInBits dflags < shift_len
               -- Use the type of e1, not always wordPrimTy: for the Int#
               -- shifts a Word#-typed error would make the Core ill-typed.
             -> return (mkRuntimeErrorApp rUNTIME_ERROR_ID (exprType e1)
                                        ("Bad shift length " ++ show shift_len))

           -- Do the shift at type Integer, but shift length is Int
           Lit (MachInt x)
             -> let op = shift_op dflags
                in  liftMaybe $ intResult dflags (x `op` fromInteger shift_len)

           Lit (MachWord x)
             -> let op = shift_op dflags
                in  liftMaybe $ wordResult dflags (x `op` fromInteger shift_len)

           _ -> mzero }

-- | Width of a target machine word in bits: the platform word size in
-- bytes times 8.
wordSizeInBits :: DynFlags -> Integer
wordSizeInBits dflags = 8 * toInteger (platformWordSize (targetPlatform dflags))
448

449
--------------------------
450 451
-- | Fold a binary operation on two Float literals; the operation is
-- applied to the 'Rational' payloads.
floatOp2 :: (Rational -> Rational -> Rational)
         -> DynFlags -> Literal -> Literal
         -> Maybe (Expr CoreBndr)
floatOp2 op dflags lit1 lit2 = case (lit1, lit2) of
  (MachFloat f1, MachFloat f2) -> Just (mkFloatVal dflags (op f1 f2))
  _                            -> Nothing
456 457

--------------------------
458 459
-- | Fold a binary operation on two Double literals; the operation is
-- applied to the 'Rational' payloads.
doubleOp2 :: (Rational -> Rational -> Rational)
          -> DynFlags -> Literal -> Literal
          -> Maybe (Expr CoreBndr)
doubleOp2 op dflags lit1 lit2 = case (lit1, lit2) of
  (MachDouble d1, MachDouble d2) -> Just (mkDoubleVal dflags (op d1 d2))
  _                              -> Nothing
464 465

--------------------------
Ben Gamari's avatar
Ben Gamari committed
466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486
{- Note [The litEq rule: converting equality to case]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
This stuff turns
     n ==# 3#
into
     case n of
       3# -> True
       m  -> False

This is a Good Thing, because it allows case-of-case transformations
to happen, and case-default absorption to happen.  For
example:

     if (n ==# 3#) || (n ==# 4#) then e1 else e2
will transform to
     case n of
       3# -> e1
       4# -> e1
       m  -> e2
(modulo the usual precautions to avoid duplicating e1)
-}
487

488 489 490 491
-- | Turn an (in)equality test against an unlifted literal, on either side,
-- into a case expression on the other argument.
-- See Note [The litEq rule: converting equality to case]
litEq :: Bool  -- True <=> equality, False <=> inequality
      -> RuleM CoreExpr
litEq is_eq = litOnLeft `mplus` litOnRight
  where
    litOnLeft = do
      [Lit lit, expr] <- getArgs
      toCase lit expr
    litOnRight = do
      [expr, Lit lit] <- getArgs
      toCase lit expr

    toCase lit expr = do
      dflags <- getDynFlags
      guard (not (litIsLifted lit))
      let true_v  = trueValInt  dflags
          false_v = falseValInt dflags
          (on_match, on_other) = if is_eq then (true_v, false_v)
                                          else (false_v, true_v)
      return (mkWildCase expr (literalType lit) intPrimTy
                [ (DEFAULT,    [], on_other)
                , (LitAlt lit, [], on_match) ])
508

509 510 511 512

-- | Decide statically a comparison that is always True or always False
-- because one argument is the minBound or maxBound of its type.  For
-- example, no Int is smaller than minBound, so such a test folds to False.
boundsCmp :: Comparison -> RuleM CoreExpr
boundsCmp op = do
  [a, b] <- getArgs
  dflags <- getDynFlags
  liftMaybe (mkRuleFn dflags op a b)
518 519 520

-- | The ordering comparisons handled by 'boundsCmp'.
data Comparison
  = Gt  -- ^ greater than
  | Ge  -- ^ greater than or equal
  | Lt  -- ^ less than
  | Le  -- ^ less than or equal

521
-- | Decide @lhs `op` rhs@ when one side is a literal equal to its type's
-- minimum or maximum bound; 'Nothing' otherwise.  Alternatives are tried
-- top to bottom.
mkRuleFn :: DynFlags -> Comparison -> CoreExpr -> CoreExpr -> Maybe CoreExpr
mkRuleFn dflags op lhs rhs = case (op, lhs, rhs) of
  (Gt, Lit l, _) | isMin l -> false   -- minBound >  x
  (Le, Lit l, _) | isMin l -> true    -- minBound <= x
  (Ge, _, Lit l) | isMin l -> true    -- x >= minBound
  (Lt, _, Lit l) | isMin l -> false   -- x <  minBound
  (Ge, Lit l, _) | isMax l -> true    -- maxBound >= x
  (Lt, Lit l, _) | isMax l -> false   -- maxBound <  x
  (Gt, _, Lit l) | isMax l -> false   -- x >  maxBound
  (Le, _, Lit l) | isMax l -> true    -- x <= maxBound
  _                        -> Nothing
  where
    isMin = isMinBound dflags
    isMax = isMaxBound dflags
    true  = Just (trueValInt  dflags)
    false = Just (falseValInt dflags)

-- | Is the literal the smallest value of its type on the target?
isMinBound :: DynFlags -> Literal -> Bool
isMinBound dflags lit = case lit of
  MachChar c   -> c == minBound
  MachInt i    -> i == tARGET_MIN_INT dflags
  MachInt64 i  -> i == toInteger (minBound :: Int64)
  MachWord i   -> i == 0
  MachWord64 i -> i == 0
  _            -> False

-- | Is the literal the largest value of its type on the target?
isMaxBound :: DynFlags -> Literal -> Bool
isMaxBound dflags lit = case lit of
  MachChar c   -> c == maxBound
  MachInt i    -> i == tARGET_MAX_INT dflags
  MachInt64 i  -> i == toInteger (maxBound :: Int64)
  MachWord i   -> i == tARGET_MAX_WORD dflags
  MachWord64 i -> i == toInteger (maxBound :: Word64)
  _            -> False
547

548 549
-- | Build an Int literal expression, wrapping the 'Integer' into the
-- target's Int range.  Always succeeds.
intResult :: DynFlags -> Integer -> Maybe CoreExpr
intResult dflags = Just . Lit . mkMachIntWrap dflags
552

553 554
-- | Build a Word literal expression, wrapping the 'Integer' into the
-- target's Word range.  Always succeeds.
wordResult :: DynFlags -> Integer -> Maybe CoreExpr
wordResult dflags = Just . Lit . mkMachWordWrap dflags
557

pcapriotti's avatar
pcapriotti committed
558 559 560 561 562 563
-- | Cancel an operation applied to the result of its inverse: if the single
-- argument is @inv e@, where @inv@ is the given primop, rewrite to @e@.
inversePrimOp :: PrimOp -> RuleM CoreExpr
inversePrimOp primop = do
  [App (Var inner_id) arg] <- getArgs
  matchPrimOpId primop inner_id
  pure arg

564 565 566 567 568 569 570 571 572 573 574
-- | @this `subsumesPrimOp` that@: when the argument is @that e@, drop the
-- inner @that@ and apply @this@ directly to @e@.
subsumesPrimOp :: PrimOp -> PrimOp -> RuleM CoreExpr
subsumesPrimOp this that = do
  [App (Var inner_id) arg] <- getArgs
  matchPrimOpId that inner_id
  pure (App (Var (mkPrimOpId this)) arg)

-- | When the argument is already an application of the given primop, the
-- outer operation is redundant: return the argument unchanged.
subsumedByPrimOp :: PrimOp -> RuleM CoreExpr
subsumedByPrimOp primop = do
  [inner@(App (Var inner_id) _)] <- getArgs
  matchPrimOpId primop inner_id
  pure inner
575 576 577 578 579

-- | @x `op` x = x@: return the first argument when both arguments are
-- cheaply seen to be equal ('cheapEqExpr').
idempotent :: RuleM CoreExpr
idempotent = do
  [lhs, rhs] <- getArgs
  if cheapEqExpr lhs rhs then pure lhs else mzero
580

Austin Seipp's avatar
Austin Seipp committed
581
{-
582 583 584 585 586 587 588 589 590 591 592
Note [Guarding against silly shifts]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Consider this code:

  import Data.Bits( (.|.), shiftL )
  chunkToBitmap :: [Bool] -> Word32
  chunkToBitmap chunk = foldr (.|.) 0 [ 1 `shiftL` n | (True,n) <- zip chunk [0..] ]

This optimises to:
Shift.$wgo = \ (w_sCS :: GHC.Prim.Int#) (w1_sCT :: [GHC.Types.Bool]) ->
    case w1_sCT of _ {
593
      [] -> 0##;
594 595 596 597 598
      : x_aAW xs_aAX ->
        case x_aAW of _ {
          GHC.Types.False ->
            case w_sCS of wild2_Xh {
              __DEFAULT -> Shift.$wgo (GHC.Prim.+# wild2_Xh 1) xs_aAX;
599
              9223372036854775807 -> 0## };
600 601 602 603 604 605 606
          GHC.Types.True ->
            case GHC.Prim.>=# w_sCS 64 of _ {
              GHC.Types.False ->
                case w_sCS of wild3_Xh {
                  __DEFAULT ->
                    case Shift.$wgo (GHC.Prim.+# wild3_Xh 1) xs_aAX of ww_sCW { __DEFAULT ->
                      GHC.Prim.or# (GHC.Prim.narrow32Word#
607
                                      (GHC.Prim.uncheckedShiftL# 1## wild3_Xh))
608 609 610 611
                                   ww_sCW
                     };
                  9223372036854775807 ->
                    GHC.Prim.narrow32Word#
612
!!!!-->                  (GHC.Prim.uncheckedShiftL# 1## 9223372036854775807)
613 614 615 616
                };
              GHC.Types.True ->
                case w_sCS of wild3_Xh {
                  __DEFAULT -> Shift.$wgo (GHC.Prim.+# wild3_Xh 1) xs_aAX;
617
                  9223372036854775807 -> 0##
618 619
                } } } }

Austin Seipp's avatar
Austin Seipp committed
620
Note the massive shift on line "!!!!".  It can't happen, because we've checked
621 622 623 624 625 626 627 628
that w < 64, but the optimiser didn't spot that. We DO NOT want to constant-fold this!
Moreover, if the programmer writes (n `uncheckedShiftL` 9223372036854775807), we
can't constant fold it, but if it gets to the assembler we get
     Error: operand type mismatch for `shl'

So the best thing to do is to rewrite the shift with a call to error,
when the second arg is stupid.

************************************************************************
*                                                                      *
\subsection{Vaguely generic functions}
*                                                                      *
************************************************************************
-}
635

636
-- | Build a 'BuiltinRule' attached to @op_name@ that fires on @n_args@
-- arguments by running the rule computation @rm@.
-- The rule is given the same name as the primop itself.
mkBasicRule :: Name -> Int -> RuleM CoreExpr -> CoreRule
mkBasicRule op_name n_args rm
  = BuiltinRule { ru_name  = rule_name
                , ru_fn    = op_name
                , ru_nargs = n_args
                , ru_try   = try_rule }
  where
    rule_name = occNameFS (nameOccName op_name)
    -- The third argument handed to ru_try is not needed by basic rules.
    try_rule dflags in_scope _ = runRuleM rm dflags in_scope
643 644

-- | A rule computation. It is run against the 'DynFlags', the in-scope
-- environment and the rule's argument expressions. It either produces a
-- result ('Just') or declines to fire ('Nothing').
newtype RuleM r = RuleM
  { runRuleM :: DynFlags -> InScopeEnv -> [CoreExpr] -> Maybe r }

instance Functor RuleM where
  fmap = liftM

instance Applicative RuleM where
  pure x = RuleM (\_ _ _ -> Just x)
  (<*>)  = ap

instance Monad RuleM where
  -- Run the first computation. On success, run the continuation against
  -- the very same flags, environment and arguments.
  m >>= k = RuleM $ \dflags iu args ->
    runRuleM m dflags iu args >>= \r -> runRuleM (k r) dflags iu args
  fail = MonadFail.fail

instance MonadFail.MonadFail RuleM where
  -- A failed pattern match inside a rule just means the rule does not fire.
  fail = const mzero

instance Alternative RuleM where
  empty = RuleM (\_ _ _ -> Nothing)
  -- Left-biased choice: use the first computation's result if it has one,
  -- otherwise try the second.
  l <|> r = RuleM $ \dflags iu args ->
    runRuleM l dflags iu args <|> runRuleM r dflags iu args

instance MonadPlus RuleM

instance HasDynFlags RuleM where
  getDynFlags = RuleM (\dflags _ _ -> Just dflags)
672 673 674 675 676 677

-- | Lift a 'Maybe' into 'RuleM'. 'Nothing' makes the rule fail and
-- @Just x@ succeeds with @x@.
liftMaybe :: Maybe a -> RuleM a
liftMaybe = maybe mzero return

-- | Like 'liftLitDynFlags', for a literal transformation that ignores the
-- 'DynFlags'.
liftLit :: (Literal -> Literal) -> RuleM CoreExpr
liftLit f = liftLitDynFlags (\_ -> f)

-- | When the rule's arguments are exactly one literal, replace the call by
-- that literal transformed with @f@. Otherwise the rule does not fire.
liftLitDynFlags :: (DynFlags -> Literal -> Literal) -> RuleM CoreExpr
liftLitDynFlags f = do
  dflags <- getDynFlags
  args   <- getArgs
  case args of
    [Lit lit] -> return (Lit (f dflags lit))
    _         -> mzero
685

pcapriotti's avatar
pcapriotti committed
686
-- | When the target word size is 32 bits, rewrite the call to its single
-- argument. On any other word size, or when there is not exactly one
-- argument, the rule does not fire.
removeOp32 :: RuleM CoreExpr
removeOp32 = do
  dflags <- getDynFlags
  guard (wordSizeInBits dflags == 32)
  args <- getArgs
  case args of
    [e] -> return e
    _   -> mzero
pcapriotti's avatar
pcapriotti committed
694

695
-- | The argument expressions the rule is being applied to (always succeeds).
getArgs :: RuleM [CoreExpr]
getArgs = RuleM (\_ _ exprs -> Just exprs)
697

698 699
-- | The in-scope environment the rule is being run with (always succeeds).
getInScopeEnv :: RuleM InScopeEnv
getInScopeEnv = RuleM (\_ in_scope _ -> Just in_scope)
700 701 702 703

-- | Return the n-th argument of this rule, if it is a literal.
-- Argument indices start from 0. The rule fails if that position holds
-- no argument, or holds a non-literal.
getLiteral :: Int -> RuleM Literal
getLiteral n = do
  args <- getArgs
  case drop n args of
    Lit l : _ -> return l
    _         -> mzero

708
-- | Apply @op@ to the rule's single literal argument, after normalising it
-- with 'convFloating'. The rule fails unless the arguments are exactly one
-- literal, and also fails when @op@ returns 'Nothing'.
unaryLit :: (DynFlags -> Literal -> Maybe CoreExpr) -> RuleM CoreExpr
unaryLit op = do
  dflags <- getDynFlags
  args   <- getArgs
  case args of
    [Lit l] -> liftMaybe (op dflags (convFloating dflags l))
    _       -> mzero
713

714
-- | Apply @op@ to the rule's two literal arguments, each normalised with
-- 'convFloating'. The rule fails unless the arguments are exactly two
-- literals, and also fails when @op@ returns 'Nothing'.
binaryLit :: (DynFlags -> Literal -> Literal -> Maybe CoreExpr) -> RuleM CoreExpr
binaryLit op = do
  dflags <- getDynFlags
  args   <- getArgs
  let conv = convFloating dflags
  case args of
    [Lit l1, Lit l2] -> liftMaybe (op dflags (conv l1) (conv l2))
    _                -> mzero
719

720 721 722 723 724
-- | Constant-fold a comparison of two literal arguments using the
-- polymorphic comparison @op@, via 'cmpOp'.
-- 'binaryLit' hands us the rule's 'DynFlags', so there is no need to fetch
-- them separately.
binaryCmpLit :: (forall a . Ord a => a -> a -> Bool) -> RuleM CoreExpr
binaryCmpLit op = binaryLit (\dflags -> cmpOp dflags op)

725
-- | Rewrite @id_lit `op` e@ to @e@.
leftIdentity :: Literal -> RuleM CoreExpr
leftIdentity id_lit = leftIdentityDynFlags (\_ -> id_lit)

-- | Rewrite @e `op` id_lit@ to @e@.
rightIdentity :: Literal -> RuleM CoreExpr
rightIdentity id_lit = rightIdentityDynFlags (\_ -> id_lit)

-- | Try 'leftIdentity' first, then 'rightIdentity'.
identity :: Literal -> RuleM CoreExpr
identity lit = leftIdentity lit <|> rightIdentity lit

-- | Given arguments @[Lit l1, e2]@ where @l1@ equals @id_lit dflags@,
-- return @e2@. In every other case the rule does not fire.
leftIdentityDynFlags :: (DynFlags -> Literal) -> RuleM CoreExpr
leftIdentityDynFlags id_lit = do
  dflags <- getDynFlags
  args   <- getArgs
  case args of
    [Lit l1, e2] | l1 == id_lit dflags -> return e2
    _                                  -> mzero

741 742 743
-- | Given arguments @[e1, Lit l2]@ where @l2@ equals @id_lit dflags@,
-- return @e1@. In every other case the rule does not fire.
rightIdentityDynFlags :: (DynFlags -> Literal) -> RuleM CoreExpr
rightIdentityDynFlags id_lit = do
  dflags <- getDynFlags
  args   <- getArgs
  case args of
    [e1, Lit l2] | l2 == id_lit dflags -> return e1
    _                                  -> mzero

748 749
-- | Try 'leftIdentityDynFlags' first, then 'rightIdentityDynFlags'.
identityDynFlags :: (DynFlags -> Literal) -> RuleM CoreExpr
identityDynFlags lit = leftIdentityDynFlags lit <|> rightIdentityDynFlags lit
750

751
-- | Given arguments @[Lit l1, _]@ where @l1@ equals @zero dflags@, return
-- that literal. In every other case the rule does not fire.
leftZero :: (DynFlags -> Literal) -> RuleM CoreExpr
leftZero zero = do
  dflags <- getDynFlags
  args   <- getArgs
  case args of
    [Lit l1, _] | l1 == zero dflags -> return (Lit l1)
    _                               -> mzero
757

758
-- | Given arguments @[_, Lit l2]@ where @l2@ equals @zero dflags@, return
-- that literal. In every other case the rule does not fire.
rightZero :: (DynFlags -> Literal) -> RuleM CoreExpr
rightZero zero = do
  dflags <- getDynFlags
  args   <- getArgs
  case args of
    [_, Lit l2] | l2 == zero dflags -> return (Lit l2)
    _                               -> mzero
764

765
-- | Try 'leftZero' first, then 'rightZero'.
zeroElem :: (DynFlags -> Literal) -> RuleM CoreExpr
zeroElem lit = leftZero lit <|> rightZero lit

-- | Succeeds only when the rule has exactly two arguments that
-- 'cheapEqExpr' considers equal.
equalArgs :: RuleM ()
equalArgs = do
  args <- getArgs
  case args of
    [e1, e2] -> guard (cheapEqExpr e1 e2)
    _        -> mzero

-- | Succeeds only when argument @n@ is a literal for which 'isZeroLit' is
-- False.
nonZeroLit :: Int -> RuleM ()
nonZeroLit n = do
  lit <- getLiteral n
  guard (not (isZeroLit lit))
775

776 777 778
-- | When excess precision is not requested (the 'Opt_ExcessPrecision' flag
-- is off), round the 'Rational' payload of a 'MachFloat' or 'MachDouble'
-- literal to the precision of 'Float' or 'Double' respectively. Every
-- other literal, and every literal when the flag is on, is returned as is.
-- We confuse host architecture and target architecture here, but it's
-- convenient (and wrong :-).
convFloating :: DynFlags -> Literal -> Literal
convFloating dflags lit = case lit of
    MachFloat  f | round_it -> MachFloat  (toRational (fromRational f :: Float ))
    MachDouble d | round_it -> MachDouble (toRational (fromRational d :: Double))
    _                       -> lit
  where
    -- Only consulted once a floating literal has been matched.
    round_it = not (gopt Opt_ExcessPrecision dflags)
785

786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801
-- | Succeeds only when the arguments are two 'MachFloat' literals whose
-- quotient is safe to constant-fold.
guardFloatDiv :: RuleM ()
guardFloatDiv = do
  args <- getArgs
  case args of
    [Lit (MachFloat f1), Lit (MachFloat f2)] -> guard (foldable f1 f2)
    _                                        -> mzero
  where
    foldable num den = (num /= 0 || den > 0) -- see Note [negative zero]
                    && den /= 0              -- avoid NaN and Infinity/-Infinity

-- | Succeeds only when the arguments are two 'MachDouble' literals whose
-- quotient is safe to constant-fold.
guardDoubleDiv :: RuleM ()
guardDoubleDiv = do
  args <- getArgs
  case args of
    [Lit (MachDouble d1), Lit (MachDouble d2)] -> guard (foldable d1 d2)
    _                                          -> mzero
  where
    foldable num den = (num /= 0 || den > 0) -- see Note [negative zero]
                    && den /= 0              -- avoid NaN and Infinity/-Infinity

-- Note [negative zero] Avoid (0 / -d), otherwise 0/(-1) reduces to
-- zero, but we might want to preserve the negative zero here which
-- is representable in Float/Double but not in (normalised)
-- Rational. (#3676) Perhaps we should generate (0 :% (-1)) instead?

802 803 804 805 806 807 808 809 810 811
-- | Rewrite @x * two_lit@ or @two_lit * x@ into @add_op x x@.
-- The right-hand literal is checked first, so when both arguments equal
-- @two_lit@ the first argument is the one that is doubled.
strengthReduction :: Literal -> PrimOp -> RuleM CoreExpr
strengthReduction two_lit add_op = do -- Note [Strength reduction]
  args <- getArgs
  arg  <- case args of
            [x, Lit mult_lit] | mult_lit == two_lit -> return x
            [Lit mult_lit, x] | mult_lit == two_lit -> return x
            _                                       -> mzero
  return (App (App (Var (mkPrimOpId add_op)) arg) arg)

-- Note [Strength reduction]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~
--
-- This rule turns floating point multiplications of the form 2.0 * x and
-- x * 2.0 into x + x addition, because addition costs less than multiplication.
-- See #7116
818

819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835
-- Note [What's true and false]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
--
-- trueValInt and falseValInt represent true and false values returned by
-- comparison primops for Char, Int, Word, Integer, Double, Float and Addr.
-- True is represented as an unboxed 1# literal, while false is represented
-- as a 0# literal.
-- We still need Bool data constructors (True and False) to use in a rule
-- for constant folding of equal Strings.

trueValInt, falseValInt :: DynFlags -> Expr CoreBndr
trueValInt  = Lit . onei   -- see Note [What's true and false]
falseValInt = Lit . zeroi

trueValBool, falseValBool :: Expr CoreBndr
trueValBool  = Var trueDataConId   -- see Note [What's true and false]
falseValBool = Var falseDataConId

-- | References to the data-constructor Ids 'ltDataConId', 'eqDataConId'
-- and 'gtDataConId'.
ltVal, eqVal, gtVal :: Expr CoreBndr
ltVal = Var ltDataConId
eqVal = Var eqDataConId
gtVal = Var gtDataConId

842 843
-- | A literal expression for a machine Int, built with 'mkMachInt'.
mkIntVal :: DynFlags -> Integer -> Expr CoreBndr
mkIntVal dflags = Lit . mkMachInt dflags

-- | A 'MachFloat' literal expression, normalised with 'convFloating'.
mkFloatVal :: DynFlags -> Rational -> Expr CoreBndr
mkFloatVal dflags = Lit . convFloating dflags . MachFloat

-- | A 'MachDouble' literal expression, normalised with 'convFloating'.
mkDoubleVal :: DynFlags -> Rational -> Expr CoreBndr
mkDoubleVal dflags = Lit . convFloating dflags . MachDouble
848