Commit c08a6c2d authored by Alex D's avatar Alex D 🍄
Browse files

Revamp derived Eq instance code generation (#17240)

This patch improves code generation for derived Eq instances.
The idea is to use 'dataToTag' to evaluate both arguments.
This allows to 'short-circuit' when tags do not match.
Unfortunately, inner evals are still present when we branch
on tags. This is due to the way 'dataToTag#' primop
evaluates its argument in the code generator. #21207 was
created to explore further optimizations.

Metric Decrease:
    LargeRecord
parent 8ff32124
......@@ -152,6 +152,12 @@ possibly zero of them). Here's an example, with both \tr{N}ullary and
data Foo ... = N1 | N2 ... | Nn | O1 a b | O2 Int | O3 Double b b | ...
* We first attempt to compare the constructor tags. If tags don't
match - we immediately bail out. Otherwise, we then generate one
branch per constructor comparing only the fields as we already
know that the tags match. Note that it only makes sense to check
the tag if there is more than one data constructor.
* For the ordinary constructors (if any), we emit clauses to do The
Usual Thing, e.g.,:
......@@ -164,23 +170,29 @@ possibly zero of them). Here's an example, with both \tr{N}ullary and
case (a1 `eqFloat#` a2) of r -> r
for that particular test.
* For nullary constructors, we emit a
catch-all clause of the form:
* For nullary constructors, we emit a catch-all clause that always
returns True since we already know that the tags match.
* So, given this data type:
data T = A | B Int | C Char
(==) a b = case (dataToTag# a) of { a# ->
case (dataToTag# b) of { b# ->
case (a# ==# b#) of {
r -> r }}}
We roughly get:
(==) a b =
case dataToTag# a /= dataToTag# b of
True -> False
False -> case a of -- Here we already know that tags match
B a1 -> case b of
B b1 -> a1 == b1 -- Only one branch
C a1 -> case b of
C b1 -> a1 == b1 -- Only one branch
_ -> True -- catch-all to match all nullary ctors
An older approach preferred regular pattern matches in some cases
but with dataToTag# forcing it's argument, and work on improving
join points, this seems no longer necessary.
* If there aren't any nullary constructors, we emit a simpler
catch-all:
(==) a b = False
* For the @(/=)@ method, we normally just use the default method.
If the type is an enumeration type, we could/may/should? generate
special code that calls @dataToTag#@, much like for @(==)@ shown
......@@ -202,58 +214,68 @@ gen_Eq_binds loc dit@(DerivInstTys{ dit_rep_tc = tycon
return (method_binds, emptyBag)
where
all_cons = getPossibleDataCons tycon tycon_args
(nullary_cons, non_nullary_cons) = partition isNullarySrcDataCon all_cons
-- For nullary constructors, use the getTag stuff.
(tag_match_cons, pat_match_cons) = (nullary_cons, non_nullary_cons)
no_tag_match_cons = null tag_match_cons
-- (LHS patterns, result)
fall_through_eqn :: [([LPat (GhcPass 'Parsed)] , LHsExpr GhcPs)]
fall_through_eqn
| no_tag_match_cons -- All constructors have arguments
= case pat_match_cons of
[] -> [] -- No constructors; no fall-though case
[_] -> [] -- One constructor; no fall-though case
_ -> -- Two or more constructors; add fall-through of
-- (==) _ _ = False
[([nlWildPat, nlWildPat], false_Expr)]
| otherwise -- One or more tag_match cons; add fall-through of
-- extract tags compare for equality,
-- The case `(C1 x) == (C1 y)` can no longer happen
-- at this point as it's matched earlier.
= [([a_Pat, b_Pat],
untag_Expr [(a_RDR,ah_RDR), (b_RDR,bh_RDR)]
(genPrimOpApp (nlHsVar ah_RDR) eqInt_RDR (nlHsVar bh_RDR)))]
non_nullary_cons = filter (not . isNullarySrcDataCon) all_cons
-- Generate tag check. See #17240
eq_expr_with_tag_check = nlHsCase
(nlHsPar (untag_Expr [(a_RDR,ah_RDR), (b_RDR,bh_RDR)]
(nlHsOpApp (nlHsVar ah_RDR) neInt_RDR (nlHsVar bh_RDR))))
[ mkHsCaseAlt (nlLitPat (HsIntPrim NoSourceText 1)) false_Expr
, mkHsCaseAlt nlWildPat (
nlHsCase
(nlHsVar a_RDR)
-- Only one branch to match all nullary constructors
-- as we already know the tags match but do not emit
-- the branch if there are no nullary constructors
(let non_nullary_pats = map pats_etc non_nullary_cons
in if null non_nullary_cons
then non_nullary_pats
else non_nullary_pats ++ [mkHsCaseAlt nlWildPat true_Expr]))
]
method_binds = unitBag eq_bind
eq_bind
= mkFunBindEC 2 loc eq_RDR (const true_Expr)
(map pats_etc pat_match_cons
++ fall_through_eqn)
eq_bind = mkFunBindEC 2 loc eq_RDR (const true_Expr) binds
where
binds
| null all_cons = []
-- Tag checking is redundant when there is only one data constructor
| [data_con] <- all_cons
, (as_needed, bs_needed, tys_needed) <- gen_con_fields_and_tys data_con
, data_con_RDR <- getRdrName data_con
, con1_pat <- nlParPat $ nlConVarPat data_con_RDR as_needed
, con2_pat <- nlParPat $ nlConVarPat data_con_RDR bs_needed
, eq_expr <- nested_eq_expr tys_needed as_needed bs_needed
= [([con1_pat, con2_pat], eq_expr)]
-- This is an enum (all constructors are nullary) - just do a simple tag check
| all isNullarySrcDataCon all_cons
= [([a_Pat, b_Pat], untag_Expr [(a_RDR,ah_RDR), (b_RDR,bh_RDR)]
(genPrimOpApp (nlHsVar ah_RDR) eqInt_RDR (nlHsVar bh_RDR)))]
| otherwise
= [([a_Pat, b_Pat], eq_expr_with_tag_check)]
------------------------------------------------------------------
pats_etc data_con
= let
con1_pat = nlParPat $ nlConVarPat data_con_RDR as_needed
con2_pat = nlParPat $ nlConVarPat data_con_RDR bs_needed
data_con_RDR = getRdrName data_con
con_arity = length tys_needed
as_needed = take con_arity as_RDRs
bs_needed = take con_arity bs_RDRs
tys_needed = derivDataConInstArgTys data_con dit
in
([con1_pat, con2_pat], nested_eq_expr tys_needed as_needed bs_needed)
nested_eq_expr [] [] [] = true_Expr
nested_eq_expr tys as bs
= foldr1 and_Expr (zipWith3Equal "nested_eq" nested_eq tys as bs)
-- Using 'foldr1' here ensures that the derived code is correctly
-- associated. See #10859.
where
nested_eq_expr [] [] [] = true_Expr
nested_eq_expr tys as bs
= foldr1 and_Expr (zipWith3Equal "nested_eq" nested_eq tys as bs)
-- Using 'foldr1' here ensures that the derived code is correctly
-- associated. See #10859.
where
nested_eq ty a b = nlHsPar (eq_Expr ty (nlHsVar a) (nlHsVar b))
nested_eq ty a b = nlHsPar (eq_Expr ty (nlHsVar a) (nlHsVar b))
gen_con_fields_and_tys data_con
| tys_needed <- derivDataConInstArgTys data_con dit
, con_arity <- length tys_needed
, as_needed <- take con_arity as_RDRs
, bs_needed <- take con_arity bs_RDRs
= (as_needed, bs_needed, tys_needed)
pats_etc data_con
| (as_needed, bs_needed, tys_needed) <- gen_con_fields_and_tys data_con
, data_con_RDR <- getRdrName data_con
, con1_pat <- nlParPat $ nlConVarPat data_con_RDR as_needed
, con2_pat <- nlParPat $ nlConVarPat data_con_RDR bs_needed
, fields_eq_expr <- nested_eq_expr tys_needed as_needed bs_needed
= mkHsCaseAlt con1_pat (nlHsCase (nlHsVar b_RDR) [mkHsCaseAlt con2_pat fields_eq_expr])
{-
************************************************************************
......@@ -1473,7 +1495,7 @@ gfoldl_RDR, gunfold_RDR, toConstr_RDR, dataTypeOf_RDR, mkConstrTag_RDR,
dataCast1_RDR, dataCast2_RDR, gcast1_RDR, gcast2_RDR,
constr_RDR, dataType_RDR,
eqChar_RDR , ltChar_RDR , geChar_RDR , gtChar_RDR , leChar_RDR ,
eqInt_RDR , ltInt_RDR , geInt_RDR , gtInt_RDR , leInt_RDR ,
eqInt_RDR , ltInt_RDR , geInt_RDR , gtInt_RDR , leInt_RDR , neInt_RDR ,
eqInt8_RDR , ltInt8_RDR , geInt8_RDR , gtInt8_RDR , leInt8_RDR ,
eqInt16_RDR , ltInt16_RDR , geInt16_RDR , gtInt16_RDR , leInt16_RDR ,
eqInt32_RDR , ltInt32_RDR , geInt32_RDR , gtInt32_RDR , leInt32_RDR ,
......@@ -1513,6 +1535,7 @@ gtChar_RDR = varQual_RDR gHC_PRIM (fsLit "gtChar#")
geChar_RDR = varQual_RDR gHC_PRIM (fsLit "geChar#")
eqInt_RDR = varQual_RDR gHC_PRIM (fsLit "==#")
neInt_RDR = varQual_RDR gHC_PRIM (fsLit "/=#")
ltInt_RDR = varQual_RDR gHC_PRIM (fsLit "<#" )
leInt_RDR = varQual_RDR gHC_PRIM (fsLit "<=#")
gtInt_RDR = varQual_RDR gHC_PRIM (fsLit ">#" )
......
module T17240 where
data T = A | B Int | C Char | D Int deriving Eq
data Nullary = X | Y | Z deriving Eq
==================== Derived instances ====================
Derived class instances:
instance GHC.Classes.Eq T17240.Nullary where
(GHC.Classes.==) a b
= case (GHC.Prim.dataToTag# a) of
a#
-> case (GHC.Prim.dataToTag# b) of
b# -> (GHC.Prim.tagToEnum# (a# GHC.Prim.==# b#))
instance GHC.Classes.Eq T17240.T where
(GHC.Classes.==) a b
= case
(case (GHC.Prim.dataToTag# a) of
a# -> case (GHC.Prim.dataToTag# b) of b# -> a# GHC.Prim./=# b#)
of
1# -> GHC.Types.False
_ -> case a of
(T17240.B a1)
-> case b of (T17240.B b1) -> ((a1 GHC.Classes.== b1))
(T17240.C a1)
-> case b of (T17240.C b1) -> ((a1 GHC.Classes.== b1))
(T17240.D a1)
-> case b of (T17240.D b1) -> ((a1 GHC.Classes.== b1))
_ -> GHC.Types.True
Derived type family instances:
==================== Filling in method body ====================
GHC.Classes.Eq [T17240.Nullary]
GHC.Classes./= = GHC.Classes.$dm/= @(T17240.Nullary)
==================== Filling in method body ====================
GHC.Classes.Eq [T17240.T]
GHC.Classes./= = GHC.Classes.$dm/= @(T17240.T)
......@@ -122,6 +122,7 @@ test('T15831', normal, compile, [''])
test('T16179', normal, compile, [''])
test('T16341', normal, compile, [''])
test('T16518', normal, compile, [''])
test('T17240', normal, compile, ['-ddump-deriv -dsuppress-uniques'])
test('T17324', normal, compile, [''])
test('T17339', normal, compile,
['-ddump-simpl -dsuppress-idinfo -dno-typeable-binds'])
......
......@@ -3,8 +3,8 @@ drvfail011.hs:8:1: error:
• No instance for (Eq a) arising from a use of ‘==’
Possible fix: add (Eq a) to the context of the instance declaration
• In the expression: a1 == b1
In an equation for ‘==’: (==) (T1 a1) (T1 b1) = ((a1 == b1))
In a case alternative: (T1 b1) -> ((a1 == b1))
In the expression: case b of (T1 b1) -> ((a1 == b1))
When typechecking the code for ‘==’
in a derived instance for ‘Eq (T a)’:
To see the code I am typechecking, use -ddump-deriv
In the instance declaration for ‘Eq (T a)’
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment