Commit 0805ed7e authored by John Ericson's avatar John Ericson Committed by Marge Bot

Use non-empty lists to remove partiality in matching code

parent 7aa4a061
Pipeline #14322 failed with stages
in 744 minutes and 2 seconds
......@@ -84,6 +84,8 @@ import qualified GHC.LanguageExtensions as LangExt
import TcEvidence
import Control.Monad ( zipWithM )
import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.List.NonEmpty as NEL
{-
************************************************************************
......@@ -186,9 +188,9 @@ worthy of a type synonym and a few handy functions.
firstPat :: EquationInfo -> Pat GhcTc
firstPat eqn = ASSERT( notNull (eqn_pats eqn) ) head (eqn_pats eqn)
shiftEqns :: [EquationInfo] -> [EquationInfo]
shiftEqns :: Functor f => f EquationInfo -> f EquationInfo
-- Drop the first pattern in each equation
shiftEqns eqns = [ eqn { eqn_pats = tail (eqn_pats eqn) } | eqn <- eqns ]
shiftEqns = fmap $ \eqn -> eqn { eqn_pats = tail (eqn_pats eqn) }
-- Functions on MatchResults
......@@ -286,13 +288,13 @@ data CaseAlt a = MkCaseAlt{ alt_pat :: a,
alt_result :: MatchResult }
mkCoAlgCaseMatchResult
:: Id -- Scrutinee
-> Type -- Type of exp
-> [CaseAlt DataCon] -- Alternatives (bndrs *include* tyvars, dicts)
:: Id -- ^ Scrutinee
-> Type -- ^ Type of exp
-> NonEmpty (CaseAlt DataCon) -- ^ Alternatives (bndrs *include* tyvars, dicts)
-> MatchResult
mkCoAlgCaseMatchResult var ty match_alts
| isNewtype -- Newtype case; use a let
= ASSERT( null (tail match_alts) && null (tail arg_ids1) )
= ASSERT( null match_alts_tail && null (tail arg_ids1) )
mkCoLetMatchResult (NonRec arg_id1 newtype_rhs) match_result1
| otherwise
......@@ -303,8 +305,8 @@ mkCoAlgCaseMatchResult var ty match_alts
-- [Interesting: because of GADTs, we can't rely on the type of
-- the scrutinised Id to be sufficiently refined to have a TyCon in it]
alt1@MkCaseAlt{ alt_bndrs = arg_ids1, alt_result = match_result1 }
= ASSERT( notNull match_alts ) head match_alts
alt1@MkCaseAlt{ alt_bndrs = arg_ids1, alt_result = match_result1 } :| match_alts_tail
= match_alts
-- Stuff for newtype
arg_id1 = ASSERT( notNull arg_ids1 ) head arg_ids1
var_ty = idType var
......@@ -315,9 +317,6 @@ mkCoAlgCaseMatchResult var ty match_alts
mkCoSynCaseMatchResult :: Id -> Type -> CaseAlt PatSyn -> MatchResult
mkCoSynCaseMatchResult var ty alt = MatchResult CanFail $ mkPatSynCase var ty alt
sort_alts :: [CaseAlt DataCon] -> [CaseAlt DataCon]
sort_alts = sortWith (dataConTag . alt_pat)
mkPatSynCase :: Id -> Type -> CaseAlt PatSyn -> CoreExpr -> DsM CoreExpr
mkPatSynCase var ty alt fail = do
matcher <- dsLExpr $ mkLHsWrap wrapper $
......@@ -337,17 +336,16 @@ mkPatSynCase var ty alt fail = do
ensure_unstrict cont | needs_void_lam = Lam voidArgId cont
| otherwise = cont
mkDataConCase :: Id -> Type -> [CaseAlt DataCon] -> MatchResult
mkDataConCase _ _ [] = panic "mkDataConCase: no alternatives"
mkDataConCase var ty alts@(alt1:_) = MatchResult fail_flag mk_case
mkDataConCase :: Id -> Type -> NonEmpty (CaseAlt DataCon) -> MatchResult
mkDataConCase var ty alts@(alt1 :| _) = MatchResult fail_flag mk_case
where
con1 = alt_pat alt1
tycon = dataConTyCon con1
data_cons = tyConDataCons tycon
match_results = map alt_result alts
match_results = fmap alt_result alts
sorted_alts :: [CaseAlt DataCon]
sorted_alts = sort_alts alts
sorted_alts :: NonEmpty (CaseAlt DataCon)
sorted_alts = NEL.sortWith (dataConTag . alt_pat) alts
var_ty = idType var
(_, ty_args) = tcSplitTyConApp var_ty -- Don't look through newtypes
......@@ -356,7 +354,7 @@ mkDataConCase var ty alts@(alt1:_) = MatchResult fail_flag mk_case
mk_case :: CoreExpr -> DsM CoreExpr
mk_case fail = do
alts <- mapM (mk_alt fail) sorted_alts
return $ mkWildCase (Var var) (idType var) ty (mk_default fail ++ alts)
return $ mkWildCase (Var var) (idType var) ty (mk_default fail ++ NEL.toList alts)
mk_alt :: CoreExpr -> CaseAlt DataCon -> DsM CoreAlt
mk_alt fail MkCaseAlt{ alt_pat = con,
......@@ -376,11 +374,11 @@ mkDataConCase var ty alts@(alt1:_) = MatchResult fail_flag mk_case
fail_flag :: CanItFail
fail_flag | exhaustive_case
= foldr orFail CantFail [can_it_fail | MatchResult can_it_fail _ <- match_results]
= foldr orFail CantFail [can_it_fail | MatchResult can_it_fail _ <- NEL.toList match_results]
| otherwise
= CanFail
mentioned_constructors = mkUniqSet $ map alt_pat alts
mentioned_constructors = mkUniqSet $ map alt_pat $ NEL.toList alts
un_mentioned_constructors
= mkUniqSet data_cons `minusUniqSet` mentioned_constructors
exhaustive_case = isEmptyUniqSet un_mentioned_constructors
......
This diff is collapsed.
......@@ -34,6 +34,7 @@ import SrcLoc
import Outputable
import Control.Monad(liftM)
import Data.List (groupBy)
import Data.List.NonEmpty (NonEmpty(..))
{-
We are confronted with the first column of patterns in a set of
......@@ -88,40 +89,38 @@ have-we-used-all-the-constructors? question; the local function
@match_cons_used@ does all the real work.
-}
matchConFamily :: [Id]
matchConFamily :: NonEmpty Id
-> Type
-> [[EquationInfo]]
-> NonEmpty (NonEmpty EquationInfo)
-> DsM MatchResult
-- Each group of eqns is for a single constructor
matchConFamily (var:vars) ty groups
matchConFamily (var :| vars) ty groups
= do alts <- mapM (fmap toRealAlt . matchOneConLike vars ty) groups
return (mkCoAlgCaseMatchResult var ty alts)
where
toRealAlt alt = case alt_pat alt of
RealDataCon dcon -> alt{ alt_pat = dcon }
_ -> panic "matchConFamily: not RealDataCon"
matchConFamily [] _ _ = panic "matchConFamily []"
matchPatSyn :: [Id]
matchPatSyn :: NonEmpty Id
-> Type
-> [EquationInfo]
-> NonEmpty EquationInfo
-> DsM MatchResult
matchPatSyn (var:vars) ty eqns
matchPatSyn (var :| vars) ty eqns
= do alt <- fmap toSynAlt $ matchOneConLike vars ty eqns
return (mkCoSynCaseMatchResult var ty alt)
where
toSynAlt alt = case alt_pat alt of
PatSynCon psyn -> alt{ alt_pat = psyn }
_ -> panic "matchPatSyn: not PatSynCon"
matchPatSyn _ _ _ = panic "matchPatSyn []"
type ConArgPats = HsConDetails (LPat GhcTc) (HsRecFields GhcTc (LPat GhcTc))
matchOneConLike :: [Id]
-> Type
-> [EquationInfo]
-> NonEmpty EquationInfo
-> DsM (CaseAlt ConLike)
matchOneConLike vars ty (eqn1 : eqns) -- All eqns for a single constructor
matchOneConLike vars ty (eqn1 :| eqns) -- All eqns for a single constructor
= do { let inst_tys = ASSERT( all tcIsTcTyVar ex_tvs )
-- ex_tvs can only be tyvars as data types in source
-- Haskell cannot mention covar yet (Aug 2018).
......@@ -195,7 +194,6 @@ matchOneConLike vars ty (eqn1 : eqns) -- All eqns for a single constructor
lookup_fld (L _ rpat) = lookupNameEnv_NF fld_var_env
(idName (unLoc (hsRecFieldId rpat)))
select_arg_vars _ [] = panic "matchOneCon/select_arg_vars []"
matchOneConLike _ _ [] = panic "matchOneCon []"
-----------------
compatible_pats :: (ConArgPats,a) -> (ConArgPats,a) -> Bool
......
......@@ -53,6 +53,8 @@ import qualified GHC.LanguageExtensions as LangExt
import Control.Monad
import Data.Int
import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.List.NonEmpty as NEL
import Data.Word
import Data.Proxy
......@@ -397,14 +399,13 @@ tidyNPat over_lit mb_neg eq outer_ty
************************************************************************
-}
matchLiterals :: [Id]
-> Type -- Type of the whole case expression
-> [[EquationInfo]] -- All PgLits
matchLiterals :: NonEmpty Id
-> Type -- ^ Type of the whole case expression
-> NonEmpty (NonEmpty EquationInfo) -- ^ All PgLits
-> DsM MatchResult
matchLiterals (var:vars) ty sub_groups
= ASSERT( notNull sub_groups && all notNull sub_groups )
do { -- Deal with each group
matchLiterals (var :| vars) ty sub_groups
= do { -- Deal with each group
; alts <- mapM match_group sub_groups
-- Combine results. For everything except String
......@@ -415,14 +416,14 @@ matchLiterals (var:vars) ty sub_groups
; mrs <- mapM (wrap_str_guard eq_str) alts
; return (foldr1 combineMatchResults mrs) }
else
return (mkCoPrimCaseMatchResult var ty alts)
return (mkCoPrimCaseMatchResult var ty $ NEL.toList alts)
}
where
match_group :: [EquationInfo] -> DsM (Literal, MatchResult)
match_group eqns
match_group :: NonEmpty EquationInfo -> DsM (Literal, MatchResult)
match_group eqns@(firstEqn :| _)
= do { dflags <- getDynFlags
; let LitPat _ hs_lit = firstPat (head eqns)
; match_result <- match vars ty (shiftEqns eqns)
; let LitPat _ hs_lit = firstPat firstEqn
; match_result <- match vars ty (NEL.toList $ shiftEqns eqns)
; return (hsLitKey dflags hs_lit, match_result) }
wrap_str_guard :: Id -> (Literal,MatchResult) -> DsM MatchResult
......@@ -436,7 +437,6 @@ matchLiterals (var:vars) ty sub_groups
; return (mkGuardedMatchResult pred mr) }
wrap_str_guard _ (l, _) = pprPanic "matchLiterals/wrap_str_guard" (ppr l)
matchLiterals [] _ _ = panic "matchLiterals []"
---------------------------
hsLitKey :: DynFlags -> HsLit GhcTc -> Literal
......@@ -467,8 +467,8 @@ hsLitKey _ l = pprPanic "hsLitKey" (ppr l)
************************************************************************
-}
matchNPats :: [Id] -> Type -> [EquationInfo] -> DsM MatchResult
matchNPats (var:vars) ty (eqn1:eqns) -- All for the same literal
matchNPats :: NonEmpty Id -> Type -> NonEmpty EquationInfo -> DsM MatchResult
matchNPats (var :| vars) ty (eqn1 :| eqns) -- All for the same literal
= do { let NPat _ (L _ lit) mb_neg eq_chk = firstPat eqn1
; lit_expr <- dsOverLit lit
; neg_lit <- case mb_neg of
......@@ -477,7 +477,6 @@ matchNPats (var:vars) ty (eqn1:eqns) -- All for the same literal
; pred_expr <- dsSyntaxExpr eq_chk [Var var, neg_lit]
; match_result <- match vars ty (shiftEqns (eqn1:eqns))
; return (mkGuardedMatchResult pred_expr match_result) }
matchNPats vars _ eqns = pprPanic "matchOneNPat" (ppr (vars, eqns))
{-
************************************************************************
......@@ -497,9 +496,9 @@ We generate:
\end{verbatim}
-}
matchNPlusKPats :: [Id] -> Type -> [EquationInfo] -> DsM MatchResult
matchNPlusKPats :: NonEmpty Id -> Type -> NonEmpty EquationInfo -> DsM MatchResult
-- All NPlusKPats, for the *same* literal k
matchNPlusKPats (var:vars) ty (eqn1:eqns)
matchNPlusKPats (var :| vars) ty (eqn1 :| eqns)
= do { let NPlusKPat _ (L _ n1) (L _ lit1) lit2 ge minus
= firstPat eqn1
; lit1_expr <- dsOverLit lit1
......@@ -517,5 +516,3 @@ matchNPlusKPats (var:vars) ty (eqn1:eqns)
= (wrapBind n n1, eqn { eqn_pats = pats })
-- The wrapBind is a no-op for the first equation
shift _ e = pprPanic "matchNPlusKPats/shift" (ppr e)
matchNPlusKPats vars _ eqns = pprPanic "matchNPlusKPats" (ppr (vars, eqns))
......@@ -124,6 +124,8 @@ import Text.Printf
import Numeric (showFFloat)
import Data.Graph (SCC(..))
import Data.List (intersperse)
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as NEL
import GHC.Fingerprint
import GHC.Show ( showMultiLineString )
......@@ -819,6 +821,9 @@ instance Outputable () where
instance (Outputable a) => Outputable [a] where
ppr xs = brackets (fsep (punctuate comma (map ppr xs)))
instance (Outputable a) => Outputable (NonEmpty a) where
ppr = ppr . NEL.toList
instance (Outputable a) => Outputable (Set a) where
ppr s = braces (fsep (punctuate comma (map ppr (Set.toList s))))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment