From 1d05971e24f6cb1120789d1e1ab4f086eebd504a Mon Sep 17 00:00:00 2001
From: sheaf <sam.derbyshire@gmail.com>
Date: Wed, 14 Jun 2023 21:38:04 +0200
Subject: [PATCH] Propagate long-distance information in do-notation

The preceding commit re-enabled pattern-match checking inside record
updates. This revealed that #21360 was in fact NOT fixed by e74fc066.

This commit makes sure we correctly propagate long-distance information
in do blocks, e.g. in

```haskell
data T = A { fld :: Int } | B

f :: T -> Maybe T
f r = do
  a@A{} <- Just r
  Just $ case a of { A _ -> A 9 }
```

we need to propagate the fact that "a" is headed by the constructor "A"
to see that the case expression "case a of { A _ -> A 9 }" cannot fail.

Fixes #21360
---
 compiler/GHC/HsToCore/Expr.hs                 | 73 ++++++++++++++-----
 compiler/GHC/HsToCore/GuardedRHSs.hs          | 13 ++--
 compiler/GHC/HsToCore/ListComp.hs             |  8 +-
 compiler/GHC/HsToCore/Match.hs                | 24 +++++-
 compiler/GHC/HsToCore/Pmc.hs                  | 58 ++++++++++-----
 compiler/GHC/HsToCore/Utils.hs                |  2 +-
 compiler/Language/Haskell/Syntax/Expr.hs      | 10 +--
 .../tests/pmcheck/should_compile/T21360.hs    | 31 ++++----
 .../tests/pmcheck/should_compile/T21360b.hs   | 10 +++
 testsuite/tests/pmcheck/should_compile/all.T  |  1 +
 10 files changed, 162 insertions(+), 68 deletions(-)
 create mode 100644 testsuite/tests/pmcheck/should_compile/T21360b.hs

diff --git a/compiler/GHC/HsToCore/Expr.hs b/compiler/GHC/HsToCore/Expr.hs
index be9347e0e2c8..29fc3a5713c3 100644
--- a/compiler/GHC/HsToCore/Expr.hs
+++ b/compiler/GHC/HsToCore/Expr.hs
@@ -28,7 +28,7 @@ import GHC.HsToCore.ListComp
 import GHC.HsToCore.Utils
 import GHC.HsToCore.Arrows
 import GHC.HsToCore.Monad
-import GHC.HsToCore.Pmc ( addTyCs, pmcGRHSs )
+import GHC.HsToCore.Pmc
 import GHC.HsToCore.Errors.Types
 import GHC.Types.SourceText
 import GHC.Types.Name
@@ -223,19 +223,11 @@ dsUnliftedBind bind body = pprPanic "dsLet: unlifted" (ppr bind $$ ppr body)
 ************************************************************************
 -}
 
-
--- | Replace the body of the function with this block to test the hsExprType
--- function in GHC.Tc.Zonk.Type:
--- putSrcSpanDs loc $ do
---   { core_expr <- dsExpr e
---   ; massertPpr (exprType core_expr `eqType` hsExprType e)
---                (ppr e <+> dcolon <+> ppr (hsExprType e) $$
---                 ppr core_expr <+> dcolon <+> ppr (exprType core_expr))
---   ; return core_expr }
+-- | Desugar a located typechecked expression.
 dsLExpr :: LHsExpr GhcTc -> DsM CoreExpr
-dsLExpr (L loc e) =
-  putSrcSpanDsA loc $ dsExpr e
+dsLExpr (L loc e) = putSrcSpanDsA loc $ dsExpr e
 
+-- | Desugar a typechecked expression.
 dsExpr :: HsExpr GhcTc -> DsM CoreExpr
 dsExpr (HsVar    _ (L _ id))           = dsHsVar id
 dsExpr (HsRecSel _ (FieldOcc id _))    = dsHsVar id
@@ -691,11 +683,13 @@ dsDo ctx stmts
            ; dsLocalBinds binds rest }
 
     go _ (BindStmt xbs pat rhs) stmts
-      = do  { body     <- goL stmts
-            ; rhs'     <- dsLExpr rhs
-            ; var   <- selectSimpleMatchVarL (xbstc_boundResultMult xbs) pat
+      = do  { var   <- selectSimpleMatchVarL (xbstc_boundResultMult xbs) pat
+            ; rhs'  <- dsLExpr rhs
             ; match <- matchSinglePatVar var Nothing (StmtCtxt (HsDoStmt ctx)) pat
-                         (xbstc_boundResultType xbs) (cantFailMatchResult body)
+                                 (xbstc_boundResultType xbs) (MR_Infallible $ goL stmts)
+            -- NB: "goL stmts" needs to happen inside matchSinglePatVar, and not
+            -- before it, so that long-distance information is properly threaded.
+            -- See Note [Long-distance information in do notation].
             ; match_code <- dsHandleMonadicFailure ctx pat match (xbstc_failOp xbs)
             ; dsSyntaxExpr (xbstc_bindOp xbs) [rhs', Lam var match_code] }
 
@@ -774,7 +768,52 @@ dsDo ctx stmts
     go _ (ParStmt   {}) _ = panic "dsDo ParStmt"
     go _ (TransStmt {}) _ = panic "dsDo TransStmt"
 
-{-
+{- Note [Long-distance information in do notation]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Consider T21360:
+
+  data Foo = A Int | B
+
+  swooble :: Foo -> Maybe Foo
+  swooble foo = do
+    bar@A{} <- Just foo
+    return $ case bar of { A _ -> A 9 }
+
+The pattern-match checker **should not** complain that the case expression
+is incomplete, because we know that 'bar' is headed by the constructor 'A',
+due to the pattern match in the line above. However, we need to ensure that we
+propagate this long-distance information; failing to do so led to #21360.
+
+To do this, we use "matchSinglePatVar" to handle the first pattern match
+
+  bar@A{} <- Just foo
+
+"matchSinglePatVar" then threads through the long-distance information to the
+desugaring of the remaining statements by using updPmNablasMatchResult.
+This avoids any spurious pattern-match warnings when handling the case
+expression on the last line.
+
+Other places that require the same treatment:
+
+  - monad comprehensions, e.g.
+
+     blorble :: Foo -> Maybe Foo
+     blorble foo = [ case bar of { A _ -> A 9 } | bar@A{} <- Just foo ]
+
+     See GHC.HsToCore.ListComp.dsMcBindStmt. Also tested in T21360.
+
+  - guards, e.g.
+
+      giddy :: Maybe Char -> Char
+      giddy x
+        | y@(Just _) <- x
+        , let z = case y of { Just w -> w }
+        = z
+
+    We don't want any inexhaustive pattern match warnings for the case expression,
+    because we already know 'y' is of the form "Just ...".
+    See test case T21360b.
+
 ************************************************************************
 *                                                                      *
    Desugaring Variables
diff --git a/compiler/GHC/HsToCore/GuardedRHSs.hs b/compiler/GHC/HsToCore/GuardedRHSs.hs
index 8a24b00a590d..7490bace47d7 100644
--- a/compiler/GHC/HsToCore/GuardedRHSs.hs
+++ b/compiler/GHC/HsToCore/GuardedRHSs.hs
@@ -79,7 +79,7 @@ dsGRHSs hs_ctx (GRHSs _ grhss binds) rhs_ty rhss_nablas
 dsGRHS :: HsMatchContext GhcTc -> Type -> Nablas -> LGRHS GhcTc (LHsExpr GhcTc)
        -> DsM (MatchResult CoreExpr)
 dsGRHS hs_ctx rhs_ty rhs_nablas (L _ (GRHS _ guards rhs))
-  = matchGuards (map unLoc guards) (PatGuard hs_ctx) rhs_nablas rhs rhs_ty
+  = matchGuards (map unLoc guards) hs_ctx rhs_nablas rhs rhs_ty
 
 {-
 ************************************************************************
@@ -90,7 +90,7 @@ dsGRHS hs_ctx rhs_ty rhs_nablas (L _ (GRHS _ guards rhs))
 -}
 
 matchGuards :: [GuardStmt GhcTc]     -- Guard
-            -> HsStmtContext GhcTc   -- Context
+            -> HsMatchContext GhcTc  -- Context
             -> Nablas                -- The RHS's covered set for PmCheck
             -> LHsExpr GhcTc         -- RHS
             -> Type                  -- Type of RHS of guard
@@ -130,15 +130,16 @@ matchGuards (LetStmt _ binds : stmts) ctx nablas rhs rhs_ty = do
 matchGuards (BindStmt _ pat bind_rhs : stmts) ctx nablas rhs rhs_ty = do
     let upat = unLoc pat
     match_var <- selectMatchVar ManyTy upat
-       -- We only allow unrestricted patterns in guard, hence the `Many`
+       -- We only allow unrestricted patterns in guards, hence the `Many`
        -- above. It isn't clear what linear patterns would mean, maybe we will
        -- figure it out in the future.
 
     match_result <- matchGuards stmts ctx nablas rhs rhs_ty
     core_rhs <- dsLExpr bind_rhs
-    match_result' <- matchSinglePatVar match_var (Just core_rhs) (StmtCtxt ctx)
-                                       pat rhs_ty match_result
-    pure $ bindNonRec match_var core_rhs <$> match_result'
+    match_result' <-
+      matchSinglePatVar match_var (Just core_rhs) (StmtCtxt $ PatGuard ctx)
+      pat rhs_ty match_result
+    return $ bindNonRec match_var core_rhs <$> match_result'
 
 matchGuards (LastStmt  {} : _) _ _ _ _ = panic "matchGuards LastStmt"
 matchGuards (ParStmt   {} : _) _ _ _ _ = panic "matchGuards ParStmt"
diff --git a/compiler/GHC/HsToCore/ListComp.hs b/compiler/GHC/HsToCore/ListComp.hs
index 34604eec3d56..07fd1a3bc8ce 100644
--- a/compiler/GHC/HsToCore/ListComp.hs
+++ b/compiler/GHC/HsToCore/ListComp.hs
@@ -603,10 +603,12 @@ dsMcBindStmt :: LPat GhcTc
              -> [ExprLStmt GhcTc]
              -> DsM CoreExpr
 dsMcBindStmt pat rhs' bind_op fail_op res1_ty stmts
-  = do  { body     <- dsMcStmts stmts
-        ; var      <- selectSimpleMatchVarL ManyTy pat
+  = do  { var   <- selectSimpleMatchVarL ManyTy pat
         ; match <- matchSinglePatVar var Nothing (StmtCtxt (HsDoStmt (DoExpr Nothing))) pat
-                                  res1_ty (cantFailMatchResult body)
+                      res1_ty (MR_Infallible $ dsMcStmts stmts)
+            -- NB: dsMcStmts needs to happen inside matchSinglePatVar, and not
+            -- before it, so that long-distance information is properly threaded.
+            -- See Note [Long-distance information in do notation] in GHC.HsToCore.Expr.
         ; match_code <- dsHandleMonadicFailure MonadComp pat match fail_op
         ; dsSyntaxExpr bind_op [rhs', Lam var match_code] }
 
diff --git a/compiler/GHC/HsToCore/Match.hs b/compiler/GHC/HsToCore/Match.hs
index 6be944d1242e..49dc24ae074d 100644
--- a/compiler/GHC/HsToCore/Match.hs
+++ b/compiler/GHC/HsToCore/Match.hs
@@ -1,4 +1,5 @@
 
+{-# LANGUAGE LambdaCase #-}
 {-# LANGUAGE MonadComprehensions #-}
 {-# LANGUAGE OverloadedLists #-}
 {-# LANGUAGE PatternSynonyms #-}
@@ -72,7 +73,7 @@ import GHC.Data.FastString
 import GHC.Types.Unique
 import GHC.Types.Unique.DFM
 
-import Control.Monad ( zipWithM, unless, when )
+import Control.Monad ( zipWithM, unless )
 import Data.List.NonEmpty (NonEmpty(..))
 import qualified Data.List.NonEmpty as NEL
 import qualified Data.Map as Map
@@ -948,16 +949,31 @@ matchSinglePatVar var mb_scrut ctx pat ty match_result
   = assertPpr (isInternalName (idName var)) (ppr var) $
     do { dflags <- getDynFlags
        ; locn   <- getSrcSpanDs
-       -- Pattern match check warnings
-       ; when (isMatchContextPmChecked dflags FromSource ctx) $
+       -- Pattern match check warnings.
+       -- See Note [Long-distance information in matchWrapper] and
+       -- Note [Long-distance information in do notation] in GHC.HsToCore.Expr.
+       ; ldi_nablas <-
+         if isMatchContextPmChecked dflags FromSource ctx
+         then
            addCoreScrutTmCs (maybeToList mb_scrut) [var] $
            pmcPatBind (DsMatchContext ctx locn) var (unLoc pat)
+         else getLdiNablas
 
        ; let eqn_info = EqnInfo { eqn_pats = [unLoc (decideBangHood dflags pat)]
                                 , eqn_orig = FromSource
-                                , eqn_rhs  = match_result }
+                                , eqn_rhs  =
+               updPmNablasMatchResult ldi_nablas match_result }
+               -- See Note [Long-distance information in do notation]
+               -- in GHC.HsToCore.Expr.
+
        ; match [var] ty [eqn_info] }
 
+updPmNablasMatchResult :: Nablas -> MatchResult r -> MatchResult r
+updPmNablasMatchResult nablas = \case
+  MR_Infallible body_fn -> MR_Infallible $
+    updPmNablas nablas body_fn
+  MR_Fallible body_fn -> MR_Fallible $ \fail ->
+    updPmNablas nablas $ body_fn fail
 
 {-
 ************************************************************************
diff --git a/compiler/GHC/HsToCore/Pmc.hs b/compiler/GHC/HsToCore/Pmc.hs
index aa72db0aed57..f588c842c8b9 100644
--- a/compiler/GHC/HsToCore/Pmc.hs
+++ b/compiler/GHC/HsToCore/Pmc.hs
@@ -98,16 +98,31 @@ noCheckDs :: DsM a -> DsM a
 noCheckDs = updTopFlags (\dflags -> foldl' wopt_unset dflags allPmCheckWarnings)
 
 -- | Check a pattern binding (let, where) for exhaustiveness.
-pmcPatBind :: DsMatchContext -> Id -> Pat GhcTc -> DsM ()
--- See Note [pmcPatBind only checks PatBindRhs]
-pmcPatBind ctxt@(DsMatchContext PatBindRhs loc) var p = do
-  !missing <- getLdiNablas
-  pat_bind <- noCheckDs $ desugarPatBind loc var p
-  tracePm "pmcPatBind {" (vcat [ppr ctxt, ppr var, ppr p, ppr pat_bind, ppr missing])
-  result <- unCA (checkPatBind pat_bind) missing
-  tracePm "}: " (ppr (cr_uncov result))
-  formatReportWarnings ReportPatBind ctxt [var] result
-pmcPatBind _ _ _ = pure ()
+pmcPatBind :: DsMatchContext -> Id -> Pat GhcTc -> DsM Nablas
+pmcPatBind ctxt@(DsMatchContext match_ctxt loc) var p
+  = mb_discard_warnings $ do
+      !missing <- getLdiNablas
+      pat_bind <- noCheckDs $ desugarPatBind loc var p
+      tracePm "pmcPatBind {" (vcat [ppr ctxt, ppr var, ppr p, ppr pat_bind, ppr missing])
+      result <- unCA (checkPatBind pat_bind) missing
+      let ldi = ldiGRHS $ ( \ pb -> case pb of PmPatBind grhs -> grhs) $ cr_ret result
+      tracePm "pmcPatBind }: " $
+        vcat [ text "cr_uncov:" <+> ppr (cr_uncov result)
+             , text "ldi:" <+> ppr ldi ]
+      formatReportWarnings ReportPatBind ctxt [var] result
+      return ldi
+  where
+    -- See Note [pmcPatBind doesn't warn on pattern guards]
+    mb_discard_warnings
+      = if want_pmc match_ctxt
+        then id
+        else discardWarningsDs
+    want_pmc PatBindRhs = True
+    want_pmc (StmtCtxt stmt_ctxt) =
+      case stmt_ctxt of
+        PatGuard {} -> False
+        _           -> True
+    want_pmc _ = False
 
 -- | Exhaustive for guard matches, is used for guards in pattern bindings and
 -- in @MultiIf@ expressions. Returns the 'Nablas' covered by the RHSs.
@@ -178,22 +193,29 @@ pmcMatches ctxt vars matches = {-# SCC "pmcMatches" #-} do
       {-# SCC "formatReportWarnings" #-} formatReportWarnings ReportMatchGroup ctxt vars result
       return (NE.toList (ldiMatchGroup (cr_ret result)))
 
-{- Note [pmcPatBind only checks PatBindRhs]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-@pmcPatBind@'s sole purpose is to check vanilla pattern bindings, like
+{- Note [pmcPatBind doesn't warn on pattern guards]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+@pmcPatBind@'s main purpose is to check vanilla pattern bindings, like
 @x :: Int; Just x = e@, which is in a @PatBindRhs@ context.
 But its caller is also called for individual pattern guards in a @StmtCtxt@.
 For example, both pattern guards in @f x y | True <- x, False <- y = ...@ will
-go through this function. It makes no sense to do coverage checking there:
+go through this function. It makes no sense to report pattern match warnings
+for these pattern guards:
+
   * Pattern guards may well fail. Fall-through is not an unrecoverable panic,
     but rather behavior the programmer expects, so inexhaustivity should not be
     reported.
+
   * Redundancy is already reported for the whole GRHS via one of the other
-    exported coverage checking functions. Also reporting individual redundant
+    exported coverage checking functions. Also, reporting individual redundant
     guards is... redundant. See #17646.
-Note that we can't just omit checking of @StmtCtxt@ altogether (by adjusting
-'isMatchContextPmChecked'), because that affects the other checking functions,
-too.
+
+However, we should not skip pattern-match checking altogether, as it may reveal
+important long-distance information. One example is described in
+Note [Long-distance information in do notation] in GHC.HsToCore.Expr.
+
+Instead, we simply discard warnings when in pattern guards, by using the function
+discardWarningsDs.
 -}
 
 --
diff --git a/compiler/GHC/HsToCore/Utils.hs b/compiler/GHC/HsToCore/Utils.hs
index 0f696d6b28d2..6ba9d842d943 100644
--- a/compiler/GHC/HsToCore/Utils.hs
+++ b/compiler/GHC/HsToCore/Utils.hs
@@ -141,7 +141,7 @@ selectMatchVar _w (VarPat _ var)    = return (localiseId (unLoc var))
                                   -- multiplicity stored within the variable
                                   -- itself. It's easier to pull it from the
                                   -- variable, so we ignore the multiplicity.
-selectMatchVar _w (AsPat _ var _ _) = assert (isManyTy _w ) (return (unLoc var))
+selectMatchVar _w (AsPat _ var _ _) = assert (isManyTy _w ) (return (localiseId (unLoc var)))
 selectMatchVar w other_pat        = newSysLocalDs w (hsPatType other_pat)
 
 {- Note [Localise pattern binders]
diff --git a/compiler/Language/Haskell/Syntax/Expr.hs b/compiler/Language/Haskell/Syntax/Expr.hs
index 46419787f82c..c79b54518a4b 100644
--- a/compiler/Language/Haskell/Syntax/Expr.hs
+++ b/compiler/Language/Haskell/Syntax/Expr.hs
@@ -1600,11 +1600,11 @@ isPatSynCtxt ctxt =
 
 -- | Haskell Statement Context.
 data HsStmtContext p
-  = HsDoStmt HsDoFlavour             -- ^Context for HsDo (do-notation and comprehensions)
-  | PatGuard (HsMatchContext p)      -- ^Pattern guard for specified thing
-  | ParStmtCtxt (HsStmtContext p)    -- ^A branch of a parallel stmt
-  | TransStmtCtxt (HsStmtContext p)  -- ^A branch of a transform stmt
-  | ArrowExpr                        -- ^do-notation in an arrow-command context
+  = HsDoStmt HsDoFlavour             -- ^ Context for HsDo (do-notation and comprehensions)
+  | PatGuard (HsMatchContext p)      -- ^ Pattern guard for specified thing
+  | ParStmtCtxt (HsStmtContext p)    -- ^ A branch of a parallel stmt
+  | TransStmtCtxt (HsStmtContext p)  -- ^ A branch of a transform stmt
+  | ArrowExpr                        -- ^ do-notation in an arrow-command context
 
 -- | Haskell arrow match context.
 data HsArrowMatchContext
diff --git a/testsuite/tests/pmcheck/should_compile/T21360.hs b/testsuite/tests/pmcheck/should_compile/T21360.hs
index 80a8afebde05..4bbe563194b0 100644
--- a/testsuite/tests/pmcheck/should_compile/T21360.hs
+++ b/testsuite/tests/pmcheck/should_compile/T21360.hs
@@ -1,20 +1,23 @@
+{-# LANGUAGE MonadComprehensions #-}
+
 module T21360 where
 
 data Foo = A {a :: Int} | B deriving Show
 
-foo = A 4
-
--- wibble is safe - no warning
-wibble = do
-  case foo of
-    bar@A{} -> Just bar{a = 9}
-    _ -> fail ":("
-
--- using guards doesn't throw a warning
-twomble | bar@A{} <- foo = Just bar{a = 9}
-        | otherwise  = fail ":("
+sworble :: Foo -> Maybe Foo
+sworble foo = do
+  bar@A{} <- Just foo
+  return $ bar { a = 9 }
+    -- we should not get a warning, because long-distance information
+    -- from the previous line should allow us to see that the record update
+    -- is not partial
 
--- sworble has the same semantics as wibble and twomble - but we get a warning!
-sworble = do
+swooble :: Foo -> Maybe Foo
+swooble foo = do
   bar@A{} <- Just foo
-  Just bar{a = 9}
+  return $ case bar of { A _ -> A 9 }
+  -- same here
+
+-- same as swooble but using a monad comprehension
+blorble :: Foo -> Maybe Foo
+blorble foo = [ case bar of { A _ -> A 9 } | bar@A{} <- Just foo ]
diff --git a/testsuite/tests/pmcheck/should_compile/T21360b.hs b/testsuite/tests/pmcheck/should_compile/T21360b.hs
new file mode 100644
index 000000000000..25059e6f9363
--- /dev/null
+++ b/testsuite/tests/pmcheck/should_compile/T21360b.hs
@@ -0,0 +1,10 @@
+module T21360b where
+
+foo :: Maybe Char -> Char
+foo x
+  | y@(Just _) <- x
+  , let z = case y of { Just w -> w }
+  , let _ = case x of { Just _ -> 'r' }
+  = z
+  | otherwise
+  = 'o'
diff --git a/testsuite/tests/pmcheck/should_compile/all.T b/testsuite/tests/pmcheck/should_compile/all.T
index f9470110a647..141aa0433b31 100644
--- a/testsuite/tests/pmcheck/should_compile/all.T
+++ b/testsuite/tests/pmcheck/should_compile/all.T
@@ -90,6 +90,7 @@ test('T19622', normal, compile, [overlapping_incomplete])
 test('T20631', normal, compile, [overlapping_incomplete])
 test('T20642', normal, compile, [overlapping_incomplete])
 test('T21360', normal, compile, [overlapping_incomplete+'-Wincomplete-record-updates'])
+test('T21360b', normal, compile, [overlapping_incomplete+'-Wincomplete-record-updates'])
 test('T23520', normal, compile, [overlapping_incomplete+'-Wincomplete-record-updates'])
 
 # Other tests
-- 
GitLab