From a40e47816f360a9b5c2c9315ba9e99bf0d0c4ceb Mon Sep 17 00:00:00 2001
From: Sylvain Henry <sylvain@haskus.fr>
Date: Wed, 24 Jan 2024 15:11:44 +0100
Subject: [PATCH] Perf: add constant folding for bitcast between float and word
 (#24331)

---
 compiler/GHC/Core.hs                          |  5 ++-
 compiler/GHC/Core/Opt/ConstantFold.hs         | 33 +++++++++++++++++++
 .../numeric/should_compile/T24331.stderr      | 10 +++---
 3 files changed, 42 insertions(+), 6 deletions(-)

diff --git a/compiler/GHC/Core.hs b/compiler/GHC/Core.hs
index cff6cb4871dc..d204151cefbe 100644
--- a/compiler/GHC/Core.hs
+++ b/compiler/GHC/Core.hs
@@ -25,7 +25,7 @@ module GHC.Core (
         mkIntLit, mkIntLitWrap,
         mkWordLit, mkWordLitWrap,
         mkWord8Lit,
-        mkWord64LitWord64, mkInt64LitInt64,
+        mkWord32LitWord32, mkWord64LitWord64, mkInt64LitInt64,
         mkCharLit, mkStringLit,
         mkFloatLit, mkFloatLitFloat,
         mkDoubleLit, mkDoubleLitDouble,
@@ -1901,6 +1901,9 @@ mkWordLitWrap platform w = Lit (mkLitWordWrap platform w)
 mkWord8Lit :: Integer -> Expr b
 mkWord8Lit    w = Lit (mkLitWord8 w)
 
+mkWord32LitWord32 :: Word32 -> Expr b
+mkWord32LitWord32 w = Lit (mkLitWord32 (toInteger w))
+
 mkWord64LitWord64 :: Word64 -> Expr b
 mkWord64LitWord64 w = Lit (mkLitWord64 (toInteger w))
 
diff --git a/compiler/GHC/Core/Opt/ConstantFold.hs b/compiler/GHC/Core/Opt/ConstantFold.hs
index a06db85798dd..02c6fa61e8b1 100644
--- a/compiler/GHC/Core/Opt/ConstantFold.hs
+++ b/compiler/GHC/Core/Opt/ConstantFold.hs
@@ -34,6 +34,7 @@ where
 import GHC.Prelude
 
 import GHC.Platform
+import GHC.Float
 
 import GHC.Types.Id.Make ( unboxedUnitExpr )
 import GHC.Types.Id
@@ -657,6 +658,38 @@ primOpRules nm = \case
                                        , removeOp32
                                        , narrowSubsumesAnd WordAndOp Narrow32WordOp 32 ]
 
+   CastWord64ToDoubleOp -> mkPrimOpRule nm 1
+      [ unaryLit $ \_env -> \case
+         LitNumber _ n
+             | v <- castWord64ToDouble (fromInteger n)
+             -- we can't represent those float literals in Core until #18897 is fixed
+             , not (isNaN v || isInfinite v || isNegativeZero v)
+             -> Just (mkDoubleLitDouble v)
+         _   -> Nothing
+      ]
+
+   CastWord32ToFloatOp -> mkPrimOpRule nm 1
+      [ unaryLit $ \_env -> \case
+          LitNumber _ n
+              | v <- castWord32ToFloat (fromInteger n)
+              -- we can't represent those float literals in Core until #18897 is fixed
+              , not (isNaN v || isInfinite v || isNegativeZero v)
+              -> Just (mkFloatLitFloat v)
+          _   -> Nothing
+      ]
+
+   CastDoubleToWord64Op -> mkPrimOpRule nm 1
+      [ unaryLit $ \_env -> \case
+         LitDouble n -> Just (mkWord64LitWord64 (castDoubleToWord64 (fromRational n)))
+         _           -> Nothing
+      ]
+
+   CastFloatToWord32Op -> mkPrimOpRule nm 1
+      [ unaryLit $ \_env -> \case
+          LitFloat n -> Just (mkWord32LitWord32 (castFloatToWord32 (fromRational n)))
+          _          -> Nothing
+      ]
+
    OrdOp          -> mkPrimOpRule nm 1 [ liftLit charToIntLit
                                        , semiInversePrimOp ChrOp ]
    ChrOp          -> mkPrimOpRule nm 1 [ do [Lit lit] <- getArgs
diff --git a/testsuite/tests/numeric/should_compile/T24331.stderr b/testsuite/tests/numeric/should_compile/T24331.stderr
index 84d0286f27c5..4f51e1596a51 100644
--- a/testsuite/tests/numeric/should_compile/T24331.stderr
+++ b/testsuite/tests/numeric/should_compile/T24331.stderr
@@ -1,15 +1,15 @@
 
 ==================== Tidy Core ====================
 Result size of Tidy Core
-  = {terms: 16, types: 4, coercions: 0, joins: 0/0}
+  = {terms: 12, types: 4, coercions: 0, joins: 0/0}
 
-a = W64# (castDoubleToWord64# 1.0##)
+a = W64# 4607182418800017408#Word64
 
-b = W32# (castFloatToWord32# 2.0#)
+b = W32# 1073741824#Word32
 
-c = D# (castWord64ToDouble# 4621819117588971520#Word64)
+c = D# 10.0##
 
-d = F# (castWord32ToFloat# 1084227584#Word32)
+d = F# 5.0#
 
 
 
-- 
GitLab