Commit 74311e10 authored by Sebastian Graf's avatar Sebastian Graf Committed by Marge Bot

PmCheck: Implement Long-distance information with Covered sets

Consider

```hs
data T = A | B | C

f :: T -> Int
f A = 1
f x = case x of
  A -> 2
  B -> 3
  C -> 4
```

Clearly, the RHS returning 2 is redundant. But we don't currently see
that, because our approximation to the covered set of the inner case
expression just picks up the positive information from surrounding
pattern matches. It lacks the context sensivity that `x` can't be `A`
anymore!

Therefore, we adopt the conceptually and practically superior approach
of reusing the covered set of a particular GRHS from an outer pattern
match. In this case, we begin checking the `case` expression with the
covered set of `f`s second clause, which encodes the information that
`x` can't be `A` anymore. After this MR, we will successfully warn about
the RHS returning 2 being redundant.

Perhaps surprisingly, this was a great simplification to the code of
both the coverage checker and the desugarer.

Found a redundant case alternative in `unix` submodule, so we have to
bump it with a fix.

Metric Decrease:
    T12227
parent 817f93ea
Pipeline #16284 passed with stages
in 427 minutes and 2 seconds
......@@ -1482,17 +1482,20 @@ genMachOp_slow opt op [x, y] = case op of
MO_FF_Conv _ _ -> panicOp
MO_V_Insert {} -> panicOp
MO_V_Extract {} -> panicOp
MO_VS_Neg {} -> panicOp
MO_VF_Insert {} -> panicOp
MO_VF_Extract {} -> panicOp
MO_VF_Neg {} -> panicOp
MO_AlignmentCheck {} -> panicOp
#if __GLASGOW_HASKELL__ < 811
MO_VF_Extract {} -> panicOp
MO_V_Extract {} -> panicOp
#endif
where
binLlvmOp ty binOp = runExprData $ do
vx <- exprToVarW x
......
......@@ -353,7 +353,9 @@ compileForeign hsc_env lang stub_c = do
LangObjc -> Cobjc
LangObjcxx -> Cobjcxx
LangAsm -> As True -- allow CPP
#if __GLASGOW_HASKELL__ < 811
RawObject -> panic "compileForeign: should be unreachable"
#endif
(_, stub_o, _) <- runPipeline StopLn hsc_env
(stub_c, Nothing, Just (RealPhase phase))
Nothing (Temporary TFL_GhcSession)
......
......@@ -75,6 +75,7 @@ import UniqSet( nonDetEltsUniqSet )
import MonadUtils
import qualified GHC.LanguageExtensions as LangExt
import Control.Monad
import Data.List.NonEmpty ( nonEmpty )
{-**********************************************************************
* *
......@@ -175,8 +176,8 @@ dsHsBind dflags b@(FunBind { fun_id = L _ fun
dsHsBind dflags (PatBind { pat_lhs = pat, pat_rhs = grhss
, pat_ext = NPatBindTc _ ty
, pat_ticks = (rhs_tick, var_ticks) })
= do { body_expr <- dsGuarded grhss ty
; checkGuardMatches PatBindGuards grhss
= do { rhss_deltas <- checkGuardMatches PatBindGuards grhss
; body_expr <- dsGuarded grhss ty (nonEmpty rhss_deltas)
; let body' = mkOptTickBox rhs_tick body_expr
pat' = decideBangHood dflags pat
; (force_var,sel_binds) <- mkSelectorBinds var_ticks pat body'
......
......@@ -69,6 +69,7 @@ import Outputable
import PatSyn
import Control.Monad
import Data.List.NonEmpty ( nonEmpty )
{-
************************************************************************
......@@ -216,8 +217,8 @@ dsUnliftedBind (PatBind {pat_lhs = pat, pat_rhs = grhss
, pat_ext = NPatBindTc _ ty }) body
= -- let C x# y# = rhs in body
-- ==> case rhs of C x# y# -> body
do { rhs <- dsGuarded grhss ty
; checkGuardMatches PatBindGuards grhss
do { rhs_deltas <- checkGuardMatches PatBindGuards grhss
; rhs <- dsGuarded grhss ty (nonEmpty rhs_deltas)
; let upat = unLoc pat
eqn = EqnInfo { eqn_pats = [upat],
eqn_orig = FromSource,
......@@ -446,9 +447,9 @@ dsExpr (HsMultiIf res_ty alts)
= mkErrorExpr
| otherwise
= do { match_result <- liftM (foldr1 combineMatchResults)
(mapM (dsGRHS IfAlt res_ty) alts)
; checkGuardMatches IfAlt (GRHSs noExtField alts (noLoc emptyLocalBinds))
= do { let grhss = GRHSs noExtField alts (noLoc emptyLocalBinds)
; rhss_deltas <- checkGuardMatches IfAlt grhss
; match_result <- dsGRHSs IfAlt grhss res_ty (nonEmpty rhss_deltas)
; error_expr <- mkErrorExpr
; extractMatchResult match_result error_expr }
where
......
......@@ -9,7 +9,7 @@ Matching guarded right-hand-sides (GRHSs)
{-# LANGUAGE CPP #-}
{-# LANGUAGE ViewPatterns #-}
module GHC.HsToCore.GuardedRHSs ( dsGuarded, dsGRHSs, dsGRHS, isTrueLHsExpr ) where
module GHC.HsToCore.GuardedRHSs ( dsGuarded, dsGRHSs, isTrueLHsExpr ) where
#include "HsVersions.h"
......@@ -23,15 +23,15 @@ import GHC.Core.Make
import GHC.Core
import GHC.Core.Utils (bindNonRec)
import BasicTypes (Origin(FromSource))
import GHC.Driver.Session
import GHC.HsToCore.PmCheck (needToRunPmCheck, addTyCsDs, addPatTmCs, addScrutTmCs)
import GHC.HsToCore.Monad
import GHC.HsToCore.Utils
import GHC.HsToCore.PmCheck.Types ( Deltas, initDeltas )
import Type ( Type )
import Util
import SrcLoc
import Outputable
import Control.Monad ( zipWithM )
import Data.List.NonEmpty ( NonEmpty, toList )
{-
@dsGuarded@ is used for pattern bindings.
......@@ -46,32 +46,38 @@ producing an expression with a runtime error in the corner if
necessary. The type argument gives the type of the @ei@.
-}
dsGuarded :: GRHSs GhcTc (LHsExpr GhcTc) -> Type -> DsM CoreExpr
dsGuarded grhss rhs_ty = do
match_result <- dsGRHSs PatBindRhs grhss rhs_ty
dsGuarded :: GRHSs GhcTc (LHsExpr GhcTc) -> Type -> Maybe (NonEmpty Deltas) -> DsM CoreExpr
dsGuarded grhss rhs_ty mb_rhss_deltas = do
match_result <- dsGRHSs PatBindRhs grhss rhs_ty mb_rhss_deltas
error_expr <- mkErrorAppDs nON_EXHAUSTIVE_GUARDS_ERROR_ID rhs_ty empty
extractMatchResult match_result error_expr
-- In contrast, @dsGRHSs@ produces a @MatchResult@.
dsGRHSs :: HsMatchContext GhcRn
-> GRHSs GhcTc (LHsExpr GhcTc) -- Guarded RHSs
-> Type -- Type of RHS
-> GRHSs GhcTc (LHsExpr GhcTc) -- ^ Guarded RHSs
-> Type -- ^ Type of RHS
-> Maybe (NonEmpty Deltas) -- ^ Refined pattern match checking
-- models, one for each GRHS. Defaults
-- to 'initDeltas' if 'Nothing'.
-> DsM MatchResult
dsGRHSs hs_ctx (GRHSs _ grhss binds) rhs_ty
dsGRHSs hs_ctx (GRHSs _ grhss binds) rhs_ty mb_rhss_deltas
= ASSERT( notNull grhss )
do { match_results <- mapM (dsGRHS hs_ctx rhs_ty) grhss
do { match_results <- case toList <$> mb_rhss_deltas of
Nothing -> mapM (dsGRHS hs_ctx rhs_ty initDeltas) grhss
Just rhss_deltas -> ASSERT( length grhss == length rhss_deltas )
zipWithM (dsGRHS hs_ctx rhs_ty) rhss_deltas grhss
; let match_result1 = foldr1 combineMatchResults match_results
match_result2 = adjustMatchResultDs (dsLocalBinds binds) match_result1
-- NB: nested dsLet inside matchResult
; return match_result2 }
dsGRHSs _ (XGRHSs nec) _ = noExtCon nec
dsGRHSs _ (XGRHSs nec) _ _ = noExtCon nec
dsGRHS :: HsMatchContext GhcRn -> Type -> LGRHS GhcTc (LHsExpr GhcTc)
dsGRHS :: HsMatchContext GhcRn -> Type -> Deltas -> LGRHS GhcTc (LHsExpr GhcTc)
-> DsM MatchResult
dsGRHS hs_ctx rhs_ty (L _ (GRHS _ guards rhs))
= matchGuards (map unLoc guards) (PatGuard hs_ctx) rhs rhs_ty
dsGRHS _ _ (L _ (XGRHS nec)) = noExtCon nec
dsGRHS hs_ctx rhs_ty rhs_deltas (L _ (GRHS _ guards rhs))
= updPmDeltas rhs_deltas (matchGuards (map unLoc guards) (PatGuard hs_ctx) rhs rhs_ty)
dsGRHS _ _ _ (L _ (XGRHS nec)) = noExtCon nec
{-
************************************************************************
......@@ -120,18 +126,9 @@ matchGuards (LetStmt _ binds : stmts) ctx rhs rhs_ty = do
matchGuards (BindStmt _ pat bind_rhs _ _ : stmts) ctx rhs rhs_ty = do
let upat = unLoc pat
dicts = collectEvVarsPat upat
match_var <- selectMatchVar upat
dflags <- getDynFlags
match_result <-
-- See Note [Type and Term Equality Propagation] in Check
applyWhen (needToRunPmCheck dflags FromSource)
-- FromSource might not be accurate, but at worst
-- we do superfluous calls to the pattern match
-- oracle.
(addTyCsDs dicts . addScrutTmCs (Just bind_rhs) [match_var] . addPatTmCs [upat] [match_var])
(matchGuards stmts ctx rhs rhs_ty)
match_result <- matchGuards stmts ctx rhs rhs_ty
core_rhs <- dsLExpr bind_rhs
match_result' <- matchSinglePatVar match_var (StmtCtxt ctx) pat rhs_ty
match_result
......
......@@ -62,7 +62,7 @@ import FastString
import Unique
import UniqDFM
import Control.Monad( when, unless )
import Control.Monad( unless )
import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.List.NonEmpty as NEL
import qualified Data.Map as Map
......@@ -742,36 +742,49 @@ matchWrapper ctxt mb_scr (MG { mg_alts = L _ matches
[] -> mapM newSysLocalDsNoLP arg_tys
(m:_) -> selectMatchVars (map unLoc (hsLMatchPats m))
; eqns_info <- mapM (mk_eqn_info new_vars) matches
-- Pattern match check warnings for /this match-group/.
-- @rhss_deltas@ is a flat list of covered Deltas for each RHS.
-- Each Match will split off one Deltas for its RHSs from this.
; rhss_deltas <- if isMatchContextPmChecked dflags origin ctxt
then addScrutTmCs mb_scr new_vars $
-- See Note [Type and Term Equality Propagation]
checkMatches (DsMatchContext ctxt locn) new_vars matches
else pure [] -- Ultimately this will result in passing Nothing
-- to dsGRHSs as match_deltas
-- Pattern match check warnings for /this match-group/
; when (isMatchContextPmChecked dflags origin ctxt) $
addScrutTmCs mb_scr new_vars $
-- See Note [Type and Term Equality Propagation]
checkMatches dflags (DsMatchContext ctxt locn) new_vars matches
; eqns_info <- mk_eqn_infos matches rhss_deltas
; result_expr <- handleWarnings $
matchEquations ctxt new_vars eqns_info rhs_ty
; return (new_vars, result_expr) }
where
-- rhss_deltas is a flat list, whereas there are multiple GRHSs per match.
-- mk_eqn_infos will thread rhss_deltas as state through calls to
-- mk_eqn_info, distributing each rhss_deltas to a GRHS.
mk_eqn_infos (L _ match : matches) rhss_deltas
= do { (info, rhss_deltas') <- mk_eqn_info match rhss_deltas
; infos <- mk_eqn_infos matches rhss_deltas'
; return (info:infos) }
mk_eqn_infos [] _ = return []
-- Called once per equation in the match, or alternative in the case
mk_eqn_info vars (L _ (Match { m_pats = pats, m_grhss = grhss }))
mk_eqn_info (Match { m_pats = pats, m_grhss = grhss }) rhss_deltas
| XGRHSs nec <- grhss = noExtCon nec
| GRHSs _ grhss' _ <- grhss, let n_grhss = length grhss'
= do { dflags <- getDynFlags
; let upats = map (unLoc . decideBangHood dflags) pats
dicts = collectEvVarsPats upats
; match_result <-
-- Extend the environment with knowledge about
-- the matches before desugaring the RHS
-- See Note [Type and Term Equality Propagation]
applyWhen (needToRunPmCheck dflags origin)
(addTyCsDs dicts . addScrutTmCs mb_scr vars . addPatTmCs upats vars)
(dsGRHSs ctxt grhss rhs_ty)
; return (EqnInfo { eqn_pats = upats
, eqn_orig = FromSource
, eqn_rhs = match_result }) }
mk_eqn_info _ (L _ (XMatch nec)) = noExtCon nec
-- Split off one Deltas for each GRHS of the current Match from the
-- flat list of GRHS Deltas *for all matches* (see the call to
-- checkMatches above).
; let (match_deltas, rhss_deltas') = splitAt n_grhss rhss_deltas
-- The list of Deltas is empty iff we don't perform any coverage
-- checking, in which case nonEmpty does the right thing by passing
-- Nothing.
; match_result <- dsGRHSs ctxt grhss rhs_ty (NEL.nonEmpty match_deltas)
; return ( EqnInfo { eqn_pats = upats
, eqn_orig = FromSource
, eqn_rhs = match_result }
, rhss_deltas' ) }
mk_eqn_info (XMatch nec) _ = noExtCon nec
handleWarnings = if isGenerated origin
then discardWarningsDs
......
......@@ -30,7 +30,7 @@ module GHC.HsToCore.Monad (
DsMetaEnv, DsMetaVal(..), dsGetMetaEnv, dsLookupMetaEnv, dsExtendMetaEnv,
-- Getting and setting pattern match oracle states
getPmDelta, updPmDelta,
getPmDeltas, updPmDeltas,
-- Get COMPLETE sets of a TyCon
dsGetCompleteMatches,
......@@ -282,7 +282,7 @@ mkDsEnvs dflags mod rdr_env type_env fam_inst_env msg_var cc_st_var
}
lcl_env = DsLclEnv { dsl_meta = emptyNameEnv
, dsl_loc = real_span
, dsl_delta = initDelta
, dsl_deltas = initDeltas
}
in (gbl_env, lcl_env)
......@@ -381,14 +381,14 @@ the @SrcSpan@ being carried around.
getGhcModeDs :: DsM GhcMode
getGhcModeDs = getDynFlags >>= return . ghcMode
-- | Get the current pattern match oracle state. See 'dsl_delta'.
getPmDelta :: DsM Delta
getPmDelta = do { env <- getLclEnv; return (dsl_delta env) }
-- | Get the current pattern match oracle state. See 'dsl_deltas'.
getPmDeltas :: DsM Deltas
getPmDeltas = do { env <- getLclEnv; return (dsl_deltas env) }
-- | Set the pattern match oracle state within the scope of the given action.
-- See 'dsl_delta'.
updPmDelta :: Delta -> DsM a -> DsM a
updPmDelta delta = updLclEnv (\env -> env { dsl_delta = delta })
-- See 'dsl_deltas'.
updPmDeltas :: Deltas -> DsM a -> DsM a
updPmDeltas delta = updLclEnv (\env -> env { dsl_deltas = delta })
getSrcSpanDs :: DsM SrcSpan
getSrcSpanDs = do { env <- getLclEnv
......
......@@ -17,7 +17,7 @@ module GHC.HsToCore.PmCheck (
needToRunPmCheck, isMatchContextPmChecked,
-- See Note [Type and Term Equality Propagation]
addTyCsDs, addScrutTmCs, addPatTmCs
addTyCsDs, addScrutTmCs
) where
#include "HsVersions.h"
......@@ -109,8 +109,8 @@ data PmGrd
-- | @PmLet x expr@ corresponds to a @let x = expr@ guard. This actually
-- /binds/ @x@.
| PmLet {
pm_id :: !Id,
pm_let_expr :: !CoreExpr
pm_id :: !Id,
_pm_let_expr :: !CoreExpr
}
-- | Should not be user-facing.
......@@ -160,10 +160,11 @@ data GrdTree
-- tree. 'redundantAndInaccessibleRhss' can figure out redundant and proper
-- inaccessible RHSs from this.
data AnnotatedTree
= AccessibleRhs !RhsInfo
-- ^ A RHS deemed accessible.
= AccessibleRhs !Deltas !RhsInfo
-- ^ A RHS deemed accessible. The 'Deltas' is the (non-empty) set of covered
-- values.
| InaccessibleRhs !RhsInfo
-- ^ A RHS deemed inaccessible; no value could possibly reach it.
-- ^ A RHS deemed inaccessible; it covers no value.
| MayDiverge !AnnotatedTree
-- ^ Asserts that the tree may force diverging values, so not all of its
-- clauses can be redundant.
......@@ -194,7 +195,7 @@ instance Outputable GrdTree where
ppr Empty = text "<empty case>"
instance Outputable AnnotatedTree where
ppr (AccessibleRhs info) = pprRhsInfo info
ppr (AccessibleRhs _ info) = pprRhsInfo info
ppr (InaccessibleRhs info) = text "inaccessible" <+> pprRhsInfo info
ppr (MayDiverge t) = text "div" <+> ppr t
-- Format nested Sequences in blocks "{ grds1; grds2; ... }"
......@@ -204,17 +205,6 @@ instance Outputable AnnotatedTree where
collect_seqs t = [ppr t]
ppr EmptyAnn = text "<empty case>"
newtype Deltas = MkDeltas (Bag Delta)
instance Outputable Deltas where
ppr (MkDeltas deltas) = ppr deltas
instance Semigroup Deltas where
MkDeltas l <> MkDeltas r = MkDeltas (l `unionBags` r)
liftDeltasM :: Monad m => (Delta -> m (Maybe Delta)) -> Deltas -> m Deltas
liftDeltasM f (MkDeltas ds) = MkDeltas . catBagMaybes <$> (traverse f ds)
-- | Lift 'addPmCts' over 'Deltas'.
addPmCtsDeltas :: Deltas -> PmCts -> DsM Deltas
addPmCtsDeltas deltas cts = liftDeltasM (\d -> addPmCts d cts) deltas
......@@ -272,7 +262,8 @@ checkSingle dflags ctxt@(DsMatchContext kind locn) var p = do
-- Omitting checking this flag emits redundancy warnings twice in obscure
-- cases like #17646.
when (exhaustive dflags kind) $ do
missing <- MkDeltas . unitBag <$> getPmDelta
-- TODO: This could probably call checkMatches, like checkGuardMatches.
missing <- getPmDeltas
tracePm "checkSingle: missing" (ppr missing)
fam_insts <- dsGetFamInstEnvs
grd_tree <- mkGrdTreeRhs (L locn $ ppr p) <$> translatePat fam_insts var p
......@@ -280,12 +271,13 @@ checkSingle dflags ctxt@(DsMatchContext kind locn) var p = do
dsPmWarn dflags ctxt [var] res
-- | Exhaustive for guard matches, is used for guards in pattern bindings and
-- in @MultiIf@ expressions.
checkGuardMatches :: HsMatchContext GhcRn -- Match context
-> GRHSs GhcTc (LHsExpr GhcTc) -- Guarded RHSs
-> DsM ()
-- in @MultiIf@ expressions. Returns the 'Deltas' covered by the RHSs.
checkGuardMatches
:: HsMatchContext GhcRn -- ^ Match context, for warning messages
-> GRHSs GhcTc (LHsExpr GhcTc) -- ^ The GRHSs to check
-> DsM [Deltas] -- ^ Covered 'Deltas' for each RHS, for long
-- distance info
checkGuardMatches hs_ctx guards@(GRHSs _ grhss _) = do
dflags <- getDynFlags
let combinedLoc = foldl1 combineSrcSpans (map getLoc grhss)
dsMatchContext = DsMatchContext hs_ctx combinedLoc
match = L combinedLoc $
......@@ -293,20 +285,37 @@ checkGuardMatches hs_ctx guards@(GRHSs _ grhss _) = do
, m_ctxt = hs_ctx
, m_pats = []
, m_grhss = guards }
checkMatches dflags dsMatchContext [] [match]
checkMatches dsMatchContext [] [match]
checkGuardMatches _ (XGRHSs nec) = noExtCon nec
-- | Check a matchgroup (case, functions, etc.)
checkMatches :: DynFlags -> DsMatchContext
-> [Id] -> [LMatch GhcTc (LHsExpr GhcTc)] -> DsM ()
checkMatches dflags ctxt vars matches = do
-- | Check a list of syntactic /match/es (part of case, functions, etc.), each
-- with a /pat/ and one or more /grhss/:
--
-- @
-- f x y | x == y = 1 -- match on x and y with two guarded RHSs
-- | otherwise = 2
-- f _ _ = 3 -- clause with a single, un-guarded RHS
-- @
--
-- Returns one 'Deltas' for each GRHS, representing its covered values, or the
-- incoming uncovered 'Deltas' (from 'getPmDeltas') if the GRHS is inaccessible.
-- Since there is at least one /grhs/ per /match/, the list of 'Deltas' is at
-- least as long as the list of matches.
checkMatches
:: DsMatchContext -- ^ Match context, for warnings messages
-> [Id] -- ^ Match variables, i.e. x and y above
-> [LMatch GhcTc (LHsExpr GhcTc)] -- ^ List of matches
-> DsM [Deltas] -- ^ One covered 'Deltas' per RHS, for long
-- distance info.
checkMatches ctxt vars matches = do
dflags <- getDynFlags
tracePm "checkMatches" (hang (vcat [ppr ctxt
, ppr vars
, text "Matches:"])
2
(vcat (map ppr matches)))
init_deltas <- MkDeltas . unitBag <$> getPmDelta
init_deltas <- getPmDeltas
missing <- case matches of
-- This must be an -XEmptyCase. See Note [Checking EmptyCase]
[] | [var] <- vars -> addPmCtDeltas init_deltas (PmNotBotCt var)
......@@ -317,6 +326,21 @@ checkMatches dflags ctxt vars matches = do
dsPmWarn dflags ctxt vars res
return (extractRhsDeltas init_deltas (cr_clauses res))
-- | Extract the 'Deltas' reaching the RHSs of the 'AnnotatedTree'.
-- For 'AccessibleRhs's, this is stored in the tree node, whereas
-- 'InaccessibleRhs's fall back to the supplied original 'Deltas'.
-- See @Note [Recovering from unsatisfiable pattern-matching constraints]@.
extractRhsDeltas :: Deltas -> AnnotatedTree -> [Deltas]
extractRhsDeltas orig_deltas = fromOL . go
where
go (AccessibleRhs deltas _) = unitOL deltas
go (InaccessibleRhs _) = unitOL orig_deltas
go (MayDiverge t) = go t
go (SequenceAnn l r) = go l Semi.<> go r
go EmptyAnn = nilOL
{- Note [Checking EmptyCase]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-XEmptyCase is useful for matching on empty data types like 'Void'. For example,
......@@ -920,7 +944,9 @@ checkGrdTree' :: GrdTree -> Deltas -> DsM CheckResult
-- RHS: Check that it covers something and wrap Inaccessible if not
checkGrdTree' (Rhs sdoc) deltas = do
is_covered <- isInhabited deltas
let clauses = if is_covered then AccessibleRhs sdoc else InaccessibleRhs sdoc
let clauses
| is_covered = AccessibleRhs deltas sdoc
| otherwise = InaccessibleRhs sdoc
pure CheckResult
{ cr_clauses = clauses
, cr_uncov = MkDeltas emptyBag
......@@ -1005,22 +1031,24 @@ f x = case x of
(_:_) -> True
[] -> False -- can't happen
Functions `addScrutTmCs' and `addPatTmCs' are responsible for generating
Functions `addScrutTmCs' is responsible for generating
these constraints.
-}
locallyExtendPmDelta :: (Delta -> DsM (Maybe Delta)) -> DsM a -> DsM a
locallyExtendPmDelta ext k = getPmDelta >>= ext >>= \case
locallyExtendPmDelta :: (Deltas -> DsM Deltas) -> DsM a -> DsM a
locallyExtendPmDelta ext k = getPmDeltas >>= ext >>= \deltas -> do
inh <- isInhabited deltas
-- If adding a constraint would lead to a contradiction, don't add it.
-- See @Note [Recovering from unsatisfiable pattern-matching constraints]@
-- for why this is done.
Nothing -> k
Just delta' -> updPmDelta delta' k
if inh
then updPmDeltas deltas k
else k
-- | Add in-scope type constraints
addTyCsDs :: Bag EvVar -> DsM a -> DsM a
addTyCsDs ev_vars =
locallyExtendPmDelta (\delta -> addPmCts delta (PmTyCt . evVarPred <$> ev_vars))
locallyExtendPmDelta (\deltas -> addPmCtsDeltas deltas (PmTyCt . evVarPred <$> ev_vars))
-- | Add equalities for the scrutinee to the local 'DsM' environment when
-- checking a case expression:
......@@ -1031,51 +1059,9 @@ addScrutTmCs :: Maybe (LHsExpr GhcTc) -> [Id] -> DsM a -> DsM a
addScrutTmCs Nothing _ k = k
addScrutTmCs (Just scr) [x] k = do
scr_e <- dsLExpr scr
locallyExtendPmDelta (\delta -> addPmCts delta (unitBag (PmCoreCt x scr_e))) k
locallyExtendPmDelta (\deltas -> addPmCtsDeltas deltas (unitBag (PmCoreCt x scr_e))) k
addScrutTmCs _ _ _ = panic "addScrutTmCs: HsCase with more than one case binder"
addPmConCts :: Delta -> Id -> PmAltCon -> [TyVar] -> [EvVar] -> [Id] -> DsM (Maybe Delta)
addPmConCts delta x con tvs dicts fields = runMaybeT $ do
delta_ty <- MaybeT $ addPmCts delta (listToBag (PmTyCt . evVarPred <$> dicts))
delta_tm_ty <- MaybeT $ addPmCts delta_ty (unitBag (PmConCt x con tvs fields))
pure delta_tm_ty
-- | Add equalities to the local 'DsM' environment when checking the RHS of a
-- case expression:
-- case e of x { p1 -> e1; ... pn -> en }
-- When we go deeper to check e.g. e1 we record (x ~ p1).
addPatTmCs :: [Pat GhcTc] -- LHS (should have length 1)
-> [Id] -- MatchVars (should have length 1)
-> DsM a
-> DsM a
-- Computes an approximation of the Covered set for p1 (which pmCheck currently
-- discards).
addPatTmCs ps xs k = do
fam_insts <- dsGetFamInstEnvs
grds <- concat <$> zipWithM (translatePat fam_insts) xs ps
locallyExtendPmDelta (\delta -> computeCovered grds delta) k
-- | A dead simple version of 'pmCheck' that only computes the Covered set.
-- So it only cares about collecting positive info.
-- We use it to collect info from a pattern when we check its RHS.
-- See 'addPatTmCs'.
computeCovered :: GrdVec -> Delta -> DsM (Maybe Delta)
-- The duplication with 'pmCheck' is really unfortunate, but it's simpler than
-- separating out the common cases with 'pmCheck', because that would make the
-- ConVar case harder to understand.
computeCovered [] delta = pure (Just delta)
computeCovered (PmLet { pm_id = x, pm_let_expr = e } : ps) delta = do
delta' <- expectJust "x is fresh" <$> addPmCts delta (unitBag (PmCoreCt x e))
computeCovered ps delta'
computeCovered (PmBang{} : ps) delta = do
computeCovered ps delta
computeCovered (p : ps) delta
| PmCon{ pm_id = x, pm_con_con = con, pm_con_tvs = tvs, pm_con_args = args
, pm_con_dicts = dicts } <- p
= addPmConCts delta x con tvs dicts args >>= \case
Nothing -> pure Nothing
Just delta' -> computeCovered ps delta'
{-
%************************************************************************
%* *
......@@ -1114,7 +1100,7 @@ redundantAndInaccessibleRhss tree = (fromOL ol_red, fromOL ol_inacc)
-- even safely delete the equation without altering semantics)
-- See Note [Determining inaccessible clauses]
go :: AnnotatedTree -> (OrdList RhsInfo, OrdList RhsInfo, OrdList RhsInfo)
go (AccessibleRhs info) = (unitOL info, nilOL, nilOL)
go (AccessibleRhs _ info) = (unitOL info, nilOL, nilOL)
go (InaccessibleRhs info) = (nilOL, nilOL, unitOL info) -- presumably redundant
go (MayDiverge t) = case go t of
-- See Note [Determining inaccessible clauses]
......
......@@ -12,7 +12,7 @@ Authors: George Karachalias <george.karachalias@cs.kuleuven.be>
module GHC.HsToCore.PmCheck.Oracle (
DsM, tracePm, mkPmId,
Delta, initDelta, lookupRefuts, lookupSolution,
Delta, initDeltas, lookupRefuts, lookupSolution,
PmCt(PmTyCt), PmCts, pattern PmVarCt, pattern PmCoreCt,
pattern PmConCt, pattern PmNotConCt, pattern PmBotCt,
......
......@@ -29,7 +29,8 @@ module GHC.HsToCore.PmCheck.Types (
setIndirectSDIE, setEntrySDIE, traverseSDIE,
-- * The pattern match oracle
VarInfo(..), TmState(..), TyState(..), Delta(..), initDelta
VarInfo(..), TmState(..), TyState(..), Delta(..),
Deltas(..), initDeltas, liftDeltasM
) where
#include "HsVersions.h"
......@@ -64,6 +65,7 @@ import Numeric (fromRat)
import Data.Foldable (find)
import qualified Data.List.NonEmpty as NonEmpty
import Data.Ratio
import qualified Data.Semigroup as Semi
-- | Literals (simple and overloaded ones) for pattern match checking.
--
......@@ -520,8 +522,7 @@ instance Outputable TyState where
initTyState :: TyState
initTyState = TySt emptyBag
-- | Term and type constraints to accompany each value vector abstraction.
-- For efficiency, we store the term oracle state instead of the term
-- | An inert set of canonical (i.e. mutually compatible) term and type
-- constraints.
data Delta = MkDelta { delta_ty_st :: TyState -- Type oracle; things like a~Int
, delta_tm_st :: TmState } -- Term oracle; things like x~Nothing
......@@ -537,3 +538,18 @@ instance Outputable Delta where
ppr (delta_tm_st delta),
ppr (delta_ty_st delta)
]
-- | A disjunctive bag of 'Delta's, representing a refinement type.
newtype Deltas = MkDeltas (Bag Delta)
initDeltas :: Deltas
initDeltas = MkDeltas (unitBag initDelta)
instance Outputable Deltas where
ppr (MkDeltas deltas) = ppr deltas
instance Semigroup Deltas where
MkDeltas l <> MkDeltas r = MkDeltas (l `unionBags` r)
liftDeltasM :: Monad m => (Delta -> m (Maybe Delta)) -> Deltas -> m Deltas
liftDeltasM f (MkDeltas ds) = MkDeltas . catBagMaybes <$> (traverse f ds)