Commit d3b546b1 authored by Sylvain Henry's avatar Sylvain Henry Committed by Ben Gamari

Scrutinee Constant Folding

This patch introduces new rules to perform constant folding through
case-expressions.

E.g.,
```
case t -# 10# of _ {  ===> case t of _ {
         5#      -> e1              15#     -> e1
         8#      -> e2              18#     -> e2
         DEFAULT -> e               DEFAULT -> e
```

The initial motivation is that it allows "Merge Nested Cases"
optimization to kick in and to further simplify the code
(see Trac #12877).

Currently we recognize the following operations for Word# and Int#: Add,
Sub, Xor, Not and Negate (for Int# only).

Test Plan: validate

Reviewers: simonpj, austin, bgamari

Reviewed By: simonpj, bgamari

Subscribers: thomie

Differential Revision: https://phabricator.haskell.org/D2762

GHC Trac Issues: #12877
parent 61932cd3
......@@ -29,7 +29,7 @@ module Literal
, inIntRange, inWordRange, tARGET_MAX_INT, inCharRange
, isZeroLit
, litFitsInChar
, litValue
, litValue, isLitValue, isLitValue_maybe, mapLitValue
-- ** Coercions
, word2IntLit, int2WordLit
......@@ -59,6 +59,7 @@ import Data.ByteString (ByteString)
import Data.Int
import Data.Word
import Data.Char
import Data.Maybe ( isJust )
import Data.Data ( Data )
import Numeric ( fromRat )
......@@ -271,13 +272,37 @@ isZeroLit _ = False
-- | Returns the 'Integer' contained in the 'Literal', for when that makes
-- sense, i.e. for 'Char', 'Int', 'Word' and 'LitInteger'.
litValue :: Literal -> Integer
litValue (MachChar c) = toInteger $ ord c
litValue (MachInt i) = i
litValue (MachInt64 i) = i
litValue (MachWord i) = i
litValue (MachWord64 i) = i
litValue (LitInteger i _) = i
litValue l = pprPanic "litValue" (ppr l)
litValue l = case isLitValue_maybe l of
Just x -> x
Nothing -> pprPanic "litValue" (ppr l)
-- | Returns the 'Integer' contained in the 'Literal', for when that makes
-- sense, i.e. for 'Char', 'Int', 'Word' and 'LitInteger'.
isLitValue_maybe :: Literal -> Maybe Integer
isLitValue_maybe (MachChar c) = Just $ toInteger $ ord c
isLitValue_maybe (MachInt i) = Just i
isLitValue_maybe (MachInt64 i) = Just i
isLitValue_maybe (MachWord i) = Just i
isLitValue_maybe (MachWord64 i) = Just i
isLitValue_maybe (LitInteger i _) = Just i
isLitValue_maybe _ = Nothing
-- | Apply a function to the 'Integer' contained in the 'Literal', for when that
-- makes sense, e.g. for 'Char', 'Int', 'Word' and 'LitInteger'.
mapLitValue :: (Integer -> Integer) -> Literal -> Literal
mapLitValue f (MachChar c) = MachChar (fchar c)
where fchar = chr . fromInteger . f . toInteger . ord
mapLitValue f (MachInt i) = MachInt (f i)
mapLitValue f (MachInt64 i) = MachInt64 (f i)
mapLitValue f (MachWord i) = MachWord (f i)
mapLitValue f (MachWord64 i) = MachWord64 (f i)
mapLitValue f (LitInteger i t) = LitInteger (f i) t
mapLitValue _ l = pprPanic "mapLitValue" (ppr l)
-- | Indicate if the `Literal` contains an 'Integer' value, e.g. 'Char',
-- 'Int', 'Word' and 'LitInteger'.
isLitValue :: Literal -> Bool
isLitValue = isJust . isLitValue_maybe
{-
Coercions
......
......@@ -445,6 +445,7 @@ data GeneralFlag
| Opt_IgnoreAsserts
| Opt_DoEtaReduction
| Opt_CaseMerge
| Opt_CaseFolding -- Constant folding through case-expressions
| Opt_UnboxStrictFields
| Opt_UnboxSmallStrictFields
| Opt_DictsCheap
......@@ -3561,6 +3562,7 @@ fFlagsDeps = [
flagSpec "building-cabal-package" Opt_BuildingCabalPackage,
flagSpec "call-arity" Opt_CallArity,
flagSpec "case-merge" Opt_CaseMerge,
flagSpec "case-folding" Opt_CaseFolding,
flagSpec "cmm-elim-common-blocks" Opt_CmmElimCommonBlocks,
flagSpec "cmm-sink" Opt_CmmSink,
flagSpec "cse" Opt_CSE,
......@@ -4012,6 +4014,7 @@ optLevelFlags -- see Note [Documenting optimisation flags]
, ([1,2], Opt_CallArity)
, ([1,2], Opt_CaseMerge)
, ([1,2], Opt_CaseFolding)
, ([1,2], Opt_CmmElimCommonBlocks)
, ([1,2], Opt_CmmSink)
, ([1,2], Opt_CSE)
......
......@@ -15,7 +15,12 @@ ToDo:
{-# LANGUAGE CPP, RankNTypes #-}
{-# OPTIONS_GHC -optc-DNON_POSIX_SOURCE #-}
module PrelRules ( primOpRules, builtinRules ) where
module PrelRules
( primOpRules
, builtinRules
, caseRules
)
where
#include "HsVersions.h"
#include "../includes/MachDeps.h"
......@@ -1385,3 +1390,53 @@ match_smallIntegerTo primOp _ _ _ [App (Var x) y]
| idName x == smallIntegerName
= Just $ App (Var (mkPrimOpId primOp)) y
match_smallIntegerTo _ _ _ _ _ = Nothing
--------------------------------------------------------
-- Constant folding through case-expressions
--
-- cf Scrutinee Constant Folding in simplCore/SimplUtils
--------------------------------------------------------
-- | Match the scrutinee of a case and potentially return a new scrutinee and a
-- function to apply to each literal alternative.
caseRules :: CoreExpr -> Maybe (CoreExpr, Integer -> Integer)
caseRules scrut = case scrut of
-- v `op` x#
App (App (Var f) v) (Lit l)
| Just op <- isPrimOpId_maybe f
, Just x <- isLitValue_maybe l ->
case op of
WordAddOp -> Just (v, \y -> y-x )
IntAddOp -> Just (v, \y -> y-x )
WordSubOp -> Just (v, \y -> y+x )
IntSubOp -> Just (v, \y -> y+x )
XorOp -> Just (v, \y -> y `xor` x)
XorIOp -> Just (v, \y -> y `xor` x)
_ -> Nothing
-- x# `op` v
App (App (Var f) (Lit l)) v
| Just op <- isPrimOpId_maybe f
, Just x <- isLitValue_maybe l ->
case op of
WordAddOp -> Just (v, \y -> y-x )
IntAddOp -> Just (v, \y -> y-x )
WordSubOp -> Just (v, \y -> x-y )
IntSubOp -> Just (v, \y -> x-y )
XorOp -> Just (v, \y -> y `xor` x)
XorIOp -> Just (v, \y -> y `xor` x)
_ -> Nothing
-- op v
App (Var f) v
| Just op <- isPrimOpId_maybe f ->
case op of
NotOp -> Just (v, \y -> complement y)
NotIOp -> Just (v, \y -> complement y)
IntNegOp -> Just (v, \y -> negate y )
_ -> Nothing
_ -> Nothing
......@@ -60,6 +60,8 @@ import Util
import MonadUtils
import Outputable
import Pair
import PrelRules
import Literal
import Control.Monad ( when )
......@@ -1752,9 +1754,46 @@ mkCase tries these things
False -> False
and similar friends.
3. Scrutinee Constant Folding
case x op# k# of _ { ===> case x of _ {
a1# -> e1 (a1# inv_op# k#) -> e1
a2# -> e2 (a2# inv_op# k#) -> e2
... ...
DEFAULT -> ed DEFAULT -> ed
where (x op# k#) inv_op# k# == x
And similarly for commuted arguments and for some unary operations.
The purpose of this transformation is not only to avoid an arithmetic
operation at runtime but to allow other transformations to apply in cascade.
Example with the "Merge Nested Cases" optimization (from #12877):
main = case t of t0
0## -> ...
DEFAULT -> case t0 `minusWord#` 1## of t1
0## -> ...
DEFAUT -> case t1 `minusWord#` 1## of t2
0## -> ...
DEFAULT -> case t2 `minusWord#` 1## of _
0## -> ...
DEFAULT -> ...
becomes:
main = case t of _
0## -> ...
1## -> ...
2## -> ...
3## -> ...
DEFAULT -> ...
-}
mkCase, mkCase1, mkCase2
mkCase, mkCase1, mkCase2, mkCase3
:: DynFlags
-> OutExpr -> OutId
-> OutType -> [OutAlt] -- Alternatives in standard (increasing) order
......@@ -1847,10 +1886,43 @@ mkCase1 _dflags scrut case_bndr _ alts@((_,_,rhs1) : _) -- Identity case
mkCase1 dflags scrut bndr alts_ty alts = mkCase2 dflags scrut bndr alts_ty alts
--------------------------------------------------
-- 2. Scrutinee Constant Folding
--------------------------------------------------
mkCase2 dflags scrut bndr alts_ty alts
| gopt Opt_CaseFolding dflags
, Just (scrut',f) <- caseRules scrut
= mkCase3 dflags scrut' bndr alts_ty (map (mapAlt f) alts)
| otherwise
= mkCase3 dflags scrut bndr alts_ty alts
where
-- We need to keep the correct association between the scrutinee and its
-- binder if the latter isn't dead. Hence we wrap rhs of alternatives with
-- "let bndr = ... in":
--
-- case v + 10 of y =====> case v of y
-- 20 -> e1 10 -> let y = 20 in e1
-- DEFAULT -> e2 DEFAULT -> let y = v + 10 in e2
--
-- Other transformations give: =====> case v of y'
-- 10 -> let y = 20 in e1
-- DEFAULT -> let y = y' + 10 in e2
--
wrap_rhs l rhs
| isDeadBinder bndr = rhs
| otherwise = Let (NonRec bndr l) rhs
mapAlt f alt@(c,bs,e) = case c of
DEFAULT -> (c, bs, wrap_rhs scrut e)
LitAlt l
| isLitValue l -> (LitAlt (mapLitValue f l), bs, wrap_rhs (Lit l) e)
_ -> pprPanic "Unexpected alternative (mkCase2)" (ppr alt)
--------------------------------------------------
-- Catch-all
--------------------------------------------------
mkCase2 _dflags scrut bndr alts_ty alts
mkCase3 _dflags scrut bndr alts_ty alts
= return (Case scrut bndr alts_ty alts)
{-
......
......@@ -115,7 +115,7 @@ list.
:default: on
Merge immediately-nested case expressions that scrutinse the same variable.
Merge immediately-nested case expressions that scrutinise the same variable.
For example, ::
case x of
......@@ -131,6 +131,25 @@ list.
Blue -> e2
Green -> e2
.. ghc-flag:: -fcase-folding
:default: on
Allow constant folding in case expressions that scrutinise some primops:
For example, ::
case x `minusWord#` 10## of
10## -> e1
20## -> e2
v -> e3
Is transformed to, ::
case x of
20## -> e1
30## -> e2
_ -> let v = x `minusWord#` 10## in e3
.. ghc-flag:: -fcall-arity
:default: on
......
-- This ugly cascading case reduces to:
-- case x of
-- 0 -> "0"
-- 1 -> "1"
-- _ -> "n"
--
-- but only if GHC's case-folding reduction kicks in.
{-# NOINLINE test #-}
test :: Word -> String
test x = case x of
0 -> "0"
1 -> "1"
t -> case t + 1 of
1 -> "0"
2 -> "1"
t -> case t + 1 of
2 -> "0"
3 -> "1"
t -> case t + 1 of
3 -> "0"
4 -> "1"
t -> case t + 1 of
4 -> "0"
5 -> "1"
t -> case t + 1 of
5 -> "0"
6 -> "1"
t -> case t + 1 of
6 -> "0"
7 -> "1"
t -> case t + 1 of
7 -> "0"
8 -> "1"
t -> case t + 1 of
8 -> "0"
9 -> "1"
t -> case t + 1 of
10 -> "0"
11 -> "1"
t -> case t + 1 of
11 -> "0"
12 -> "1"
t -> case t + 1 of
12 -> "0"
13 -> "1"
t -> case t + 1 of
13 -> "0"
14 -> "1"
t -> case t + 1 of
14 -> "0"
15 -> "1"
t -> case t + 1 of
15 -> "0"
16 -> "1"
t -> case t + 1 of
16 -> "0"
17 -> "1"
t -> case t + 1 of
17 -> "0"
18 -> "1"
t -> case t + 1 of
18 -> "0"
19 -> "1"
t -> case t + 1 of
19 -> "0"
20 -> "1"
t -> case t + 1 of
20 -> "0"
21 -> "1"
t -> case t + 1 of
21 -> "0"
22 -> "1"
t -> case t + 1 of
22 -> "0"
23 -> "1"
t -> case t + 1 of
23 -> "0"
24 -> "1"
t -> case t + 1 of
24 -> "0"
25 -> "1"
t -> case t + 1 of
25 -> "0"
26 -> "1"
t -> case t + 1 of
26 -> "0"
27 -> "1"
t -> case t + 1 of
27 -> "0"
28 -> "1"
t -> case t + 1 of
28 -> "0"
29 -> "1"
t -> case t + 1 of
29 -> "0"
30 -> "1"
t -> case t + 1 of
30 -> "0"
31 -> "1"
t -> case t + 1 of
31 -> "0"
32 -> "1"
t -> case t + 1 of
32 -> "0"
33 -> "1"
t -> case t + 1 of
33 -> "0"
34 -> "1"
t -> case t + 1 of
34 -> "0"
35 -> "1"
_ -> "n"
main :: IO ()
main = do
putStrLn [last (concat (fmap test [0..12345678]))]
......@@ -895,3 +895,16 @@ test('T12234',
compile,
[''])
test('T12877',
[ stats_num_field('bytes allocated',
[(wordsize(64), 197582248, 5),
# initial: 197582248 (Linux)
])
, compiler_stats_num_field('bytes allocated',
[(wordsize(64), 135979000, 5),
# initial: 135979000 (Linux)
]),
],
compile_and_run,
['-O2'])
......@@ -15,6 +15,11 @@ optimizationsOptions =
, flagType = DynamicFlag
, flagReverse = "-fno-case-merge"
}
, flag { flagName = "-fcase-folding"
, flagDescription = "Enable constant folding in case expressions. Implied by :ghc-flag:`-O`."
, flagType = DynamicFlag
, flagReverse = "-fno-case-folding"
}
, flag { flagName = "-fcmm-elim-common-blocks"
, flagDescription =
"Enable Cmm common block elimination. Implied by :ghc-flag:`-O`."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment