Commit eb975d2e authored by Ben Gamari's avatar Ben Gamari Committed by Ben Gamari

Fix treatment of -0.0

Here we fix a few mis-optimizations that could occur in code with
floating point comparisons with -0.0. These issues arose from our
insistence on rewriting equalities into case analyses and the
simplifier's ignorance of floating-point semantics.

For instance, in Trac #10215 (and the similar issue Trac #9238) we
turned `ds == 0.0` into a case analysis,

```
case ds of
    __DEFAULT -> ...
    0.0 -> ...
```

Where the second alternative matches where `ds` is +0.0 and *also* -0.0.
However, the simplifier doesn't realize this and will introduce a local
inlining of `ds = -- +0.0` as it believes this is the only
value that matches this pattern.

Instead of teaching the simplifier about floating-point semantics
we simply prohibit case analysis on floating-point scrutinees and keep
this logic in the comparison primops, where it belongs.

We do several things here,

 - Add test cases from relevant tickets
 - Clean up a bit of documentation
 - Desugar literal matches against floats into applications of the
   appropriate equality primitive instead of case analysis
 - Add a CoreLint to ensure we don't pattern match on floats in Core

Test Plan: validate with included testcases

Reviewers: goldfire, simonpj, austin

Subscribers: thomie

Differential Revision: https://phabricator.haskell.org/D1061

GHC Trac Issues: #10215, #9238
parent a52db231
...@@ -32,6 +32,7 @@ import Literal ...@@ -32,6 +32,7 @@ import Literal
import DataCon import DataCon
import TysWiredIn import TysWiredIn
import TysPrim import TysPrim
import TcType ( isFloatingTy )
import Var import Var
import VarEnv import VarEnv
import VarSet import VarSet
...@@ -662,6 +663,15 @@ lintCoreExpr e@(Case scrut var alt_ty alts) = ...@@ -662,6 +663,15 @@ lintCoreExpr e@(Case scrut var alt_ty alts) =
(ptext (sLit "No alternatives for a case scrutinee not known to diverge for sure:") <+> ppr scrut) (ptext (sLit "No alternatives for a case scrutinee not known to diverge for sure:") <+> ppr scrut)
} }
-- See Note [Rules for floating-point comparisons] in PrelRules
; let isLitPat (LitAlt _, _ , _) = True
isLitPat _ = False
; checkL (not $ isFloatingTy scrut_ty && any isLitPat alts)
(ptext (sLit $ "Lint warning: Scrutinising floating-point " ++
"expression with literal pattern in case " ++
"analysis (see Trac #9238).")
$$ text "scrut" <+> ppr scrut)
; case tyConAppTyCon_maybe (idType var) of ; case tyConAppTyCon_maybe (idType var) of
Just tycon Just tycon
| debugIsOn && | debugIsOn &&
......
...@@ -233,6 +233,10 @@ These data types are the heart of the compiler ...@@ -233,6 +233,10 @@ These data types are the heart of the compiler
-- The inner case does not need a @Red@ alternative, because @x@ -- The inner case does not need a @Red@ alternative, because @x@
-- can't be @Red@ at that program point. -- can't be @Red@ at that program point.
-- --
-- 5. Floating-point values must not be scrutinised against literals.
-- See Trac #9238 and Note [Rules for floating-point comparisons]
-- in PrelRules for rationale.
--
-- * Cast an expression to a particular type. -- * Cast an expression to a particular type.
-- This is used to implement @newtype@s (a @newtype@ constructor or -- This is used to implement @newtype@s (a @newtype@ constructor or
-- destructor just becomes a 'Cast' in Core) and GADTs. -- destructor just becomes a 'Cast' in Core) and GADTs.
...@@ -329,6 +333,9 @@ simplifier calling findAlt with argument (LitAlt 3). No no. Integer ...@@ -329,6 +333,9 @@ simplifier calling findAlt with argument (LitAlt 3). No no. Integer
literals are an opaque encoding of an algebraic data type, not of literals are an opaque encoding of an algebraic data type, not of
an unlifted literal, like all the others. an unlifted literal, like all the others.
Also, we do not permit case analysis with literal patterns on floating-point
types. See Trac #9238 and Note [Rules for floating-point comparisons] in
PrelRules for the rationale for this restriction.
-------------------------- CoreSyn INVARIANTS --------------------------- -------------------------- CoreSyn INVARIANTS ---------------------------
......
...@@ -295,10 +295,12 @@ tidyNPat tidy_lit_pat (OverLit val False _ ty) mb_neg _ ...@@ -295,10 +295,12 @@ tidyNPat tidy_lit_pat (OverLit val False _ ty) mb_neg _
= mk_con_pat intDataCon (HsIntPrim "" int_lit) = mk_con_pat intDataCon (HsIntPrim "" int_lit)
| isWordTy ty, Just int_lit <- mb_int_lit | isWordTy ty, Just int_lit <- mb_int_lit
= mk_con_pat wordDataCon (HsWordPrim "" int_lit) = mk_con_pat wordDataCon (HsWordPrim "" int_lit)
| isFloatTy ty, Just rat_lit <- mb_rat_lit = mk_con_pat floatDataCon (HsFloatPrim rat_lit)
| isDoubleTy ty, Just rat_lit <- mb_rat_lit = mk_con_pat doubleDataCon (HsDoublePrim rat_lit)
| isStringTy ty, Just str_lit <- mb_str_lit | isStringTy ty, Just str_lit <- mb_str_lit
= tidy_lit_pat (HsString "" str_lit) = tidy_lit_pat (HsString "" str_lit)
-- NB: do /not/ convert Float or Double literals to F# 3.8 or D# 5.3
-- If we do convert to the constructor form, we'll generate a case
-- expression on a Float# or Double# and that's not allowed in Core; see
-- Trac #9238 and Note [Rules for floating-point comparisons] in PrelRules
where where
mk_con_pat :: DataCon -> HsLit -> Pat Id mk_con_pat :: DataCon -> HsLit -> Pat Id
mk_con_pat con lit = unLoc (mkPrefixConPat con [noLoc $ LitPat lit] []) mk_con_pat con lit = unLoc (mkPrefixConPat con [noLoc $ LitPat lit] [])
...@@ -309,15 +311,6 @@ tidyNPat tidy_lit_pat (OverLit val False _ ty) mb_neg _ ...@@ -309,15 +311,6 @@ tidyNPat tidy_lit_pat (OverLit val False _ ty) mb_neg _
(Just _, HsIntegral _ i) -> Just (-i) (Just _, HsIntegral _ i) -> Just (-i)
_ -> Nothing _ -> Nothing
mb_rat_lit :: Maybe FractionalLit
mb_rat_lit = case (mb_neg, val) of
(Nothing, HsIntegral _ i) -> Just (integralFractionalLit (fromInteger i))
(Just _, HsIntegral _ i) -> Just (integralFractionalLit
(fromInteger (-i)))
(Nothing, HsFractional f) -> Just f
(Just _, HsFractional f) -> Just (negateFractionalLit f)
_ -> Nothing
mb_str_lit :: Maybe FastString mb_str_lit :: Maybe FastString
mb_str_lit = case (mb_neg, val) of mb_str_lit = case (mb_neg, val) of
(Nothing, HsIsString _ s) -> Just s (Nothing, HsIsString _ s) -> Just s
......
...@@ -241,19 +241,19 @@ primOpRules nm CharGeOp = mkRelOpRule nm (>=) [ boundsCmp Ge ] ...@@ -241,19 +241,19 @@ primOpRules nm CharGeOp = mkRelOpRule nm (>=) [ boundsCmp Ge ]
primOpRules nm CharLeOp = mkRelOpRule nm (<=) [ boundsCmp Le ] primOpRules nm CharLeOp = mkRelOpRule nm (<=) [ boundsCmp Le ]
primOpRules nm CharLtOp = mkRelOpRule nm (<) [ boundsCmp Lt ] primOpRules nm CharLtOp = mkRelOpRule nm (<) [ boundsCmp Lt ]
primOpRules nm FloatGtOp = mkFloatingRelOpRule nm (>) [] primOpRules nm FloatGtOp = mkFloatingRelOpRule nm (>)
primOpRules nm FloatGeOp = mkFloatingRelOpRule nm (>=) [] primOpRules nm FloatGeOp = mkFloatingRelOpRule nm (>=)
primOpRules nm FloatLeOp = mkFloatingRelOpRule nm (<=) [] primOpRules nm FloatLeOp = mkFloatingRelOpRule nm (<=)
primOpRules nm FloatLtOp = mkFloatingRelOpRule nm (<) [] primOpRules nm FloatLtOp = mkFloatingRelOpRule nm (<)
primOpRules nm FloatEqOp = mkFloatingRelOpRule nm (==) [ litEq True ] primOpRules nm FloatEqOp = mkFloatingRelOpRule nm (==)
primOpRules nm FloatNeOp = mkFloatingRelOpRule nm (/=) [ litEq False ] primOpRules nm FloatNeOp = mkFloatingRelOpRule nm (/=)
primOpRules nm DoubleGtOp = mkFloatingRelOpRule nm (>) [] primOpRules nm DoubleGtOp = mkFloatingRelOpRule nm (>)
primOpRules nm DoubleGeOp = mkFloatingRelOpRule nm (>=) [] primOpRules nm DoubleGeOp = mkFloatingRelOpRule nm (>=)
primOpRules nm DoubleLeOp = mkFloatingRelOpRule nm (<=) [] primOpRules nm DoubleLeOp = mkFloatingRelOpRule nm (<=)
primOpRules nm DoubleLtOp = mkFloatingRelOpRule nm (<) [] primOpRules nm DoubleLtOp = mkFloatingRelOpRule nm (<)
primOpRules nm DoubleEqOp = mkFloatingRelOpRule nm (==) [ litEq True ] primOpRules nm DoubleEqOp = mkFloatingRelOpRule nm (==)
primOpRules nm DoubleNeOp = mkFloatingRelOpRule nm (/=) [ litEq False ] primOpRules nm DoubleNeOp = mkFloatingRelOpRule nm (/=)
primOpRules nm WordGtOp = mkRelOpRule nm (>) [ boundsCmp Gt ] primOpRules nm WordGtOp = mkRelOpRule nm (>) [ boundsCmp Gt ]
primOpRules nm WordGeOp = mkRelOpRule nm (>=) [ boundsCmp Ge ] primOpRules nm WordGeOp = mkRelOpRule nm (>=) [ boundsCmp Ge ]
...@@ -284,29 +284,49 @@ mkPrimOpRule nm arity rules = Just $ mkBasicRule nm arity (msum rules) ...@@ -284,29 +284,49 @@ mkPrimOpRule nm arity rules = Just $ mkBasicRule nm arity (msum rules)
mkRelOpRule :: Name -> (forall a . Ord a => a -> a -> Bool) mkRelOpRule :: Name -> (forall a . Ord a => a -> a -> Bool)
-> [RuleM CoreExpr] -> Maybe CoreRule -> [RuleM CoreExpr] -> Maybe CoreRule
mkRelOpRule nm cmp extra mkRelOpRule nm cmp extra
= mkPrimOpRule nm 2 $ rules ++ extra = mkPrimOpRule nm 2 $
binaryCmpLit cmp : equal_rule : extra
where where
rules = [ binaryCmpLit cmp -- x `cmp` x does not depend on x, so
, do equalArgs -- compute it for the arbitrary value 'True'
-- x `cmp` x does not depend on x, so -- and use that result
-- compute it for the arbitrary value 'True' equal_rule = do { equalArgs
-- and use that result ; dflags <- getDynFlags
dflags <- getDynFlags ; return (if cmp True True
return (if cmp True True then trueValInt dflags
then trueValInt dflags else falseValInt dflags) }
else falseValInt dflags) ]
{- Note [Rules for floating-point comparisons]
-- Note [Rules for floating-point comparisons] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ We need different rules for floating-point values because for floats
-- it is not true that x = x (for NaNs); so we do not want the equal_rule
-- We need different rules for floating-point values because for floats rule that mkRelOpRule uses.
-- it is not true that x = x. The special case when this does not occur
-- are NaNs. Note also that, in the case of equality/inequality, we do /not/
want to switch to a case-expression. For example, we do not want
to convert
case (eqFloat# x 3.8#) of
True -> this
False -> that
to
case x of
3.8#::Float# -> this
_ -> that
See Trac #9238. Reason: comparing floating-point values for equality
delicate, and we don't want to implement that delicacy in the code for
case expressions. So we make it an invariant of Core that a case
expression never scrutinises a Float# or Double#.
This transformation is what the litEq rule does;
see Note [The litEq rule: converting equality to case].
So we /refrain/ from using litEq for mkFloatingRelOpRule.
-}
mkFloatingRelOpRule :: Name -> (forall a . Ord a => a -> a -> Bool) mkFloatingRelOpRule :: Name -> (forall a . Ord a => a -> a -> Bool)
-> [RuleM CoreExpr] -> Maybe CoreRule -> Maybe CoreRule
mkFloatingRelOpRule nm cmp extra -- See Note [Rules for floating-point comparisons] -- See Note [Rules for floating-point comparisons]
= mkPrimOpRule nm 2 $ binaryCmpLit cmp : extra mkFloatingRelOpRule nm cmp
= mkPrimOpRule nm 2 [binaryCmpLit cmp]
-- common constants -- common constants
zeroi, onei, zerow, onew :: DynFlags -> Literal zeroi, onei, zerow, onew :: DynFlags -> Literal
...@@ -428,24 +448,27 @@ doubleOp2 op dflags (MachDouble f1) (MachDouble f2) ...@@ -428,24 +448,27 @@ doubleOp2 op dflags (MachDouble f1) (MachDouble f2)
doubleOp2 _ _ _ _ = Nothing doubleOp2 _ _ _ _ = Nothing
-------------------------- --------------------------
-- This stuff turns {- Note [The litEq rule: converting equality to case]
-- n ==# 3# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-- into This stuff turns
-- case n of n ==# 3#
-- 3# -> True into
-- m -> False case n of
-- 3# -> True
-- This is a Good Thing, because it allows case-of case things m -> False
-- to happen, and case-default absorption to happen. For
-- example: This is a Good Thing, because it allows case-of case things
-- to happen, and case-default absorption to happen. For
-- if (n ==# 3#) || (n ==# 4#) then e1 else e2 example:
-- will transform to
-- case n of if (n ==# 3#) || (n ==# 4#) then e1 else e2
-- 3# -> e1 will transform to
-- 4# -> e1 case n of
-- m -> e2 3# -> e1
-- (modulo the usual precautions to avoid duplicating e1) 4# -> e1
m -> e2
(modulo the usual precautions to avoid duplicating e1)
-}
litEq :: Bool -- True <=> equality, False <=> inequality litEq :: Bool -- True <=> equality, False <=> inequality
-> RuleM CoreExpr -> RuleM CoreExpr
......
...@@ -65,7 +65,7 @@ module TcType ( ...@@ -65,7 +65,7 @@ module TcType (
eqType, eqTypes, eqPred, cmpType, cmpTypes, cmpPred, eqTypeX, eqType, eqTypes, eqPred, cmpType, cmpTypes, cmpPred, eqTypeX,
tcEqType, tcEqKind, tcEqType, tcEqKind,
isSigmaTy, isRhoTy, isOverloadedTy, isSigmaTy, isRhoTy, isOverloadedTy,
isDoubleTy, isFloatTy, isIntTy, isWordTy, isStringTy, isFloatingTy, isDoubleTy, isFloatTy, isIntTy, isWordTy, isStringTy,
isIntegerTy, isBoolTy, isUnitTy, isCharTy, isIntegerTy, isBoolTy, isUnitTy, isCharTy,
isTauTy, isTauTyCon, tcIsTyVarTy, tcIsForAllTy, isTauTy, isTauTyCon, tcIsTyVarTy, tcIsForAllTy,
isPredTy, isTyVarClassPred, isTyVarExposed, isTyVarUnderDatatype, isPredTy, isTyVarClassPred, isTyVarExposed, isTyVarUnderDatatype,
...@@ -1439,6 +1439,11 @@ isUnitTy = is_tc unitTyConKey ...@@ -1439,6 +1439,11 @@ isUnitTy = is_tc unitTyConKey
isCharTy = is_tc charTyConKey isCharTy = is_tc charTyConKey
isAnyTy = is_tc anyTyConKey isAnyTy = is_tc anyTyConKey
-- | Does a type represent a floating-point number?
isFloatingTy :: Type -> Bool
isFloatingTy ty = isFloatTy ty || isDoubleTy ty
-- | Is a type 'String'?
isStringTy :: Type -> Bool isStringTy :: Type -> Bool
isStringTy ty isStringTy ty
= case tcSplitTyConApp_maybe ty of = case tcSplitTyConApp_maybe ty of
......
testF :: Float -> Bool
testF x = x == 0 && not (isNegativeZero x)
testD :: Double -> Bool
testD x = x == 0 && not (isNegativeZero x)
main :: IO ()
main = do print $ testF (-0.0)
print $ testD (-0.0)
compareDouble :: Double -> Double -> Ordering
compareDouble x y =
case (isNaN x, isNaN y) of
(True, True) -> EQ
(True, False) -> LT
(False, True) -> GT
(False, False) ->
-- Make -0 less than 0
case (x == 0, y == 0, isNegativeZero x, isNegativeZero y) of
(True, True, True, False) -> LT
(True, True, False, True) -> GT
_ -> x `compare` y
main = do
let l = [-0, 0]
print [ (x, y, compareDouble x y) | x <- l, y <- l ]
[(-0.0,-0.0,EQ),(-0.0,0.0,LT),(0.0,-0.0,GT),(0.0,0.0,EQ)]
...@@ -46,5 +46,7 @@ test('DsStaticPointers', ...@@ -46,5 +46,7 @@ test('DsStaticPointers',
], ],
compile_and_run, ['']) compile_and_run, [''])
test('T8952', normal, compile_and_run, ['']) test('T8952', normal, compile_and_run, [''])
test('T9238', normal, compile_and_run, [''])
test('T9844', normal, compile_and_run, ['']) test('T9844', normal, compile_and_run, [''])
test('T10215', normal, compile_and_run, [''])
test('DsStrictData', normal, compile_and_run, ['']) test('DsStrictData', normal, compile_and_run, [''])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment