Commit 193664d4 authored by Simon Peyton Jones's avatar Simon Peyton Jones Committed by David Feuer
Browse files

Re-engineer caseRules to add tagToEnum/dataToTag

See Note [Scrutinee Constant Folding] in SimplUtils

* Add cases for tagToEnum and dataToTag. This is the main new
  bit.  It allows the simplifier to remove the pervasive uses
  of     case tagToEnum (a > b) of
            False -> e1
            True  -> e2
  and replace it by the simpler
         case a > b of
            DEFAULT -> e1
            1#      -> e2
  See Note [caseRules for tagToEnum]
  and Note [caseRules for dataToTag] in PrelRules.

* This required some changes to the API of caseRules, and hence
  to code in SimplUtils.  See Note [Scrutinee Constant Folding]
  in SimplUtils.

* Avoid duplication of work in the (unusual) case of
     case BIG + 3# of b
       DEFAULT -> e1
       6#      -> e2

  Previously we got
     case BIG of
       DEFAULT -> let b = BIG + 3# in e1
       3#      -> let b = 6#       in e2

  Now we get
     case BIG of b#
       DEFAULT -> let b = b' + 3# in e1
       3#      -> let b = 6#      in e2

* Avoid duplicated code in caseRules

A knock-on refactoring:

* Move Note [Word/Int underflow/overflow] to Literal, as
  documentation to accompany mkMachIntWrap etc; and get
  rid of PrelRuls.intResult' in favour of mkMachIntWrap
parent 1cae73aa
...@@ -222,6 +222,24 @@ instance Ord Literal where ...@@ -222,6 +222,24 @@ instance Ord Literal where
~~~~~~~~~~~~ ~~~~~~~~~~~~
-} -}
{- Note [Word/Int underflow/overflow]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
According to the Haskell Report 2010 (Sections 18.1 and 23.1 about signed and
unsigned integral types): "All arithmetic is performed modulo 2^n, where n is
the number of bits in the type."
GHC stores Word# and Int# constant values as Integer. Core optimizations such
as constant folding must ensure that the Integer value remains in the valid
target Word/Int range (see #13172). The following functions are used to
ensure this.
Note that we *don't* warn the user about overflow. It's not done at runtime
either, and compilation of completely harmless things like
((124076834 :: Word32) + (2147483647 :: Word32))
doesn't yield a warning. Instead we simply squash the value into the *target*
Int/Word range.
-}
-- | Creates a 'Literal' of type @Int#@ -- | Creates a 'Literal' of type @Int#@
mkMachInt :: DynFlags -> Integer -> Literal mkMachInt :: DynFlags -> Integer -> Literal
mkMachInt dflags x = ASSERT2( inIntRange dflags x, integer x ) mkMachInt dflags x = ASSERT2( inIntRange dflags x, integer x )
...@@ -229,6 +247,7 @@ mkMachInt dflags x = ASSERT2( inIntRange dflags x, integer x ) ...@@ -229,6 +247,7 @@ mkMachInt dflags x = ASSERT2( inIntRange dflags x, integer x )
-- | Creates a 'Literal' of type @Int#@. -- | Creates a 'Literal' of type @Int#@.
-- If the argument is out of the (target-dependent) range, it is wrapped. -- If the argument is out of the (target-dependent) range, it is wrapped.
-- See Note [Word/Int underflow/overflow]
mkMachIntWrap :: DynFlags -> Integer -> Literal mkMachIntWrap :: DynFlags -> Integer -> Literal
mkMachIntWrap dflags i mkMachIntWrap dflags i
= MachInt $ case platformWordSize (targetPlatform dflags) of = MachInt $ case platformWordSize (targetPlatform dflags) of
...@@ -243,6 +262,7 @@ mkMachWord dflags x = ASSERT2( inWordRange dflags x, integer x ) ...@@ -243,6 +262,7 @@ mkMachWord dflags x = ASSERT2( inWordRange dflags x, integer x )
-- | Creates a 'Literal' of type @Word#@. -- | Creates a 'Literal' of type @Word#@.
-- If the argument is out of the (target-dependent) range, it is wrapped. -- If the argument is out of the (target-dependent) range, it is wrapped.
-- See Note [Word/Int underflow/overflow]
mkMachWordWrap :: DynFlags -> Integer -> Literal mkMachWordWrap :: DynFlags -> Integer -> Literal
mkMachWordWrap dflags i mkMachWordWrap dflags i
= MachWord $ case platformWordSize (targetPlatform dflags) of = MachWord $ case platformWordSize (targetPlatform dflags) of
...@@ -336,6 +356,7 @@ isLitValue_maybe _ = Nothing ...@@ -336,6 +356,7 @@ isLitValue_maybe _ = Nothing
-- makes sense, e.g. for 'Char', 'Int', 'Word' and 'LitInteger'. For -- makes sense, e.g. for 'Char', 'Int', 'Word' and 'LitInteger'. For
-- fixed-size integral literals, the result will be wrapped in -- fixed-size integral literals, the result will be wrapped in
-- accordance with the semantics of the target type. -- accordance with the semantics of the target type.
-- See Note [Word/Int underflow/overflow]
mapLitValue :: DynFlags -> (Integer -> Integer) -> Literal -> Literal mapLitValue :: DynFlags -> (Integer -> Integer) -> Literal -> Literal
mapLitValue _ f (MachChar c) = mkMachChar (fchar c) mapLitValue _ f (MachChar c) = mkMachChar (fchar c)
where fchar = chr . fromInteger . f . toInteger . ord where fchar = chr . fromInteger . f . toInteger . ord
......
...@@ -1682,6 +1682,8 @@ ltAlt a1 a2 = (a1 `cmpAlt` a2) == LT ...@@ -1682,6 +1682,8 @@ ltAlt a1 a2 = (a1 `cmpAlt` a2) == LT
cmpAltCon :: AltCon -> AltCon -> Ordering cmpAltCon :: AltCon -> AltCon -> Ordering
-- ^ Compares 'AltCon's within a single list of alternatives -- ^ Compares 'AltCon's within a single list of alternatives
-- DEFAULT comes out smallest, so that sorting by AltCon
-- puts alternatives in the order required by #case_invariants#
cmpAltCon DEFAULT DEFAULT = EQ cmpAltCon DEFAULT DEFAULT = EQ
cmpAltCon DEFAULT _ = LT cmpAltCon DEFAULT _ = LT
......
...@@ -35,8 +35,9 @@ import CoreOpt ( exprIsLiteral_maybe ) ...@@ -35,8 +35,9 @@ import CoreOpt ( exprIsLiteral_maybe )
import PrimOp ( PrimOp(..), tagToEnumKey ) import PrimOp ( PrimOp(..), tagToEnumKey )
import TysWiredIn import TysWiredIn
import TysPrim import TysPrim
import TyCon ( tyConDataCons_maybe, isEnumerationTyCon, isNewTyCon, unwrapNewTyCon_maybe ) import TyCon ( tyConDataCons_maybe, isEnumerationTyCon, isNewTyCon
import DataCon ( dataConTag, dataConTyCon, dataConWorkId ) , unwrapNewTyCon_maybe, tyConDataCons )
import DataCon ( DataCon, dataConTagZ, dataConTyCon, dataConWorkId )
import CoreUtils ( cheapEqExpr, exprIsHNF ) import CoreUtils ( cheapEqExpr, exprIsHNF )
import CoreUnfold ( exprIsConApp_maybe ) import CoreUnfold ( exprIsConApp_maybe )
import Type import Type
...@@ -538,51 +539,15 @@ isMaxBound dflags (MachWord i) = i == tARGET_MAX_WORD dflags ...@@ -538,51 +539,15 @@ isMaxBound dflags (MachWord i) = i == tARGET_MAX_WORD dflags
isMaxBound _ (MachWord64 i) = i == toInteger (maxBound :: Word64) isMaxBound _ (MachWord64 i) = i == toInteger (maxBound :: Word64)
isMaxBound _ _ = False isMaxBound _ _ = False
-- Note [Word/Int underflow/overflow]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
--
-- According to the Haskell Report 2010 (Sections 18.1 and 23.1 about signed and
-- unsigned integral types): "All arithmetic is performed modulo 2^n, where n is
-- the number of bits in the type."
--
-- GHC stores Word# and Int# constant values as Integer. Core optimizations such
-- as constant folding must ensure that the Integer value remains in the valid
-- target Word/Int range (see #13172). The following functions are used to
-- ensure this.
--
-- Note that we *don't* warn the user about overflow. It's not done at runtime
-- either, and compilation of completely harmless things like
-- ((124076834 :: Word32) + (2147483647 :: Word32))
-- doesn't yield a warning. Instead we simply squash the value into the *target*
-- Int/Word range.
-- | Ensure the given Integer is in the target Int range
intResult' :: DynFlags -> Integer -> Integer
intResult' dflags result = case platformWordSize (targetPlatform dflags) of
4 -> toInteger (fromInteger result :: Int32)
8 -> toInteger (fromInteger result :: Int64)
w -> panic ("intResult: Unknown platformWordSize: " ++ show w)
-- | Ensure the given Integer is in the target Word range
wordResult' :: DynFlags -> Integer -> Integer
wordResult' dflags result = case platformWordSize (targetPlatform dflags) of
4 -> toInteger (fromInteger result :: Word32)
8 -> toInteger (fromInteger result :: Word64)
w -> panic ("wordResult: Unknown platformWordSize: " ++ show w)
-- | Create an Int literal expression while ensuring the given Integer is in the -- | Create an Int literal expression while ensuring the given Integer is in the
-- target Int range -- target Int range
intResult :: DynFlags -> Integer -> Maybe CoreExpr intResult :: DynFlags -> Integer -> Maybe CoreExpr
intResult dflags result = Just (mkIntVal dflags (intResult' dflags result)) intResult dflags result = Just (Lit (mkMachIntWrap dflags result))
-- | Create a Word literal expression while ensuring the given Integer is in the -- | Create a Word literal expression while ensuring the given Integer is in the
-- target Word range -- target Word range
wordResult :: DynFlags -> Integer -> Maybe CoreExpr wordResult :: DynFlags -> Integer -> Maybe CoreExpr
wordResult dflags result = Just (mkWordVal dflags (wordResult' dflags result)) wordResult dflags result = Just (Lit (mkMachWordWrap dflags result))
inversePrimOp :: PrimOp -> RuleM CoreExpr inversePrimOp :: PrimOp -> RuleM CoreExpr
inversePrimOp primop = do inversePrimOp primop = do
...@@ -872,8 +837,6 @@ gtVal = Var gtDataConId ...@@ -872,8 +837,6 @@ gtVal = Var gtDataConId
mkIntVal :: DynFlags -> Integer -> Expr CoreBndr mkIntVal :: DynFlags -> Integer -> Expr CoreBndr
mkIntVal dflags i = Lit (mkMachInt dflags i) mkIntVal dflags i = Lit (mkMachInt dflags i)
mkWordVal :: DynFlags -> Integer -> Expr CoreBndr
mkWordVal dflags w = Lit (mkMachWord dflags w)
mkFloatVal :: DynFlags -> Rational -> Expr CoreBndr mkFloatVal :: DynFlags -> Rational -> Expr CoreBndr
mkFloatVal dflags f = Lit (convFloating dflags (MachFloat f)) mkFloatVal dflags f = Lit (convFloating dflags (MachFloat f))
mkDoubleVal :: DynFlags -> Rational -> Expr CoreBndr mkDoubleVal :: DynFlags -> Rational -> Expr CoreBndr
...@@ -921,7 +884,7 @@ tagToEnumRule = do ...@@ -921,7 +884,7 @@ tagToEnumRule = do
case splitTyConApp_maybe ty of case splitTyConApp_maybe ty of
Just (tycon, tc_args) | isEnumerationTyCon tycon -> do Just (tycon, tc_args) | isEnumerationTyCon tycon -> do
let tag = fromInteger i let tag = fromInteger i
correct_tag dc = (dataConTag dc - fIRST_TAG) == tag correct_tag dc = (dataConTagZ dc) == tag
(dc:rest) <- return $ filter correct_tag (tyConDataCons_maybe tycon `orElse` []) (dc:rest) <- return $ filter correct_tag (tyConDataCons_maybe tycon `orElse` [])
ASSERT(null rest) return () ASSERT(null rest) return ()
return $ mkTyApps (Var (dataConWorkId dc)) tc_args return $ mkTyApps (Var (dataConWorkId dc)) tc_args
...@@ -951,7 +914,7 @@ dataToTagRule = a `mplus` b ...@@ -951,7 +914,7 @@ dataToTagRule = a `mplus` b
in_scope <- getInScopeEnv in_scope <- getInScopeEnv
(dc,_,_) <- liftMaybe $ exprIsConApp_maybe in_scope val_arg (dc,_,_) <- liftMaybe $ exprIsConApp_maybe in_scope val_arg
ASSERT( not (isNewTyCon (dataConTyCon dc)) ) return () ASSERT( not (isNewTyCon (dataConTyCon dc)) ) return ()
return $ mkIntVal dflags (toInteger (dataConTag dc - fIRST_TAG)) return $ mkIntVal dflags (toInteger (dataConTagZ dc))
{- {-
************************************************************************ ************************************************************************
...@@ -1183,7 +1146,7 @@ match_append_lit _ _ _ _ = Nothing ...@@ -1183,7 +1146,7 @@ match_append_lit _ _ _ _ = Nothing
--------------------------------------------------- ---------------------------------------------------
-- The rule is this: -- The rule is this:
-- eqString (unpackCString# (Lit s1)) (unpackCString# (Lit s2) = s1==s2 -- eqString (unpackCString# (Lit s1)) (unpackCString# (Lit s2)) = s1==s2
match_eq_string :: RuleFun match_eq_string :: RuleFun
match_eq_string _ id_unf _ match_eq_string _ id_unf _
...@@ -1432,46 +1395,150 @@ match_smallIntegerTo _ _ _ _ _ = Nothing ...@@ -1432,46 +1395,150 @@ match_smallIntegerTo _ _ _ _ _ = Nothing
-- | Match the scrutinee of a case and potentially return a new scrutinee and a -- | Match the scrutinee of a case and potentially return a new scrutinee and a
-- function to apply to each literal alternative. -- function to apply to each literal alternative.
caseRules :: DynFlags -> CoreExpr -> Maybe (CoreExpr, Integer -> Integer) caseRules :: DynFlags
caseRules dflags scrut = case scrut of -> CoreExpr -- Scrutinee
-> Maybe ( CoreExpr -- New scrutinee
-- We need to call wordResult' and intResult' to ensure that the literal , AltCon -> AltCon -- How to fix up the alt pattern
-- alternatives remain in Word/Int target ranges (cf Note [Word/Int , Id -> CoreExpr) -- How to reconstruct the original scrutinee
-- underflow/overflow] and #13172). -- from the new case-binder
-- e.g case e of b {
-- v `op` x# -- ...;
App (App (Var f) v) (Lit l) -- con bs -> rhs;
| Just op <- isPrimOpId_maybe f -- ... }
, Just x <- isLitValue_maybe l -> -- ==>
case op of -- case e' of b' {
WordAddOp -> Just (v, \y -> wordResult' dflags $ y-x ) -- ...;
IntAddOp -> Just (v, \y -> intResult' dflags $ y-x ) -- fixup_altcon[con] bs -> let b = mk_orig[b] in rhs;
WordSubOp -> Just (v, \y -> wordResult' dflags $ y+x ) -- ... }
IntSubOp -> Just (v, \y -> intResult' dflags $ y+x )
XorOp -> Just (v, \y -> wordResult' dflags $ y `xor` x) caseRules dflags (App (App (Var f) v) (Lit l)) -- v `op` x#
XorIOp -> Just (v, \y -> intResult' dflags $ y `xor` x) | Just op <- isPrimOpId_maybe f
, Just x <- isLitValue_maybe l
, Just adjust_lit <- adjustDyadicRight op x
= Just (v, tx_lit_con dflags adjust_lit
, \v -> (App (App (Var f) (Var v)) (Lit l)))
caseRules dflags (App (App (Var f) (Lit l)) v) -- x# `op` v
| Just op <- isPrimOpId_maybe f
, Just x <- isLitValue_maybe l
, Just adjust_lit <- adjustDyadicLeft x op
= Just (v, tx_lit_con dflags adjust_lit
, \v -> (App (App (Var f) (Var v)) (Lit l)))
caseRules dflags (App (Var f) v ) -- op v
| Just op <- isPrimOpId_maybe f
, Just adjust_lit <- adjustUnary op
= Just (v, tx_lit_con dflags adjust_lit
, \v -> App (Var f) (Var v))
-- See Note [caseRules for tagToEnum]
caseRules dflags (App (App (Var f) type_arg) v)
| Just TagToEnumOp <- isPrimOpId_maybe f
= Just (v, tx_con_tte dflags
, \v -> (App (App (Var f) type_arg) (Var v)))
-- See Note [caseRules for dataToTag]
caseRules _ (App (App (Var f) (Type ty)) v) -- dataToTag x
| Just DataToTagOp <- isPrimOpId_maybe f
= Just (v, tx_con_dtt ty
, \v -> App (App (Var f) (Type ty)) (Var v))
caseRules _ _ = Nothing
tx_lit_con :: DynFlags -> (Integer -> Integer) -> AltCon -> AltCon
tx_lit_con _ _ DEFAULT = DEFAULT
tx_lit_con dflags adjust (LitAlt l) = LitAlt (mapLitValue dflags adjust l)
tx_lit_con _ _ alt = pprPanic "caseRules" (ppr alt)
-- NB: mapLitValue uses mkMachIntWrap etc, to ensure that the
-- literal alternatives remain in Word/Int target ranges
-- (See Note [Word/Int underflow/overflow] in Literal and #13172).
adjustDyadicRight :: PrimOp -> Integer -> Maybe (Integer -> Integer)
-- Given (x `op` lit) return a function 'f' s.t. f (x `op` lit) = x
adjustDyadicRight op lit
= case op of
WordAddOp -> Just (\y -> y-lit )
IntAddOp -> Just (\y -> y-lit )
WordSubOp -> Just (\y -> y+lit )
IntSubOp -> Just (\y -> y+lit )
XorOp -> Just (\y -> y `xor` lit)
XorIOp -> Just (\y -> y `xor` lit)
_ -> Nothing _ -> Nothing
-- x# `op` v adjustDyadicLeft :: Integer -> PrimOp -> Maybe (Integer -> Integer)
App (App (Var f) (Lit l)) v -- Given (lit `op` x) return a function 'f' s.t. f (lit `op` x) = x
| Just op <- isPrimOpId_maybe f adjustDyadicLeft lit op
, Just x <- isLitValue_maybe l -> = case op of
case op of WordAddOp -> Just (\y -> y-lit )
WordAddOp -> Just (v, \y -> wordResult' dflags $ y-x ) IntAddOp -> Just (\y -> y-lit )
IntAddOp -> Just (v, \y -> intResult' dflags $ y-x ) WordSubOp -> Just (\y -> lit-y )
WordSubOp -> Just (v, \y -> wordResult' dflags $ x-y ) IntSubOp -> Just (\y -> lit-y )
IntSubOp -> Just (v, \y -> intResult' dflags $ x-y ) XorOp -> Just (\y -> y `xor` lit)
XorOp -> Just (v, \y -> wordResult' dflags $ y `xor` x) XorIOp -> Just (\y -> y `xor` lit)
XorIOp -> Just (v, \y -> intResult' dflags $ y `xor` x)
_ -> Nothing _ -> Nothing
-- op v
App (Var f) v adjustUnary :: PrimOp -> Maybe (Integer -> Integer)
| Just op <- isPrimOpId_maybe f -> -- Given (op x) return a function 'f' s.t. f (op x) = x
case op of adjustUnary op
NotOp -> Just (v, \y -> wordResult' dflags $ complement y) = case op of
NotIOp -> Just (v, \y -> intResult' dflags $ complement y) NotOp -> Just (\y -> complement y)
IntNegOp -> Just (v, \y -> intResult' dflags $ negate y ) NotIOp -> Just (\y -> complement y)
IntNegOp -> Just (\y -> negate y )
_ -> Nothing _ -> Nothing
_ -> Nothing tx_con_tte :: DynFlags -> AltCon -> AltCon
tx_con_tte _ DEFAULT = DEFAULT
tx_con_tte dflags (DataAlt dc)
| tag == 0 = DEFAULT -- See Note [caseRules for tagToEnum]
| otherwise = LitAlt (mkMachInt dflags (toInteger tag))
where
tag = dataConTagZ dc
tx_con_tte _ alt = pprPanic "caseRules" (ppr alt)
tx_con_dtt :: Type -> AltCon -> AltCon
tx_con_dtt _ DEFAULT = DEFAULT
tx_con_dtt ty (LitAlt (MachInt i)) = DataAlt (get_con ty (fromInteger i))
tx_con_dtt _ alt = pprPanic "caseRules" (ppr alt)
get_con :: Type -> ConTagZ -> DataCon
get_con ty tag = tyConDataCons (tyConAppTyCon ty) !! tag
{- Note [caseRules for tagToEnum]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
We want to transform
case tagToEnum x of
False -> e1
True -> e2
into
case x of
0# -> e1
1# -> e1
This rule elimiantes a lot of boilerplate. For
if (x>y) then e1 else e2
we generate
case tagToEnum (x ># y) of
False -> e2
True -> e1
and it is nice to then get rid of the tagToEnum.
NB: in SimplUtils, where we invoke caseRules,
we convert that 0# to DEFAULT
Note [caseRules for dataToTag]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
We want to transform
case dataToTag x of
DEFAULT -> e1
1# -> e2
into
case x of
DEFAULT -> e1
(:) _ _ -> e2
Note the need for some wildcard binders in
the 'cons' case.
-}
...@@ -53,7 +53,7 @@ import Demand ...@@ -53,7 +53,7 @@ import Demand
import SimplMonad import SimplMonad
import Type hiding( substTy ) import Type hiding( substTy )
import Coercion hiding( substCo ) import Coercion hiding( substCo )
import DataCon ( dataConWorkId ) import DataCon ( dataConWorkId, isNullaryRepDataCon )
import VarEnv import VarEnv
import VarSet import VarSet
import BasicTypes import BasicTypes
...@@ -62,7 +62,7 @@ import MonadUtils ...@@ -62,7 +62,7 @@ import MonadUtils
import Outputable import Outputable
import Pair import Pair
import PrelRules import PrelRules
import Literal import FastString ( fsLit )
import Control.Monad ( when ) import Control.Monad ( when )
import Data.List ( sortBy ) import Data.List ( sortBy )
...@@ -1779,8 +1779,12 @@ prepareAlts scrut case_bndr' alts ...@@ -1779,8 +1779,12 @@ prepareAlts scrut case_bndr' alts
mkCase tries these things mkCase tries these things
1. Merge Nested Cases * Note [Nerge nested cases]
* Note [Elimiante identity case]
* Note [Scrutinee constant folding]
Note [Merge Nested Cases]
~~~~~~~~~~~~~~~~~~~~~~~~~
case e of b { ==> case e of b { case e of b { ==> case e of b {
p1 -> rhs1 p1 -> rhs1 p1 -> rhs1 p1 -> rhs1
... ... ... ...
...@@ -1792,21 +1796,21 @@ mkCase tries these things ...@@ -1792,21 +1796,21 @@ mkCase tries these things
_ -> rhsd _ -> rhsd
} }
which merges two cases in one case when -- the default alternative of which merges two cases in one case when -- the default alternative of
the outer case scrutises the same variable as the outer case. This the outer case scrutises the same variable as the outer case. This
transformation is called Case Merging. It avoids that the same transformation is called Case Merging. It avoids that the same
variable is scrutinised multiple times. variable is scrutinised multiple times.
2. Eliminate Identity Case
Note [Eliminate Identity Case]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
case e of ===> e case e of ===> e
True -> True; True -> True;
False -> False False -> False
and similar friends. and similar friends.
3. Scrutinee Constant Folding
Note [Scrutinee Constant Folding]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
case x op# k# of _ { ===> case x of _ { case x op# k# of _ { ===> case x of _ {
a1# -> e1 (a1# inv_op# k#) -> e1 a1# -> e1 (a1# inv_op# k#) -> e1
a2# -> e2 (a2# inv_op# k#) -> e2 a2# -> e2 (a2# inv_op# k#) -> e2
...@@ -1815,32 +1819,66 @@ mkCase tries these things ...@@ -1815,32 +1819,66 @@ mkCase tries these things
where (x op# k#) inv_op# k# == x where (x op# k#) inv_op# k# == x
And similarly for commuted arguments and for some unary operations. And similarly for commuted arguments and for some unary operations.
The purpose of this transformation is not only to avoid an arithmetic The purpose of this transformation is not only to avoid an arithmetic
operation at runtime but to allow other transformations to apply in cascade. operation at runtime but to allow other transformations to apply in cascade.
Example with the "Merge Nested Cases" optimization (from #12877): Example with the "Merge Nested Cases" optimization (from #12877):
main = case t of t0 main = case t of t0
0## -> ... 0## -> ...
DEFAULT -> case t0 `minusWord#` 1## of t1 DEFAULT -> case t0 `minusWord#` 1## of t1
0## -> ... 0## -> ...
DEFAUT -> case t1 `minusWord#` 1## of t2 DEFAUT -> case t1 `minusWord#` 1## of t2
0## -> ... 0## -> ...
DEFAULT -> case t2 `minusWord#` 1## of _ DEFAULT -> case t2 `minusWord#` 1## of _
0## -> ... 0## -> ...
DEFAULT -> ... DEFAULT -> ...
becomes: becomes:
main = case t of _ main = case t of _
0## -> ... 0## -> ...
1## -> ... 1## -> ...
2## -> ... 2## -> ...
3## -> ... 3## -> ...
DEFAULT -> ... DEFAULT -> ...
There are some wrinkles
* Do not apply caseRules if there is just a single DEFAULT alternative
case e +# 3# of b { DEFAULT -> rhs }
If we applied the transformation here we would (stupidly) get
case a of b' { DEFAULT -> let b = e +# 3# in rhs }
and now the process may repeat, because that let will really
be a case.
* The type of the scrutinee might change. E.g.
case tagToEnum (x :: Int#) of (b::Bool)
False -> e1
True -> e2
==>
case x of (b'::Int#)
DEFAULT -> e1
1# -> e2
* The case binder may be used in the right hand sides, so we need
to make a local binding for it, if it is alive. e.g.
case e +# 10# of b
DEFAULT -> blah...b...
44# -> blah2...b...
===>
case e of b'
DEFAULT -> let b = b' +# 10# in blah...b...
34# -> let b = 44# in blah2...b...
Note that in the non-DEFAULT cases we know what to bind 'b' to,
whereas in the DEFAULT case we must reconstruct the original value.
But NB: we use b'; we do not duplicate 'e'.
* In dataToTag we might need to make up some fake binders;
see Note [caseRules for dataToTag] in PrelRules
-} -}
mkCase, mkCase1, mkCase2, mkCase3 mkCase, mkCase1, mkCase2, mkCase3
...@@ -1941,9 +1979,18 @@ mkCase1 dflags scrut bndr alts_ty alts = mkCase2 dflags scrut bndr alts_ty alts ...@@ -1941,9 +1979,18 @@ mkCase1 dflags scrut bndr alts_ty alts = mkCase2 dflags scrut bndr alts_ty alts
-------------------------------------------------- --------------------------------------------------
mkCase2 dflags scrut bndr alts_ty alts mkCase2 dflags scrut bndr alts_ty alts
| gopt Opt_CaseFolding dflags | -- See Note [Scrutinee Constant Folding]
, Just (scrut',f) <- caseRules dflags scrut case alts of -- Not if there is just a DEFAULT alterantive
= mkCase3 dflags scrut' bndr alts_ty (new_alts f) [(DEFAULT,_,_)] -> False
_ -> True
, gopt Opt_CaseFolding dflags
, Just (scrut', tx_con, mk_orig) <- caseRules dflags scrut
= do { bndr' <- newId (fsLit "lwild") (exprType scrut')
; alts' <- mapM (tx_alt tx_con mk_orig bndr') alts
; mkCase3 dflags scrut' bndr' alts_ty $
add_default (re_sort alts')
}
| otherwise | otherwise
= mkCase3 dflags scrut bndr alts_ty alts = mkCase3 dflags scrut bndr alts_ty alts
where where
...@@ -1959,19 +2006,41 @@ mkCase2 dflags scrut bndr alts_ty alts ...@@ -1959,19 +2006,41 @@ mkCase2 dflags scrut bndr alts_ty alts
-- 10 -> let y = 20 in e1 -- 10 -> let y = 20 in e1