Commit d551dbfe authored by simonpj's avatar simonpj

[project @ 2005-04-04 11:55:11 by simonpj]

This commit combines three overlapping things:

1.  Make rebindable syntax work for do-notation. The idea
    here is that, in particular, (>>=) can have a type that
    has class constraints on its argument types, e.g.
       (>>=) :: (Foo m, Baz a) => m a -> (a -> m b) -> m b
    The consequence is that a BindStmt and ExprStmt must have
    individual evidence attached -- previously it was one
    batch of evidence for the entire Do
    
    Sadly, we can't do this for MDo, because we use bind at
    a polymorphic type (to tie the knot), so we still use one
    blob of evidence (now in the HsStmtContext) for MDo.
    
    For arrow syntax, the evidence is in the HsCmd.
    
    For list comprehensions, it's all built-in anyway.
    
    So the evidence on a BindStmt is only used for ordinary
    do-notation.

2.  Tidy up HsSyn.  In particular:

	- Eliminate a few "Out" forms, which we can manage
	without (e.g. 

	- It ought to be the case that the type checker only
	decorates the syntax tree, but doesn't change one
	construct into another.  That wasn't true for NPat,
	LitPat, NPlusKPat, so I've fixed that.

	- Eliminate ResultStmts from Stmt.  They always had
	to be the last Stmt, which led to awkward pattern
	matching in some places; and the benefits didn't seem
	to outweigh the costs.  Now each construct that uses
	[Stmt] has a result expression too (e.g. GRHS).


3.  Make 'deriving( Ix )' generate a binding for unsafeIndex,
    rather than for index.  This is loads more efficient.

    (This item only affects TcGenDeriv, but some of point (2)
    also affects TcGenDeriv, so it has to be in one commit.)
parent cb486104
......@@ -289,6 +289,13 @@ data IdInfo
strictnessInfo :: StrictnessInfo, -- Strictness properties
#endif
workerInfo :: WorkerInfo, -- Pointer to Worker Function
-- Within one module this is irrelevant; the
-- inlining of a worker is handled via the Unfolding
-- WorkerInfo is used *only* to indicate the form of
-- the RHS, so that interface files don't actually
-- need to contain the RHS; it can be derived from
-- the strictness info
unfoldingInfo :: Unfolding, -- Its unfolding
cafInfo :: CafInfo, -- CAF info
lbvarInfo :: LBVarInfo, -- Info about a lambda-bound variable
......
......@@ -411,12 +411,19 @@ get_used_lits qs = remove_dups' all_literals
get_used_lits' :: [(EqnNo, EquationInfo)] -> [HsLit]
get_used_lits' [] = []
get_used_lits' (q:qs)
| LitPat lit <- first_pat = lit : get_used_lits qs
| NPatOut lit _ _ <- first_pat = lit : get_used_lits qs
| otherwise = get_used_lits qs
| LitPat lit <- first_pat = lit : get_used_lits qs
| NPat lit _ _ _ <- first_pat = over_lit_lit lit : get_used_lits qs
| otherwise = get_used_lits qs
where
first_pat = firstPatN q
over_lit_lit :: HsOverLit id -> HsLit
-- Get a representative HsLit to stand for the OverLit
-- It doesn't matter which one, because they will only be compared
-- with other HsLits gotten in the same way
over_lit_lit (HsIntegral i _) = HsIntPrim i
over_lit_lit (HsFractional f _) = HsFloatPrim f
get_unused_cons :: [Pat Id] -> [DataCon]
get_unused_cons used_cons = unused_cons
where
......@@ -462,7 +469,7 @@ is_con _ = False
is_lit :: Pat Id -> Bool
is_lit (LitPat _) = True
is_lit (NPatOut _ _ _) = True
is_lit (NPat _ _ _ _) = True
is_lit _ = False
is_var :: Pat Id -> Bool
......@@ -475,10 +482,10 @@ is_var_con con (ConPatOut (L _ id) _ _ _ _ _) | id == con = True
is_var_con con _ = False
is_var_lit :: HsLit -> Pat Id -> Bool
is_var_lit lit (WildPat _) = True
is_var_lit lit (LitPat lit') | lit == lit' = True
is_var_lit lit (NPatOut lit' _ _) | lit == lit' = True
is_var_lit lit _ = False
is_var_lit lit (WildPat _) = True
is_var_lit lit (LitPat lit') = lit == lit'
is_var_lit lit (NPat lit' _ _ _) = lit == over_lit_lit lit'
is_var_lit lit _ = False
\end{code}
The difference beteewn @make_con@ and @make_whole_con@ is that
......@@ -608,19 +615,19 @@ simplify_pat (TuplePat ps boxity)
where
arity = length ps
simplify_pat pat@(LitPat lit) = unLoc (tidyLitPat lit (noLoc pat))
-- unpack string patterns fully, so we can see when they overlap with
-- each other, or even explicit lists of Chars.
simplify_pat pat@(NPatOut (HsString s) _ _) =
simplify_pat pat@(LitPat (HsString s)) =
foldr (\c pat -> mk_simple_con_pat consDataCon (PrefixCon [mk_char_lit c,noLoc pat]) stringTy)
(mk_simple_con_pat nilDataCon (PrefixCon []) stringTy) (unpackFS s)
where
mk_char_lit c = noLoc (mk_simple_con_pat charDataCon (PrefixCon [nlLitPat (HsCharPrim c)]) charTy)
simplify_pat pat@(NPatOut lit lit_ty hsexpr) = unLoc (tidyNPat lit lit_ty (noLoc pat))
simplify_pat pat@(LitPat lit) = unLoc (tidyLitPat lit (noLoc pat))
simplify_pat pat@(NPat lit mb_neg _ lit_ty) = unLoc (tidyNPat lit mb_neg lit_ty (noLoc pat))
simplify_pat (NPlusKPatOut id hslit hsexpr1 hsexpr2)
simplify_pat (NPlusKPat id hslit hsexpr1 hsexpr2)
= WildPat (idType (unLoc id))
simplify_pat (DictPat dicts methods)
......
......@@ -13,7 +13,7 @@ import DsUtils ( mkErrorAppDs,
mkCoreTupTy, mkCoreTup, selectSimpleMatchVarL,
mkTupleCase, mkBigCoreTup, mkTupleType,
mkTupleExpr, mkTupleSelector,
dsReboundNames, lookupReboundName )
dsSyntaxTable, lookupEvidence )
import DsMonad
import HsSyn
......@@ -57,17 +57,17 @@ data DsCmdEnv = DsCmdEnv {
arr_id, compose_id, first_id, app_id, choice_id, loop_id :: CoreExpr
}
mkCmdEnv :: ReboundNames Id -> DsM DsCmdEnv
mkCmdEnv :: SyntaxTable Id -> DsM DsCmdEnv
mkCmdEnv ids
= dsReboundNames ids `thenDs` \ (meth_binds, ds_meths) ->
= dsSyntaxTable ids `thenDs` \ (meth_binds, ds_meths) ->
return $ DsCmdEnv {
meth_binds = meth_binds,
arr_id = lookupReboundName ds_meths arrAName,
compose_id = lookupReboundName ds_meths composeAName,
first_id = lookupReboundName ds_meths firstAName,
app_id = lookupReboundName ds_meths appAName,
choice_id = lookupReboundName ds_meths choiceAName,
loop_id = lookupReboundName ds_meths loopAName
arr_id = Var (lookupEvidence ds_meths arrAName),
compose_id = Var (lookupEvidence ds_meths composeAName),
first_id = Var (lookupEvidence ds_meths firstAName),
app_id = Var (lookupEvidence ds_meths appAName),
choice_id = Var (lookupEvidence ds_meths choiceAName),
loop_id = Var (lookupEvidence ds_meths loopAName)
}
bindCmdEnv :: DsCmdEnv -> CoreExpr -> CoreExpr
......@@ -388,7 +388,7 @@ dsCmd ids local_vars env_ids stack res_ty (HsApp cmd arg)
-- ---> arr (\ ((((xs), p1), ... pk)*ts) -> ((ys)*ts)) >>> c
dsCmd ids local_vars env_ids stack res_ty
(HsLam (MatchGroup [L _ (Match pats _ (GRHSs [L _ (GRHS [L _ (ResultStmt body)])] _ ))] _))
(HsLam (MatchGroup [L _ (Match pats _ (GRHSs [L _ (GRHS [] body)] _ ))] _))
= let
pat_vars = mkVarSet (collectPatsBinders pats)
local_vars' = local_vars `unionVarSet` pat_vars
......@@ -575,8 +575,8 @@ dsCmd ids local_vars env_ids stack res_ty (HsLet binds body)
core_body,
exprFreeVars core_binds `intersectVarSet` local_vars)
dsCmd ids local_vars env_ids [] res_ty (HsDo _ctxt stmts _ _)
= dsCmdDo ids local_vars env_ids res_ty stmts
dsCmd ids local_vars env_ids [] res_ty (HsDo _ctxt stmts body _)
= dsCmdDo ids local_vars env_ids res_ty stmts body
-- A |- e :: forall e. a1 (e*ts1) t1 -> ... an (e*tsn) tn -> a (e*ts) t
-- A | xs |- ci :: [tsi] ti
......@@ -650,7 +650,8 @@ dsCmdDo :: DsCmdEnv -- arrow combinators
-- This is typically fed back,
-- so don't pull on it too early
-> Type -- return type of the statement
-> [LStmt Id] -- statements to desugar
-> [LStmt Id] -- statements to desugar
-> LHsExpr Id -- body
-> DsM (CoreExpr, -- desugared expression
IdSet) -- set of local vars that occur free
......@@ -658,16 +659,16 @@ dsCmdDo :: DsCmdEnv -- arrow combinators
-- --------------------------
-- A | xs |- do { c } :: [] t
dsCmdDo ids local_vars env_ids res_ty [L _ (ResultStmt cmd)]
= dsLCmd ids local_vars env_ids [] res_ty cmd
dsCmdDo ids local_vars env_ids res_ty [] body
= dsLCmd ids local_vars env_ids [] res_ty body
dsCmdDo ids local_vars env_ids res_ty (stmt:stmts)
dsCmdDo ids local_vars env_ids res_ty (stmt:stmts) body
= let
bound_vars = mkVarSet (map unLoc (collectLStmtBinders stmt))
local_vars' = local_vars `unionVarSet` bound_vars
in
fixDs (\ ~(_,_,env_ids') ->
dsCmdDo ids local_vars' env_ids' res_ty stmts
dsCmdDo ids local_vars' env_ids' res_ty stmts body
`thenDs` \ (core_stmts, fv_stmts) ->
returnDs (core_stmts, fv_stmts, varSetElems fv_stmts))
`thenDs` \ (core_stmts, fv_stmts, env_ids') ->
......@@ -708,7 +709,7 @@ dsCmdStmt
-- ---> arr (\ (xs) -> ((xs1),(xs'))) >>> first c >>>
-- arr snd >>> ss
dsCmdStmt ids local_vars env_ids out_ids (ExprStmt cmd c_ty)
dsCmdStmt ids local_vars env_ids out_ids (ExprStmt cmd _ c_ty)
= dsfixCmd ids local_vars [] c_ty cmd
`thenDs` \ (core_cmd, fv_cmd, env_ids1) ->
matchEnvStack env_ids []
......@@ -740,7 +741,7 @@ dsCmdStmt ids local_vars env_ids out_ids (ExprStmt cmd c_ty)
-- It would be simpler and more consistent to do this using second,
-- but that's likely to be defined in terms of first.
dsCmdStmt ids local_vars env_ids out_ids (BindStmt pat cmd)
dsCmdStmt ids local_vars env_ids out_ids (BindStmt pat cmd _ _)
= dsfixCmd ids local_vars [] (hsPatType pat) cmd
`thenDs` \ (core_cmd, fv_cmd, env_ids1) ->
let
......@@ -820,8 +821,8 @@ dsCmdStmt ids local_vars env_ids out_ids (LetStmt binds)
-- first (loop (arr (\((ys1),~(ys2)) -> (ys)) >>> ss)) >>>
-- arr (\((xs1),(xs2)) -> (xs')) >>> ss'
dsCmdStmt ids local_vars env_ids out_ids (RecStmt stmts later_ids rec_ids rhss)
= let
dsCmdStmt ids local_vars env_ids out_ids (RecStmt stmts later_ids rec_ids rhss binds)
= let -- ****** binds not desugared; ROSS PLEASE FIX ********
env2_id_set = mkVarSet out_ids `minusVarSet` mkVarSet later_ids
env2_ids = varSetElems env2_id_set
env2_ty = mkTupleType env2_ids
......@@ -885,7 +886,7 @@ dsRecCmd ids local_vars stmts later_ids rec_ids rhss
-- mk_pair_fn = \ (out_ids) -> ((later_ids),(rhss))
mappM dsLExpr rhss `thenDs` \ core_rhss ->
mappM dsExpr rhss `thenDs` \ core_rhss ->
let
later_tuple = mkTupleExpr later_ids
later_ty = mkTupleType later_ids
......@@ -1011,10 +1012,9 @@ leavesMatch (L _ (Match pats _ (GRHSs grhss binds)))
mkVarSet (map unLoc (collectGroupBinders binds))
in
[(expr,
mkVarSet (map unLoc (collectStmtsBinders stmts))
mkVarSet (map unLoc (collectLStmtsBinders stmts))
`unionVarSet` defined_vars)
| L _ (GRHS stmts) <- grhss,
let L _ (ResultStmt expr) = last stmts]
| L _ (GRHS stmts expr) <- grhss]
\end{code}
Replace the leaf commands in a match
......@@ -1037,8 +1037,8 @@ replaceLeavesGRHS
-> LGRHS Id -- rhss of a case command
-> ([LHsExpr Id],-- remaining leaf expressions
LGRHS Id) -- updated GRHS
replaceLeavesGRHS (leaf:leaves) (L loc (GRHS stmts))
= (leaves, L loc (GRHS (init stmts ++ [L (getLoc leaf) (ResultStmt leaf)])))
replaceLeavesGRHS (leaf:leaves) (L loc (GRHS stmts rhs))
= (leaves, L loc (GRHS stmts leaf))
\end{code}
Balanced fold of a non-empty list.
......
......@@ -9,14 +9,14 @@ module DsExpr ( dsExpr, dsLExpr, dsLet, dsLit ) where
#include "HsVersions.h"
import Match ( matchWrapper, matchSimply )
import MatchLit ( dsLit )
import Match ( matchWrapper, matchSimply, matchSinglePat )
import MatchLit ( dsLit, dsOverLit )
import DsBinds ( dsHsNestedBinds )
import DsGRHSs ( dsGuarded )
import DsListComp ( dsListComp, dsPArrComp )
import DsUtils ( mkErrorAppDs, mkStringExpr, mkConsExpr, mkNilExpr,
mkCoreTupTy, selectSimpleMatchVarL,
dsReboundNames, lookupReboundName )
extractMatchResult, cantFailMatchResult, matchCanFail,
mkCoreTupTy, selectSimpleMatchVarL, lookupEvidence )
import DsArrows ( dsProcExpr )
import DsMonad
......@@ -34,13 +34,13 @@ import TcHsSyn ( hsPatType )
-- Sigh. This is a pain.
import TcType ( tcSplitAppTy, tcSplitFunTys, tcTyConAppTyCon, tcTyConAppArgs,
tcTyConAppArgs, isUnLiftedType, Type, mkAppTy, tcEqType )
tcTyConAppArgs, isUnLiftedType, Type, mkAppTy )
import Type ( funArgTy, splitFunTys, isUnboxedTupleType, mkFunTy )
import CoreSyn
import CoreUtils ( exprType, mkIfThenElse, bindNonRec )
import CostCentre ( mkUserCC )
import Id ( Id, idType, idName )
import Id ( Id, idType, idName, isDataConWorkId_maybe )
import PrelInfo ( rEC_CON_ERROR_ID, iRREFUT_PAT_ERROR_ID )
import DataCon ( DataCon, dataConWrapId, dataConFieldLabels, dataConInstOrigArgTys )
import DataCon ( isVanillaDataCon )
......@@ -53,6 +53,7 @@ import PrelNames ( toPName,
mfixName )
import SrcLoc ( Located(..), unLoc, getLoc, noLoc )
import Util ( zipEqual, zipWithEqual )
import Maybe ( fromJust )
import Bag ( bagToList )
import Outputable
import FastString
......@@ -156,10 +157,15 @@ dsExpr :: HsExpr Id -> DsM CoreExpr
dsExpr (HsPar e) = dsLExpr e
dsExpr (ExprWithTySigOut e _) = dsLExpr e
dsExpr (HsVar var) = returnDs (Var var)
dsExpr (HsIPVar ip) = returnDs (Var (ipNameName ip))
dsExpr (HsLit lit) = dsLit lit
-- HsOverLit has been gotten rid of by the type checker
dsExpr (HsVar var) = returnDs (Var var)
dsExpr (HsIPVar ip) = returnDs (Var (ipNameName ip))
dsExpr (HsLit lit) = dsLit lit
dsExpr (HsOverLit lit) = dsOverLit lit
dsExpr (NegApp expr neg_expr)
= do { core_expr <- dsLExpr expr
; core_neg <- dsExpr neg_expr
; return (core_neg `App` core_expr) }
dsExpr expr@(HsLam a_Match)
= matchWrapper LambdaExpr a_Match `thenDs` \ (binders, matching_code) ->
......@@ -264,19 +270,21 @@ dsExpr (HsLet binds body)
-- We need the `ListComp' form to use `deListComp' (rather than the "do" form)
-- because the interpretation of `stmts' depends on what sort of thing it is.
--
dsExpr (HsDo ListComp stmts _ result_ty)
dsExpr (HsDo ListComp stmts body result_ty)
= -- Special case for list comprehensions
dsListComp stmts elt_ty
dsListComp stmts body elt_ty
where
[elt_ty] = tcTyConAppArgs result_ty
dsExpr (HsDo do_or_lc stmts ids result_ty)
| isDoExpr do_or_lc
= dsDo do_or_lc stmts ids result_ty
dsExpr (HsDo DoExpr stmts body result_ty)
= dsDo stmts body result_ty
dsExpr (HsDo (MDoExpr tbl) stmts body result_ty)
= dsMDo tbl stmts body result_ty
dsExpr (HsDo PArrComp stmts _ result_ty)
dsExpr (HsDo PArrComp stmts body result_ty)
= -- Special case for array comprehensions
dsPArrComp (map unLoc stmts) elt_ty
dsPArrComp (map unLoc stmts) body elt_ty
where
[elt_ty] = tcTyConAppArgs result_ty
......@@ -334,44 +342,44 @@ dsExpr (ExplicitTuple expr_list boxity)
returnDs (mkConApp (tupleCon boxity (length expr_list))
(map (Type . exprType) core_exprs ++ core_exprs))
dsExpr (ArithSeqOut expr (From from))
= dsLExpr expr `thenDs` \ expr2 ->
dsLExpr from `thenDs` \ from2 ->
dsExpr (ArithSeq expr (From from))
= dsExpr expr `thenDs` \ expr2 ->
dsLExpr from `thenDs` \ from2 ->
returnDs (App expr2 from2)
dsExpr (ArithSeqOut expr (FromTo from two))
= dsLExpr expr `thenDs` \ expr2 ->
dsLExpr from `thenDs` \ from2 ->
dsExpr (ArithSeq expr (FromTo from two))
= dsExpr expr `thenDs` \ expr2 ->
dsLExpr from `thenDs` \ from2 ->
dsLExpr two `thenDs` \ two2 ->
returnDs (mkApps expr2 [from2, two2])
dsExpr (ArithSeqOut expr (FromThen from thn))
= dsLExpr expr `thenDs` \ expr2 ->
dsLExpr from `thenDs` \ from2 ->
dsExpr (ArithSeq expr (FromThen from thn))
= dsExpr expr `thenDs` \ expr2 ->
dsLExpr from `thenDs` \ from2 ->
dsLExpr thn `thenDs` \ thn2 ->
returnDs (mkApps expr2 [from2, thn2])
dsExpr (ArithSeqOut expr (FromThenTo from thn two))
= dsLExpr expr `thenDs` \ expr2 ->
dsLExpr from `thenDs` \ from2 ->
dsExpr (ArithSeq expr (FromThenTo from thn two))
= dsExpr expr `thenDs` \ expr2 ->
dsLExpr from `thenDs` \ from2 ->
dsLExpr thn `thenDs` \ thn2 ->
dsLExpr two `thenDs` \ two2 ->
returnDs (mkApps expr2 [from2, thn2, two2])
dsExpr (PArrSeqOut expr (FromTo from two))
= dsLExpr expr `thenDs` \ expr2 ->
dsLExpr from `thenDs` \ from2 ->
dsExpr (PArrSeq expr (FromTo from two))
= dsExpr expr `thenDs` \ expr2 ->
dsLExpr from `thenDs` \ from2 ->
dsLExpr two `thenDs` \ two2 ->
returnDs (mkApps expr2 [from2, two2])
dsExpr (PArrSeqOut expr (FromThenTo from thn two))
= dsLExpr expr `thenDs` \ expr2 ->
dsLExpr from `thenDs` \ from2 ->
dsExpr (PArrSeq expr (FromThenTo from thn two))
= dsExpr expr `thenDs` \ expr2 ->
dsLExpr from `thenDs` \ from2 ->
dsLExpr thn `thenDs` \ thn2 ->
dsLExpr two `thenDs` \ two2 ->
returnDs (mkApps expr2 [from2, thn2, two2])
dsExpr (PArrSeqOut expr _)
dsExpr (PArrSeq expr _)
= panic "DsExpr.dsExpr: Infinite parallel array!"
-- the parser shouldn't have generated it and the renamer and typechecker
-- shouldn't have let it through
......@@ -399,8 +407,8 @@ We also handle @C{}@ as valid construction syntax for an unlabelled
constructor @C@, setting all of @C@'s fields to bottom.
\begin{code}
dsExpr (RecordConOut data_con con_expr rbinds)
= dsLExpr con_expr `thenDs` \ con_expr' ->
dsExpr (RecordCon (L _ data_con_id) con_expr rbinds)
= dsExpr con_expr `thenDs` \ con_expr' ->
let
(arg_tys, _) = tcSplitFunTys (exprType con_expr')
-- A newtype in the corner should be opaque;
......@@ -413,7 +421,8 @@ dsExpr (RecordConOut data_con con_expr rbinds)
[] -> mkErrorAppDs rEC_CON_ERROR_ID arg_ty (showSDoc (ppr lbl))
unlabelled_bottom arg_ty = mkErrorAppDs rEC_CON_ERROR_ID arg_ty ""
labels = dataConFieldLabels data_con
labels = dataConFieldLabels (fromJust (isDataConWorkId_maybe data_con_id))
-- The data_con_id is guaranteed to be the work id of the constructor
in
(if null labels
......@@ -446,10 +455,10 @@ might do some argument-evaluation first; and may have to throw away some
dictionaries.
\begin{code}
dsExpr (RecordUpdOut record_expr record_in_ty record_out_ty [])
dsExpr (RecordUpd record_expr [] record_in_ty record_out_ty)
= dsLExpr record_expr
dsExpr expr@(RecordUpdOut record_expr record_in_ty record_out_ty rbinds)
dsExpr expr@(RecordUpd record_expr rbinds record_in_ty record_out_ty)
= dsLExpr record_expr `thenDs` \ record_expr' ->
-- Desugar the rbinds, and generate let-bindings if
......@@ -553,8 +562,6 @@ dsExpr (HsProc pat cmd) = dsProcExpr pat cmd
#ifdef DEBUG
-- HsSyn constructs that just shouldn't be here:
dsExpr (ExprWithTySig _ _) = panic "dsExpr:ExprWithTySig"
dsExpr (ArithSeqIn _) = panic "dsExpr:ArithSeqIn"
dsExpr (PArrSeqIn _) = panic "dsExpr:PArrSeqIn"
#endif
\end{code}
......@@ -566,64 +573,48 @@ handled in DsListComp). Basically does the translation given in the
Haskell 98 report:
\begin{code}
dsDo :: HsStmtContext Name
-> [LStmt Id]
-> ReboundNames Id -- id for: [return,fail,>>=,>>] and possibly mfixName
-> Type -- Element type; the whole expression has type (m t)
dsDo :: [LStmt Id]
-> LHsExpr Id
-> Type -- Type of the whole expression
-> DsM CoreExpr
dsDo do_or_lc stmts ids result_ty
= dsReboundNames ids `thenDs` \ (meth_binds, ds_meths) ->
let
fail_id = lookupReboundName ds_meths failMName
bind_id = lookupReboundName ds_meths bindMName
then_id = lookupReboundName ds_meths thenMName
(m_ty, b_ty) = tcSplitAppTy result_ty -- result_ty must be of the form (m b)
-- For ExprStmt, see the comments near HsExpr.Stmt about
-- exactly what ExprStmts mean!
--
-- In dsDo we can only see DoStmt and ListComp (no guards)
go [ResultStmt expr] = dsLExpr expr
go (ExprStmt expr a_ty : stmts)
= dsLExpr expr `thenDs` \ expr2 ->
go stmts `thenDs` \ rest ->
returnDs (mkApps then_id [Type a_ty, Type b_ty, expr2, rest])
go (LetStmt binds : stmts)
= go stmts `thenDs` \ rest ->
dsLet binds rest
go (BindStmt pat expr : stmts)
= go stmts `thenDs` \ body ->
dsLExpr expr `thenDs` \ rhs ->
mkStringExpr (mk_msg (getLoc pat)) `thenDs` \ core_msg ->
let
-- In a do expression, pattern-match failure just calls
-- the monadic 'fail' rather than throwing an exception
fail_expr = mkApps fail_id [Type b_ty, core_msg]
a_ty = hsPatType pat
in
selectSimpleMatchVarL pat `thenDs` \ var ->
matchSimply (Var var) (StmtCtxt do_or_lc) pat
body fail_expr `thenDs` \ match_code ->
returnDs (mkApps bind_id [Type a_ty, Type b_ty, rhs, Lam var match_code])
go (RecStmt rec_stmts later_vars rec_vars rec_rets : stmts)
= go (bind_stmt : stmts)
where
bind_stmt = dsRecStmt m_ty ds_meths rec_stmts later_vars rec_vars rec_rets
in
go (map unLoc stmts) `thenDs` \ stmts_code ->
returnDs (foldr Let stmts_code meth_binds)
dsDo stmts body result_ty
= go (map unLoc stmts)
where
mk_msg locn = "Pattern match failure in do expression at " ++ showSDoc (ppr locn)
go [] = dsLExpr body
go (ExprStmt rhs then_expr _ : stmts)
= do { rhs2 <- dsLExpr rhs
; then_expr2 <- dsExpr then_expr
; rest <- go stmts
; returnDs (mkApps then_expr2 [rhs2, rest]) }
go (LetStmt binds : stmts)
= do { rest <- go stmts
; dsLet binds rest }
go (BindStmt pat rhs bind_op fail_op : stmts)
= do { body <- go stmts
; var <- selectSimpleMatchVarL pat
; match <- matchSinglePat (Var var) (StmtCtxt DoExpr) pat
result_ty (cantFailMatchResult body)
; match_code <- handle_failure pat match fail_op
; rhs' <- dsLExpr rhs
; bind_op' <- dsExpr bind_op
; returnDs (mkApps bind_op' [rhs', Lam var match_code]) }
-- In a do expression, pattern-match failure just calls
-- the monadic 'fail' rather than throwing an exception
handle_failure pat match fail_op
| matchCanFail match
= do { fail_op' <- dsExpr fail_op
; fail_msg <- mkStringExpr (mk_fail_msg pat)
; extractMatchResult match (App fail_op' fail_msg) }
| otherwise
= extractMatchResult match (error "It can't fail")
mk_fail_msg pat = "Pattern match failure in do expression at " ++
showSDoc (ppr (getLoc pat))
\end{code}
Translation for RecStmt's:
......@@ -634,48 +625,79 @@ We turn (RecStmt [v1,..vn] stmts) into:
return (v1,..vn))
\begin{code}
dsRecStmt :: Type -- Monad type constructor :: * -> *
-> [(Name,Id)] -- Rebound Ids
-> [LStmt Id]
-> [Id] -> [Id] -> [LHsExpr Id]
-> Stmt Id
dsRecStmt m_ty ds_meths stmts later_vars rec_vars rec_rets
= ASSERT( length rec_vars > 0 )
ASSERT( length rec_vars == length rec_rets )
BindStmt (mk_tup_pat later_pats) mfix_app
where
-- Remove any vars from later_vars that already in rec_vars
-- NB that having the same name is not enough; they must have
-- the same type. See Note [RecStmt] in HsExpr.
trimmed_laters = filter not_in_rec later_vars
not_in_rec lv = null [ v | let lv_type = idType lv
, v <- rec_vars
, v == lv
, lv_type `tcEqType` idType v ]
dsMDo :: PostTcTable
-> [LStmt Id]
-> LHsExpr Id
-> Type -- Type of the whole expression
-> DsM CoreExpr
dsMDo tbl stmts body result_ty
= go (map unLoc stmts)
where
(m_ty, b_ty) = tcSplitAppTy result_ty -- result_ty must be of the form (m b)
mfix_id = lookupEvidence tbl mfixName
return_id = lookupEvidence tbl returnMName
bind_id = lookupEvidence tbl bindMName
then_id = lookupEvidence tbl thenMName
fail_id = lookupEvidence tbl failMName
ctxt = MDoExpr tbl
go [] = dsLExpr body
go (LetStmt binds : stmts)
= do { rest <- go stmts
; dsLet binds rest }
go (ExprStmt rhs _ rhs_ty : stmts)
= do { rhs2 <- dsLExpr rhs
; rest <- go stmts
; returnDs (mkApps (Var then_id) [Type rhs_ty, Type b_ty, rhs2, rest]) }
go (BindStmt pat rhs _ _ : stmts)
= do { body <- go stmts
; var <- selectSimpleMatchVarL pat
; match <- matchSinglePat (Var var) (StmtCtxt ctxt) pat
result_ty (cantFailMatchResult body)
; fail_msg <- mkStringExpr (mk_fail_msg pat)
; let fail_expr = mkApps (Var fail_id) [Type b_ty, fail_msg]
; match_code <- extractMatchResult match fail_expr
; rhs' <- dsLExpr rhs
; returnDs (mkApps (Var bind_id) [Type (hsPatType pat), Type b_ty,
rhs', Lam var match_code]) }
go (RecStmt rec_stmts later_ids rec_ids rec_rets binds : stmts)
= ASSERT( length rec_ids > 0 )
ASSERT( length rec_ids == length rec_rets )
go (new_bind_stmt : let_stmt : stmts)
where
new_bind_stmt = mkBindStmt (mk_tup_pat later_pats) mfix_app
let_stmt = LetStmt [HsBindGroup binds [] Recursive]
-- Remove the later_ids that appear (without fancy coercions)
-- in rec_rets, because there's no need to knot-tie them separately
-- See Note [RecStmt] in HsExpr
later_ids' = filter (`notElem` mono_rec_ids) later_ids
mono_rec_ids = [ id | HsVar id <- rec_rets ]
mfix_app = nlHsApp (noLoc $ TyApp (nlHsVar mfix_id) [tup_ty]) mfix_arg
mfix_arg = noLoc $ HsLam (MatchGroup [mkSimpleMatch [mfix_pat] body]
(mkFunTy tup_ty body_ty))
-- The rec_tup_pat must bind the rec_vars only; remember that the
-- The rec_tup_pat must bind the rec_ids only; remember that the
-- trimmed_laters may share the same Names
-- Meanwhile, the later_pats must bind the later_vars
rec_tup_pats = map mk_wild_pat trimmed_laters ++ map nlVarPat rec_vars
later_pats = map nlVarPat trimmed_laters ++ map mk_later_pat rec_vars
rets = map nlHsVar trimmed_laters ++ rec_rets
rec_tup_pats = map mk_wild_pat later_ids' ++ map nlVarPat rec_ids
later_pats = map nlVarPat later_ids' ++ map mk_later_pat rec_ids
rets = map nlHsVar later_ids' ++ map noLoc rec_rets
mfix_pat = noLoc $ LazyPat $ mk_tup_pat rec_tup_pats
body = noLoc $ HsDo DoExpr (stmts ++ [return_stmt])
[(n, HsVar id) | (n,id) <- ds_meths] -- A bit of a hack
body_ty
body = noLoc $ HsDo ctxt rec_stmts return_app body_ty
body_ty = mkAppTy m_ty tup_ty
tup_ty = mkCoreTupTy (map idType (trimmed_laters ++ rec_vars))
tup_ty = mkCoreTupTy (map idType (later_ids' ++ rec_ids))
-- mkCoreTupTy deals with singleton case
Var return_id = lookupReboundName ds_meths returnMName
Var mfix_id = lookupReboundName ds_meths mfixName
return_stmt = noLoc $ ResultStmt return_app
return_app = nlHsApp (noLoc $ TyApp (nlHsVar return_id) [tup_ty])
(mk_ret_tup rets)
......@@ -683,8 +705,8 @@ dsRecStmt m_ty ds_meths stmts later_vars rec_vars rec_rets
mk_wild_pat v = noLoc $ WildPat $ idType v
mk_later_pat :: Id -> LPat Id
mk_later_pat v | v `elem` trimmed_laters = mk_wild_pat v
| otherwise = nlVarPat v
mk_later_pat v | v `elem` later_ids' = mk_wild_pat v
| otherwise = nlVarPat v
mk_tup_pat :: [LPat Id] -> LPat Id
mk_tup_pat [p] = p
......
......@@ -12,7 +12,7 @@ import {-# SOURCE #-} DsExpr ( dsLExpr, dsLet )
import {-# SOURCE #-} Match ( matchSinglePat )
import HsSyn ( Stmt(..), HsExpr(..), GRHSs(..), GRHS(..),
HsMatchContext(..), Pat(..) )
LHsExpr, HsMatchContext(..), Pat(..) )
import CoreSyn ( CoreExpr )
import Var ( Id )
import Type ( Type )
......@@ -64,8 +64,9 @@ dsGRHSs kind pats (GRHSs grhss binds) rhs_ty