From 833e250c74da9899896796b6ff8d1630f8295ec3 Mon Sep 17 00:00:00 2001
From: Simon Peyton Jones <simon.peytonjones@gmail.com>
Date: Thu, 2 Nov 2023 23:24:06 +0000
Subject: [PATCH] Update the unification count in wrapUnifierX

Omitting this caused type inference to fail in #24146.
This was an accidental omision in my refactoring of the
equality solver.
---
 compiler/GHC/Tc/Solver/Monad.hs               | 40 ++++++++++++-------
 .../tests/typecheck/should_compile/T24146.hs  | 18 +++++++++
 .../tests/typecheck/should_compile/all.T      |  1 +
 3 files changed, 44 insertions(+), 15 deletions(-)
 create mode 100644 testsuite/tests/typecheck/should_compile/T24146.hs

diff --git a/compiler/GHC/Tc/Solver/Monad.hs b/compiler/GHC/Tc/Solver/Monad.hs
index 272b3316bacb..f267b917c7e5 100644
--- a/compiler/GHC/Tc/Solver/Monad.hs
+++ b/compiler/GHC/Tc/Solver/Monad.hs
@@ -1197,6 +1197,9 @@ if you do so.
 -- Getters and setters of GHC.Tc.Utils.Env fields
 -- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
+getUnifiedRef :: TcS (IORef Int)
+getUnifiedRef = TcS (return . tcs_unified)
+
 -- Getter of inerts and worklist
 getInertSetRef :: TcS (IORef InertSet)
 getInertSetRef = TcS (return . tcs_inerts)
@@ -2040,21 +2043,28 @@ wrapUnifierX :: CtEvidence -> Role
              -> (UnifyEnv -> TcM a)  -- Some calls to uType
              -> TcS (a, Bag Ct, [TcTyVar], RewriterSet)
 wrapUnifierX ev role do_unifications
-  = wrapTcS $
-    do { defer_ref   <- TcM.newTcRef emptyBag
-       ; unified_ref <- TcM.newTcRef []
-       ; rewriters   <- TcM.zonkRewriterSet (ctEvRewriters ev)
-       ; let env = UE { u_role      = role
-                      , u_rewriters = rewriters
-                      , u_loc       = ctEvLoc ev
-                      , u_defer     = defer_ref
-                      , u_unified   = Just unified_ref}
-
-       ; res <- do_unifications env
-
-       ; cts     <- TcM.readTcRef defer_ref
-       ; unified <- TcM.readTcRef unified_ref
-       ; return (res, cts, unified, rewriters) }
+  = do { unif_count_ref <- getUnifiedRef
+       ; wrapTcS $
+         do { defer_ref   <- TcM.newTcRef emptyBag
+            ; unified_ref <- TcM.newTcRef []
+            ; rewriters   <- TcM.zonkRewriterSet (ctEvRewriters ev)
+            ; let env = UE { u_role      = role
+                           , u_rewriters = rewriters
+                           , u_loc       = ctEvLoc ev
+                           , u_defer     = defer_ref
+                           , u_unified   = Just unified_ref}
+
+            ; res <- do_unifications env
+
+            ; cts     <- TcM.readTcRef defer_ref
+            ; unified <- TcM.readTcRef unified_ref
+
+            -- Don't forget to update the count of variables
+            -- unified, lest we forget to iterate (#24146)
+            ; unless (null unified) $
+              TcM.updTcRef unif_count_ref (+ (length unified))
+
+            ; return (res, cts, unified, rewriters) } }
 
 
 {-
diff --git a/testsuite/tests/typecheck/should_compile/T24146.hs b/testsuite/tests/typecheck/should_compile/T24146.hs
new file mode 100644
index 000000000000..b46615161e10
--- /dev/null
+++ b/testsuite/tests/typecheck/should_compile/T24146.hs
@@ -0,0 +1,18 @@
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE TypeFamilies #-}
+module M where
+
+class (a ~ b) => Aggregate a b where
+instance Aggregate a a where
+
+liftM :: (Aggregate ae am) => (forall r. am -> r) -> ae
+liftM _ = undefined
+
+class Positive a
+
+mytake :: (Positive n) => n -> r
+mytake = undefined
+
+x :: (Positive n) => n
+x = liftM mytake
diff --git a/testsuite/tests/typecheck/should_compile/all.T b/testsuite/tests/typecheck/should_compile/all.T
index 5e13e55a84fd..218a17b65fc1 100644
--- a/testsuite/tests/typecheck/should_compile/all.T
+++ b/testsuite/tests/typecheck/should_compile/all.T
@@ -902,3 +902,4 @@ test('InstanceWarnings', normal, multimod_compile, ['InstanceWarnings', ''])
 test('T23861', normal, compile, [''])
 test('T23918', normal, compile, [''])
 test('T17564', normal, compile, [''])
+test('T24146', normal, compile, [''])
-- 
GitLab