Commit a8941e2a authored by Simon Peyton Jones's avatar Simon Peyton Jones

Refactor HsExpr.MatchGroup

 * Make MatchGroup into a record, and use the record fields

 * Split the type field into two: mg_arg_tys and mg_res_ty
   This makes life much easier for the desugarer when the
   case alterantives are empty

A little bit of this change unavoidably ended up in the preceding
commit about empty case alternatives
parent 3671e674
......@@ -271,7 +271,7 @@ addTickLHsBind (L pos (funBind@(FunBind { fun_id = (L _ id) }))) = do
-- See Note [inline sccs]
if inline && gopt Opt_SccProfilingOn dflags then return (L pos funBind) else do
(fvs, (MatchGroup matches' ty)) <-
(fvs, mg@(MG { mg_alts = matches' })) <-
getFreeVars $
addPathEntry name $
addTickMatchGroup False (fun_matches funBind)
......@@ -293,7 +293,7 @@ addTickLHsBind (L pos (funBind@(FunBind { fun_id = (L _ id) }))) = do
else
return Nothing
return $ L pos $ funBind { fun_matches = MatchGroup matches' ty
return $ L pos $ funBind { fun_matches = mg { mg_alts = matches' }
, fun_tick = tick }
where
......@@ -586,10 +586,10 @@ addTickTupArg (Present e) = do { e' <- addTickLHsExpr e; return (Present e') }
addTickTupArg (Missing ty) = return (Missing ty)
addTickMatchGroup :: Bool{-is lambda-} -> MatchGroup Id (LHsExpr Id) -> TM (MatchGroup Id (LHsExpr Id))
addTickMatchGroup is_lam (MatchGroup matches ty) = do
addTickMatchGroup is_lam mg@(MG { mg_alts = matches }) = do
let isOneOfMany = matchesOneOfMany matches
matches' <- mapM (liftL (addTickMatch isOneOfMany is_lam)) matches
return $ MatchGroup matches' ty
return $ mg { mg_alts = matches' }
addTickMatch :: Bool -> Bool -> Match Id (LHsExpr Id) -> TM (Match Id (LHsExpr Id))
addTickMatch isOneOfMany isLambda (Match pats opSig gRHSs) =
......@@ -799,9 +799,9 @@ addTickHsCmd (HsCmdArrForm e fix cmdtop) =
--addTickHsCmd e = pprPanic "addTickHsCmd" (ppr e)
addTickCmdMatchGroup :: MatchGroup Id (LHsCmd Id) -> TM (MatchGroup Id (LHsCmd Id))
addTickCmdMatchGroup (MatchGroup matches ty) = do
addTickCmdMatchGroup mg@(MG { mg_alts = matches }) = do
matches' <- mapM (liftL addTickCmdMatch) matches
return $ MatchGroup matches' ty
return $ mg { mg_alts = matches' }
addTickCmdMatch :: Match Id (LHsCmd Id) -> TM (Match Id (LHsCmd Id))
addTickCmdMatch (Match pats opSig gRHSs) =
......
......@@ -33,7 +33,6 @@ import {-# SOURCE #-} DsExpr ( dsExpr, dsLExpr, dsLocalBinds )
import TcType
import TcEvidence
import Type
import CoreSyn
import CoreFVs
import CoreUtils
......@@ -382,7 +381,7 @@ dsCmd ids local_vars stack res_ty (HsCmdApp cmd arg) env_ids = do
-- ---> premap (\ ((((xs), p1), ... pk)*ts) -> ((ys)*ts)) c
dsCmd ids local_vars stack res_ty
(HsCmdLam (MatchGroup [L _ (Match pats _ (GRHSs [L _ (GRHS [] body)] _ ))] _))
(HsCmdLam (MG { mg_alts = [L _ (Match pats _ (GRHSs [L _ (GRHS [] body)] _ ))] }))
env_ids = do
let
pat_vars = mkVarSet (collectPatsBinders pats)
......@@ -483,8 +482,9 @@ case bodies, containing the following fields:
bodies with |||.
\begin{code}
dsCmd ids local_vars stack res_ty (HsCmdCase exp (MatchGroup matches match_ty))
env_ids = do
dsCmd ids local_vars stack res_ty
(HsCmdCase exp (MG { mg_alts = matches, mg_arg_tys = arg_tys }))
env_ids = do
stack_ids <- mapM newSysLocalDs stack
-- Extract and desugar the leaf commands in the case, building tuple
......@@ -526,12 +526,11 @@ dsCmd ids local_vars stack res_ty (HsCmdCase exp (MatchGroup matches match_ty))
(_, matches') = mapAccumL (replaceLeavesMatch res_ty) leaves' matches
in_ty = envStackType env_ids stack
pat_ty = funArgTy match_ty
match_ty' = mkFunTy pat_ty sum_ty
core_body <- dsExpr (HsCase exp (MG { mg_alts = matches', mg_arg_tys = arg_tys
, mg_res_ty = sum_ty }))
-- Note that we replace the HsCase result type by sum_ty,
-- which is the type of matches'
core_body <- dsExpr (HsCase exp (MatchGroup matches' match_ty'))
core_matches <- matchEnvStack env_ids stack_ids core_body
return (do_premap ids in_ty sum_ty res_ty core_matches core_choices,
exprFreeIds core_body `intersectVarSet` local_vars)
......
......@@ -490,7 +490,7 @@ dsExpr expr@(RecordUpd record_expr (HsRecFields { rec_flds = fields })
-- constructor aguments.
; alts <- mapM (mk_alt upd_fld_env) cons_to_upd
; ([discrim_var], matching_code)
<- matchWrapper RecUpd (MatchGroup alts in_out_ty)
<- matchWrapper RecUpd (MG { mg_alts = alts, mg_arg_tys = [in_ty], mg_res_ty = out_ty })
; return (add_field_binds field_binds' $
bindNonRec discrim_var record_expr' matching_code) }
......@@ -512,7 +512,7 @@ dsExpr expr@(RecordUpd record_expr (HsRecFields { rec_flds = fields })
-- from instance type to family type
tycon = dataConTyCon (head cons_to_upd)
in_ty = mkTyConApp tycon in_inst_tys
in_out_ty = mkFunTy in_ty (mkFamilyTyConApp tycon out_inst_tys)
out_ty = mkFamilyTyConApp tycon out_inst_tys
mk_alt upd_fld_env con
= do { let (univ_tvs, ex_tvs, eq_spec,
......@@ -761,8 +761,8 @@ dsDo stmts
later_pats = rec_tup_pats
rets = map noLoc rec_rets
mfix_app = nlHsApp (noLoc mfix_op) mfix_arg
mfix_arg = noLoc $ HsLam (MatchGroup [mkSimpleMatch [mfix_pat] body]
(mkFunTy tup_ty body_ty))
mfix_arg = noLoc $ HsLam (MG { mg_alts = [mkSimpleMatch [mfix_pat] body]
, mg_arg_tys = [tup_ty], mg_res_ty = body_ty })
mfix_pat = noLoc $ LazyPat $ mkBigLHsPatTup rec_tup_pats
body = noLoc $ HsDo DoExpr (rec_stmts ++ [ret_stmt]) body_ty
ret_app = nlHsApp (noLoc return_op) (mkBigLHsTup rets)
......
......@@ -25,6 +25,7 @@ import TysWiredIn
import PrelNames
import Module
import Name
import Util
import SrcLoc
import Outputable
\end{code}
......@@ -56,16 +57,15 @@ dsGRHSs :: HsMatchContext Name -> [Pat Id] -- These are to build a MatchCon
-> GRHSs Id (LHsExpr Id) -- Guarded RHSs
-> Type -- Type of RHS
-> DsM MatchResult
dsGRHSs hs_ctx _ (GRHSs grhss binds) rhs_ty = do
match_results <- mapM (dsGRHS hs_ctx rhs_ty) grhss
let
match_result1 = foldr1 combineMatchResults match_results
match_result2 = adjustMatchResultDs
dsGRHSs hs_ctx _ (GRHSs grhss binds) rhs_ty
= ASSERT( notNull grhss )
do { match_results <- mapM (dsGRHS hs_ctx rhs_ty) grhss
; let match_result1 = foldr1 combineMatchResults match_results
match_result2 = adjustMatchResultDs
(\e -> dsLocalBinds binds e)
match_result1
-- NB: nested dsLet inside matchResult
--
return match_result2
; return match_result2 }
dsGRHS :: HsMatchContext Name -> Type -> LGRHS Id (LHsExpr Id) -> DsM MatchResult
dsGRHS hs_ctx rhs_ty (L _ (GRHS guards rhs))
......
......@@ -917,8 +917,8 @@ repE e@(HsIPVar _) = notHandled "Implicit parameters" (ppr e)
-- HsOverlit can definitely occur
repE (HsOverLit l) = do { a <- repOverloadedLiteral l; repLit a }
repE (HsLit l) = do { a <- repLiteral l; repLit a }
repE (HsLam (MatchGroup [m] _)) = repLambda m
repE (HsLamCase _ (MatchGroup ms _))
repE (HsLam (MG { mg_alts = [m] })) = repLambda m
repE (HsLamCase _ (MG { mg_alts = ms }))
= do { ms' <- mapM repMatchTup ms
; repLamCase (nonEmptyCoreList ms') }
repE (HsApp x y) = do {a <- repLE x; b <- repLE y; repApp a b}
......@@ -935,7 +935,7 @@ repE (NegApp x _) = do
repE (HsPar x) = repLE x
repE (SectionL x y) = do { a <- repLE x; b <- repLE y; repSectionL a b }
repE (SectionR x y) = do { a <- repLE x; b <- repLE y; repSectionR a b }
repE (HsCase e (MatchGroup ms _))
repE (HsCase e (MG { mg_alts = ms }))
= do { arg <- repLE e
; ms2 <- mapM repMatchTup ms
; repCaseE arg (nonEmptyCoreList ms2) }
......@@ -1166,7 +1166,7 @@ rep_bind :: LHsBind Name -> DsM (SrcSpan, Core TH.DecQ)
-- e.g. x = g 5 as a Fun MonoBinds. This is indicated by a single match
-- with an empty list of patterns
rep_bind (L loc (FunBind { fun_id = fn,
fun_matches = MatchGroup [L _ (Match [] _ (GRHSs guards wheres))] _ }))
fun_matches = MG { mg_alts = [L _ (Match [] _ (GRHSs guards wheres))] } }))
= do { (ss,wherecore) <- repBinds wheres
; guardcore <- addBinds ss (repGuards guards)
; fn' <- lookupLBinder fn
......@@ -1175,7 +1175,7 @@ rep_bind (L loc (FunBind { fun_id = fn,
; ans' <- wrapGenSyms ss ans
; return (loc, ans') }
rep_bind (L loc (FunBind { fun_id = fn, fun_matches = MatchGroup ms _ }))
rep_bind (L loc (FunBind { fun_id = fn, fun_matches = MG { mg_alts = ms } }))
= do { ms1 <- mapM repClauseTup ms
; fn' <- lookupLBinder fn
; ans <- repFun fn' (nonEmptyCoreList ms1)
......
......@@ -831,11 +831,12 @@ patterns in each equation.
\begin{code}
data MatchGroup id body
= MatchGroup
[LMatch id body] -- The alternatives
PostTcType -- The type is the type of the entire group
-- t1 -> ... -> tn -> tr
-- where there are n patterns
= MG { mg_alts :: [LMatch id body] -- The alternatives
, mg_arg_tys :: [PostTcType] -- Types of the arguments, t1..tn
, mg_res_ty :: PostTcType } -- Type of the result, tr
-- The type is the type of the entire group
-- t1 -> ... -> tn -> tr
-- where there are n patterns
deriving (Data, Typeable)
type LMatch id body = Located (Match id body)
......@@ -849,17 +850,14 @@ data Match id body
deriving (Data, Typeable)
isEmptyMatchGroup :: MatchGroup id body -> Bool
isEmptyMatchGroup (MatchGroup ms _) = null ms
isEmptyMatchGroup (MG { mg_alts = ms }) = null ms
matchGroupArity :: MatchGroup id body -> Arity
matchGroupArity (MatchGroup [] _)
= panic "matchGroupArity" -- Precondition: MatchGroup is non-empty
matchGroupArity (MatchGroup (match:matches) _)
= ASSERT( all ((== n_pats) . length . hsLMatchPats) matches )
-- Assertion just checks that all the matches have the same number of pats
n_pats
where
n_pats = length (hsLMatchPats match)
-- Precondition: MatchGroup is non-empty
-- This is called before type checking, when mg_arg_tys is not set
matchGroupArity (MG { mg_alts = alts })
| (alt1:_) <- alts = length (hsLMatchPats alt1)
| otherwise = panic "matchGroupArity"
hsLMatchPats :: LMatch id body -> [LPat id]
hsLMatchPats (L _ (Match pats _ _)) = pats
......@@ -884,7 +882,7 @@ We know the list must have at least one @Match@ in it.
\begin{code}
pprMatches :: (OutputableBndr idL, OutputableBndr idR, Outputable body)
=> HsMatchContext idL -> MatchGroup idR body -> SDoc
pprMatches ctxt (MatchGroup matches _)
pprMatches ctxt (MG { mg_alts = matches })
= vcat (map (pprMatch ctxt) (map unLoc matches))
-- Don't print the type; it's only a place-holder before typechecking
......
......@@ -128,7 +128,7 @@ unguardedRHS :: Located (body id) -> [LGRHS id (Located (body id))]
unguardedRHS rhs@(L loc _) = [L loc (GRHS [] rhs)]
mkMatchGroup :: [LMatch id (Located (body id))] -> MatchGroup id (Located (body id))
mkMatchGroup matches = MatchGroup matches placeHolderType
mkMatchGroup matches = MG { mg_alts = matches, mg_arg_tys = [], mg_res_ty = placeHolderType }
mkHsAppTy :: LHsType name -> LHsType name -> LHsType name
mkHsAppTy t1 t2 = addCLoc t1 t2 (HsAppTy t1 t2)
......
......@@ -310,13 +310,13 @@ getMonoBind :: LHsBind RdrName -> [LHsDecl RdrName]
-- No AndMonoBinds or EmptyMonoBinds here; just single equations
getMonoBind (L loc1 (FunBind { fun_id = fun_id1@(L _ f1), fun_infix = is_infix1,
fun_matches = MatchGroup mtchs1 _ })) binds
fun_matches = MG { mg_alts = mtchs1 } })) binds
| has_args mtchs1
= go is_infix1 mtchs1 loc1 binds []
where
go is_infix mtchs loc
(L loc2 (ValD (FunBind { fun_id = L _ f2, fun_infix = is_infix2,
fun_matches = MatchGroup mtchs2 _ })) : binds) _
fun_matches = MG { mg_alts = mtchs2 } })) : binds) _
| f1 == f2 = go (is_infix || is_infix2) (mtchs2 ++ mtchs)
(combineSrcSpans loc loc2) binds []
go is_infix mtchs loc (doc_decl@(L loc2 (DocD _)) : binds) doc_decls
......@@ -886,9 +886,9 @@ checkCmdStmt _ stmt@(RecStmt { recS_stmts = stmts }) = do
checkCmdStmt l stmt = cmdStmtFail l stmt
checkCmdMatchGroup :: MatchGroup RdrName (LHsExpr RdrName) -> P (MatchGroup RdrName (LHsCmd RdrName))
checkCmdMatchGroup (MatchGroup ms ty) = do
checkCmdMatchGroup mg@(MG { mg_alts = ms }) = do
ms' <- mapM (locMap $ const convert) ms
return $ MatchGroup ms' ty
return $ mg { mg_alts = ms' }
where convert (Match pat mty grhss) = do
grhss' <- checkCmdGRHSs grhss
return $ Match pat mty grhss'
......
......@@ -606,7 +606,7 @@ rnMethodBind :: Name
-> RnM (Bag (LHsBindLR Name Name), FreeVars)
rnMethodBind cls sig_fn
(L loc bind@(FunBind { fun_id = name, fun_infix = is_infix
, fun_matches = MatchGroup matches _ }))
, fun_matches = MG { mg_alts = matches } }))
= setSrcSpan loc $ do
sel_name <- wrapLocM (lookupInstDeclBndr cls (ptext (sLit "method"))) name
let plain_name = unLoc sel_name
......@@ -614,7 +614,7 @@ rnMethodBind cls sig_fn
(new_matches, fvs) <- bindSigTyVarsFV (sig_fn plain_name) $
mapFvRn (rnMatch (FunRhs plain_name is_infix) rnLExpr) matches
let new_group = MatchGroup new_matches placeHolderType
let new_group = mkMatchGroup new_matches
when is_infix $ checkPrecMatch plain_name new_group
return (unitBag (L loc (bind { fun_id = sel_name
......
......@@ -525,7 +525,7 @@ methodNamesCmd (HsCmdCase _ matches)
---------------------------------------------------
methodNamesMatch :: MatchGroup Name (LHsCmd Name) -> FreeVars
methodNamesMatch (MatchGroup ms _)
methodNamesMatch (MG { mg_alts = ms })
= plusFVs (map do_one ms)
where
do_one (L _ (Match _ _ grhss)) = methodNamesGRHSs grhss
......
......@@ -719,7 +719,7 @@ checkPrecMatch :: Name -> MatchGroup Name body -> RnM ()
-- eg a `op` b `C` c = ...
-- See comments with rnExpr (OpApp ...) about "deriving"
checkPrecMatch op (MatchGroup ms _)
checkPrecMatch op (MG { mg_alts = ms })
= mapM_ check ms
where
check (L _ (Match (L l1 p1 : L l2 p2 :_) _ _))
......
......@@ -18,7 +18,7 @@ import {-# SOURCE #-} TcExpr( tcMonoExpr, tcInferRho, tcSyntaxOp, tcCheckId )
import HsSyn
import TcMatches
-- import TcSimplify( solveWantedsTcM )
import TcHsSyn( hsLPatType )
import TcType
import TcMType
import TcBinds
......@@ -192,7 +192,7 @@ tc_cmd env cmd@(HsCmdApp fun arg) (cmd_stk, res_ty)
-------------------------------------------
-- Lambda
tc_cmd env cmd@(HsCmdLam (MatchGroup [L mtch_loc (match@(Match pats _maybe_rhs_sig grhss))] _))
tc_cmd env cmd@(HsCmdLam (MG { mg_alts = [L mtch_loc (match@(Match pats _maybe_rhs_sig grhss))] }))
(cmd_stk, res_ty)
= addErrCtxt (pprMatchInCtxt match_ctxt match) $
......@@ -206,7 +206,10 @@ tc_cmd env cmd@(HsCmdLam (MatchGroup [L mtch_loc (match@(Match pats _maybe_rhs_s
tc_grhss grhss res_ty
; let match' = L mtch_loc (Match pats' Nothing grhss')
; return (HsCmdLam (MatchGroup [match'] res_ty))
arg_tys = map hsLPatType pats'
; return (HsCmdLam (MG { mg_alts = [match'], mg_arg_tys = arg_tys
, mg_res_ty = res_ty }))
-- Or should we decompose res_ty?
}
where
......
......@@ -1305,8 +1305,8 @@ decideGeneralisationPlan dflags type_env bndr_names lbinds sig_fn
&& no_sig (unLoc v)
restricted (AbsBinds {}) = panic "isRestrictedGroup/unrestricted AbsBinds"
restricted_match (MatchGroup (L _ (Match [] _ _) : _) _) = True
restricted_match _ = False
restricted_match (MG { mg_alts = L _ (Match [] _ _) : _ }) = True
restricted_match _ = False
-- No args => like a pattern binding
-- Some args => a function binding
......
......@@ -498,10 +498,11 @@ zonkLTcSpecPrags env ps
zonkMatchGroup :: ZonkEnv
-> (ZonkEnv -> Located (body TcId) -> TcM (Located (body Id)))
-> MatchGroup TcId (Located (body TcId)) -> TcM (MatchGroup Id (Located (body Id)))
zonkMatchGroup env zBody (MatchGroup ms ty)
zonkMatchGroup env zBody (MG { mg_alts = ms, mg_arg_tys = arg_tys, mg_res_ty = res_ty })
= do { ms' <- mapM (zonkMatch env zBody) ms
; ty' <- zonkTcTypeToType env ty
; return (MatchGroup ms' ty') }
; arg_tys' <- zonkTcTypeToTypes env arg_tys
; res_ty' <- zonkTcTypeToType env res_ty
; return (MG { mg_alts = ms', mg_arg_tys = arg_tys', mg_res_ty = res_ty' }) }
zonkMatch :: ZonkEnv
-> (ZonkEnv -> Located (body TcId) -> TcM (Located (body Id)))
......
......@@ -109,7 +109,7 @@ tcMatchesCase :: (Outputable (body Name)) =>
tcMatchesCase ctxt scrut_ty matches res_ty
| isEmptyMatchGroup matches -- Allow empty case expressions
= return (MatchGroup [] (mkFunTys [scrut_ty] res_ty))
= return (MG { mg_alts = [], mg_arg_tys = [scrut_ty], mg_res_ty = res_ty })
| otherwise
= tcMatches ctxt [scrut_ty] res_ty matches
......@@ -180,10 +180,10 @@ data TcMatchCtxt body -- c.f. TcStmtCtxt, also in this module
-> TcRhoType
-> TcM (Located (body TcId)) }
tcMatches ctxt pat_tys rhs_ty (MatchGroup matches _)
tcMatches ctxt pat_tys rhs_ty (MG { mg_alts = matches })
= ASSERT( not (null matches) ) -- Ensure that rhs_ty is filled in
do { matches' <- mapM (tcMatch ctxt pat_tys rhs_ty) matches
; return (MatchGroup matches' (mkFunTys pat_tys rhs_ty)) }
; return (MG { mg_alts = matches', mg_arg_tys = pat_tys, mg_res_ty = rhs_ty }) }
-------------
tcMatch :: (Outputable (body Name)) => TcMatchCtxt body
......@@ -855,8 +855,11 @@ number of args are used in each equation.
\begin{code}
checkArgs :: Name -> MatchGroup Name body -> TcM ()
checkArgs fun (MatchGroup (match1:matches) _)
| null bad_matches = return ()
checkArgs _ (MG { mg_alts = [] })
= return ()
checkArgs fun (MG { mg_alts = match1:matches })
| null bad_matches
= return ()
| otherwise
= failWithTc (vcat [ptext (sLit "Equations for") <+> quotes (ppr fun) <+>
ptext (sLit "have different numbers of arguments"),
......@@ -868,6 +871,5 @@ checkArgs fun (MatchGroup (match1:matches) _)
args_in_match :: LMatch Name body -> Int
args_in_match (L _ (Match pats _ _)) = length pats
checkArgs fun _ = pprPanic "TcPat.checkArgs" (ppr fun) -- Matches always non-empty
\end{code}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment