From ea1581cb2fd567e9b2ab8c9628196950f50a1e7e Mon Sep 17 00:00:00 2001
From: Ryan Scott <ryan.gl.scott@gmail.com>
Date: Wed, 15 Jan 2020 13:02:34 -0500
Subject: [PATCH] Use splitLHs{ForAll,Sigma}TyInvis throughout the codebase

Richard points out in #17688 that we use `splitLHsForAllTy` and
`splitLHsSigmaTy` in places that we ought to be using the
corresponding `-Invis` variants instead, identifying two bugs
that are caused by this oversight:

* Certain TH-quoted type signatures, such as those that appear in
  quoted `SPECIALISE` pragmas, silently turn visible `forall`s into
  invisible `forall`s.
* When quoted, the type `forall a -> (a ~ a) => a` will turn into
  `forall a -> a` due to a bug in `DsMeta.repForall` that drops
  contexts that follow visible `forall`s.

These are both ultimately caused by the fact that `splitLHsForAllTy`
and `splitLHsSigmaTy` split apart visible `forall`s in addition to
invisible ones. This patch cleans things up:

* We now use `splitLHsForAllTyInvis` and `splitLHsSigmaTyInvis`
  throughout the codebase. Relatedly, the `splitLHsForAllTy` and
  `splitLHsSigmaTy` have been removed, as they are easy to misuse.
* `DsMeta.repForall` now only handles invisible `forall`s to reduce
  the chance for confusion with visible `forall`s, which need to be
  handled differently. I also renamed it from `repForall` to
  `repForallT` to emphasize that its distinguishing characteristic
  is the fact that it desugars down to `L.H.TH.Syntax.ForallT`.

Fixes #17688.

(cherry picked from commit 18c0d037f90482d082c459c2fb74f287ebe3be90)
---
 compiler/GHC/Hs/Types.hs          | 39 +++++++------------------------
 compiler/deSugar/DsMeta.hs        | 32 +++++++++++++++----------
 compiler/rename/RnSource.hs       |  2 +-
 compiler/typecheck/TcBinds.hs     |  2 +-
 compiler/typecheck/TcDeriv.hs     |  2 +-
 compiler/typecheck/TcHsType.hs    |  2 +-
 testsuite/tests/th/T17688a.hs     | 10 ++++++++
 testsuite/tests/th/T17688a.stderr |  1 +
 testsuite/tests/th/T17688b.hs     | 15 ++++++++++++
 testsuite/tests/th/T17688b.stderr |  2 ++
 testsuite/tests/th/all.T          |  2 ++
 11 files changed, 62 insertions(+), 47 deletions(-)
 create mode 100644 testsuite/tests/th/T17688a.hs
 create mode 100644 testsuite/tests/th/T17688a.stderr
 create mode 100644 testsuite/tests/th/T17688b.hs
 create mode 100644 testsuite/tests/th/T17688b.stderr

diff --git a/compiler/GHC/Hs/Types.hs b/compiler/GHC/Hs/Types.hs
index fcf22584cb96..c534e67d2caa 100644
--- a/compiler/GHC/Hs/Types.hs
+++ b/compiler/GHC/Hs/Types.hs
@@ -56,8 +56,7 @@ module GHC.Hs.Types (
         hsLTyVarName, hsLTyVarNames, hsLTyVarLocName, hsExplicitLTyVarNames,
         splitLHsInstDeclTy, getLHsInstDeclHead, getLHsInstDeclClass_maybe,
         splitLHsPatSynTy,
-        splitLHsForAllTy, splitLHsForAllTyInvis,
-        splitLHsQualTy, splitLHsSigmaTy, splitLHsSigmaTyInvis,
+        splitLHsForAllTyInvis, splitLHsQualTy, splitLHsSigmaTyInvis,
         splitHsFunType, hsTyGetAppHead_maybe,
         mkHsOpTy, mkHsAppTy, mkHsAppTys, mkHsAppKindTy,
         ignoreParens, hsSigType, hsSigWcType,
@@ -1248,21 +1247,9 @@ splitLHsPatSynTy ty = (univs, reqs, exis, provs, ty4)
     (provs, ty4) = splitLHsQualTy ty3
 
 -- | Decompose a sigma type (of the form @forall <tvs>. context => body@)
--- into its constituent parts.
---
--- Note that this function looks through parentheses, so it will work on types
--- such as @(forall a. <...>)@. The downside to this is that it is not
--- generally possible to take the returned types and reconstruct the original
--- type (parentheses and all) from them.
-splitLHsSigmaTy :: LHsType pass
-                -> ([LHsTyVarBndr pass], LHsContext pass, LHsType pass)
-splitLHsSigmaTy ty
-  | (tvs, ty1)  <- splitLHsForAllTy ty
-  , (ctxt, ty2) <- splitLHsQualTy ty1
-  = (tvs, ctxt, ty2)
-
--- | Like 'splitLHsSigmaTy', but only splits type variable binders that were
--- quantified invisibly (e.g., @forall a.@, with a dot).
+-- into its constituent parts. Note that only /invisible/ @forall@s
+-- (i.e., @forall a.@, with a dot) are split apart; /visible/ @forall@s
+-- (i.e., @forall a ->@, with an arrow) are left untouched.
 --
 -- This function is used to split apart certain types, such as instance
 -- declaration types, which disallow visible @forall@s. For instance, if GHC
@@ -1280,20 +1267,10 @@ splitLHsSigmaTyInvis ty
   , (ctxt, ty2) <- splitLHsQualTy ty1
   = (tvs, ctxt, ty2)
 
--- | Decompose a type of the form @forall <tvs>. body@) into its constituent
--- parts.
---
--- Note that this function looks through parentheses, so it will work on types
--- such as @(forall a. <...>)@. The downside to this is that it is not
--- generally possible to take the returned types and reconstruct the original
--- type (parentheses and all) from them.
-splitLHsForAllTy :: LHsType pass -> ([LHsTyVarBndr pass], LHsType pass)
-splitLHsForAllTy (L _ (HsParTy _ ty)) = splitLHsForAllTy ty
-splitLHsForAllTy (L _ (HsForAllTy { hst_bndrs = tvs, hst_body = body })) = (tvs, body)
-splitLHsForAllTy body              = ([], body)
-
--- | Like 'splitLHsForAllTy', but only splits type variable binders that
--- were quantified invisibly (e.g., @forall a.@, with a dot).
+-- | Decompose a type of the form @forall <tvs>. body@ into its constituent
+-- parts. Note that only /invisible/ @forall@s
+-- (i.e., @forall a.@, with a dot) are split apart; /visible/ @forall@s
+-- (i.e., @forall a ->@, with an arrow) are left untouched.
 --
 -- This function is used to split apart certain types, such as instance
 -- declaration types, which disallow visible @forall@s. For instance, if GHC
diff --git a/compiler/deSugar/DsMeta.hs b/compiler/deSugar/DsMeta.hs
index 5290d1a978cc..aa102d1952f4 100644
--- a/compiler/deSugar/DsMeta.hs
+++ b/compiler/deSugar/DsMeta.hs
@@ -208,7 +208,7 @@ get_scoped_tvs (dL->L _ signature)
       --    here 'k' scopes too
       | HsIB { hsib_ext = implicit_vars
              , hsib_body = hs_ty } <- sig
-      , (explicit_vars, _) <- splitLHsForAllTy hs_ty
+      , (explicit_vars, _) <- splitLHsForAllTyInvis hs_ty
       = implicit_vars ++ hsLTyVarNames explicit_vars
     get_scoped_tvs_from_sig (XHsImplicitBndrs nec)
       = noExtCon nec
@@ -1095,7 +1095,7 @@ repContext ctxt = do preds <- repList typeQTyConName repLTy ctxt
 repHsSigType :: LHsSigType GhcRn -> DsM (Core TH.TypeQ)
 repHsSigType (HsIB { hsib_ext = implicit_tvs
                    , hsib_body = body })
-  | (explicit_tvs, ctxt, ty) <- splitLHsSigmaTy body
+  | (explicit_tvs, ctxt, ty) <- splitLHsSigmaTyInvis body
   = addSimpleTyVarBinds implicit_tvs $
       -- See Note [Don't quantify implicit type variables in quotes]
     addHsTyVarBinds explicit_tvs $ \ th_explicit_tvs ->
@@ -1119,21 +1119,29 @@ repLTys tys = mapM repLTy tys
 repLTy :: LHsType GhcRn -> DsM (Core TH.TypeQ)
 repLTy ty = repTy (unLoc ty)
 
-repForall :: ForallVisFlag -> HsType GhcRn -> DsM (Core TH.TypeQ)
--- Arg of repForall is always HsForAllTy or HsQualTy
-repForall fvf ty
- | (tvs, ctxt, tau) <- splitLHsSigmaTy (noLoc ty)
+-- Desugar a type headed by an invisible forall (e.g., @forall a. a@) or
+-- a context (e.g., @Show a => a@) into a ForallT from L.H.TH.Syntax.
+-- In other words, the argument to this function is always an
+-- @HsForAllTy ForallInvis@ or @HsQualTy@.
+-- Types headed by visible foralls (which are desugared to ForallVisT) are
+-- handled separately in repTy.
+repForallT :: HsType GhcRn -> DsM (Core TH.TypeQ)
+repForallT ty
+ | (tvs, ctxt, tau) <- splitLHsSigmaTyInvis (noLoc ty)
  = addHsTyVarBinds tvs $ \bndrs ->
    do { ctxt1  <- repLContext ctxt
-      ; ty1    <- repLTy tau
-      ; case fvf of
-          ForallVis   -> repTForallVis bndrs ty1    -- forall a      -> {...}
-          ForallInvis -> repTForall bndrs ctxt1 ty1 -- forall a. C a => {...}
+      ; tau1   <- repLTy tau
+      ; repTForall bndrs ctxt1 tau1 -- forall a. C a => {...}
       }
 
 repTy :: HsType GhcRn -> DsM (Core TH.TypeQ)
-repTy ty@(HsForAllTy {hst_fvf = fvf}) = repForall fvf         ty
-repTy ty@(HsQualTy {})                = repForall ForallInvis ty
+repTy ty@(HsForAllTy { hst_fvf = fvf, hst_bndrs = tvs, hst_body = body }) =
+  case fvf of
+    ForallInvis -> repForallT ty
+    ForallVis   -> addHsTyVarBinds tvs $ \bndrs ->
+                   do body1 <- repLTy body
+                      repTForallVis bndrs body1
+repTy ty@(HsQualTy {}) = repForallT ty
 
 repTy (HsTyVar _ _ (dL->L _ n))
   | isLiftedTypeKindTyConName n       = repTStar
diff --git a/compiler/rename/RnSource.hs b/compiler/rename/RnSource.hs
index 791b6a4cebb6..b67c2cec888e 100644
--- a/compiler/rename/RnSource.hs
+++ b/compiler/rename/RnSource.hs
@@ -1837,7 +1837,7 @@ rnLDerivStrategy doc mds thing_inside
           do (via_ty', fvs1) <- rnHsSigType doc TypeLevel via_ty
              let HsIB { hsib_ext  = via_imp_tvs
                       , hsib_body = via_body } = via_ty'
-                 (via_exp_tv_bndrs, _, _) = splitLHsSigmaTy via_body
+                 (via_exp_tv_bndrs, _, _) = splitLHsSigmaTyInvis via_body
                  via_exp_tvs = hsLTyVarNames via_exp_tv_bndrs
                  via_tvs = via_imp_tvs ++ via_exp_tvs
              (thing, fvs2) <- extendTyVarEnvFVRn via_tvs thing_inside
diff --git a/compiler/typecheck/TcBinds.hs b/compiler/typecheck/TcBinds.hs
index 6421be4f16e2..a9de7ac1f604 100644
--- a/compiler/typecheck/TcBinds.hs
+++ b/compiler/typecheck/TcBinds.hs
@@ -1634,7 +1634,7 @@ decideGeneralisationPlan dflags lbinds closed sig_fn
       = [ null theta
         | TcIdSig (PartialSig { psig_hs_ty = hs_ty })
             <- mapMaybe sig_fn (collectHsBindListBinders lbinds)
-        , let (_, dL->L _ theta, _) = splitLHsSigmaTy (hsSigWcType hs_ty) ]
+        , let (_, dL->L _ theta, _) = splitLHsSigmaTyInvis (hsSigWcType hs_ty) ]
 
     has_partial_sigs   = not (null partial_sig_mrs)
 
diff --git a/compiler/typecheck/TcDeriv.hs b/compiler/typecheck/TcDeriv.hs
index 876acac8320b..1c1b95c9ca4a 100644
--- a/compiler/typecheck/TcDeriv.hs
+++ b/compiler/typecheck/TcDeriv.hs
@@ -716,7 +716,7 @@ tcStandaloneDerivInstType
 tcStandaloneDerivInstType ctxt
     (HsWC { hswc_body = deriv_ty@(HsIB { hsib_ext = vars
                                        , hsib_body   = deriv_ty_body })})
-  | (tvs, theta, rho) <- splitLHsSigmaTy deriv_ty_body
+  | (tvs, theta, rho) <- splitLHsSigmaTyInvis deriv_ty_body
   , L _ [wc_pred] <- theta
   , L wc_span (HsWildCardTy _) <- ignoreParens wc_pred
   = do dfun_ty <- tcHsClsInstType ctxt $
diff --git a/compiler/typecheck/TcHsType.hs b/compiler/typecheck/TcHsType.hs
index 1e5d2ce48f74..18985ae6ec18 100644
--- a/compiler/typecheck/TcHsType.hs
+++ b/compiler/typecheck/TcHsType.hs
@@ -3130,7 +3130,7 @@ tcHsPartialSigType ctxt sig_ty
   | HsWC { hswc_ext  = sig_wcs, hswc_body = ib_ty } <- sig_ty
   , HsIB { hsib_ext = implicit_hs_tvs
          , hsib_body = hs_ty } <- ib_ty
-  , (explicit_hs_tvs, L _ hs_ctxt, hs_tau) <- splitLHsSigmaTy hs_ty
+  , (explicit_hs_tvs, L _ hs_ctxt, hs_tau) <- splitLHsSigmaTyInvis hs_ty
   = addSigCtxt ctxt hs_ty $
     do { (implicit_tvs, (explicit_tvs, (wcs, wcx, theta, tau)))
             <- solveLocalEqualities "tcHsPartialSigType"    $
diff --git a/testsuite/tests/th/T17688a.hs b/testsuite/tests/th/T17688a.hs
new file mode 100644
index 000000000000..aae0b6da21af
--- /dev/null
+++ b/testsuite/tests/th/T17688a.hs
@@ -0,0 +1,10 @@
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE TemplateHaskell #-}
+module T17688a where
+
+import Language.Haskell.TH
+import System.IO
+
+$( do ty <- [d| {-# SPECIALISE id :: forall a -> a -> a #-} |]
+      runIO $ hPutStrLn stderr $ pprint ty
+      return [] )
diff --git a/testsuite/tests/th/T17688a.stderr b/testsuite/tests/th/T17688a.stderr
new file mode 100644
index 000000000000..f746b553b8ea
--- /dev/null
+++ b/testsuite/tests/th/T17688a.stderr
@@ -0,0 +1 @@
+{-# SPECIALISE GHC.Base.id :: forall a_0 -> a_0 -> a_0 #-}
diff --git a/testsuite/tests/th/T17688b.hs b/testsuite/tests/th/T17688b.hs
new file mode 100644
index 000000000000..f78cf0266ad5
--- /dev/null
+++ b/testsuite/tests/th/T17688b.hs
@@ -0,0 +1,15 @@
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TypeFamilies #-}
+module T17688b where
+
+import Data.Kind
+import Language.Haskell.TH hiding (Type)
+import System.IO
+
+$(do decs <- [d| type T :: forall (a :: Type) -> (a ~ a) => Type
+                 data T x |]
+     runIO $ hPutStrLn stderr $ pprint decs
+     return [] )
diff --git a/testsuite/tests/th/T17688b.stderr b/testsuite/tests/th/T17688b.stderr
new file mode 100644
index 000000000000..e5384ff04547
--- /dev/null
+++ b/testsuite/tests/th/T17688b.stderr
@@ -0,0 +1,2 @@
+type T_0 :: forall (a_1 :: *) -> a_1 ~ a_1 => *
+data T_0 x_2
diff --git a/testsuite/tests/th/all.T b/testsuite/tests/th/all.T
index a14734997ef2..94360f225e67 100644
--- a/testsuite/tests/th/all.T
+++ b/testsuite/tests/th/all.T
@@ -495,3 +495,5 @@ test('T17379a', normal, compile_fail, [''])
 test('T17379b', normal, compile_fail, [''])
 test('T17461', normal, compile, ['-v0 -ddump-splices -dsuppress-uniques'])
 test('T17511', normal, compile, [''])
+test('T17688a', normal, compile, [''])
+test('T17688b', normal, compile, [''])
-- 
GitLab