Commit 60e4bb4d authored by Sylvain Henry's avatar Sylvain Henry Committed by Ben Gamari
Browse files

Enhanced constant folding

Until now GHC only supported basic constant folding (lit op lit, expr op
0, etc.).

This patch uses laws of +/-/* (associativity, commutativity,
distributivity) to support some constant folding into nested
expressions.

Examples of new transformations:

   - simple nesting: (10 + x) + 10 becomes 20 + x
   - deep nesting: 5 + x + (y + (z + (t + 5))) becomes 10 + (x + (y + (z + t)))
   - distribution: (5 + x) * 6 becomes 30 + 6*x
   - simple factorization: 5 + x + (x + (x + (x + 5))) becomes 10 + (4 *x)
   - siblings: (5 + 4*x) - (3*x + 2) becomes 3 + x

Test Plan: validate

Reviewers: simonpj, austin, bgamari

Reviewed By: bgamari

Subscribers: thomie

GHC Trac Issues: #9136

Differential Revision: https://phabricator.haskell.org/D2858

(cherry picked from commit fea04def)
parent f998947f
......@@ -490,6 +490,7 @@ data GeneralFlag
| Opt_SolveConstantDicts
| Opt_AlignmentSanitisation
| Opt_CatchBottoms
| Opt_NumConstantFolding
-- PreInlining is on by default. The option is there just to see how
-- bad things get if you turn it off!
......@@ -3984,6 +3985,7 @@ fFlagsDeps = [
flagSpec "solve-constant-dicts" Opt_SolveConstantDicts,
flagSpec "catch-bottoms" Opt_CatchBottoms,
flagSpec "alignment-sanitisation" Opt_AlignmentSanitisation,
flagSpec "num-constant-folding" Opt_NumConstantFolding,
flagSpec "show-warning-groups" Opt_ShowWarnGroups,
flagSpec "hide-source-paths" Opt_HideSourcePaths,
flagSpec "show-loaded-modules" Opt_ShowLoadedModules,
......@@ -4437,6 +4439,7 @@ optLevelFlags -- see Note [Documenting optimisation flags]
, ([1,2], Opt_CprAnal)
, ([1,2], Opt_WorkerWrapper)
, ([1,2], Opt_SolveConstantDicts)
, ([1,2], Opt_NumConstantFolding)
, ([2], Opt_LiberateCase)
, ([2], Opt_SpecConstr)
......
......@@ -12,7 +12,7 @@ ToDo:
(i1 + i2) only if it results in a valid Float.
-}
{-# LANGUAGE CPP, RankNTypes #-}
{-# LANGUAGE CPP, RankNTypes, PatternSynonyms, ViewPatterns, RecordWildCards #-}
{-# OPTIONS_GHC -optc-DNON_POSIX_SOURCE #-}
module PrelRules
......@@ -90,10 +90,14 @@ primOpRules nm DataToTagOp = mkPrimOpRule nm 2 [ dataToTagRule ]
-- Int operations
primOpRules nm IntAddOp = mkPrimOpRule nm 2 [ binaryLit (intOp2 (+))
, identityDynFlags zeroi ]
, identityDynFlags zeroi
, numFoldingRules IntAddOp intPrimOps
]
primOpRules nm IntSubOp = mkPrimOpRule nm 2 [ binaryLit (intOp2 (-))
, rightIdentityDynFlags zeroi
, equalArgs >> retLit zeroi ]
, equalArgs >> retLit zeroi
, numFoldingRules IntSubOp intPrimOps
]
primOpRules nm IntAddCOp = mkPrimOpRule nm 2 [ binaryLit (intOpC2 (+))
, identityCDynFlags zeroi ]
primOpRules nm IntSubCOp = mkPrimOpRule nm 2 [ binaryLit (intOpC2 (-))
......@@ -101,7 +105,9 @@ primOpRules nm IntSubCOp = mkPrimOpRule nm 2 [ binaryLit (intOpC2 (-))
, equalArgs >> retLitNoC zeroi ]
primOpRules nm IntMulOp = mkPrimOpRule nm 2 [ binaryLit (intOp2 (*))
, zeroElem zeroi
, identityDynFlags onei ]
, identityDynFlags onei
, numFoldingRules IntMulOp intPrimOps
]
primOpRules nm IntQuotOp = mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (intOp2 quot)
, leftZero zeroi
, rightIdentityDynFlags onei
......@@ -136,17 +142,23 @@ primOpRules nm ISrlOp = mkPrimOpRule nm 2 [ shiftRule shiftRightLogical
-- Word operations
primOpRules nm WordAddOp = mkPrimOpRule nm 2 [ binaryLit (wordOp2 (+))
, identityDynFlags zerow ]
, identityDynFlags zerow
, numFoldingRules WordAddOp wordPrimOps
]
primOpRules nm WordSubOp = mkPrimOpRule nm 2 [ binaryLit (wordOp2 (-))
, rightIdentityDynFlags zerow
, equalArgs >> retLit zerow ]
, equalArgs >> retLit zerow
, numFoldingRules WordSubOp wordPrimOps
]
primOpRules nm WordAddCOp = mkPrimOpRule nm 2 [ binaryLit (wordOpC2 (+))
, identityCDynFlags zerow ]
primOpRules nm WordSubCOp = mkPrimOpRule nm 2 [ binaryLit (wordOpC2 (-))
, rightIdentityCDynFlags zerow
, equalArgs >> retLitNoC zerow ]
primOpRules nm WordMulOp = mkPrimOpRule nm 2 [ binaryLit (wordOp2 (*))
, identityDynFlags onew ]
, identityDynFlags onew
, numFoldingRules WordMulOp wordPrimOps
]
primOpRules nm WordQuotOp = mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (wordOp2 quot)
, rightIdentityDynFlags onew ]
primOpRules nm WordRemOp = mkPrimOpRule nm 2 [ nonZeroLit 1 >> binaryLit (wordOp2 rem)
......@@ -581,7 +593,10 @@ isMaxBound _ _ = False
-- | Create an Int literal expression while ensuring the given Integer is in the
-- target Int range
intResult :: DynFlags -> Integer -> Maybe CoreExpr
intResult dflags result = Just (Lit (mkMachIntWrap dflags result))
intResult dflags result = Just (intResult' dflags result)
intResult' :: DynFlags -> Integer -> CoreExpr
intResult' dflags result = Lit (mkMachIntWrap dflags result)
-- | Create an unboxed pair of an Int literal expression, ensuring the given
-- Integer is in the target Int range and the corresponding overflow flag
......@@ -596,7 +611,10 @@ intCResult dflags result = Just (mkPair [Lit lit, Lit c])
-- | Create a Word literal expression while ensuring the given Integer is in the
-- target Word range
wordResult :: DynFlags -> Integer -> Maybe CoreExpr
wordResult dflags result = Just (Lit (mkMachWordWrap dflags result))
wordResult dflags result = Just (wordResult' dflags result)
wordResult' :: DynFlags -> Integer -> CoreExpr
wordResult' dflags result = Lit (mkMachWordWrap dflags result)
-- | Create an unboxed pair of a Word literal expression, ensuring the given
-- Integer is in the target Word range and the corresponding carry flag
......@@ -1633,6 +1651,275 @@ match_smallIntegerTo _ _ _ _ _ = Nothing
--------------------------------------------------------
-- Note [Constant folding through nested expressions]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
--
-- We use rewrites rules to perform constant folding. It means that we don't
-- have a global view of the expression we are trying to optimise. As a
-- consequence we only perform local (small-step) transformations that either:
-- 1) reduce the number of operations
-- 2) rearrange the expression to increase the odds that other rules will
-- match
--
-- We don't try to handle more complex expression optimisation cases that would
-- require a global view. For example, rewriting expressions to increase
-- sharing (e.g., Horner's method); optimisations that require local
-- transformations increasing the number of operations; rearrangements to
-- cancel/factorize terms (e.g., (a+b-a-b) isn't rearranged to reduce to 0).
--
-- We already have rules to perform constant folding on expressions with the
-- following shape (where a and/or b are literals):
--
-- D) op
-- /\
-- / \
-- / \
-- a b
--
-- To support nested expressions, we match three other shapes of expression
-- trees:
--
-- A) op1 B) op1 C) op1
-- /\ /\ /\
-- / \ / \ / \
-- / \ / \ / \
-- a op2 op2 c op2 op3
-- /\ /\ /\ /\
-- / \ / \ / \ / \
-- b c a b a b c d
--
--
-- R1) +/- simplification:
-- ops = + or -, two literals (not siblings)
--
-- Examples:
-- A: 5 + (10-x) ==> 15-x
-- B: (10+x) + 5 ==> 15+x
-- C: (5+a)-(5-b) ==> 0+(a+b)
--
-- R2) * simplification
-- ops = *, two literals (not siblings)
--
-- Examples:
-- A: 5 * (10*x) ==> 50*x
-- B: (10*x) * 5 ==> 50*x
-- C: (5*a)*(5*b) ==> 25*(a*b)
--
-- R3) * distribution over +/-
-- op1 = *, op2 = + or -, two literals (not siblings)
--
-- This transformation doesn't reduce the number of operations but switches
-- the outer and the inner operations so that the outer is (+) or (-) instead
-- of (*). It increases the odds that other rules will match after this one.
--
-- Examples:
-- A: 5 * (10-x) ==> 50 - (5*x)
-- B: (10+x) * 5 ==> 50 + (5*x)
-- C: Not supported as it would increase the number of operations:
-- (5+a)*(5-b) ==> 25 - 5*b + 5*a - a*b
--
-- R4) Simple factorization
--
-- op1 = + or -, op2/op3 = *,
-- one literal for each innermost * operation (except in the D case),
-- the two other terms are equals
--
-- Examples:
-- A: x - (10*x) ==> (-9)*x
-- B: (10*x) + x ==> 11*x
-- C: (5*x)-(x*3) ==> 2*x
-- D: x+x ==> 2*x
--
-- R5) +/- propagation
--
-- ops = + or -, one literal
--
-- This transformation doesn't reduce the number of operations but propagates
-- the constant to the outer level. It increases the odds that other rules
-- will match after this one.
--
-- Examples:
-- A: x - (10-y) ==> (x+y) - 10
-- B: (10+x) - y ==> 10 + (x-y)
-- C: N/A (caught by the A and B cases)
--
--------------------------------------------------------
-- | Rules to perform constant folding into nested expressions
--
--See Note [Constant folding through nested expressions]
numFoldingRules :: PrimOp -> (DynFlags -> PrimOps) -> RuleM CoreExpr
numFoldingRules op dict = do
[e1,e2] <- getArgs
dflags <- getDynFlags
let PrimOps{..} = dict dflags
if not (gopt Opt_NumConstantFolding dflags)
then mzero
else case BinOpApp e1 op e2 of
-- R1) +/- simplification
x :++: (y :++: v) -> return $ mkL (x+y) `add` v
x :++: (L y :-: v) -> return $ mkL (x+y) `sub` v
x :++: (v :-: L y) -> return $ mkL (x-y) `add` v
L x :-: (y :++: v) -> return $ mkL (x-y) `sub` v
L x :-: (L y :-: v) -> return $ mkL (x-y) `add` v
L x :-: (v :-: L y) -> return $ mkL (x+y) `sub` v
(y :++: v) :-: L x -> return $ mkL (y-x) `add` v
(L y :-: v) :-: L x -> return $ mkL (y-x) `sub` v
(v :-: L y) :-: L x -> return $ mkL (0-y-x) `add` v
(x :++: w) :+: (y :++: v) -> return $ mkL (x+y) `add` (w `add` v)
(w :-: L x) :+: (L y :-: v) -> return $ mkL (y-x) `add` (w `sub` v)
(w :-: L x) :+: (v :-: L y) -> return $ mkL (0-x-y) `add` (w `add` v)
(L x :-: w) :+: (L y :-: v) -> return $ mkL (x+y) `sub` (w `add` v)
(L x :-: w) :+: (v :-: L y) -> return $ mkL (x-y) `add` (v `sub` w)
(w :-: L x) :+: (y :++: v) -> return $ mkL (y-x) `add` (w `add` v)
(L x :-: w) :+: (y :++: v) -> return $ mkL (x+y) `add` (v `sub` w)
(y :++: v) :+: (w :-: L x) -> return $ mkL (y-x) `add` (w `add` v)
(y :++: v) :+: (L x :-: w) -> return $ mkL (x+y) `add` (v `sub` w)
(v :-: L y) :-: (w :-: L x) -> return $ mkL (x-y) `add` (v `sub` w)
(v :-: L y) :-: (L x :-: w) -> return $ mkL (0-x-y) `add` (v `add` w)
(L y :-: v) :-: (w :-: L x) -> return $ mkL (x+y) `sub` (v `add` w)
(L y :-: v) :-: (L x :-: w) -> return $ mkL (y-x) `add` (w `add` v)
(x :++: w) :-: (y :++: v) -> return $ mkL (x-y) `add` (w `sub` v)
(w :-: L x) :-: (y :++: v) -> return $ mkL (0-y-x) `add` (w `sub` v)
(L x :-: w) :-: (y :++: v) -> return $ mkL (x-y) `sub` (v `add` w)
(y :++: v) :-: (w :-: L x) -> return $ mkL (y+x) `add` (v `sub` w)
(y :++: v) :-: (L x :-: w) -> return $ mkL (y-x) `add` (v `add` w)
-- R2) * simplification
x :**: (y :**: v) -> return $ mkL (x*y) `mul` v
(x :**: w) :*: (y :**: v) -> return $ mkL (x*y) `mul` (w `mul` v)
-- R3) * distribution over +/-
x :**: (y :++: v) -> return $ mkL (x*y) `add` (mkL x `mul` v)
x :**: (L y :-: v) -> return $ mkL (x*y) `sub` (mkL x `mul` v)
x :**: (v :-: L y) -> return $ (mkL x `mul` v) `sub` mkL (x*y)
-- R4) Simple factorization
v :+: w
| w `cheapEqExpr` v -> return $ mkL 2 `mul` v
w :+: (y :**: v)
| w `cheapEqExpr` v -> return $ mkL (1+y) `mul` v
w :-: (y :**: v)
| w `cheapEqExpr` v -> return $ mkL (1-y) `mul` v
(y :**: v) :+: w
| w `cheapEqExpr` v -> return $ mkL (y+1) `mul` v
(y :**: v) :-: w
| w `cheapEqExpr` v -> return $ mkL (y-1) `mul` v
(x :**: w) :+: (y :**: v)
| w `cheapEqExpr` v -> return $ mkL (x+y) `mul` v
(x :**: w) :-: (y :**: v)
| w `cheapEqExpr` v -> return $ mkL (x-y) `mul` v
-- R5) +/- propagation
w :+: (y :++: v) -> return $ mkL y `add` (w `add` v)
(y :++: v) :+: w -> return $ mkL y `add` (w `add` v)
w :-: (y :++: v) -> return $ (w `sub` v) `sub` mkL y
(y :++: v) :-: w -> return $ mkL y `add` (v `sub` w)
w :-: (L y :-: v) -> return $ (w `add` v) `sub` mkL y
(L y :-: v) :-: w -> return $ mkL y `sub` (w `add` v)
w :+: (L y :-: v) -> return $ mkL y `add` (w `sub` v)
w :+: (v :-: L y) -> return $ (w `add` v) `sub` mkL y
(L y :-: v) :+: w -> return $ mkL y `add` (w `sub` v)
(v :-: L y) :+: w -> return $ (w `add` v) `sub` mkL y
_ -> mzero
-- | Match the application of a binary primop
pattern BinOpApp :: Arg CoreBndr -> PrimOp -> Arg CoreBndr -> CoreExpr
pattern BinOpApp x op y = OpVal op `App` x `App` y
-- | Match a primop
pattern OpVal :: PrimOp -> Arg CoreBndr
pattern OpVal op <- Var (isPrimOpId_maybe -> Just op) where
OpVal op = Var (mkPrimOpId op)
-- | Match a literal
pattern L :: Integer -> Arg CoreBndr
pattern L l <- Lit (isLitValue_maybe -> Just l)
-- | Match an addition
pattern (:+:) :: Arg CoreBndr -> Arg CoreBndr -> CoreExpr
pattern x :+: y <- BinOpApp x (isAddOp -> True) y
-- | Match an addition with a literal (handle commutativity)
pattern (:++:) :: Integer -> Arg CoreBndr -> CoreExpr
pattern l :++: x <- (isAdd -> Just (l,x))
isAdd :: CoreExpr -> Maybe (Integer,CoreExpr)
isAdd e = case e of
L l :+: x -> Just (l,x)
x :+: L l -> Just (l,x)
_ -> Nothing
-- | Match a multiplication
pattern (:*:) :: Arg CoreBndr -> Arg CoreBndr -> CoreExpr
pattern x :*: y <- BinOpApp x (isMulOp -> True) y
-- | Match a multiplication with a literal (handle commutativity)
pattern (:**:) :: Integer -> Arg CoreBndr -> CoreExpr
pattern l :**: x <- (isMul -> Just (l,x))
isMul :: CoreExpr -> Maybe (Integer,CoreExpr)
isMul e = case e of
L l :*: x -> Just (l,x)
x :*: L l -> Just (l,x)
_ -> Nothing
-- | Match a subtraction
pattern (:-:) :: Arg CoreBndr -> Arg CoreBndr -> CoreExpr
pattern x :-: y <- BinOpApp x (isSubOp -> True) y
isSubOp :: PrimOp -> Bool
isSubOp IntSubOp = True
isSubOp WordSubOp = True
isSubOp _ = False
isAddOp :: PrimOp -> Bool
isAddOp IntAddOp = True
isAddOp WordAddOp = True
isAddOp _ = False
isMulOp :: PrimOp -> Bool
isMulOp IntMulOp = True
isMulOp WordMulOp = True
isMulOp _ = False
-- | Explicit "type-class"-like dictionary for numeric primops
--
-- Depends on DynFlags because creating a literal value depends on DynFlags
data PrimOps = PrimOps
{ add :: CoreExpr -> CoreExpr -> CoreExpr -- ^ Add two numbers
, sub :: CoreExpr -> CoreExpr -> CoreExpr -- ^ Sub two numbers
, mul :: CoreExpr -> CoreExpr -> CoreExpr -- ^ Multiply two numbers
, mkL :: Integer -> CoreExpr -- ^ Create a literal value
}
intPrimOps :: DynFlags -> PrimOps
intPrimOps dflags = PrimOps
{ add = \x y -> BinOpApp x IntAddOp y
, sub = \x y -> BinOpApp x IntSubOp y
, mul = \x y -> BinOpApp x IntMulOp y
, mkL = intResult' dflags
}
wordPrimOps :: DynFlags -> PrimOps
wordPrimOps dflags = PrimOps
{ add = \x y -> BinOpApp x WordAddOp y
, sub = \x y -> BinOpApp x WordSubOp y
, mul = \x y -> BinOpApp x WordMulOp y
, mkL = wordResult' dflags
}
--------------------------------------------------------
-- Constant folding through case-expressions
--
......
==================== Tidy Core ====================
Result size of Tidy Core
= {terms: 172, types: 62, coercions: 0, joins: 0/2}
= {terms: 150, types: 60, coercions: 0, joins: 0/0}
-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
Roman.$trModule4 :: GHC.Prim.Addr#
......@@ -59,29 +59,20 @@ Roman.foo3
= Control.Exception.Base.patError @ 'GHC.Types.LiftedRep @ Int lvl
Rec {
-- RHS size: {terms: 52, types: 6, coercions: 0, joins: 0/1}
-- RHS size: {terms: 40, types: 5, coercions: 0, joins: 0/0}
Roman.foo_$s$wgo [Occ=LoopBreaker]
:: GHC.Prim.Int# -> GHC.Prim.Int# -> GHC.Prim.Int#
[GblId, Arity=2, Caf=NoCafRefs, Str=<S,U><S,U>, Unf=OtherCon []]
[GblId, Arity=2, Caf=NoCafRefs, Str=<L,U><S,U>, Unf=OtherCon []]
Roman.foo_$s$wgo
= \ (sc :: GHC.Prim.Int#) (sc1 :: GHC.Prim.Int#) ->
let {
m :: GHC.Prim.Int#
[LclId]
m = GHC.Prim.+#
(GHC.Prim.+#
(GHC.Prim.+#
(GHC.Prim.+# (GHC.Prim.+# (GHC.Prim.+# sc sc) sc) sc) sc)
sc)
sc } in
case GHC.Prim.<=# sc1 0# of {
__DEFAULT ->
case GHC.Prim.<# sc1 100# of {
__DEFAULT ->
case GHC.Prim.<# sc1 500# of {
__DEFAULT ->
Roman.foo_$s$wgo (GHC.Prim.+# m m) (GHC.Prim.-# sc1 1#);
1# -> Roman.foo_$s$wgo m (GHC.Prim.-# sc1 3#)
Roman.foo_$s$wgo (GHC.Prim.*# 14# sc) (GHC.Prim.-# sc1 1#);
1# -> Roman.foo_$s$wgo (GHC.Prim.*# 7# sc) (GHC.Prim.-# sc1 3#)
};
1# -> Roman.foo_$s$wgo sc (GHC.Prim.-# sc1 2#)
};
......@@ -89,31 +80,22 @@ Roman.foo_$s$wgo
}
end Rec }
-- RHS size: {terms: 71, types: 19, coercions: 0, joins: 0/1}
-- RHS size: {terms: 61, types: 18, coercions: 0, joins: 0/0}
Roman.$wgo [InlPrag=NOUSERINLINE[2]]
:: Maybe Int -> Maybe Int -> GHC.Prim.Int#
[GblId,
Arity=2,
Str=<S,1*U><S,1*U>,
Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True,
WorkFree=True, Expandable=True, Guidance=IF_ARGS [60 30] 253 0}]
WorkFree=True, Expandable=True, Guidance=IF_ARGS [61 30] 249 0}]
Roman.$wgo
= \ (w :: Maybe Int) (w1 :: Maybe Int) ->
case w1 of {
Nothing -> case Roman.foo3 of wild1 { };
Just x ->
case x of { GHC.Types.I# ipv ->
let {
m :: GHC.Prim.Int#
[LclId]
m = GHC.Prim.+#
(GHC.Prim.+#
(GHC.Prim.+#
(GHC.Prim.+# (GHC.Prim.+# (GHC.Prim.+# ipv ipv) ipv) ipv) ipv)
ipv)
ipv } in
case w of {
Nothing -> Roman.foo_$s$wgo m 10#;
Nothing -> Roman.foo_$s$wgo (GHC.Prim.*# 7# ipv) 10#;
Just n ->
case n of { GHC.Types.I# x2 ->
case GHC.Prim.<=# x2 0# of {
......@@ -122,8 +104,8 @@ Roman.$wgo
__DEFAULT ->
case GHC.Prim.<# x2 500# of {
__DEFAULT ->
Roman.foo_$s$wgo (GHC.Prim.+# m m) (GHC.Prim.-# x2 1#);
1# -> Roman.foo_$s$wgo m (GHC.Prim.-# x2 3#)
Roman.foo_$s$wgo (GHC.Prim.*# 14# ipv) (GHC.Prim.-# x2 1#);
1# -> Roman.foo_$s$wgo (GHC.Prim.*# 7# ipv) (GHC.Prim.-# x2 3#)
};
1# -> Roman.foo_$s$wgo ipv (GHC.Prim.-# x2 2#)
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment