From edd8bc43566b3f002758e5d08c399b6f4c3d7443 Mon Sep 17 00:00:00 2001
From: Krzysztof Gogolewski <krzysztof.gogolewski@tweag.io>
Date: Thu, 10 Aug 2023 00:13:59 +0200
Subject: [PATCH] Fix MultiWayIf linearity checking (#23814)

Co-authored-by: Thomas BAGREL <thomas.bagrel@tweag.io>
---
 compiler/GHC/Tc/Gen/Expr.hs                   | 27 ++++++++++++++++++-
 .../tests/linear/should_compile/T23814.hs     | 17 ++++++++++++
 testsuite/tests/linear/should_compile/all.T   |  1 +
 .../tests/linear/should_fail/T23814fail.hs    | 17 ++++++++++++
 .../linear/should_fail/T23814fail.stderr      | 17 ++++++++++++
 testsuite/tests/linear/should_fail/all.T      |  1 +
 6 files changed, 79 insertions(+), 1 deletion(-)
 create mode 100644 testsuite/tests/linear/should_compile/T23814.hs
 create mode 100644 testsuite/tests/linear/should_fail/T23814fail.hs
 create mode 100644 testsuite/tests/linear/should_fail/T23814fail.stderr

diff --git a/compiler/GHC/Tc/Gen/Expr.hs b/compiler/GHC/Tc/Gen/Expr.hs
index c7a3c412b3e6..007eeb9dfbd0 100644
--- a/compiler/GHC/Tc/Gen/Expr.hs
+++ b/compiler/GHC/Tc/Gen/Expr.hs
@@ -396,9 +396,34 @@ tcExpr (HsIf x pred b1 b2) res_ty
        ; tcEmitBindingUsage (supUE u1 u2)
        ; return (HsIf x pred' b1' b2') }
 
+{-
+Note [MultiWayIf linearity checking]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Suppose we'd like to compute the usage environment for
+
+if | b1 -> e1
+   | b2 -> e2
+   | otherwise -> e3
+
+and let u1, u2, v1, v2, v3 denote the usage env for b1, b2, e1, e2, e3
+respectively.
+
+Since a multi-way if is mere sugar for nested if expressions, the usage
+environment should ideally be u1 + sup(v1, u2 + sup(v2, v3)).
+However, currently we don't support linear guards (#19193). All variables
+used in guards from u1 and u2 will have multiplicity Many.
+But in that case, we have equality u1 + sup(x,y) = sup(u1 + x, y),
+                      and likewise u2 + sup(x,y) = sup(u2 + x, y) for any x,y.
+Using this identity, we can just compute sup(u1 + v1, u2 + v2, v3) instead.
+This is simple to do, since we get u_i + v_i directly from tcGRHS.
+If we add linear guards, this code will have to be revisited.
+Not using 'sup' caused #23814.
+-}
+
 tcExpr (HsMultiIf _ alts) res_ty
-  = do { alts' <- mapM (wrapLocMA $ tcGRHS match_ctxt res_ty) alts
+  = do { (ues, alts') <- mapAndUnzipM (\alt -> tcCollectingUsage $ wrapLocMA (tcGRHS match_ctxt res_ty) alt) alts
        ; res_ty <- readExpType res_ty
+       ; tcEmitBindingUsage (supUEs ues)  -- See Note [MultiWayIf linearity checking]
        ; return (HsMultiIf res_ty alts') }
   where match_ctxt = MC { mc_what = IfAlt, mc_body = tcBody }
 
diff --git a/testsuite/tests/linear/should_compile/T23814.hs b/testsuite/tests/linear/should_compile/T23814.hs
new file mode 100644
index 000000000000..d072452f5776
--- /dev/null
+++ b/testsuite/tests/linear/should_compile/T23814.hs
@@ -0,0 +1,17 @@
+{-# LANGUAGE LinearTypes #-}
+{-# LANGUAGE MultiWayIf #-}
+
+module T23814 where
+
+f :: Bool -> Int %1 -> Int
+f b x =
+  if
+    | b -> x
+    | otherwise -> x
+
+g :: Bool -> Bool -> Int %1 -> Int %1 -> (Int, Int)
+g b c x y =
+  if
+    | b -> (x,y)
+    | c -> (y,x)
+    | otherwise -> (x,y)
diff --git a/testsuite/tests/linear/should_compile/all.T b/testsuite/tests/linear/should_compile/all.T
index 4250d3432c10..39d0f82d5f8b 100644
--- a/testsuite/tests/linear/should_compile/all.T
+++ b/testsuite/tests/linear/should_compile/all.T
@@ -42,3 +42,4 @@ test('T20023', normal, compile, [''])
 test('T22546', normal, compile, [''])
 test('T23025', normal, compile, ['-dlinear-core-lint'])
 test('LinearRecUpd', normal, compile, [''])
+test('T23814', normal, compile, [''])
diff --git a/testsuite/tests/linear/should_fail/T23814fail.hs b/testsuite/tests/linear/should_fail/T23814fail.hs
new file mode 100644
index 000000000000..56ad8bdfe4c5
--- /dev/null
+++ b/testsuite/tests/linear/should_fail/T23814fail.hs
@@ -0,0 +1,17 @@
+{-# LANGUAGE LinearTypes #-}
+{-# LANGUAGE MultiWayIf #-}
+
+module T23814fail where
+
+f' :: Bool -> Int %1 -> Int
+f' b x =
+  if
+    | b -> x
+    | otherwise -> 0
+
+g' :: Bool -> Bool -> Int %1 -> Int
+g' b c x =
+   if
+     | b -> x
+     | c -> 0
+     | otherwise -> 0
diff --git a/testsuite/tests/linear/should_fail/T23814fail.stderr b/testsuite/tests/linear/should_fail/T23814fail.stderr
new file mode 100644
index 000000000000..7dad7ee0093f
--- /dev/null
+++ b/testsuite/tests/linear/should_fail/T23814fail.stderr
@@ -0,0 +1,17 @@
+
+T23814fail.hs:7:6: error: [GHC-18872]
+    • Couldn't match type ‘Many’ with ‘One’
+        arising from multiplicity of ‘x’
+    • In an equation for ‘f'’:
+          f' b x
+            = if | b -> x
+                 | otherwise -> 0
+
+T23814fail.hs:13:8: error: [GHC-18872]
+    • Couldn't match type ‘Many’ with ‘One’
+        arising from multiplicity of ‘x’
+    • In an equation for ‘g'’:
+          g' b c x
+            = if | b -> x
+                 | c -> 0
+                 | otherwise -> 0
diff --git a/testsuite/tests/linear/should_fail/all.T b/testsuite/tests/linear/should_fail/all.T
index 2d7c6ed50997..f98692689c41 100644
--- a/testsuite/tests/linear/should_fail/all.T
+++ b/testsuite/tests/linear/should_fail/all.T
@@ -41,3 +41,4 @@ test('T19120', normal, compile_fail, [''])
 test('T20083', normal, compile_fail, ['-XLinearTypes'])
 test('T19361', normal, compile_fail, [''])
 test('T21278', normal, compile_fail, ['-XLinearTypes'])
+test('T23814fail', normal, compile_fail, [''])
-- 
GitLab