From 5a2400c6570e4069b04f8d727c9058620bb99f3c Mon Sep 17 00:00:00 2001
From: Viktor Dukhovni <ietf-dane@dukhovni.org>
Date: Mon, 5 Oct 2020 01:43:26 -0400
Subject: [PATCH] Naming, value types and tests for Addr# atomics

The atomic Exchange and CAS operations on integral types are updated to
take and return the more natural `Word#` rather than `Int#` values.  These
are bit-block operations, not arithmetic ones, and the sign bit plays no
special role.

Standardises the names to `atomic<OpType><ValType>Addr#`, where `OpType` is one
of `Cas` or `Exchange` and `ValType` is presently either `Word` or `Addr`.
Eventually, variants for `Word32` and `Word64` can and should be added,
once #11953 and related issues (e.g. #13825) are resolved.

Adds tests for `Addr#` CAS that mirror existing tests for
`MutableByteArray#`.
---
 compiler/GHC/Builtin/primops.txt.pp           | 36 ++++++++-
 compiler/GHC/StgToCmm/Prim.hs                 |  7 +-
 libraries/base/GHC/Ptr.hs                     |  2 +-
 libraries/ghc-prim/changelog.md               |  9 ++-
 .../tests/codeGen/should_compile/cg011.hs     |  8 +-
 .../tests/codeGen/should_run/cgrun080.hs      | 15 ++--
 .../concurrent/should_run/AtomicPrimops.hs    | 80 ++++++++++++++++---
 .../should_run/AtomicPrimops.stdout           |  1 +
 8 files changed, 128 insertions(+), 30 deletions(-)

diff --git a/compiler/GHC/Builtin/primops.txt.pp b/compiler/GHC/Builtin/primops.txt.pp
index 261d02aa6736..170c6b2f8df8 100644
--- a/compiler/GHC/Builtin/primops.txt.pp
+++ b/compiler/GHC/Builtin/primops.txt.pp
@@ -2471,17 +2471,47 @@ primop  WriteOffAddrOp_Word64 "writeWord64OffAddr#" GenPrimOp
    with has_side_effects = True
         can_fail         = True
 
-primop  InterlockedExchange_Addr "interlockedExchangeAddr#" GenPrimOp
+primop  InterlockedExchange_Addr "atomicExchangeAddrAddr#" GenPrimOp
    Addr# -> Addr# -> State# s -> (# State# s, Addr# #)
    {The atomic exchange operation. Atomically exchanges the value at the first address
     with the Addr# given as second argument. Implies a read barrier.}
    with has_side_effects = True
+        can_fail         = True
 
-primop  InterlockedExchange_Int "interlockedExchangeInt#" GenPrimOp
-   Addr# -> Int# -> State# s -> (# State# s, Int# #)
+primop  InterlockedExchange_Word "atomicExchangeWordAddr#" GenPrimOp
+   Addr# -> Word# -> State# s -> (# State# s, Word# #)
    {The atomic exchange operation. Atomically exchanges the value at the address
     with the given value. Returns the old value. Implies a read barrier.}
    with has_side_effects = True
+        can_fail         = True
+
+primop  CasAddrOp_Addr "atomicCasAddrAddr#" GenPrimOp
+   Addr# -> Addr# -> Addr# -> State# s -> (# State# s, Addr# #)
+   { Compare and swap on a word-sized and aligned memory location.
+
+     Use as: \s -> atomicCasAddrAddr# location expected desired s
+
+     This version always returns the old value read. This follows the normal
+     protocol for CAS operations (and matches the underlying instruction on
+     most architectures).
+
+     Implies a full memory barrier.}
+   with has_side_effects = True
+        can_fail         = True
+
+primop  CasAddrOp_Word "atomicCasWordAddr#" GenPrimOp
+   Addr# -> Word# -> Word# -> State# s -> (# State# s, Word# #)
+   { Compare and swap on a word-sized and aligned memory location.
+
+     Use as: \s -> atomicCasWordAddr# location expected desired s
+
+     This version always returns the old value read. This follows the normal
+     protocol for CAS operations (and matches the underlying instruction on
+     most architectures).
+
+     Implies a full memory barrier.}
+   with has_side_effects = True
+        can_fail         = True
 
 ------------------------------------------------------------------------
 section "Mutable variables"
diff --git a/compiler/GHC/StgToCmm/Prim.hs b/compiler/GHC/StgToCmm/Prim.hs
index afbcc34836c1..1bc6954f3c05 100644
--- a/compiler/GHC/StgToCmm/Prim.hs
+++ b/compiler/GHC/StgToCmm/Prim.hs
@@ -846,9 +846,14 @@ emitPrimOp dflags primop = case primop of
 -- Atomic operations
   InterlockedExchange_Addr -> \[src, value] -> opIntoRegs $ \[res] ->
     emitPrimCall [res] (MO_Xchg (wordWidth platform)) [src, value]
-  InterlockedExchange_Int -> \[src, value] -> opIntoRegs $ \[res] ->
+  InterlockedExchange_Word -> \[src, value] -> opIntoRegs $ \[res] ->
     emitPrimCall [res] (MO_Xchg (wordWidth platform)) [src, value]
 
+  CasAddrOp_Addr -> \[dst, expected, new] -> opIntoRegs $ \[res] ->
+    emitPrimCall [res] (MO_Cmpxchg (wordWidth platform)) [dst, expected, new]
+  CasAddrOp_Word -> \[dst, expected, new] -> opIntoRegs $ \[res] ->
+    emitPrimCall [res] (MO_Cmpxchg (wordWidth platform)) [dst, expected, new]
+
 -- SIMD primops
   (VecBroadcastOp vcat n w) -> \[e] -> opIntoRegs $ \[res] -> do
     checkVecCompatibility dflags vcat n w
diff --git a/libraries/base/GHC/Ptr.hs b/libraries/base/GHC/Ptr.hs
index 6cbcc07ddceb..612d3ef94b17 100644
--- a/libraries/base/GHC/Ptr.hs
+++ b/libraries/base/GHC/Ptr.hs
@@ -171,7 +171,7 @@ castPtrToFunPtr (Ptr addr) = FunPtr addr
 exchangePtr :: Ptr (Ptr a) -> Ptr b -> IO (Ptr c)
 exchangePtr (Ptr dst) (Ptr val) =
   IO $ \s ->
-      case (interlockedExchangeAddr# dst val s) of
+      case (atomicExchangeAddrAddr# dst val s) of
         (# s2, old_val #) -> (# s2, Ptr old_val #)
 
 ------------------------------------------------------------------------
diff --git a/libraries/ghc-prim/changelog.md b/libraries/ghc-prim/changelog.md
index 3df0b5e2ed75..4a3e9f640c0d 100644
--- a/libraries/ghc-prim/changelog.md
+++ b/libraries/ghc-prim/changelog.md
@@ -21,8 +21,13 @@
 
 - Add primops for atomic exchange:
 
-        interlockedExchangeAddr# :: Addr# -> Addr# -> State# s -> (# State# s, Addr# #)
-        interlockedExchangeInt# :: Addr# -> Int# -> State# s -> (# State# s, Int# #)
+        atomicExchangeAddrAddr# :: Addr# -> Addr# -> State# s -> (# State# s, Addr# #)
+        atomicExchangeWordAddr# :: Addr# -> Word# -> State# s -> (# State# s, Word# #)
+
+- Add primops for atomic compare and swap at a given Addr#:
+
+        atomicCasAddrAddr# :: Addr# -> Addr# -> Addr# -> State# s -> (# State# s, Addr# #)
+        atomicCasWordAddr# :: Addr# -> Word# -> Word# -> State# s -> (# State# s, Word# #)
 
 - Add an explicit fixity for `(~)` and `(~~)`: 
 
diff --git a/testsuite/tests/codeGen/should_compile/cg011.hs b/testsuite/tests/codeGen/should_compile/cg011.hs
index 5d8096854705..77f2e6f7f908 100644
--- a/testsuite/tests/codeGen/should_compile/cg011.hs
+++ b/testsuite/tests/codeGen/should_compile/cg011.hs
@@ -1,11 +1,11 @@
 {-# LANGUAGE CPP, MagicHash, BlockArguments, UnboxedTuples #-}
 
--- Tests compilation for interlockedExchange primop.
+-- Tests compilation for atomicExchangeWordAddr# primop.
 
 module M where
 
-import GHC.Exts (interlockedExchangeInt#, Int#, Addr#, State# )
+import GHC.Exts (atomicExchangeWordAddr#, Word#, Addr#, State# )
 
-swap :: Addr# -> Int# -> State# s -> (# #)
-swap ptr val s = case (interlockedExchangeInt# ptr val s) of
+swap :: Addr# -> Word# -> State# s -> (# #)
+swap ptr val s = case (atomicExchangeWordAddr# ptr val s) of
             (# s2, old_val #) -> (# #)
diff --git a/testsuite/tests/codeGen/should_run/cgrun080.hs b/testsuite/tests/codeGen/should_run/cgrun080.hs
index 5390dd11aefe..78d54700f951 100644
--- a/testsuite/tests/codeGen/should_run/cgrun080.hs
+++ b/testsuite/tests/codeGen/should_run/cgrun080.hs
@@ -25,8 +25,8 @@ import GHC.Types
 
 main = do
    alloca $ \ptr_i -> do
-      poke ptr_i (1 :: Int)
-      w1 <- newEmptyMVar :: IO (MVar Int)
+      poke ptr_i (1 :: Word)
+      w1 <- newEmptyMVar :: IO (MVar Word)
       forkIO $ do
          v <- swapN 50000 2 ptr_i
          putMVar w1 v
@@ -37,15 +37,14 @@ main = do
       -- Should be [1,2,3]
       print $ sort [v0,v1,v2]
 
-swapN :: Int -> Int -> Ptr Int -> IO Int
+swapN :: Word -> Word -> Ptr Word -> IO Word
 swapN 0 val ptr = return val
 swapN n val ptr = do
    val' <- swap ptr val
    swapN (n-1) val' ptr
 
 
-swap :: Ptr Int -> Int -> IO Int
-swap (Ptr ptr) (I# val) = do
-   IO $ \s -> case (interlockedExchangeInt# ptr val s) of
-            (# s2, old_val #) -> (# s2, I# old_val #)
-
+swap :: Ptr Word -> Word -> IO Word
+swap (Ptr ptr) (W# val) = do
+   IO $ \s -> case (atomicExchangeWordAddr# ptr val s) of
+            (# s2, old_val #) -> (# s2, W# old_val #)
diff --git a/testsuite/tests/concurrent/should_run/AtomicPrimops.hs b/testsuite/tests/concurrent/should_run/AtomicPrimops.hs
index 1789e26bbb46..aeed9eaab62c 100644
--- a/testsuite/tests/concurrent/should_run/AtomicPrimops.hs
+++ b/testsuite/tests/concurrent/should_run/AtomicPrimops.hs
@@ -6,6 +6,8 @@ module Main ( main ) where
 import Control.Concurrent
 import Control.Concurrent.MVar
 import Control.Monad (when)
+import Foreign.Marshal.Alloc
+import Foreign.Ptr
 import Foreign.Storable
 import GHC.Exts
 import GHC.IO
@@ -22,6 +24,7 @@ main = do
     fetchOrTest
     fetchXorTest
     casTest
+    casTestAddr
     readWriteTest
 
 -- | Test fetchAddIntArray# by having two threads concurrenctly
@@ -54,12 +57,14 @@ fetchXorTest = do
     work mba 0 val = return ()
     work mba n val = fetchXorIntArray mba 0 val >> work mba (n-1) val
 
-    -- Initial value is a large prime and the two patterns are 1010...
-    -- and 0101...
+    -- The two patterns are 1010... and 0101...  The second pattern exceeds
+    -- (maxBound :: Int), so avoid warnings by initialising it as a Word.
     (n0, t1pat, t2pat)
         | sizeOf (undefined :: Int) == 8 =
-            (0x00000000ffffffff, 0x5555555555555555, 0x9999999999999999)
-        | otherwise = (0x0000ffff, 0x55555555, 0x99999999)
+            ( 0x00000000ffffffff, 0x5555555555555555
+            , fromIntegral (0x9999999999999999 :: Word))
+        | otherwise = ( 0x0000ffff, 0x55555555
+                      , fromIntegral (0x99999999 :: Word))
     expected
         | sizeOf (undefined :: Int) == 8 = 4294967295
         | otherwise = 65535
@@ -90,13 +95,15 @@ fetchOpTest op expected name = do
 
 -- | Initial value and operation arguments for race test.
 --
--- Initial value is a large prime and the two patterns are 1010...
--- and 0101...
+-- The two patterns are 1010... and 0101...  The second pattern exceeds
+-- (maxBound :: Int), so avoid warnings by initialising it as a Word.
 n0, t1pat, t2pat :: Int
 (n0, t1pat, t2pat)
     | sizeOf (undefined :: Int) == 8 =
-        (0x00000000ffffffff, 0x5555555555555555, 0x9999999999999999)
-    | otherwise = (0x0000ffff, 0x55555555, 0x99999999)
+        ( 0x00000000ffffffff, 0x5555555555555555
+        , fromIntegral (0x9999999999999999 :: Word))
+    | otherwise = ( 0x0000ffff, 0x55555555
+                  , fromIntegral (0x99999999 :: Word))
 
 fetchAndTest :: IO ()
 fetchAndTest = fetchOpTest fetchAndIntArray expected "fetchAndTest"
@@ -120,8 +127,10 @@ fetchNandTest = do
 fetchOrTest :: IO ()
 fetchOrTest = fetchOpTest fetchOrIntArray expected "fetchOrTest"
   where expected
-            | sizeOf (undefined :: Int) == 8 = 15987178197787607039
-            | otherwise = 3722313727
+            | sizeOf (undefined :: Int) == 8
+            = fromIntegral (15987178197787607039 :: Word)
+            | otherwise
+            = fromIntegral (3722313727 :: Word)
 
 -- | Test casIntArray# by using it to emulate fetchAddIntArray# and
 -- then having two threads concurrenctly increment a counter,
@@ -131,7 +140,7 @@ casTest = do
     tot <- race 0
         (\ mba -> work mba iters 1)
         (\ mba -> work mba iters 2)
-    assertEq 3000000 tot "casTest"
+    assertEq (3 * iters) tot "casTest"
   where
     work :: MByteArray -> Int -> Int -> IO ()
     work mba 0 val = return ()
@@ -179,6 +188,45 @@ race n0 thread1 thread2 = do
     mapM_ takeMVar [done1, done2]
     readIntArray mba 0
 
+-- | Test atomicCasWordAddr# by having two threads concurrently increment a
+-- counter, checking the sum at the end.
+casTestAddr :: IO ()
+casTestAddr = do
+    tot <- raceAddr 0
+        (\ addr -> work addr (fromIntegral iters) 1)
+        (\ addr -> work addr (fromIntegral iters) 2)
+    assertEq (3 * fromIntegral iters) tot "casTestAddr"
+  where
+    work :: Ptr Word -> Word -> Word -> IO ()
+    work ptr 0 val = return ()
+    work ptr n val = add ptr val >> work ptr (n-1) val
+
+    -- Fetch-and-add implemented using CAS.
+    add :: Ptr Word -> Word -> IO ()
+    add ptr n = peek ptr >>= go
+      where
+        go old = do
+            old' <- atomicCasWordPtr ptr old (old + n)
+            when (old /= old') $ go old'
+
+    -- | Create two threads that concurrently mutate the word-sized memory
+    -- location passed to them by pointer.
+    raceAddr :: Word                -- ^ Initial value (unused; callocBytes zeroes the word)
+            -> (Ptr Word -> IO ())  -- ^ Thread 1 action
+            -> (Ptr Word -> IO ())  -- ^ Thread 2 action
+            -> IO Word              -- ^ Final value stored at the pointer
+    raceAddr n0 thread1 thread2 = do
+        done1 <- newEmptyMVar
+        done2 <- newEmptyMVar
+        ptr <- asWordPtr <$> callocBytes (sizeOf (undefined :: Word))
+        forkIO $ thread1 ptr >> putMVar done1 ()
+        forkIO $ thread2 ptr >> putMVar done2 ()
+        mapM_ takeMVar [done1, done2]
+        peek ptr
+      where
+        asWordPtr :: Ptr a -> Ptr Word
+        asWordPtr = castPtr
+
 ------------------------------------------------------------------------
 -- Test helper
 
@@ -254,3 +302,13 @@ casIntArray :: MByteArray -> Int -> Int -> Int -> IO Int
 casIntArray (MBA mba#) (I# ix#) (I# old#) (I# new#) = IO $ \ s# ->
     case casIntArray# mba# ix# old# new# s# of
         (# s2#, old2# #) -> (# s2#, I# old2# #)
+
+------------------------------------------------------------------------
+-- Wrappers around Addr#
+
+-- Should this be added to Foreign.Storable?  Like poke, but the update is an
+-- atomic compare-and-swap that returns the value previously stored.
+atomicCasWordPtr :: Ptr Word -> Word -> Word -> IO Word
+atomicCasWordPtr (Ptr addr#) (W# old#) (W# new#) = IO $ \ s# ->
+    case atomicCasWordAddr# addr# old# new# s# of
+        (# s2#, old2# #) -> (# s2#, W# old2# #)
diff --git a/testsuite/tests/concurrent/should_run/AtomicPrimops.stdout b/testsuite/tests/concurrent/should_run/AtomicPrimops.stdout
index c37041a04098..c9ea7ee5007d 100644
--- a/testsuite/tests/concurrent/should_run/AtomicPrimops.stdout
+++ b/testsuite/tests/concurrent/should_run/AtomicPrimops.stdout
@@ -4,4 +4,5 @@ fetchNandTest: OK
 fetchOrTest: OK
 fetchXorTest: OK
 casTest: OK
+casTestAddr: OK
 readWriteTest: OK
-- 
GitLab