diff --git a/src/CallishOp.hs b/src/CallishOp.hs index 5ed4071219f2637c4bc961c27bdcdebbf316b928..d37914d70a14e9acdf77bb16760a5ae8bd173b38 100644 --- a/src/CallishOp.hs +++ b/src/CallishOp.hs @@ -2,7 +2,7 @@ module CallishOp ( prop_callish_ops_correct -- * Individual tests - , popcnt, pdep, pext + , popcnt, pdep, pext, mul2Op, mul2uOp , refImpl , prop_callish_correct -- * Evaluating callish machops @@ -10,12 +10,13 @@ module CallishOp , evalCallishOpCmm ) where -import Numeric.Natural +import Data.List (intercalate) import Data.Bits import Data.Proxy -import Test.QuickCheck +import Numeric.Natural +import Test.QuickCheck hiding ((.&.)) import Test.Tasty -import Test.Tasty.QuickCheck +import Test.Tasty.QuickCheck hiding ((.&.)) import Data.Foldable (foldl') import Width @@ -27,9 +28,13 @@ import Expr prop_callish_ops_correct :: EvalMethod -> TestTree prop_callish_ops_correct em = testGroup "callish ops" [ testCallishOpNoW8 "bswap" (\(_ :: Proxy w) -> toProp $ bswap @w) - , testCallishOp "popcnt" (\(_ :: Proxy w) -> toProp $ popcnt @w) - , testCallishOp "pdep" (\(_ :: Proxy w) -> toProp $ pdep @w) - , testCallishOp "pext" (\(_ :: Proxy w) -> toProp $ pext @w) + , testCallishOp "popcnt" (\(_ :: Proxy w) -> toProp $ popcnt @w) + , testCallishOp "pdep" (\(_ :: Proxy w) -> toProp $ pdep @w) + , testCallishOp "pext" (\(_ :: Proxy w) -> toProp $ pext @w) + , testCallishOp "mul2_lo" (\(_ :: Proxy w) -> toProp $ mul2Op @w False) + , testCallishOp "mul2_hi" (\(_ :: Proxy w) -> toProp $ mul2Op @w True) + , testCallishOp "mul2u_lo" (\(_ :: Proxy w) -> toProp $ mul2uOp @w False) + , testCallishOp "mul2u_hi" (\(_ :: Proxy w) -> toProp $ mul2uOp @w True) ] where toProp :: forall args w. (CmmArgs args, Arbitrary args, Show args, KnownWidth w) @@ -54,18 +59,16 @@ prop_callish_ops_correct em = testGroup "callish ops" ] popcnt :: forall w. (KnownWidth w) => CallishOp (Expr w) WordSize -popcnt = CallishOp - { name = "%popcnt" ++ show (widthBits (knownWidth @w)) - , refImpl = fromUnsigned . fromIntegral . popCount . toUnsigned . interpret - } +popcnt = simpleCallishOp name refImpl + where + name = "%popcnt" ++ show (widthBits (knownWidth @w)) + refImpl = fromUnsigned . fromIntegral . popCount . toUnsigned . interpret -- | Arguments are @(source, mask)@. pdep :: forall w. (KnownWidth w) => CallishOp (Expr w, Expr w) w -pdep = CallishOp - { name = "%pdep" ++ show (widthBits (knownWidth @w)) - , refImpl = uncurry ref - } +pdep = simpleCallishOp name (uncurry ref) where + name = "%pdep" ++ show (widthBits (knownWidth @w)) ref :: Expr w -> Expr w -> Number w ref x0 mask0 = fromUnsigned $ fromBits $ go (exprBits mask0) (exprBits x0) where @@ -78,11 +81,9 @@ pdep = CallishOp -- | Arguments are @(source, mask)@. pext :: forall w. (KnownWidth w) => CallishOp (Expr w, Expr w) w -pext = CallishOp - { name = "%pext" ++ show (widthBits (knownWidth @w)) - , refImpl = uncurry ref - } +pext = simpleCallishOp name (uncurry ref) where + name = "%pext" ++ show (widthBits (knownWidth @w)) ref :: Expr w -> Expr w -> Number w ref x mask = fromUnsigned @@ -93,17 +94,51 @@ pext = CallishOp -- | Arguments are @(source, mask)@. bswap :: forall w. (KnownWidth w) => CallishOp (Expr w) w -bswap = CallishOp - { name = "%bswap" ++ show (widthBits (knownWidth @w)) - , refImpl = ref . interpret - } +bswap = simpleCallishOp name (ref . interpret) where + name = "%bswap" ++ show (widthBits (knownWidth @w)) ref :: Number w -> Number w ref = fromBytes . reverse . toBytes where toBytes = chunk 8 . toBits fromBytes = fromUnsigned . fromBits . concat +mul2Op :: forall w. (KnownWidth w) => Bool -> CallishOp (Expr w, Expr w) w +mul2Op high = + CallishOp { name, refImpl, resultArity = 3, select } + where + select = if high then 1 else 2 + name = "%mul2_" <> show (widthBits (knownWidth @w)) + refImpl :: (Expr w, Expr w) -> Number w + refImpl (a,b) + | high = x + | otherwise = y + where + (x,y) = mul2 Signed (interpret a) (interpret b) + +mul2uOp :: forall w. (KnownWidth w) => Bool -> CallishOp (Expr w, Expr w) w +mul2uOp high = + CallishOp { name, refImpl, resultArity = 2, select } + where + select = if high then 0 else 1 + name = "%mul2u_" <> show (widthBits (knownWidth @w)) + refImpl :: (Expr w, Expr w) -> Number w + refImpl (a,b) + | high = x + | otherwise = y + where + (x,y) = mul2 Unsigned (interpret a) (interpret b) + +mul2 :: forall w. (KnownWidth w) + => Signedness -> Number w -> Number w -> (Number w, Number w) +mul2 s a b + | Unsigned <- s = (fromUnsignedC $ fromIntegral hi, fromUnsignedC $ fromIntegral lo) + | otherwise = (fromSigned hi, fromUnsignedC $ fromIntegral lo) + where + r = asInteger s a * asInteger s b + (hi, lo) = r `divMod` maxB + maxB = 2^widthBits (knownWidth @w) + toBits :: forall w. (KnownWidth w) => Number w -> [Bool] toBits x = [ x `testBit` i | i <- [0..widthBits (knownWidth @w)-1] ] @@ -121,9 +156,15 @@ prop_callish_correct em op args = counterexample (evalCallishOpCmm op args) $ io r <- evalCallishOp em op args return $ refImpl op args === r +simpleCallishOp :: String -> (args -> Number result) -> CallishOp args result +simpleCallishOp name refImpl = + CallishOp { name, refImpl, resultArity = 1, select = 0 } + data CallishOp args result = CallishOp { name :: String , refImpl :: args -> Number result + , resultArity :: Int -- ^ the arity of the result + , select :: Int -- ^ which component of the result to select } class CmmArgs arg where @@ -149,14 +190,18 @@ evalCallishOpCmm => CallishOp args width -> args -> String -evalCallishOpCmm op args = unlines - [ "test ( " <> cmmWordType <> " buffer ) {" - , " " <> cmmType width <> " ret;" - , " (ret) = prim " ++ name op ++ argList ++ ";" - , " return (ret);" +evalCallishOpCmm op args = unlines $ + [ "test ( " <> cmmWordType <> " buffer ) {" ] ++ + [ " " <> cmmType width <> " " <> bndr <> ";" | bndr <- bndrs ] ++ + [ " " <> lhs <> " = prim " ++ name op ++ argList ++ ";" + , " return (" <> result <> ");" , "}" ] where + lhs = "(" <> intercalate "," bndrs <> ")" + bndrs :: [String] + bndrs = take (resultArity op) [ "res" <> show n | n <- [0 :: Int ..] ] + result = bndrs !! select op argList = parens $ commaList [exprToCmm e | SomeExpr e <- getArgs args] width = knownWidth @width diff --git a/src/Number.hs b/src/Number.hs index 5bf6178d50eca970d5a98702a2286263372375bf..47464f166ec7d012a65687b6c8fed989173ed86f 100644 --- a/src/Number.hs +++ b/src/Number.hs @@ -13,6 +13,7 @@ module Number , fromUnsigned , fromUnsignedC , chooseNumber + , bounds , signedMinBound , n8, n16, n32, n64 , ones @@ -160,6 +161,17 @@ instance (KnownWidth width) => Bounded (Number width) where signedMinBound :: forall width. (KnownWidth width) => Number width signedMinBound = fromIntegral (toUnsigned (maxBound :: Number width) `div` 2 + 1) +bounds :: forall width. (KnownWidth width) => Signedness -> (Integer, Integer) +bounds Signed = (a, b) + where + w = widthBits (knownWidth @width) - 1 + a = negate (1 `shiftL` w) + b = (1 `shiftL` w) - 1 +bounds Unsigned = (0, b) + where + w = widthBits $ knownWidth @width + b = (1 `shiftL` w) - 1 + instance (KnownWidth width) => Arbitrary (Number width) where arbitrary = chooseNumber (minBound, maxBound) shrink (Number 0) = []