From 9f987235c8f20df55b9b16c8f9e0232288c9faeb Mon Sep 17 00:00:00 2001
From: Apoorv Ingle <apoorv-ingle@uiowa.edu>
Date: Mon, 5 Feb 2024 06:39:25 -0600
Subject: [PATCH] Enable mdo statements to use HsExpansions Fixes: #24411 Added
 test T24411 for regression

---
 compiler/GHC/Tc/Gen/App.hs                     |  3 +++
 compiler/GHC/Tc/Gen/Match.hs                   |  7 +++----
 testsuite/tests/typecheck/should_run/T24411.hs | 18 ++++++++++++++++++
 .../tests/typecheck/should_run/T24411.stdout   |  2 ++
 testsuite/tests/typecheck/should_run/all.T     |  1 +
 5 files changed, 27 insertions(+), 4 deletions(-)
 create mode 100644 testsuite/tests/typecheck/should_run/T24411.hs
 create mode 100644 testsuite/tests/typecheck/should_run/T24411.stdout

diff --git a/compiler/GHC/Tc/Gen/App.hs b/compiler/GHC/Tc/Gen/App.hs
index ee6c2fe599d..6b22d661875 100644
--- a/compiler/GHC/Tc/Gen/App.hs
+++ b/compiler/GHC/Tc/Gen/App.hs
@@ -791,6 +791,9 @@ addArgCtxt :: AppCtxt -> LHsExpr GhcRn
 --    b. Or, we are typechecking the second argument which would be a generated lambda
 --       so we set the location to be whatever the location in the context is
 --  See Note [Expanding HsDo with XXExprGhcRn] in GHC.Tc.Gen.Do
+-- For future: we need a cleaner way of doing this bit of adding the right error context.
+-- There is a delicate dance of looking at source locations and reconstructing
+-- whether the piece of code is a `do`-expanded code or some other expanded code.
 addArgCtxt ctxt (L arg_loc arg) thing_inside
   = do { in_generated_code <- inGeneratedCode
        ; case ctxt of
diff --git a/compiler/GHC/Tc/Gen/Match.hs b/compiler/GHC/Tc/Gen/Match.hs
index 576d27d1a36..83ff2e51191 100644
--- a/compiler/GHC/Tc/Gen/Match.hs
+++ b/compiler/GHC/Tc/Gen/Match.hs
@@ -364,10 +364,9 @@ tcDoStmts doExpr@(DoExpr _) ss@(L l stmts) res_ty
                   ; mkExpandedExprTc (HsDo noExtField doExpr ss) <$> tcExpr (unLoc expanded_expr) res_ty }
         }
 
-tcDoStmts mDoExpr@(MDoExpr _) (L l stmts) res_ty
-  = do  { stmts' <- tcStmts (HsDoStmt mDoExpr) tcDoStmt stmts res_ty
-        ; res_ty <- readExpType res_ty
-        ; return (HsDo res_ty mDoExpr (L l stmts')) }
+tcDoStmts mDoExpr@(MDoExpr _) ss@(L _ stmts) res_ty
+  = do  { expanded_expr <- expandDoStmts mDoExpr stmts -- Do expansion on the fly
+        ; mkExpandedExprTc (HsDo noExtField mDoExpr ss) <$> tcExpr (unLoc expanded_expr) res_ty  }
 
 tcDoStmts MonadComp (L l stmts) res_ty
   = do  { stmts' <- tcStmts (HsDoStmt MonadComp) tcMcStmt stmts res_ty
diff --git a/testsuite/tests/typecheck/should_run/T24411.hs b/testsuite/tests/typecheck/should_run/T24411.hs
new file mode 100644
index 00000000000..367eb3dd3a8
--- /dev/null
+++ b/testsuite/tests/typecheck/should_run/T24411.hs
@@ -0,0 +1,18 @@
+{-# LANGUAGE ImpredicativeTypes, RecursiveDo #-}
+
+type Id = forall a. a -> a
+
+t :: IO Id
+t = return id
+
+p :: Id -> (Bool, Int)
+p f = (f True, f 3)
+
+foo1 = t >>= \x -> return (p x)
+
+foo2 = mdo { x <- t ; return (p x) }
+
+main = do x <- foo2
+          y <- foo1
+          putStrLn $ show x
+          putStrLn $ show y
diff --git a/testsuite/tests/typecheck/should_run/T24411.stdout b/testsuite/tests/typecheck/should_run/T24411.stdout
new file mode 100644
index 00000000000..6c72082ed05
--- /dev/null
+++ b/testsuite/tests/typecheck/should_run/T24411.stdout
@@ -0,0 +1,2 @@
+(True,3)
+(True,3)
diff --git a/testsuite/tests/typecheck/should_run/all.T b/testsuite/tests/typecheck/should_run/all.T
index 1d170695492..5d33c00809f 100755
--- a/testsuite/tests/typecheck/should_run/all.T
+++ b/testsuite/tests/typecheck/should_run/all.T
@@ -176,3 +176,4 @@ test('T23761b', normal, compile_and_run, [''])
 test('T18324', normal, compile_and_run, [''])
 test('T15598', normal, compile_and_run, [''])
 test('T22086', normal, compile_and_run, [''])
+test('T24411', normal, compile_and_run, [''])
-- 
GitLab