From f746e1557b80c5c097529c9198bab2a25083d0fb Mon Sep 17 00:00:00 2001
From: Ben Gamari <ben@smart-cactus.org>
Date: Wed, 22 Jan 2025 13:08:22 -0500
Subject: [PATCH 1/4] Number: Introduce bounds

---
 src/Number.hs | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/src/Number.hs b/src/Number.hs
index 5bf6178..47464f1 100644
--- a/src/Number.hs
+++ b/src/Number.hs
@@ -13,6 +13,7 @@ module Number
     , fromUnsigned
     , fromUnsignedC
     , chooseNumber
+    , bounds
     , signedMinBound
     , n8, n16, n32, n64
     , ones
@@ -160,6 +161,17 @@ instance (KnownWidth width) => Bounded (Number width) where
 signedMinBound :: forall width. (KnownWidth width) => Number width
 signedMinBound = fromIntegral (toUnsigned (maxBound :: Number width) `div` 2 + 1)
 
+bounds :: forall width. (KnownWidth width) => Signedness -> (Integer, Integer)
+bounds Signed = (a, b)
+  where
+    w = widthBits (knownWidth @width) - 1
+    a = negate (1 `shiftL` w)
+    b = (1 `shiftL` w) - 1
+bounds Unsigned = (0, b)
+  where
+    w = widthBits $ knownWidth @width
+    b = (1 `shiftL` w) - 1
+
 instance (KnownWidth width) => Arbitrary (Number width) where
     arbitrary = chooseNumber (minBound, maxBound)
     shrink (Number 0) = []
-- 
GitLab


From a091008fbcd8471bbbf1f111879021159b10f94e Mon Sep 17 00:00:00 2001
From: Ben Gamari <ben@smart-cactus.org>
Date: Wed, 22 Jan 2025 13:08:45 -0500
Subject: [PATCH 2/4] CallishOp: Add support for non-unary results

---
 src/CallishOp.hs | 49 ++++++++++++++++++++++++++----------------------
 1 file changed, 27 insertions(+), 22 deletions(-)

diff --git a/src/CallishOp.hs b/src/CallishOp.hs
index 5ed4071..bde43a3 100644
--- a/src/CallishOp.hs
+++ b/src/CallishOp.hs
@@ -10,9 +10,10 @@ module CallishOp
     , evalCallishOpCmm
     ) where
 
-import Numeric.Natural
+import Data.List (intercalate)
 import Data.Bits
 import Data.Proxy
+import Numeric.Natural
 import Test.QuickCheck
 import Test.Tasty
 import Test.Tasty.QuickCheck
@@ -54,18 +55,16 @@ prop_callish_ops_correct em = testGroup "callish ops"
         ]
 
 popcnt :: forall w. (KnownWidth w) => CallishOp (Expr w) WordSize
-popcnt = CallishOp
-    { name = "%popcnt" ++ show (widthBits (knownWidth @w))
-    , refImpl = fromUnsigned . fromIntegral . popCount . toUnsigned . interpret
-    }
+popcnt = simpleCallishOp name refImpl
+  where
+    name = "%popcnt" ++ show (widthBits (knownWidth @w))
+    refImpl = fromUnsigned . fromIntegral . popCount . toUnsigned . interpret
 
 -- | Arguments are @(source, mask)@.
 pdep :: forall w. (KnownWidth w) => CallishOp (Expr w, Expr w) w
-pdep = CallishOp
-    { name = "%pdep" ++ show (widthBits (knownWidth @w))
-    , refImpl = uncurry ref
-    }
+pdep = simpleCallishOp name (uncurry ref)
   where
+    name = "%pdep" ++ show (widthBits (knownWidth @w))
     ref :: Expr w -> Expr w -> Number w
     ref x0 mask0 = fromUnsigned $ fromBits $ go (exprBits mask0) (exprBits x0)
       where
@@ -78,11 +77,9 @@ pdep = CallishOp
 
 -- | Arguments are @(source, mask)@.
 pext :: forall w. (KnownWidth w) => CallishOp (Expr w, Expr w) w
-pext = CallishOp
-    { name = "%pext" ++ show (widthBits (knownWidth @w))
-    , refImpl = uncurry ref
-    }
+pext = simpleCallishOp name (uncurry ref)
   where
+    name = "%pext" ++ show (widthBits (knownWidth @w))
     ref :: Expr w -> Expr w -> Number w
     ref x mask =
         fromUnsigned
@@ -93,11 +90,9 @@ pext = CallishOp
 
 -- | Arguments are @(source, mask)@.
 bswap :: forall w. (KnownWidth w) => CallishOp (Expr w) w
-bswap = CallishOp
-    { name = "%bswap" ++ show (widthBits (knownWidth @w))
-    , refImpl = ref . interpret
-    }
+bswap = simpleCallishOp name (ref . interpret)
   where
+    name = "%bswap" ++ show (widthBits (knownWidth @w))
     ref :: Number w -> Number w
     ref = fromBytes . reverse . toBytes
       where
@@ -121,9 +116,15 @@ prop_callish_correct em op args = counterexample (evalCallishOpCmm op args) $ io
     r <- evalCallishOp em op args
     return $ refImpl op args === r
 
+simpleCallishOp :: String -> (args -> Number result) -> CallishOp args result
+simpleCallishOp name refImpl =
+    CallishOp { name, refImpl, resultArity = 1, select = 0 }
+
 data CallishOp args result
     = CallishOp { name :: String
                 , refImpl :: args -> Number result
+                , resultArity :: Int -- ^ the arity of the result
+                , select :: Int      -- ^ which component of the result to select
                 }
 
 class CmmArgs arg where
@@ -149,14 +150,18 @@ evalCallishOpCmm
     => CallishOp args width
     -> args
     -> String
-evalCallishOpCmm op args = unlines
-    [ "test ( " <> cmmWordType <> " buffer ) {"
-    , "  " <> cmmType width <> " ret;"
-    , "  (ret) = prim " ++ name op ++ argList ++ ";"
-    , "  return (ret);"
+evalCallishOpCmm op args = unlines $
+    [ "test ( " <> cmmWordType <> " buffer ) {" ] ++
+    [ "  " <> cmmType width <> " " <> bndr <> ";" | bndr <- bndrs ] ++
+    [ "  " <> lhs <> " = prim " ++ name op ++ argList ++ ";"
+    , "  return (" <> result <> ");"
     , "}"
     ]
   where
+    lhs = "(" <> intercalate "," bndrs <> ")"
+    bndrs :: [String]
+    bndrs = take (resultArity op) [ "res" <> show n | n <- [0 :: Int ..] ]
+    result = bndrs !! select op
     argList = parens $ commaList [exprToCmm e | SomeExpr e <- getArgs args]
     width = knownWidth @width
 
-- 
GitLab


From 0ada780e7bb78f1cbdfe7062ef13da4d7751319d Mon Sep 17 00:00:00 2001
From: Ben Gamari <ben@smart-cactus.org>
Date: Wed, 22 Jan 2025 13:11:18 -0500
Subject: [PATCH 3/4] CallishOp: Test Op_Mul2 correctness

---
 src/CallishOp.hs | 57 +++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 51 insertions(+), 6 deletions(-)

diff --git a/src/CallishOp.hs b/src/CallishOp.hs
index bde43a3..8d5cc28 100644
--- a/src/CallishOp.hs
+++ b/src/CallishOp.hs
@@ -2,7 +2,7 @@
 module CallishOp
     ( prop_callish_ops_correct
       -- * Individual tests
-    , popcnt, pdep, pext
+    , popcnt, pdep, pext, mul2Op, mul2uOp
     , refImpl
     , prop_callish_correct
       -- * Evaluating callish machops
@@ -14,9 +14,9 @@ import Data.List (intercalate)
 import Data.Bits
 import Data.Proxy
 import Numeric.Natural
-import Test.QuickCheck
+import Test.QuickCheck hiding ((.&.))
 import Test.Tasty
-import Test.Tasty.QuickCheck
+import Test.Tasty.QuickCheck hiding ((.&.))
 import Data.Foldable (foldl')
 
 import Width
@@ -28,9 +28,13 @@ import Expr
 prop_callish_ops_correct :: EvalMethod -> TestTree
 prop_callish_ops_correct em = testGroup "callish ops"
     [ testCallishOpNoW8 "bswap"  (\(_ :: Proxy w) -> toProp $ bswap @w)
-    , testCallishOp "popcnt" (\(_ :: Proxy w) -> toProp $ popcnt @w)
-    , testCallishOp "pdep"   (\(_ :: Proxy w) -> toProp $ pdep @w)
-    , testCallishOp "pext"   (\(_ :: Proxy w) -> toProp $ pext @w)
+    , testCallishOp     "popcnt" (\(_ :: Proxy w) -> toProp $ popcnt @w)
+    , testCallishOp     "pdep"   (\(_ :: Proxy w) -> toProp $ pdep @w)
+    , testCallishOp     "pext"   (\(_ :: Proxy w) -> toProp $ pext @w)
+    , testCallishOp "mul2_lo"    (\(_ :: Proxy w) -> toProp $ mul2Op @w False)
+    , testCallishOp "mul2_hi"    (\(_ :: Proxy w) -> toProp $ mul2Op @w True)
+    , testCallishOp "mul2u_lo"   (\(_ :: Proxy w) -> toProp $ mul2uOp @w False)
+    , testCallishOp "mul2u_hi"   (\(_ :: Proxy w) -> toProp $ mul2uOp @w True)
     ]
   where
     toProp :: forall args w. (CmmArgs args, Arbitrary args, Show args, KnownWidth w)
@@ -99,6 +103,47 @@ bswap = simpleCallishOp name (ref . interpret)
         toBytes = chunk 8 . toBits
         fromBytes = fromUnsigned . fromBits . concat
 
+mul2Op :: forall w. (KnownWidth w) => Bool -> CallishOp (Expr w, Expr w) w
+mul2Op high =
+    CallishOp { name, refImpl, resultArity = 3, select }
+  where
+    select = if high then 1 else 2
+    name = "%mul2_" <> show (widthBits (knownWidth @w))
+    refImpl :: (Expr w, Expr w) -> Number w
+    refImpl (a,b)
+      | high      = x
+      | otherwise = y
+      where
+        (x,y) = mul2 Signed (interpret a) (interpret b)
+
+mul2uOp :: forall w. (KnownWidth w) => Bool -> CallishOp (Expr w, Expr w) w
+mul2uOp high =
+    CallishOp { name, refImpl, resultArity = 2, select }
+  where
+    select = if high then 0 else 1
+    name = "%mul2u_" <> show (widthBits (knownWidth @w))
+    refImpl :: (Expr w, Expr w) -> Number w
+    refImpl (a,b)
+      | high      = x
+      | otherwise = y
+      where
+        (x,y) = mul2 Unsigned (interpret a) (interpret b)
+
+mul2 :: forall w. (KnownWidth w)
+     => Signedness -> Number w -> Number w -> (Number w, Number w)
+mul2 Unsigned a b =
+    (fromUnsignedC $ fromIntegral hi, fromUnsignedC $ fromIntegral lo)
+  where
+    r = asInteger Unsigned a * asInteger Unsigned b
+    (hi, lo) = r `divMod` maxB
+    maxB = 2^widthBits (knownWidth @w)
+mul2 Signed a b =
+    (fromSigned hi, fromUnsignedC $ fromIntegral lo)
+  where
+    r = asInteger Signed a * asInteger Signed b
+    (hi, lo) = r `divMod` maxB
+    maxB = 2^widthBits (knownWidth @w)
+
 toBits :: forall w. (KnownWidth w) => Number w -> [Bool]
 toBits x =
     [ x `testBit` i | i <- [0..widthBits (knownWidth @w)-1] ]
-- 
GitLab


From a6f2cb85d64bcc69f99a7082ea6d1f295f8a66bd Mon Sep 17 00:00:00 2001
From: Ben Gamari <ben@smart-cactus.org>
Date: Wed, 22 Jan 2025 19:51:35 -0500
Subject: [PATCH 4/4] Simplify mul2 reference implementation

---
 src/CallishOp.hs | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/src/CallishOp.hs b/src/CallishOp.hs
index 8d5cc28..d37914d 100644
--- a/src/CallishOp.hs
+++ b/src/CallishOp.hs
@@ -131,16 +131,11 @@ mul2uOp high =
 
 mul2 :: forall w. (KnownWidth w)
      => Signedness -> Number w -> Number w -> (Number w, Number w)
-mul2 Unsigned a b =
-    (fromUnsignedC $ fromIntegral hi, fromUnsignedC $ fromIntegral lo)
+mul2 s a b
+  | Unsigned <- s = (fromUnsignedC $ fromIntegral hi, fromUnsignedC $ fromIntegral lo)
+  | otherwise     = (fromSigned hi, fromUnsignedC $ fromIntegral lo)
   where
-    r = asInteger Unsigned a * asInteger Unsigned b
-    (hi, lo) = r `divMod` maxB
-    maxB = 2^widthBits (knownWidth @w)
-mul2 Signed a b =
-    (fromSigned hi, fromUnsignedC $ fromIntegral lo)
-  where
-    r = asInteger Signed a * asInteger Signed b
+    r = asInteger s a * asInteger s b
     (hi, lo) = r `divMod` maxB
     maxB = 2^widthBits (knownWidth @w)
 
-- 
GitLab