From f746e1557b80c5c097529c9198bab2a25083d0fb Mon Sep 17 00:00:00 2001 From: Ben Gamari <ben@smart-cactus.org> Date: Wed, 22 Jan 2025 13:08:22 -0500 Subject: [PATCH 1/4] Number: Introduce bounds --- src/Number.hs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/Number.hs b/src/Number.hs index 5bf6178..47464f1 100644 --- a/src/Number.hs +++ b/src/Number.hs @@ -13,6 +13,7 @@ module Number , fromUnsigned , fromUnsignedC , chooseNumber + , bounds , signedMinBound , n8, n16, n32, n64 , ones @@ -160,6 +161,17 @@ instance (KnownWidth width) => Bounded (Number width) where signedMinBound :: forall width. (KnownWidth width) => Number width signedMinBound = fromIntegral (toUnsigned (maxBound :: Number width) `div` 2 + 1) +bounds :: forall width. (KnownWidth width) => Signedness -> (Integer, Integer) +bounds Signed = (a, b) + where + w = widthBits (knownWidth @width) - 1 + a = negate (1 `shiftL` w) + b = (1 `shiftL` w) - 1 +bounds Unsigned = (0, b) + where + w = widthBits $ knownWidth @width + b = (1 `shiftL` w) - 1 + instance (KnownWidth width) => Arbitrary (Number width) where arbitrary = chooseNumber (minBound, maxBound) shrink (Number 0) = [] -- GitLab From a091008fbcd8471bbbf1f111879021159b10f94e Mon Sep 17 00:00:00 2001 From: Ben Gamari <ben@smart-cactus.org> Date: Wed, 22 Jan 2025 13:08:45 -0500 Subject: [PATCH 2/4] CallishOp: Add support for non-unary results --- src/CallishOp.hs | 49 ++++++++++++++++++++++++++---------------------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/src/CallishOp.hs b/src/CallishOp.hs index 5ed4071..bde43a3 100644 --- a/src/CallishOp.hs +++ b/src/CallishOp.hs @@ -10,9 +10,10 @@ module CallishOp , evalCallishOpCmm ) where -import Numeric.Natural +import Data.List (intercalate) import Data.Bits import Data.Proxy +import Numeric.Natural import Test.QuickCheck import Test.Tasty import Test.Tasty.QuickCheck @@ -54,18 +55,16 @@ prop_callish_ops_correct em = testGroup "callish ops" ] popcnt :: forall w. (KnownWidth w) => CallishOp (Expr w) WordSize -popcnt = CallishOp - { name = "%popcnt" ++ show (widthBits (knownWidth @w)) - , refImpl = fromUnsigned . fromIntegral . popCount . toUnsigned . interpret - } +popcnt = simpleCallishOp name refImpl + where + name = "%popcnt" ++ show (widthBits (knownWidth @w)) + refImpl = fromUnsigned . fromIntegral . popCount . toUnsigned . interpret -- | Arguments are @(source, mask)@. pdep :: forall w. (KnownWidth w) => CallishOp (Expr w, Expr w) w -pdep = CallishOp - { name = "%pdep" ++ show (widthBits (knownWidth @w)) - , refImpl = uncurry ref - } +pdep = simpleCallishOp name (uncurry ref) where + name = "%pdep" ++ show (widthBits (knownWidth @w)) ref :: Expr w -> Expr w -> Number w ref x0 mask0 = fromUnsigned $ fromBits $ go (exprBits mask0) (exprBits x0) where @@ -78,11 +77,9 @@ pdep = CallishOp -- | Arguments are @(source, mask)@. pext :: forall w. (KnownWidth w) => CallishOp (Expr w, Expr w) w -pext = CallishOp - { name = "%pext" ++ show (widthBits (knownWidth @w)) - , refImpl = uncurry ref - } +pext = simpleCallishOp name (uncurry ref) where + name = "%pext" ++ show (widthBits (knownWidth @w)) ref :: Expr w -> Expr w -> Number w ref x mask = fromUnsigned @@ -93,11 +90,9 @@ pext = CallishOp -- | Arguments are @(source, mask)@. bswap :: forall w. (KnownWidth w) => CallishOp (Expr w) w -bswap = CallishOp - { name = "%bswap" ++ show (widthBits (knownWidth @w)) - , refImpl = ref . interpret - } +bswap = simpleCallishOp name (ref . interpret) where + name = "%bswap" ++ show (widthBits (knownWidth @w)) ref :: Number w -> Number w ref = fromBytes . reverse . toBytes where @@ -121,9 +116,15 @@ prop_callish_correct em op args = counterexample (evalCallishOpCmm op args) $ io r <- evalCallishOp em op args return $ refImpl op args === r +simpleCallishOp :: String -> (args -> Number result) -> CallishOp args result +simpleCallishOp name refImpl = + CallishOp { name, refImpl, resultArity = 1, select = 0 } + data CallishOp args result = CallishOp { name :: String , refImpl :: args -> Number result + , resultArity :: Int -- ^ the arity of the result + , select :: Int -- ^ which component of the result to select } class CmmArgs arg where @@ -149,14 +150,18 @@ evalCallishOpCmm => CallishOp args width -> args -> String -evalCallishOpCmm op args = unlines - [ "test ( " <> cmmWordType <> " buffer ) {" - , " " <> cmmType width <> " ret;" - , " (ret) = prim " ++ name op ++ argList ++ ";" - , " return (ret);" +evalCallishOpCmm op args = unlines $ + [ "test ( " <> cmmWordType <> " buffer ) {" ] ++ + [ " " <> cmmType width <> " " <> bndr <> ";" | bndr <- bndrs ] ++ + [ " " <> lhs <> " = prim " ++ name op ++ argList ++ ";" + , " return (" <> result <> ");" , "}" ] where + lhs = "(" <> intercalate "," bndrs <> ")" + bndrs :: [String] + bndrs = take (resultArity op) [ "res" <> show n | n <- [0 :: Int ..] ] + result = bndrs !! select op argList = parens $ commaList [exprToCmm e | SomeExpr e <- getArgs args] width = knownWidth @width -- GitLab From 0ada780e7bb78f1cbdfe7062ef13da4d7751319d Mon Sep 17 00:00:00 2001 From: Ben Gamari <ben@smart-cactus.org> Date: Wed, 22 Jan 2025 13:11:18 -0500 Subject: [PATCH 3/4] CallishOp: Test Op_Mul2 correctness --- src/CallishOp.hs | 57 +++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 51 insertions(+), 6 deletions(-) diff --git a/src/CallishOp.hs b/src/CallishOp.hs index bde43a3..8d5cc28 100644 --- a/src/CallishOp.hs +++ b/src/CallishOp.hs @@ -2,7 +2,7 @@ module CallishOp ( prop_callish_ops_correct -- * Individual tests - , popcnt, pdep, pext + , popcnt, pdep, pext, mul2Op, mul2uOp , refImpl , prop_callish_correct -- * Evaluating callish machops @@ -14,9 +14,9 @@ import Data.List (intercalate) import Data.Bits import Data.Proxy import Numeric.Natural -import Test.QuickCheck +import Test.QuickCheck hiding ((.&.)) import Test.Tasty -import Test.Tasty.QuickCheck +import Test.Tasty.QuickCheck hiding ((.&.)) import Data.Foldable (foldl') import Width @@ -28,9 +28,13 @@ import Expr prop_callish_ops_correct :: EvalMethod -> TestTree prop_callish_ops_correct em = testGroup "callish ops" [ testCallishOpNoW8 "bswap" (\(_ :: Proxy w) -> toProp $ bswap @w) - , testCallishOp "popcnt" (\(_ :: Proxy w) -> toProp $ popcnt @w) - , testCallishOp "pdep" (\(_ :: Proxy w) -> toProp $ pdep @w) - , testCallishOp "pext" (\(_ :: Proxy w) -> toProp $ pext @w) + , testCallishOp "popcnt" (\(_ :: Proxy w) -> toProp $ popcnt @w) + , testCallishOp "pdep" (\(_ :: Proxy w) -> toProp $ pdep @w) + , testCallishOp "pext" (\(_ :: Proxy w) -> toProp $ pext @w) + , testCallishOp "mul2_lo" (\(_ :: Proxy w) -> toProp $ mul2Op @w False) + , testCallishOp "mul2_hi" (\(_ :: Proxy w) -> toProp $ mul2Op @w True) + , testCallishOp "mul2u_lo" (\(_ :: Proxy w) -> toProp $ mul2uOp @w False) + , testCallishOp "mul2u_hi" (\(_ :: Proxy w) -> toProp $ mul2uOp @w True) ] where toProp :: forall args w. (CmmArgs args, Arbitrary args, Show args, KnownWidth w) @@ -99,6 +103,47 @@ bswap = simpleCallishOp name (ref . interpret) toBytes = chunk 8 . toBits fromBytes = fromUnsigned . fromBits . concat +mul2Op :: forall w. (KnownWidth w) => Bool -> CallishOp (Expr w, Expr w) w +mul2Op high = + CallishOp { name, refImpl, resultArity = 3, select } + where + select = if high then 1 else 2 + name = "%mul2_" <> show (widthBits (knownWidth @w)) + refImpl :: (Expr w, Expr w) -> Number w + refImpl (a,b) + | high = x + | otherwise = y + where + (x,y) = mul2 Signed (interpret a) (interpret b) + +mul2uOp :: forall w. (KnownWidth w) => Bool -> CallishOp (Expr w, Expr w) w +mul2uOp high = + CallishOp { name, refImpl, resultArity = 2, select } + where + select = if high then 0 else 1 + name = "%mul2u_" <> show (widthBits (knownWidth @w)) + refImpl :: (Expr w, Expr w) -> Number w + refImpl (a,b) + | high = x + | otherwise = y + where + (x,y) = mul2 Unsigned (interpret a) (interpret b) + +mul2 :: forall w. (KnownWidth w) + => Signedness -> Number w -> Number w -> (Number w, Number w) +mul2 Unsigned a b = + (fromUnsignedC $ fromIntegral hi, fromUnsignedC $ fromIntegral lo) + where + r = asInteger Unsigned a * asInteger Unsigned b + (hi, lo) = r `divMod` maxB + maxB = 2^widthBits (knownWidth @w) +mul2 Signed a b = + (fromSigned hi, fromUnsignedC $ fromIntegral lo) + where + r = asInteger Signed a * asInteger Signed b + (hi, lo) = r `divMod` maxB + maxB = 2^widthBits (knownWidth @w) + toBits :: forall w. (KnownWidth w) => Number w -> [Bool] toBits x = [ x `testBit` i | i <- [0..widthBits (knownWidth @w)-1] ] -- GitLab From a6f2cb85d64bcc69f99a7082ea6d1f295f8a66bd Mon Sep 17 00:00:00 2001 From: Ben Gamari <ben@smart-cactus.org> Date: Wed, 22 Jan 2025 19:51:35 -0500 Subject: [PATCH 4/4] Simplify mul2 reference implementation --- src/CallishOp.hs | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/CallishOp.hs b/src/CallishOp.hs index 8d5cc28..d37914d 100644 --- a/src/CallishOp.hs +++ b/src/CallishOp.hs @@ -131,16 +131,11 @@ mul2uOp high = mul2 :: forall w. (KnownWidth w) => Signedness -> Number w -> Number w -> (Number w, Number w) -mul2 Unsigned a b = - (fromUnsignedC $ fromIntegral hi, fromUnsignedC $ fromIntegral lo) +mul2 s a b + | Unsigned <- s = (fromUnsignedC $ fromIntegral hi, fromUnsignedC $ fromIntegral lo) + | otherwise = (fromSigned hi, fromUnsignedC $ fromIntegral lo) where - r = asInteger Unsigned a * asInteger Unsigned b - (hi, lo) = r `divMod` maxB - maxB = 2^widthBits (knownWidth @w) -mul2 Signed a b = - (fromSigned hi, fromUnsignedC $ fromIntegral lo) - where - r = asInteger Signed a * asInteger Signed b + r = asInteger s a * asInteger s b (hi, lo) = r `divMod` maxB maxB = 2^widthBits (knownWidth @w) -- GitLab