Commit 00cbbab3 authored by eir@cis.upenn.edu's avatar eir@cis.upenn.edu

Refactor the typechecker to use ExpTypes.

The idea here is described in [wiki:Typechecker]. Briefly,
this refactor keeps solid track of "synthesis" mode vs
"checking" in GHC's bidirectional type-checking algorithm.
When in synthesis mode, the expected type is just an IORef
to write to.

In addition, this patch does a significant reworking of
RebindableSyntax, allowing much more freedom in the types
of the rebindable operators. For example, we can now have
`negate :: Int -> Bool` and
`(>>=) :: m a -> (forall x. a x -> m b) -> m b`. The magic
is in tcSyntaxOp.

This addresses tickets #11397, #11452, and #11458.

Tests:
  typecheck/should_compile/{RebindHR,RebindNegate,T11397,T11458}
  th/T11452
parent 2899aa58
...@@ -31,6 +31,7 @@ import Id ...@@ -31,6 +31,7 @@ import Id
import ConLike import ConLike
import DataCon import DataCon
import Name import Name
import FamInstEnv
import TysWiredIn import TysWiredIn
import TyCon import TyCon
import SrcLoc import SrcLoc
...@@ -148,7 +149,8 @@ type PmResult = ( [[LPat Id]] ...@@ -148,7 +149,8 @@ type PmResult = ( [[LPat Id]]
checkSingle :: Id -> Pat Id -> DsM PmResult checkSingle :: Id -> Pat Id -> DsM PmResult
checkSingle var p = do checkSingle var p = do
let lp = [noLoc p] let lp = [noLoc p]
vec <- liftUs (translatePat p) fam_insts <- dsGetFamInstEnvs
vec <- liftUs (translatePat fam_insts p)
vsa <- initial_uncovered [var] vsa <- initial_uncovered [var]
(c,d,us') <- patVectProc False (vec,[]) vsa -- no guards (c,d,us') <- patVectProc False (vec,[]) vsa -- no guards
us <- pruneVSA us' us <- pruneVSA us'
...@@ -171,7 +173,8 @@ checkMatches oversimplify vars matches ...@@ -171,7 +173,8 @@ checkMatches oversimplify vars matches
return ([], [], missing') return ([], [], missing')
go (m:ms) missing = do go (m:ms) missing = do
clause <- liftUs (translateMatch m) fam_insts <- dsGetFamInstEnvs
clause <- liftUs (translateMatch fam_insts m)
(c, d, us ) <- patVectProc oversimplify clause missing (c, d, us ) <- patVectProc oversimplify clause missing
(rs, is, us') <- go ms us (rs, is, us') <- go ms us
return $ case (c,d) of return $ case (c,d) of
...@@ -209,7 +212,8 @@ noFailingGuards clauses = sum [ countPatVecs gvs | (_, gvs) <- clauses ] ...@@ -209,7 +212,8 @@ noFailingGuards clauses = sum [ countPatVecs gvs | (_, gvs) <- clauses ]
computeNoGuards :: [LMatch Id (LHsExpr Id)] -> PmM Int computeNoGuards :: [LMatch Id (LHsExpr Id)] -> PmM Int
computeNoGuards matches = do computeNoGuards matches = do
matches' <- mapM (liftUs . translateMatch) matches fam_insts <- dsGetFamInstEnvs
matches' <- mapM (liftUs . translateMatch fam_insts) matches
return (noFailingGuards matches') return (noFailingGuards matches')
maximum_failing_guards :: Int maximum_failing_guards :: Int
...@@ -264,46 +268,47 @@ mkLitPattern lit = PmLit { pm_lit_lit = PmSLit lit } ...@@ -264,46 +268,47 @@ mkLitPattern lit = PmLit { pm_lit_lit = PmSLit lit }
-- ----------------------------------------------------------------------- -- -----------------------------------------------------------------------
-- * Transform (Pat Id) into of (PmPat Id) -- * Transform (Pat Id) into of (PmPat Id)
translatePat :: Pat Id -> UniqSM PatVec translatePat :: FamInstEnvs -> Pat Id -> UniqSM PatVec
translatePat pat = case pat of translatePat fam_insts pat = case pat of
WildPat ty -> mkPmVarsSM [ty] WildPat ty -> mkPmVarsSM [ty]
VarPat id -> return [PmVar (unLoc id)] VarPat id -> return [PmVar (unLoc id)]
ParPat p -> translatePat (unLoc p) ParPat p -> translatePat fam_insts (unLoc p)
LazyPat _ -> mkPmVarsSM [hsPatType pat] -- like a variable LazyPat _ -> mkPmVarsSM [hsPatType pat] -- like a variable
-- ignore strictness annotations for now -- ignore strictness annotations for now
BangPat p -> translatePat (unLoc p) BangPat p -> translatePat fam_insts (unLoc p)
AsPat lid p -> do AsPat lid p -> do
-- Note [Translating As Patterns] -- Note [Translating As Patterns]
ps <- translatePat (unLoc p) ps <- translatePat fam_insts (unLoc p)
let [e] = map valAbsToPmExpr (coercePatVec ps) let [e] = map valAbsToPmExpr (coercePatVec ps)
g = PmGrd [PmVar (unLoc lid)] e g = PmGrd [PmVar (unLoc lid)] e
return (ps ++ [g]) return (ps ++ [g])
SigPatOut p _ty -> translatePat (unLoc p) SigPatOut p _ty -> translatePat fam_insts (unLoc p)
-- See Note [Translate CoPats] -- See Note [Translate CoPats]
CoPat wrapper p ty CoPat wrapper p ty
| isIdHsWrapper wrapper -> translatePat p | isIdHsWrapper wrapper -> translatePat fam_insts p
| WpCast co <- wrapper, isReflexiveCo co -> translatePat p | WpCast co <- wrapper, isReflexiveCo co -> translatePat fam_insts p
| otherwise -> do | otherwise -> do
ps <- translatePat p ps <- translatePat fam_insts p
(xp,xe) <- mkPmId2FormsSM ty (xp,xe) <- mkPmId2FormsSM ty
let g = mkGuard ps (HsWrap wrapper (unLoc xe)) let g = mkGuard ps (HsWrap wrapper (unLoc xe))
return [xp,g] return [xp,g]
-- (n + k) ===> x (True <- x >= k) (n <- x-k) -- (n + k) ===> x (True <- x >= k) (n <- x-k)
NPlusKPat (L _ n) k ge minus -> do NPlusKPat (L _ n) k1 k2 ge minus ty -> do
(xp, xe) <- mkPmId2FormsSM (idType n) (xp, xe) <- mkPmId2FormsSM ty
let ke = L (getLoc k) (HsOverLit (unLoc k)) let ke1 = L (getLoc k1) (HsOverLit (unLoc k1))
g1 = mkGuard [truePattern] (OpApp xe (noLoc ge) no_fixity ke) ke2 = L (getLoc k1) (HsOverLit k2)
g2 = mkGuard [PmVar n] (OpApp xe (noLoc minus) no_fixity ke) g1 = mkGuard [truePattern] (unLoc $ nlHsSyntaxApps ge [xe, ke1])
g2 = mkGuard [PmVar n] (unLoc $ nlHsSyntaxApps minus [xe, ke2])
return [xp, g1, g2] return [xp, g1, g2]
-- (fun -> pat) ===> x (pat <- fun x) -- (fun -> pat) ===> x (pat <- fun x)
ViewPat lexpr lpat arg_ty -> do ViewPat lexpr lpat arg_ty -> do
ps <- translatePat (unLoc lpat) ps <- translatePat fam_insts (unLoc lpat)
-- See Note [Guards and Approximation] -- See Note [Guards and Approximation]
case all cantFailPattern ps of case all cantFailPattern ps of
True -> do True -> do
...@@ -316,15 +321,18 @@ translatePat pat = case pat of ...@@ -316,15 +321,18 @@ translatePat pat = case pat of
-- list -- list
ListPat ps ty Nothing -> do ListPat ps ty Nothing -> do
foldr (mkListPatVec ty) [nilPattern ty] <$> translatePatVec (map unLoc ps) foldr (mkListPatVec ty) [nilPattern ty] <$> translatePatVec fam_insts (map unLoc ps)
-- overloaded list -- overloaded list
ListPat lpats elem_ty (Just (pat_ty, _to_list)) ListPat lpats elem_ty (Just (pat_ty, _to_list))
| Just e_ty <- splitListTyConApp_maybe pat_ty, elem_ty `eqType` e_ty -> | Just e_ty <- splitListTyConApp_maybe pat_ty
, (_, norm_elem_ty) <- normaliseType fam_insts Nominal elem_ty
-- elem_ty is frequently something like `Item [Int]`, but we prefer `Int`
, norm_elem_ty `eqType` e_ty ->
-- We have to ensure that the element types are exactly the same. -- We have to ensure that the element types are exactly the same.
-- Otherwise, one may give an instance IsList [Int] (more specific than -- Otherwise, one may give an instance IsList [Int] (more specific than
-- the default IsList [a]) with a different implementation for `toList' -- the default IsList [a]) with a different implementation for `toList'
translatePat (ListPat lpats e_ty Nothing) translatePat fam_insts (ListPat lpats e_ty Nothing)
| otherwise -> do | otherwise -> do
-- See Note [Guards and Approximation] -- See Note [Guards and Approximation]
var <- mkPmVarSM pat_ty var <- mkPmVarSM pat_ty
...@@ -345,29 +353,29 @@ translatePat pat = case pat of ...@@ -345,29 +353,29 @@ translatePat pat = case pat of
, pat_tvs = ex_tvs , pat_tvs = ex_tvs
, pat_dicts = dicts , pat_dicts = dicts
, pat_args = ps } -> do , pat_args = ps } -> do
args <- translateConPatVec arg_tys ex_tvs con ps args <- translateConPatVec fam_insts arg_tys ex_tvs con ps
return [PmCon { pm_con_con = con return [PmCon { pm_con_con = con
, pm_con_arg_tys = arg_tys , pm_con_arg_tys = arg_tys
, pm_con_tvs = ex_tvs , pm_con_tvs = ex_tvs
, pm_con_dicts = dicts , pm_con_dicts = dicts
, pm_con_args = args }] , pm_con_args = args }]
NPat (L _ ol) mb_neg _eq -> translateNPat ol mb_neg NPat (L _ ol) mb_neg _eq ty -> translateNPat fam_insts ol mb_neg ty
LitPat lit LitPat lit
-- If it is a string then convert it to a list of characters -- If it is a string then convert it to a list of characters
| HsString src s <- lit -> | HsString src s <- lit ->
foldr (mkListPatVec charTy) [nilPattern charTy] <$> foldr (mkListPatVec charTy) [nilPattern charTy] <$>
translatePatVec (map (LitPat . HsChar src) (unpackFS s)) translatePatVec fam_insts (map (LitPat . HsChar src) (unpackFS s))
| otherwise -> return [mkLitPattern lit] | otherwise -> return [mkLitPattern lit]
PArrPat ps ty -> do PArrPat ps ty -> do
tidy_ps <- translatePatVec (map unLoc ps) tidy_ps <- translatePatVec fam_insts (map unLoc ps)
let fake_con = parrFakeCon (length ps) let fake_con = parrFakeCon (length ps)
return [vanillaConPattern fake_con [ty] (concat tidy_ps)] return [vanillaConPattern fake_con [ty] (concat tidy_ps)]
TuplePat ps boxity tys -> do TuplePat ps boxity tys -> do
tidy_ps <- translatePatVec (map unLoc ps) tidy_ps <- translatePatVec fam_insts (map unLoc ps)
let tuple_con = tupleDataCon boxity (length ps) let tuple_con = tupleDataCon boxity (length ps)
return [vanillaConPattern tuple_con tys (concat tidy_ps)] return [vanillaConPattern tuple_con tys (concat tidy_ps)]
...@@ -378,33 +386,35 @@ translatePat pat = case pat of ...@@ -378,33 +386,35 @@ translatePat pat = case pat of
SigPatIn {} -> panic "Check.translatePat: SigPatIn" SigPatIn {} -> panic "Check.translatePat: SigPatIn"
-- | Translate an overloaded literal (see `tidyNPat' in deSugar/MatchLit.hs) -- | Translate an overloaded literal (see `tidyNPat' in deSugar/MatchLit.hs)
translateNPat :: HsOverLit Id -> Maybe (SyntaxExpr Id) -> UniqSM PatVec translateNPat :: FamInstEnvs
translateNPat (OverLit val False _ ty) mb_neg -> HsOverLit Id -> Maybe (SyntaxExpr Id) -> Type -> UniqSM PatVec
| isStringTy ty, HsIsString src s <- val, Nothing <- mb_neg translateNPat fam_insts (OverLit val False _ ty) mb_neg outer_ty
= translatePat (LitPat (HsString src s)) | not type_change, isStringTy ty, HsIsString src s <- val, Nothing <- mb_neg
| isIntTy ty, HsIntegral src i <- val = translatePat fam_insts (LitPat (HsString src s))
= translatePat (mk_num_lit HsInt src i) | not type_change, isIntTy ty, HsIntegral src i <- val
| isWordTy ty, HsIntegral src i <- val = translatePat fam_insts (mk_num_lit HsInt src i)
= translatePat (mk_num_lit HsWordPrim src i) | not type_change, isWordTy ty, HsIntegral src i <- val
= translatePat fam_insts (mk_num_lit HsWordPrim src i)
where where
type_change = not (outer_ty `eqType` ty)
mk_num_lit c src i = LitPat $ case mb_neg of mk_num_lit c src i = LitPat $ case mb_neg of
Nothing -> c src i Nothing -> c src i
Just _ -> c src (-i) Just _ -> c src (-i)
translateNPat ol mb_neg translateNPat _ ol mb_neg _
= return [PmLit { pm_lit_lit = PmOLit (isJust mb_neg) ol }] = return [PmLit { pm_lit_lit = PmOLit (isJust mb_neg) ol }]
-- | Translate a list of patterns (Note: each pattern is translated -- | Translate a list of patterns (Note: each pattern is translated
-- to a pattern vector but we do not concatenate the results). -- to a pattern vector but we do not concatenate the results).
translatePatVec :: [Pat Id] -> UniqSM [PatVec] translatePatVec :: FamInstEnvs -> [Pat Id] -> UniqSM [PatVec]
translatePatVec pats = mapM translatePat pats translatePatVec fam_insts pats = mapM (translatePat fam_insts) pats
translateConPatVec :: [Type] -> [TyVar] translateConPatVec :: FamInstEnvs -> [Type] -> [TyVar]
-> DataCon -> HsConPatDetails Id -> UniqSM PatVec -> DataCon -> HsConPatDetails Id -> UniqSM PatVec
translateConPatVec _univ_tys _ex_tvs _ (PrefixCon ps) translateConPatVec fam_insts _univ_tys _ex_tvs _ (PrefixCon ps)
= concat <$> translatePatVec (map unLoc ps) = concat <$> translatePatVec fam_insts (map unLoc ps)
translateConPatVec _univ_tys _ex_tvs _ (InfixCon p1 p2) translateConPatVec fam_insts _univ_tys _ex_tvs _ (InfixCon p1 p2)
= concat <$> translatePatVec (map unLoc [p1,p2]) = concat <$> translatePatVec fam_insts (map unLoc [p1,p2])
translateConPatVec univ_tys ex_tvs c (RecCon (HsRecFields fs _)) translateConPatVec fam_insts univ_tys ex_tvs c (RecCon (HsRecFields fs _))
-- Nothing matched. Make up some fresh term variables -- Nothing matched. Make up some fresh term variables
| null fs = mkPmVarsSM arg_tys | null fs = mkPmVarsSM arg_tys
-- The data constructor was not defined using record syntax. For the -- The data constructor was not defined using record syntax. For the
...@@ -417,7 +427,7 @@ translateConPatVec univ_tys ex_tvs c (RecCon (HsRecFields fs _)) ...@@ -417,7 +427,7 @@ translateConPatVec univ_tys ex_tvs c (RecCon (HsRecFields fs _))
| matched_lbls `subsetOf` orig_lbls | matched_lbls `subsetOf` orig_lbls
= ASSERT(length orig_lbls == length arg_tys) = ASSERT(length orig_lbls == length arg_tys)
let translateOne (lbl, ty) = case lookup lbl matched_pats of let translateOne (lbl, ty) = case lookup lbl matched_pats of
Just p -> translatePat p Just p -> translatePat fam_insts p
Nothing -> mkPmVarsSM [ty] Nothing -> mkPmVarsSM [ty]
in concatMapM translateOne (zip orig_lbls arg_tys) in concatMapM translateOne (zip orig_lbls arg_tys)
-- The fields that appear are not in the correct order. Make up fresh -- The fields that appear are not in the correct order. Make up fresh
...@@ -426,7 +436,7 @@ translateConPatVec univ_tys ex_tvs c (RecCon (HsRecFields fs _)) ...@@ -426,7 +436,7 @@ translateConPatVec univ_tys ex_tvs c (RecCon (HsRecFields fs _))
| otherwise = do | otherwise = do
arg_var_pats <- mkPmVarsSM arg_tys arg_var_pats <- mkPmVarsSM arg_tys
translated_pats <- forM matched_pats $ \(x,pat) -> do translated_pats <- forM matched_pats $ \(x,pat) -> do
pvec <- translatePat pat pvec <- translatePat fam_insts pat
return (x, pvec) return (x, pvec)
let zipped = zip orig_lbls [ x | PmVar x <- arg_var_pats ] let zipped = zip orig_lbls [ x | PmVar x <- arg_var_pats ]
...@@ -453,10 +463,10 @@ translateConPatVec univ_tys ex_tvs c (RecCon (HsRecFields fs _)) ...@@ -453,10 +463,10 @@ translateConPatVec univ_tys ex_tvs c (RecCon (HsRecFields fs _))
| x == y = subsetOf xs ys | x == y = subsetOf xs ys
| otherwise = subsetOf (x:xs) ys | otherwise = subsetOf (x:xs) ys
translateMatch :: LMatch Id (LHsExpr Id) -> UniqSM (PatVec,[PatVec]) translateMatch :: FamInstEnvs -> LMatch Id (LHsExpr Id) -> UniqSM (PatVec,[PatVec])
translateMatch (L _ (Match _ lpats _ grhss)) = do translateMatch fam_insts (L _ (Match _ lpats _ grhss)) = do
pats' <- concat <$> translatePatVec pats pats' <- concat <$> translatePatVec fam_insts pats
guards' <- mapM translateGuards guards guards' <- mapM (translateGuards fam_insts) guards
return (pats', guards') return (pats', guards')
where where
extractGuards :: LGRHS Id (LHsExpr Id) -> [GuardStmt Id] extractGuards :: LGRHS Id (LHsExpr Id) -> [GuardStmt Id]
...@@ -469,9 +479,9 @@ translateMatch (L _ (Match _ lpats _ grhss)) = do ...@@ -469,9 +479,9 @@ translateMatch (L _ (Match _ lpats _ grhss)) = do
-- * Transform source guards (GuardStmt Id) to PmPats (Pattern) -- * Transform source guards (GuardStmt Id) to PmPats (Pattern)
-- | Translate a list of guard statements to a pattern vector -- | Translate a list of guard statements to a pattern vector
translateGuards :: [GuardStmt Id] -> UniqSM PatVec translateGuards :: FamInstEnvs -> [GuardStmt Id] -> UniqSM PatVec
translateGuards guards = do translateGuards fam_insts guards = do
all_guards <- concat <$> mapM translateGuard guards all_guards <- concat <$> mapM (translateGuard fam_insts) guards
return (replace_unhandled all_guards) return (replace_unhandled all_guards)
-- It should have been (return $ all_guards) but it is too expressive. -- It should have been (return $ all_guards) but it is too expressive.
-- Since the term oracle does not handle all constraints we generate, -- Since the term oracle does not handle all constraints we generate,
...@@ -509,24 +519,24 @@ cantFailPattern (PmGrd pv _e) ...@@ -509,24 +519,24 @@ cantFailPattern (PmGrd pv _e)
cantFailPattern _ = False cantFailPattern _ = False
-- | Translate a guard statement to Pattern -- | Translate a guard statement to Pattern
translateGuard :: GuardStmt Id -> UniqSM PatVec translateGuard :: FamInstEnvs -> GuardStmt Id -> UniqSM PatVec
translateGuard (BodyStmt e _ _ _) = translateBoolGuard e translateGuard _ (BodyStmt e _ _ _) = translateBoolGuard e
translateGuard (LetStmt binds) = translateLet (unLoc binds) translateGuard _ (LetStmt binds) = translateLet (unLoc binds)
translateGuard (BindStmt p e _ _) = translateBind p e translateGuard fam_insts (BindStmt p e _ _ _) = translateBind fam_insts p e
translateGuard (LastStmt {}) = panic "translateGuard LastStmt" translateGuard _ (LastStmt {}) = panic "translateGuard LastStmt"
translateGuard (ParStmt {}) = panic "translateGuard ParStmt" translateGuard _ (ParStmt {}) = panic "translateGuard ParStmt"
translateGuard (TransStmt {}) = panic "translateGuard TransStmt" translateGuard _ (TransStmt {}) = panic "translateGuard TransStmt"
translateGuard (RecStmt {}) = panic "translateGuard RecStmt" translateGuard _ (RecStmt {}) = panic "translateGuard RecStmt"
translateGuard (ApplicativeStmt {}) = panic "translateGuard ApplicativeLastStmt" translateGuard _ (ApplicativeStmt {}) = panic "translateGuard ApplicativeLastStmt"
-- | Translate let-bindings -- | Translate let-bindings
translateLet :: HsLocalBinds Id -> UniqSM PatVec translateLet :: HsLocalBinds Id -> UniqSM PatVec
translateLet _binds = return [] translateLet _binds = return []
-- | Translate a pattern guard -- | Translate a pattern guard
translateBind :: LPat Id -> LHsExpr Id -> UniqSM PatVec translateBind :: FamInstEnvs -> LPat Id -> LHsExpr Id -> UniqSM PatVec
translateBind (L _ p) e = do translateBind fam_insts (L _ p) e = do
ps <- translatePat p ps <- translatePat fam_insts p
return [mkGuard ps (unLoc e)] return [mkGuard ps (unLoc e)]
-- | Translate a boolean guard -- | Translate a boolean guard
...@@ -600,7 +610,8 @@ below is the *right thing to do*: ...@@ -600,7 +610,8 @@ below is the *right thing to do*:
The case with literals is a bit different. a literal @l@ should be translated The case with literals is a bit different. a literal @l@ should be translated
to @x (True <- x == from l)@. Since we want to have better warnings for to @x (True <- x == from l)@. Since we want to have better warnings for
overloaded literals as it is a very common feature, we treat them differently. overloaded literals as it is a very common feature, we treat them differently.
They are mainly covered in Note [Undecidable Equality on Overloaded Literals]. They are mainly covered in Note [Undecidable Equality on Overloaded Literals]
in PmExpr.
4. N+K Patterns & Pattern Synonyms 4. N+K Patterns & Pattern Synonyms
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...@@ -845,9 +856,6 @@ coercePmPat (PmCon { pm_con_con = con, pm_con_arg_tys = arg_tys ...@@ -845,9 +856,6 @@ coercePmPat (PmCon { pm_con_con = con, pm_con_arg_tys = arg_tys
, pm_con_args = coercePatVec args }] , pm_con_args = coercePatVec args }]
coercePmPat (PmGrd {}) = [] -- drop the guards coercePmPat (PmGrd {}) = [] -- drop the guards
no_fixity :: a -- TODO: Can we retrieve the fixity from the operator name?
no_fixity = panic "Check: no fixity"
-- Get all constructors in the family (including given) -- Get all constructors in the family (including given)
allConstructors :: DataCon -> [DataCon] allConstructors :: DataCon -> [DataCon]
allConstructors = tyConDataCons . dataConTyCon allConstructors = tyConDataCons . dataConTyCon
...@@ -1101,7 +1109,7 @@ cMatcher us gvsa (p@(PmCon { pm_con_con = c1, pm_con_args = args1 })) ps ...@@ -1101,7 +1109,7 @@ cMatcher us gvsa (p@(PmCon { pm_con_con = c1, pm_con_args = args1 })) ps
-- CLitLit -- CLitLit
cMatcher us gvsa (PmLit l1) ps (va@(PmLit l2)) vsa = case eqPmLit l1 l2 of cMatcher us gvsa (PmLit l1) ps (va@(PmLit l2)) vsa = case eqPmLit l1 l2 of
-- See Note [Undecidable Equality for Overloaded Literals] -- See Note [Undecidable Equality for Overloaded Literals] in PmExpr
True -> va `mkCons` covered us gvsa ps vsa -- match True -> va `mkCons` covered us gvsa ps vsa -- match
False -> Empty -- mismatch False -> Empty -- mismatch
...@@ -1172,7 +1180,7 @@ uMatcher us gvsa ( p@(PmCon { pm_con_con = c1, pm_con_args = args1 })) ps ...@@ -1172,7 +1180,7 @@ uMatcher us gvsa ( p@(PmCon { pm_con_con = c1, pm_con_args = args1 })) ps
-- ULitLit -- ULitLit
uMatcher us gvsa (PmLit l1) ps (va@(PmLit l2)) vsa = case eqPmLit l1 l2 of uMatcher us gvsa (PmLit l1) ps (va@(PmLit l2)) vsa = case eqPmLit l1 l2 of
-- See Note [Undecidable Equality for Overloaded Literals] -- See Note [Undecidable Equality for Overloaded Literals] in PmExpr
True -> va `mkCons` uncovered us gvsa ps vsa -- match True -> va `mkCons` uncovered us gvsa ps vsa -- match
False -> va `mkCons` vsa -- mismatch False -> va `mkCons` vsa -- mismatch
...@@ -1256,7 +1264,7 @@ dMatcher us gvsa (p@(PmCon { pm_con_con = c1, pm_con_args = args1 })) ps ...@@ -1256,7 +1264,7 @@ dMatcher us gvsa (p@(PmCon { pm_con_con = c1, pm_con_args = args1 })) ps
-- DLitLit -- DLitLit
dMatcher us gvsa (PmLit l1) ps (va@(PmLit l2)) vsa = case eqPmLit l1 l2 of dMatcher us gvsa (PmLit l1) ps (va@(PmLit l2)) vsa = case eqPmLit l1 l2 of
-- See Note [Undecidable Equality for Overloaded Literals] -- See Note [Undecidable Equality for Overloaded Literals] in PmExpr
True -> va `mkCons` divergent us gvsa ps vsa -- match True -> va `mkCons` divergent us gvsa ps vsa -- match
False -> Empty -- mismatch False -> Empty -- mismatch
...@@ -1331,10 +1339,12 @@ genCaseTmCs2 :: Maybe (LHsExpr Id) -- Scrutinee ...@@ -1331,10 +1339,12 @@ genCaseTmCs2 :: Maybe (LHsExpr Id) -- Scrutinee
-> [Id] -- MatchVars (should have length 1) -> [Id] -- MatchVars (should have length 1)
-> DsM (Bag SimpleEq) -> DsM (Bag SimpleEq)
genCaseTmCs2 Nothing _ _ = return emptyBag genCaseTmCs2 Nothing _ _ = return emptyBag
genCaseTmCs2 (Just scr) [p] [var] = liftUs $ do genCaseTmCs2 (Just scr) [p] [var] = do
[e] <- map valAbsToPmExpr . coercePatVec <$> translatePat p fam_insts <- dsGetFamInstEnvs
let scr_e = lhsExprToPmExpr scr liftUs $ do
return $ listToBag [(var, e), (var, scr_e)] [e] <- map valAbsToPmExpr . coercePatVec <$> translatePat fam_insts p
let scr_e = lhsExprToPmExpr scr
return $ listToBag [(var, e), (var, scr_e)]
genCaseTmCs2 _ _ _ = panic "genCaseTmCs2: HsCase" genCaseTmCs2 _ _ _ = panic "genCaseTmCs2: HsCase"
-- | Generate a simple equality when checking a case expression: -- | Generate a simple equality when checking a case expression:
......
...@@ -592,8 +592,9 @@ addTickHsExpr (ExplicitList ty wit es) = ...@@ -592,8 +592,9 @@ addTickHsExpr (ExplicitList ty wit es) =
(addTickWit wit) (addTickWit wit)
(mapM (addTickLHsExpr) es) (mapM (addTickLHsExpr) es)
where addTickWit Nothing = return Nothing where addTickWit Nothing = return Nothing
addTickWit (Just fln) = do fln' <- addTickHsExpr fln addTickWit (Just fln)
return (Just fln') = do fln' <- addTickSyntaxExpr hpcSrcSpan fln
return (Just fln')
addTickHsExpr (ExplicitPArr ty es) = addTickHsExpr (ExplicitPArr ty es) =
liftM2 ExplicitPArr liftM2 ExplicitPArr
(return ty) (return ty)
...@@ -621,7 +622,7 @@ addTickHsExpr (ArithSeq ty wit arith_seq) = ...@@ -621,7 +622,7 @@ addTickHsExpr (ArithSeq ty wit arith_seq) =
(addTickWit wit) (addTickWit wit)
(addTickArithSeqInfo arith_seq) (addTickArithSeqInfo arith_seq)
where addTickWit Nothing = return Nothing where addTickWit Nothing = return Nothing
addTickWit (Just fl) = do fl' <- addTickHsExpr fl addTickWit (Just fl) = do fl' <- addTickSyntaxExpr hpcSrcSpan fl
return (Just fl') return (Just fl')
-- We might encounter existing ticks (multiple Coverage passes) -- We might encounter existing ticks (multiple Coverage passes)
...@@ -732,12 +733,13 @@ addTickStmt _isGuard (LastStmt e noret ret) = do ...@@ -732,12 +733,13 @@ addTickStmt _isGuard (LastStmt e noret ret) = do
(addTickLHsExpr e) (addTickLHsExpr e)
(pure noret) (pure noret)
(addTickSyntaxExpr hpcSrcSpan ret) (addTickSyntaxExpr hpcSrcSpan ret)
addTickStmt _isGuard (BindStmt pat e bind fail) = do addTickStmt _isGuard (BindStmt pat e bind fail ty) = do
liftM4 BindStmt liftM5 BindStmt
(addTickLPat pat) (addTickLPat pat)
(addTickLHsExprRHS e) (addTickLHsExprRHS e)
(addTickSyntaxExpr hpcSrcSpan bind) (addTickSyntaxExpr hpcSrcSpan bind)
(addTickSyntaxExpr hpcSrcSpan fail) (addTickSyntaxExpr hpcSrcSpan fail)
(return ty)
addTickStmt isGuard (BodyStmt e bind' guard' ty) = do addTickStmt isGuard (BodyStmt e bind' guard' ty) = do
liftM4 BodyStmt liftM4 BodyStmt
(addTick isGuard e) (addTick isGuard e)
...@@ -747,11 +749,12 @@ addTickStmt isGuard (BodyStmt e bind' guard' ty) = do ...@@ -747,11 +749,12 @@ addTickStmt isGuard (BodyStmt e bind' guard' ty) = do
addTickStmt _isGuard (LetStmt (L l binds)) = do addTickStmt _isGuard (LetStmt (L l binds)) = do
liftM (LetStmt . L l) liftM (LetStmt . L l)
(addTickHsLocalBinds binds) (addTickHsLocalBinds binds)
addTickStmt isGuard (ParStmt pairs mzipExpr bindExpr) = do addTickStmt isGuard (ParStmt pairs mzipExpr bindExpr ty) = do
liftM3 ParStmt liftM4 ParStmt
(mapM (addTickStmtAndBinders isGuard) pairs) (mapM (addTickStmtAndBinders isGuard) pairs)
(addTickSyntaxExpr hpcSrcSpan mzipExpr) (unLoc <$> addTickLHsExpr (L hpcSrcSpan mzipExpr))
(addTickSyntaxExpr hpcSrcSpan bindExpr) (addTickSyntaxExpr hpcSrcSpan bindExpr)
(return ty)
addTickStmt isGuard (ApplicativeStmt args mb_join body_ty) = do addTickStmt isGuard (ApplicativeStmt args mb_join body_ty) = do
args' <- mapM (addTickApplicativeArg isGuard) args args' <- mapM (addTickApplicativeArg isGuard) args
return (ApplicativeStmt args' mb_join body_ty) return (ApplicativeStmt args' mb_join body_ty)
...@@ -765,7 +768,7 @@ addTickStmt isGuard stmt@(TransStmt { trS_stmts = stmts ...@@ -765,7 +768,7 @@ addTickStmt isGuard stmt@(TransStmt { trS_stmts = stmts
t_u <- addTickLHsExprRHS using t_u <- addTickLHsExprRHS using
t_f <- addTickSyntaxExpr hpcSrcSpan returnExpr t_f <- addTickSyntaxExpr hpcSrcSpan returnExpr
t_b <- addTickSyntaxExpr hpcSrcSpan bindExpr t_b <- addTickSyntaxExpr hpcSrcSpan bindExpr
t_m <- addTickSyntaxExpr hpcSrcSpan liftMExpr L _ t_m <- addTickLHsExpr