From a6316eb4503fa0e2f13087dad6c5b2ff34367c41 Mon Sep 17 00:00:00 2001
From: Ben Gamari <ben@smart-cactus.org>
Date: Tue, 18 Apr 2023 07:17:40 -0400
Subject: [PATCH] STM: Use acquire loads when possible

Full sequential consistency is not needed here.
---
 rts/STM.c | 48 +++++++++++++++++++++++++-----------------------
 1 file changed, 25 insertions(+), 23 deletions(-)

diff --git a/rts/STM.c b/rts/STM.c
index 805091303389..dd136100b085 100644
--- a/rts/STM.c
+++ b/rts/STM.c
@@ -187,7 +187,7 @@ static StgClosure *lock_tvar(Capability *cap STG_UNUSED,
                              StgTVar *s STG_UNUSED) {
   StgClosure *result;
   TRACE("%p : lock_tvar(%p)", trec, s);
-  result = SEQ_CST_LOAD(&s->current_value);
+  result = ACQUIRE_LOAD(&s->current_value);
   return result;
 }
 
@@ -198,7 +198,7 @@ static void unlock_tvar(Capability *cap,
                         StgBool force_update) {
   TRACE("%p : unlock_tvar(%p)", trec, s);
   if (force_update) {
-    StgClosure *old_value = SEQ_CST_LOAD(&s->current_value);
+    StgClosure *old_value = ACQUIRE_LOAD(&s->current_value);
     RELEASE_STORE(&s->current_value, c);
     dirty_TVAR(cap, s, old_value);
   }
@@ -210,7 +210,7 @@ static StgBool cond_lock_tvar(Capability *cap STG_UNUSED,
                               StgClosure *expected) {
   StgClosure *result;
   TRACE("%p : cond_lock_tvar(%p, %p)", trec, s, expected);
-  result = SEQ_CST_LOAD(&s->current_value);
+  result = ACQUIRE_LOAD(&s->current_value);
   TRACE("%p : %s", trec, (result == expected) ? "success" : "failure");
   return (result == expected);
 }
@@ -231,7 +231,7 @@ static void lock_stm(StgTRecHeader *trec) {
 static void unlock_stm(StgTRecHeader *trec STG_UNUSED) {
   TRACE("%p : unlock_stm()", trec);
   ASSERT(smp_locked == trec);
-  SEQ_CST_STORE(&smp_locked, 0);
+  RELEASE_STORE(&smp_locked, 0);
 }
 
 static StgClosure *lock_tvar(Capability *cap STG_UNUSED,
@@ -240,7 +240,7 @@ static StgClosure *lock_tvar(Capability *cap STG_UNUSED,
   StgClosure *result;
   TRACE("%p : lock_tvar(%p)", trec, s);
   ASSERT(smp_locked == trec);
-  result = SEQ_CST_LOAD(&s->current_value);
+  result = ACQUIRE_LOAD(&s->current_value);
   return result;
 }
 
@@ -252,7 +252,7 @@ static void *unlock_tvar(Capability *cap,
   TRACE("%p : unlock_tvar(%p, %p)", trec, s, c);
   ASSERT(smp_locked == trec);
   if (force_update) {
-    StgClosure *old_value = SEQ_CST_LOAD(&s->current_value);
+    StgClosure *old_value = ACQUIRE_LOAD(&s->current_value);
     RELEASE_STORE(&s->current_value, c);
     dirty_TVAR(cap, s, old_value);
   }
@@ -265,7 +265,7 @@ static StgBool cond_lock_tvar(Capability *cap STG_UNUSED,
   StgClosure *result;
   TRACE("%p : cond_lock_tvar(%p, %p)", trec, s, expected);
   ASSERT(smp_locked == trec);
-  result = SEQ_CST_LOAD(&s->current_value);
+  result = ACQUIRE_LOAD(&s->current_value);
   TRACE("%p : %d", result ? "success" : "failure");
   return (result == expected);
 }
@@ -291,9 +291,11 @@ static StgClosure *lock_tvar(Capability *cap,
   StgClosure *result;
   TRACE("%p : lock_tvar(%p)", trec, s);
   do {
+    const StgInfoTable *info;
     do {
-      result = SEQ_CST_LOAD(&s->current_value);
-    } while (GET_INFO(UNTAG_CLOSURE(result)) == &stg_TREC_HEADER_info);
+      result = ACQUIRE_LOAD(&s->current_value);
+      info = GET_INFO(UNTAG_CLOSURE(result));
+    } while (info == &stg_TREC_HEADER_info);
   } while (cas((void *) &s->current_value,
                (StgWord)result, (StgWord)trec) != (StgWord)result);
 
@@ -311,7 +313,7 @@ static void unlock_tvar(Capability *cap,
                         StgClosure *c,
                         StgBool force_update STG_UNUSED) {
   TRACE("%p : unlock_tvar(%p, %p)", trec, s, c);
-  ASSERT(SEQ_CST_LOAD(&s->current_value) == (StgClosure *)trec);
+  ASSERT(ACQUIRE_LOAD(&s->current_value) == (StgClosure *)trec);
   RELEASE_STORE(&s->current_value, c);
   dirty_TVAR(cap, s, (StgClosure *) trec);
 }
@@ -375,7 +377,7 @@ static void unpark_waiters_on(Capability *cap, StgTVar *s) {
   StgTVarWatchQueue *trail;
   TRACE("unpark_waiters_on tvar=%p", s);
   // unblock TSOs in reverse order, to be a bit fairer (#2319)
-  for (q = SEQ_CST_LOAD(&s->first_watch_queue_entry), trail = q;
+  for (q = ACQUIRE_LOAD(&s->first_watch_queue_entry), trail = q;
        q != END_STM_WATCH_QUEUE;
        q = q -> next_queue_entry) {
     trail = q;
@@ -532,16 +534,16 @@ static void build_watch_queue_entries_for_trec(Capability *cap,
     StgTVarWatchQueue *fq;
     s = e -> tvar;
     TRACE("%p : adding tso=%p to watch queue for tvar=%p", trec, tso, s);
-    ACQ_ASSERT(SEQ_CST_LOAD(&s->current_value) == (StgClosure *)trec);
-    NACQ_ASSERT(SEQ_CST_LOAD(&s->current_value) == e -> expected_value);
-    fq = SEQ_CST_LOAD(&s->first_watch_queue_entry);
+    ACQ_ASSERT(ACQUIRE_LOAD(&s->current_value) == (StgClosure *)trec);
+    NACQ_ASSERT(ACQUIRE_LOAD(&s->current_value) == e -> expected_value);
+    fq = ACQUIRE_LOAD(&s->first_watch_queue_entry);
     q = alloc_stg_tvar_watch_queue(cap, (StgClosure*) tso);
     q -> next_queue_entry = fq;
     q -> prev_queue_entry = END_STM_WATCH_QUEUE;
     if (fq != END_STM_WATCH_QUEUE) {
       fq -> prev_queue_entry = q;
     }
-    SEQ_CST_STORE(&s->first_watch_queue_entry, q);
+    RELEASE_STORE(&s->first_watch_queue_entry, q);
     e -> new_value = (StgClosure *) q;
     dirty_TVAR(cap, s, (StgClosure *) fq); // we modified first_watch_queue_entry
   });
@@ -569,7 +571,7 @@ static void remove_watch_queue_entries_for_trec(Capability *cap,
           trec,
           q -> closure,
           s);
-    ACQ_ASSERT(SEQ_CST_LOAD(&s->current_value) == (StgClosure *)trec);
+    ACQ_ASSERT(ACQUIRE_LOAD(&s->current_value) == (StgClosure *)trec);
     nq = q -> next_queue_entry;
     pq = q -> prev_queue_entry;
     if (nq != END_STM_WATCH_QUEUE) {
@@ -578,8 +580,8 @@ static void remove_watch_queue_entries_for_trec(Capability *cap,
     if (pq != END_STM_WATCH_QUEUE) {
       pq -> next_queue_entry = nq;
     } else {
-      ASSERT(SEQ_CST_LOAD(&s->first_watch_queue_entry) == q);
-      SEQ_CST_STORE(&s->first_watch_queue_entry, nq);
+      ASSERT(ACQUIRE_LOAD(&s->first_watch_queue_entry) == q);
+      RELEASE_STORE(&s->first_watch_queue_entry, nq);
       dirty_TVAR(cap, s, (StgClosure *) q); // we modified first_watch_queue_entry
     }
     free_stg_tvar_watch_queue(cap, q);
@@ -727,7 +729,7 @@ static StgBool entry_is_read_only(TRecEntry *e) {
 static StgBool tvar_is_locked(StgTVar *s, StgTRecHeader *h) {
   StgClosure *c;
   StgBool result;
-  c = SEQ_CST_LOAD(&s->current_value);
+  c = ACQUIRE_LOAD(&s->current_value);
   result = (c == (StgClosure *) h);
   return result;
 }
@@ -803,13 +805,13 @@ static StgBool validate_and_acquire_ownership (Capability *cap,
           // The memory ordering here must ensure that we have two distinct
           // reads to current_value, with the read from num_updates between
           // them.
-          if (SEQ_CST_LOAD(&s->current_value) != e -> expected_value) {
+          if (ACQUIRE_LOAD(&s->current_value) != e -> expected_value) {
             TRACE("%p : doesn't match", trec);
             result = false;
             BREAK_FOR_EACH;
           }
           e->num_updates = SEQ_CST_LOAD(&s->num_updates);
-          if (SEQ_CST_LOAD(&s->current_value) != e -> expected_value) {
+          if (ACQUIRE_LOAD(&s->current_value) != e -> expected_value) {
             TRACE("%p : doesn't match (race)", trec);
             result = false;
             BREAK_FOR_EACH;
@@ -852,7 +854,7 @@ static StgBool check_read_only(StgTRecHeader *trec STG_UNUSED) {
 
         // We must first load current_value then num_updates; this is inverse of
         // the order of the stores in stmCommitTransaction.
-        StgClosure *current_value = SEQ_CST_LOAD(&s->current_value);
+        StgClosure *current_value = ACQUIRE_LOAD(&s->current_value);
         StgInt num_updates = SEQ_CST_LOAD(&s->num_updates);
 
         // Note we need both checks and in this order as the TVar could be
@@ -1186,7 +1188,7 @@ StgBool stmCommitNestedTransaction(Capability *cap, StgTRecHeader *trec) {
             unlock_tvar(cap, trec, s, e -> expected_value, false);
         }
         merge_update_into(cap, et, s, e -> expected_value, e -> new_value);
-        ACQ_ASSERT(s -> current_value != (StgClosure *)trec);
+        ACQ_ASSERT(ACQUIRE_LOAD(&s->current_value) != (StgClosure *)trec);
       });
     } else {
         revert_ownership(cap, trec, false);
-- 
GitLab