Commit f04dead9 authored by simonpj@microsoft.com's avatar simonpj@microsoft.com

Add 'rec' to stmts in a 'do', and deprecate 'mdo'

The change is this (see Trac #2798).  Instead of writing

  mdo { a <- getChar
      ; b <- f c
      ; c <- g b
      ; putChar c
      ; return b }

you would write

  do { a <- getChar
     ; rec { b <- f c
           ; c <- g b }
     ; putChar c
     ; return b }

That is, 
  * 'mdo' is eliminated 
  * 'rec' is added, which groups a bunch of statements
    into a single recursive statement

This 'rec' thing is already present for the arrow notation, so it  
makes the two more uniform.  Moreover, 'rec' lets you say more
precisely where the recursion is (if you want to), whereas 'mdo' just
says "there's recursion here somewhere".  Lastly, all this works with
rebindable syntax (which mdo does not).

Currently 'mdo' is enabled by -XRecursiveDo.  So we now deprecate this
flag, with another flag -XDoRec to enable the 'rec' keyword.

Implementation notes:
  * Some changes in Lexer.x
  * All uses of RecStmt now use record syntax

I'm still not really happy with the "rec_ids" and "later_ids" in the
RecStmt constructor, but I don't dare change it without consulting Ross
about the consequences for arrow syntax.
parent 69f8ed93
......@@ -461,13 +461,15 @@ addTickStmt isGuard (GroupStmt (stmts, binderMap) groupByClause) = do
case x of
Left a -> f a >>= (return . Left)
Right b -> g b >>= (return . Right)
addTickStmt isGuard (RecStmt stmts ids1 ids2 tys dictbinds) = do
liftM5 RecStmt
(addTickLStmts isGuard stmts)
(return ids1)
(return ids2)
(return tys)
(addTickDictBinds dictbinds)
addTickStmt isGuard stmt@(RecStmt {})
= do { stmts' <- addTickLStmts isGuard (recS_stmts stmt)
; ret' <- addTickSyntaxExpr hpcSrcSpan (recS_ret_fn stmt)
; mfix' <- addTickSyntaxExpr hpcSrcSpan (recS_mfix_fn stmt)
; bind' <- addTickSyntaxExpr hpcSrcSpan (recS_bind_fn stmt)
; dicts' <- addTickDictBinds (recS_dicts stmt)
; return (stmt { recS_stmts = stmts', recS_ret_fn = ret'
, recS_mfix_fn = mfix', recS_bind_fn = bind'
, recS_dicts = dicts' }) }
addTick :: Maybe (Bool -> BoxLabel) -> LHsExpr Id -> TM (LHsExpr Id)
addTick isGuard e | Just fn <- isGuard = addBinTickLHsExpr fn e
......
......@@ -779,7 +779,9 @@ dsCmdStmt ids local_vars env_ids out_ids (LetStmt binds) = do
-- first (loop (arr (\((ys1),~(ys2)) -> (ys)) >>> ss)) >>>
-- arr (\((xs1),(xs2)) -> (xs')) >>> ss'
dsCmdStmt ids local_vars env_ids out_ids (RecStmt stmts later_ids rec_ids rhss _binds) = do
dsCmdStmt ids local_vars env_ids out_ids
(RecStmt { recS_stmts = stmts, recS_later_ids = later_ids, recS_rec_ids = rec_ids
, recS_rec_rets = rhss, recS_dicts = _binds }) = do
let -- ToDo: ****** binds not desugared; ROSS PLEASE FIX ********
env2_id_set = mkVarSet out_ids `minusVarSet` mkVarSet later_ids
env2_ids = varSetElems env2_id_set
......
......@@ -49,6 +49,7 @@ import DynFlags
import StaticFlags
import CostCentre
import Id
import Var
import PrelInfo
import DataCon
import TysWiredIn
......@@ -676,13 +677,16 @@ dsDo :: [LStmt Id]
-> Type -- Type of the whole expression
-> DsM CoreExpr
dsDo stmts body _result_ty
dsDo stmts body result_ty
= goL stmts
where
-- result_ty must be of the form (m b)
(m_ty, _b_ty) = tcSplitAppTy result_ty
goL [] = dsLExpr body
goL ((L loc stmt):lstmts) = putSrcSpanDs loc (go stmt lstmts)
goL ((L loc stmt):lstmts) = putSrcSpanDs loc (go loc stmt lstmts)
go (ExprStmt rhs then_expr _) stmts
go _ (ExprStmt rhs then_expr _) stmts
= do { rhs2 <- dsLExpr rhs
; case tcSplitAppTy_maybe (exprType rhs2) of
Just (container_ty, returning_ty) -> warnDiscardedDoBindings rhs container_ty returning_ty
......@@ -691,23 +695,52 @@ dsDo stmts body _result_ty
; rest <- goL stmts
; return (mkApps then_expr2 [rhs2, rest]) }
go (LetStmt binds) stmts
go _ (LetStmt binds) stmts
= do { rest <- goL stmts
; dsLocalBinds binds rest }
go (BindStmt pat rhs bind_op fail_op) stmts
=
do { body <- goL stmts
; rhs' <- dsLExpr rhs
; bind_op' <- dsExpr bind_op
; var <- selectSimpleMatchVarL pat
; let bind_ty = exprType bind_op' -- rhs -> (pat -> res1) -> res2
res1_ty = funResultTy (funArgTy (funResultTy bind_ty))
; match <- matchSinglePat (Var var) (StmtCtxt DoExpr) pat
res1_ty (cantFailMatchResult body)
; match_code <- handle_failure pat match fail_op
; return (mkApps bind_op' [rhs', Lam var match_code]) }
go _ (BindStmt pat rhs bind_op fail_op) stmts
= do { body <- goL stmts
; rhs' <- dsLExpr rhs
; bind_op' <- dsExpr bind_op
; var <- selectSimpleMatchVarL pat
; let bind_ty = exprType bind_op' -- rhs -> (pat -> res1) -> res2
res1_ty = funResultTy (funArgTy (funResultTy bind_ty))
; match <- matchSinglePat (Var var) (StmtCtxt DoExpr) pat
res1_ty (cantFailMatchResult body)
; match_code <- handle_failure pat match fail_op
; return (mkApps bind_op' [rhs', Lam var match_code]) }
go loc (RecStmt { recS_stmts = rec_stmts, recS_later_ids = later_ids
, recS_rec_ids = rec_ids, recS_ret_fn = return_op
, recS_mfix_fn = mfix_op, recS_bind_fn = bind_op
, recS_rec_rets = rec_rets, recS_dicts = binds }) stmts
= ASSERT( length rec_ids > 0 )
goL (new_bind_stmt : let_stmt : stmts)
where
-- returnE <- dsExpr return_id
-- mfixE <- dsExpr mfix_id
new_bind_stmt = L loc $ BindStmt (mkLHsPatTup later_pats) mfix_app
bind_op
noSyntaxExpr -- Tuple cannot fail
let_stmt = L loc $ LetStmt (HsValBinds (ValBindsOut [(Recursive, binds)] []))
tup_ids = rec_ids ++ filterOut (`elem` rec_ids) later_ids
rec_tup_pats = map nlVarPat tup_ids
later_pats = rec_tup_pats
rets = map noLoc rec_rets
mfix_app = nlHsApp (noLoc mfix_op) mfix_arg
mfix_arg = noLoc $ HsLam (MatchGroup [mkSimpleMatch [mfix_pat] body]
(mkFunTy tup_ty body_ty))
mfix_pat = noLoc $ LazyPat $ mkLHsPatTup rec_tup_pats
body = noLoc $ HsDo DoExpr rec_stmts return_app body_ty
return_app = nlHsApp (noLoc return_op) (mkLHsTupleExpr rets)
body_ty = mkAppTy m_ty tup_ty
tup_ty = mkCoreTupTy (map idType tup_ids)
-- mkCoreTupTy deals with singleton case
-- In a do expression, pattern-match failure just calls
-- the monadic 'fail' rather than throwing an exception
handle_failure pat match fail_op
......@@ -774,10 +807,11 @@ dsMDo tbl stmts body result_ty
; return (mkApps (Var bind_id) [Type (hsLPatType pat), Type b_ty,
rhs', Lam var match_code]) }
go loc (RecStmt rec_stmts later_ids rec_ids rec_rets binds) stmts
go loc (RecStmt rec_stmts later_ids rec_ids _ _ _ rec_rets binds) stmts
= ASSERT( length rec_ids > 0 )
ASSERT( length rec_ids == length rec_rets )
goL (new_bind_stmt : let_stmt : stmts)
pprTrace "dsMDo" (ppr later_ids) $
goL (new_bind_stmt : let_stmt : stmts)
where
new_bind_stmt = L loc $ mkBindStmt (mk_tup_pat later_pats) mfix_app
let_stmt = L loc $ LetStmt (HsValBinds (ValBindsOut [(Recursive, binds)] []))
......
......@@ -847,26 +847,41 @@ data StmtLR idL idR
-- the names which they group over in statements
-- Recursive statement (see Note [RecStmt] below)
| RecStmt [LStmtLR idL idR]
--- The next two fields are only valid after renaming
[idR] -- The ids are a subset of the variables bound by the
-- stmts that are used in stmts that follow the RecStmt
[idR] -- Ditto, but these variables are the "recursive" ones,
-- that are used before they are bound in the stmts of
-- the RecStmt. From a type-checking point of view,
-- these ones have to be monomorphic
--- These fields are only valid after typechecking
[PostTcExpr] -- These expressions correspond 1-to-1 with
-- the "recursive" [id], and are the
-- expressions that should be returned by
-- the recursion.
-- They may not quite be the Ids themselves,
-- because the Id may be *polymorphic*, but
-- the returned thing has to be *monomorphic*.
(DictBinds idR) -- Method bindings of Ids bound by the
-- RecStmt, and used afterwards
| RecStmt
{ recS_stmts :: [LStmtLR idL idR]
-- The next two fields are only valid after renaming
, recS_later_ids :: [idR] -- The ids are a subset of the variables bound by the
-- stmts that are used in stmts that follow the RecStmt
, recS_rec_ids :: [idR] -- Ditto, but these variables are the "recursive" ones,
-- that are used before they are bound in the stmts of
-- the RecStmt.
-- An Id can be in both groups
-- Both sets of Ids are (now) treated monomorphically
-- The only reason they are separate is becuase the DsArrows
-- code uses them separately, and I don't understand it well
-- enough to change it
-- Rebindable syntax
, recS_bind_fn :: SyntaxExpr idR -- The bind function
, recS_ret_fn :: SyntaxExpr idR -- The return function
, recS_mfix_fn :: SyntaxExpr idR -- The mfix function
-- These fields are only valid after typechecking
, recS_rec_rets :: [PostTcExpr] -- These expressions correspond 1-to-1 with
-- recS_rec_ids, and are the
-- expressions that should be returned by
-- the recursion.
-- They may not quite be the Ids themselves,
-- because the Id may be *polymorphic*, but
-- the returned thing has to be *monomorphic*,
-- so they may be type applications
, recS_dicts :: DictBinds idR -- Method bindings of Ids bound by the
-- RecStmt, and used afterwards
}
\end{code}
ExprStmts are a bit tricky, because what they mean
......@@ -894,8 +909,8 @@ depends on the context. Consider the following contexts:
Array comprehensions are handled like list comprehensions -=chak
Note [RecStmt]
~~~~~~~~~~~~~~
Note [How RecStmt works]
~~~~~~~~~~~~~~~~~~~~~~~~
Example:
HsDo [ BindStmt x ex
......@@ -917,6 +932,17 @@ Here, the RecStmt binds a,b,c; but
Nota Bene: the two a's have different types, even though they
have the same Name.
Note [Typing a RecStmt]
~~~~~~~~~~~~~~~~~~~~~~~
A (RecStmt stmts) types as if you had written
(v1,..,vn, _, ..., _) <- mfix (\~(_, ..., _, r1, ..., rm) ->
do { stmts
; return (v1,..vn, r1, ..., rm) })
where v1..vn are the later_ids
r1..rm are the rec_ids
\begin{code}
instance (OutputableBndr idL, OutputableBndr idR) => Outputable (StmtLR idL idR) where
......@@ -934,7 +960,11 @@ pprStmt (TransformStmt (stmts, _) usingExpr maybeByExpr)
byExprDoc = maybe empty (\byExpr -> hsep [ptext (sLit "by"), ppr byExpr]) maybeByExpr
pprStmt (GroupStmt (stmts, _) groupByClause) = (hsep [stmtsDoc, ptext (sLit "then group"), pprGroupByClause groupByClause])
where stmtsDoc = interpp'SP stmts
pprStmt (RecStmt segment _ _ _ _) = ptext (sLit "rec") <+> braces (vcat (map ppr segment))
pprStmt (RecStmt { recS_stmts = segment, recS_rec_ids = rec_ids, recS_later_ids = later_ids })
= ptext (sLit "rec") <+>
vcat [ braces (vcat (map ppr segment))
, ifPprDebug (vcat [ ptext (sLit "rec_ids=") <> ppr rec_ids
, ptext (sLit "later_ids=") <> ppr later_ids])]
pprGroupByClause :: (OutputableBndr id) => GroupByClause id -> SDoc
pprGroupByClause (GroupByNothing usingExpr) = hsep [ptext (sLit "using"), ppr usingExpr]
......
......@@ -139,7 +139,9 @@ mkGroupByUsingStmt :: [LStmt idL] -> LHsExpr idR -> LHsExpr idR -> StmtLR idL id
mkExprStmt :: LHsExpr idR -> StmtLR idL idR
mkBindStmt :: LPat idL -> LHsExpr idR -> StmtLR idL idR
mkRecStmt :: [LStmtLR idL idR] -> StmtLR idL idR
emptyRecStmt :: StmtLR idL idR
mkRecStmt :: [LStmtLR idL idR] -> StmtLR idL idR
mkHsIntegral i = OverLit (HsIntegral i) noRebindableInfo noSyntaxExpr
......@@ -163,7 +165,13 @@ mkGroupByUsingStmt stmts byExpr usingExpr = GroupStmt (stmts, []) (GroupBySometh
mkExprStmt expr = ExprStmt expr noSyntaxExpr placeHolderType
mkBindStmt pat expr = BindStmt pat expr noSyntaxExpr noSyntaxExpr
mkRecStmt stmts = RecStmt stmts [] [] [] emptyLHsBinds
emptyRecStmt = RecStmt { recS_stmts = [], recS_later_ids = [], recS_rec_ids = []
, recS_ret_fn = noSyntaxExpr, recS_mfix_fn = noSyntaxExpr
, recS_bind_fn = noSyntaxExpr
, recS_rec_rets = [], recS_dicts = emptyLHsBinds }
mkRecStmt stmts = emptyRecStmt { recS_stmts = stmts }
-------------------------------
--- A useful function for building @OpApps@. The operator is always a
......@@ -414,8 +422,8 @@ collectStmtBinders (ExprStmt _ _ _) = []
collectStmtBinders (ParStmt xs) = collectLStmtsBinders
$ concatMap fst xs
collectStmtBinders (TransformStmt (stmts, _) _ _) = collectLStmtsBinders stmts
collectStmtBinders (GroupStmt (stmts, _) _) = collectLStmtsBinders stmts
collectStmtBinders (RecStmt ss _ _ _ _) = collectLStmtsBinders ss
collectStmtBinders (GroupStmt (stmts, _) _) = collectLStmtsBinders stmts
collectStmtBinders (RecStmt { recS_stmts = ss }) = collectLStmtsBinders ss
\end{code}
......
......@@ -246,6 +246,7 @@ data DynFlag
| Opt_TransformListComp
| Opt_GeneralizedNewtypeDeriving
| Opt_RecursiveDo
| Opt_DoRec
| Opt_PostfixOperators
| Opt_TupleSections
| Opt_PatternGuards
......@@ -1650,7 +1651,7 @@ mkFlag turnOn flagPrefix f (name, dynflag, deprecated)
deprecatedForLanguage :: String -> Bool -> Deprecated
deprecatedForLanguage lang turn_on
= Deprecated ("use -X" ++ flag ++ " or pragma {-# LANGUAGE " ++ flag ++ "#-} instead")
= Deprecated ("use -X" ++ flag ++ " or pragma {-# LANGUAGE " ++ flag ++ " #-} instead")
where
flag | turn_on = lang
| otherwise = "No"++lang
......@@ -1801,7 +1802,9 @@ xFlags = [
( "RankNTypes", Opt_RankNTypes, const Supported ),
( "ImpredicativeTypes", Opt_ImpredicativeTypes, const Supported ),
( "TypeOperators", Opt_TypeOperators, const Supported ),
( "RecursiveDo", Opt_RecursiveDo, const Supported ),
( "RecursiveDo", Opt_RecursiveDo,
deprecatedForLanguage "DoRec"),
( "DoRec", Opt_DoRec, const Supported ),
( "Arrows", Opt_Arrows, const Supported ),
( "PArr", Opt_PArr, const Supported ),
( "TemplateHaskell", Opt_TemplateHaskell, const Supported ),
......@@ -1911,7 +1914,7 @@ glasgowExtsFlags = [
, Opt_LiberalTypeSynonyms
, Opt_RankNTypes
, Opt_TypeOperators
, Opt_RecursiveDo
, Opt_DoRec
, Opt_ParallelListComp
, Opt_EmptyDataDecls
, Opt_KindSignatures
......
......@@ -662,7 +662,7 @@ reservedWordsFM = listToUFM $
( "ccall", ITccallconv, bit ffiBit),
( "prim", ITprimcallconv, bit ffiBit),
( "rec", ITrec, bit arrowsBit),
( "rec", ITrec, bit recBit),
( "proc", ITproc, bit arrowsBit)
]
......@@ -1672,6 +1672,8 @@ rawTokenStreamBit :: Int
rawTokenStreamBit = 20 -- producing a token stream with all comments included
newQualOpsBit :: Int
newQualOpsBit = 21 -- Haskell' qualified operator syntax, e.g. Prelude.(+)
recBit :: Int
recBit = 22 -- rec
always :: Int -> Bool
always _ = True
......@@ -1766,6 +1768,8 @@ mkPState buf loc flags =
.|. magicHashBit `setBitIf` dopt Opt_MagicHash flags
.|. kindSigsBit `setBitIf` dopt Opt_KindSignatures flags
.|. recursiveDoBit `setBitIf` dopt Opt_RecursiveDo flags
.|. recBit `setBitIf` dopt Opt_DoRec flags
.|. recBit `setBitIf` dopt Opt_Arrows flags
.|. unicodeSyntaxBit `setBitIf` dopt Opt_UnicodeSyntax flags
.|. unboxedTuplesBit `setBitIf` dopt Opt_UnboxedTuples flags
.|. standaloneDerivingBit `setBitIf` dopt Opt_StandaloneDeriving flags
......
......@@ -32,9 +32,7 @@ import RnTypes ( rnHsTypeFVs, rnSplice, checkTH,
import RnPat
import DynFlags ( DynFlag(..) )
import BasicTypes ( FixityDirection(..) )
import PrelNames ( hasKey, assertIdKey, assertErrorName,
loopAName, choiceAName, appAName, arrAName, composeAName, firstAName,
negateName, thenMName, bindMName, failMName, groupWithName )
import PrelNames
import Name
import NameSet
......@@ -454,8 +452,8 @@ convertOpFormsStmt (BindStmt pat cmd _ _)
= BindStmt pat (convertOpFormsLCmd cmd) noSyntaxExpr noSyntaxExpr
convertOpFormsStmt (ExprStmt cmd _ _)
= ExprStmt (convertOpFormsLCmd cmd) noSyntaxExpr placeHolderType
convertOpFormsStmt (RecStmt stmts lvs rvs es binds)
= RecStmt (map (fmap convertOpFormsStmt) stmts) lvs rvs es binds
convertOpFormsStmt stmt@(RecStmt { recS_stmts = stmts })
= stmt { recS_stmts = map (fmap convertOpFormsStmt) stmts }
convertOpFormsStmt stmt = stmt
convertOpFormsMatch :: MatchGroup id -> MatchGroup id
......@@ -537,14 +535,13 @@ methodNamesLStmt :: Located (StmtLR Name Name) -> FreeVars
methodNamesLStmt = methodNamesStmt . unLoc
methodNamesStmt :: StmtLR Name Name -> FreeVars
methodNamesStmt (ExprStmt cmd _ _) = methodNamesLCmd cmd
methodNamesStmt (BindStmt _ cmd _ _) = methodNamesLCmd cmd
methodNamesStmt (RecStmt stmts _ _ _ _)
= methodNamesStmts stmts `addOneFV` loopAName
methodNamesStmt (LetStmt _) = emptyFVs
methodNamesStmt (ParStmt _) = emptyFVs
methodNamesStmt (TransformStmt _ _ _) = emptyFVs
methodNamesStmt (GroupStmt _ _) = emptyFVs
methodNamesStmt (ExprStmt cmd _ _) = methodNamesLCmd cmd
methodNamesStmt (BindStmt _ cmd _ _) = methodNamesLCmd cmd
methodNamesStmt (RecStmt { recS_stmts = stmts }) = methodNamesStmts stmts `addOneFV` loopAName
methodNamesStmt (LetStmt _) = emptyFVs
methodNamesStmt (ParStmt _) = emptyFVs
methodNamesStmt (TransformStmt _ _ _) = emptyFVs
methodNamesStmt (GroupStmt _ _) = emptyFVs
-- ParStmt, TransformStmt and GroupStmt can't occur in commands, but it's not convenient to error
-- here so we just do what's convenient
\end{code}
......@@ -636,67 +633,95 @@ rnStmts ctxt = rnNormalStmts ctxt
rnNormalStmts :: HsStmtContext Name -> [LStmt RdrName]
-> RnM (thing, FreeVars)
-> RnM (([LStmt Name], thing), FreeVars)
-- Used for cases *other* than recursive mdo
-- Implements nested scopes
rnNormalStmts _ [] thing_inside
= do { (thing, fvs) <- thing_inside
; return (([],thing), fvs) }
rnNormalStmts ctxt (L loc stmt : stmts) thing_inside
= do { ((stmt', (stmts', thing)), fvs) <- rnStmt ctxt stmt $
rnNormalStmts ctxt stmts thing_inside
; return (((L loc stmt' : stmts'), thing), fvs) }
rnNormalStmts ctxt (stmt@(L loc _) : stmts) thing_inside
= do { ((stmts1, (stmts2, thing)), fvs)
<- setSrcSpan loc $
rnStmt ctxt stmt $
rnNormalStmts ctxt stmts thing_inside
; return (((stmts1 ++ stmts2), thing), fvs) }
rnStmt :: HsStmtContext Name -> Stmt RdrName
rnStmt :: HsStmtContext Name -> LStmt RdrName
-> RnM (thing, FreeVars)
-> RnM ((Stmt Name, thing), FreeVars)
-> RnM (([LStmt Name], thing), FreeVars)
rnStmt _ (ExprStmt expr _ _) thing_inside
rnStmt _ (L loc (ExprStmt expr _ _)) thing_inside
= do { (expr', fv_expr) <- rnLExpr expr
; (then_op, fvs1) <- lookupSyntaxName thenMName
; (thing, fvs2) <- thing_inside
; return ((ExprStmt expr' then_op placeHolderType, thing),
; return (([L loc (ExprStmt expr' then_op placeHolderType)], thing),
fv_expr `plusFV` fvs1 `plusFV` fvs2) }
rnStmt ctxt (BindStmt pat expr _ _) thing_inside
rnStmt ctxt (L loc (BindStmt pat expr _ _)) thing_inside
= do { (expr', fv_expr) <- rnLExpr expr
-- The binders do not scope over the expression
; (bind_op, fvs1) <- lookupSyntaxName bindMName
; (fail_op, fvs2) <- lookupSyntaxName failMName
; rnPats (StmtCtxt ctxt) [pat] $ \ [pat'] -> do
{ (thing, fvs3) <- thing_inside
; return ((BindStmt pat' expr' bind_op fail_op, thing),
; return (([L loc (BindStmt pat' expr' bind_op fail_op)], thing),
fv_expr `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) }}
-- fv_expr shouldn't really be filtered by the rnPatsAndThen
-- but it does not matter because the names are unique
rnStmt ctxt (LetStmt binds) thing_inside
rnStmt ctxt (L loc (LetStmt binds)) thing_inside
= do { checkLetStmt ctxt binds
; rnLocalBindsAndThen binds $ \binds' -> do
{ (thing, fvs) <- thing_inside
; return ((LetStmt binds', thing), fvs) } }
; return (([L loc (LetStmt binds')], thing), fvs) } }
rnStmt ctxt (RecStmt rec_stmts _ _ _ _) thing_inside
rnStmt ctxt (L _ (RecStmt { recS_stmts = rec_stmts })) thing_inside
= do { checkRecStmt ctxt
; rn_rec_stmts_and_then rec_stmts $ \ segs -> do
{ (thing, fvs) <- thing_inside
-- Step1: Bring all the binders of the mdo into scope
-- (Remember that this also removes the binders from the
-- finally-returned free-vars.)
-- And rename each individual stmt, making a
-- singleton segment. At this stage the FwdRefs field
-- isn't finished: it's empty for all except a BindStmt
-- for which it's the fwd refs within the bind itself
-- (This set may not be empty, because we're in a recursive
-- context.)
; rn_rec_stmts_and_then rec_stmts $ \ segs -> do
{ (thing, fvs_later) <- thing_inside
; (return_op, fvs1) <- lookupSyntaxName returnMName
; (mfix_op, fvs2) <- lookupSyntaxName mfixName
; (bind_op, fvs3) <- lookupSyntaxName bindMName
; let
-- Step 2: Fill in the fwd refs.
-- The segments are all singletons, but their fwd-ref
-- field mentions all the things used by the segment
-- that are bound after their use
segs_w_fwd_refs = addFwdRefs segs
(ds, us, fs, rec_stmts') = unzip4 segs_w_fwd_refs
later_vars = nameSetToList (plusFVs ds `intersectNameSet` fvs)
fwd_vars = nameSetToList (plusFVs fs)
uses = plusFVs us
rec_stmt = RecStmt rec_stmts' later_vars fwd_vars [] emptyLHsBinds
; return ((rec_stmt, thing), uses `plusFV` fvs) } }
rnStmt ctxt (ParStmt segs) thing_inside
-- Step 3: Group together the segments to make bigger segments
-- Invariant: in the result, no segment uses a variable
-- bound in a later segment
grouped_segs = glomSegments segs_w_fwd_refs
-- Step 4: Turn the segments into Stmts
-- Use RecStmt when and only when there are fwd refs
-- Also gather up the uses from the end towards the
-- start, so we can tell the RecStmt which things are
-- used 'after' the RecStmt
empty_rec_stmt = emptyRecStmt { recS_ret_fn = return_op
, recS_mfix_fn = mfix_op
, recS_bind_fn = bind_op }
(rec_stmts', fvs) = segsToStmts empty_rec_stmt grouped_segs fvs_later
; return ((rec_stmts', thing), fvs `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) } }
rnStmt ctxt (L loc (ParStmt segs)) thing_inside
= do { checkParStmt ctxt
; ((segs', thing), fvs) <- rnParallelStmts (ParStmtCtxt ctxt) segs thing_inside
; return ((ParStmt segs', thing), fvs) }
; return (([L loc (ParStmt segs')], thing), fvs) }
rnStmt ctxt (TransformStmt (stmts, _) usingExpr maybeByExpr) thing_inside = do
rnStmt ctxt (L loc (TransformStmt (stmts, _) usingExpr maybeByExpr)) thing_inside = do
checkTransformStmt ctxt
(usingExpr', fv_usingExpr) <- rnLExpr usingExpr
......@@ -707,14 +732,15 @@ rnStmt ctxt (TransformStmt (stmts, _) usingExpr maybeByExpr) thing_inside = do
return ((maybeByExpr', thing), fv_maybeByExpr `plusFV` fv_thing)
return ((TransformStmt (stmts', binders) usingExpr' maybeByExpr', thing), fv_usingExpr `plusFV` fvs)
return (([L loc (TransformStmt (stmts', binders) usingExpr' maybeByExpr')], thing),
fv_usingExpr `plusFV` fvs)
where
rnMaybeLExpr Nothing = return (Nothing, emptyFVs)
rnMaybeLExpr (Just expr) = do
(expr', fv_expr) <- rnLExpr expr
return (Just expr', fv_expr)
rnStmt ctxt (GroupStmt (stmts, _) groupByClause) thing_inside = do
rnStmt ctxt (L loc (GroupStmt (stmts, _) groupByClause)) thing_inside = do
checkTransformStmt ctxt
-- We must rename the using expression in the context before the transform is begun
......@@ -771,7 +797,7 @@ rnStmt ctxt (GroupStmt (stmts, _) groupByClause) thing_inside = do
return ((groupByClause', usedBinderMap, thing), fv_groupByClause `plusFV` real_fv_thing)
traceRn (text "rnStmt: implicitly rebound these used binders:" <+> ppr usedBinderMap)
return ((GroupStmt (stmts', usedBinderMap) groupByClause', thing), fvs)
return (([L loc (GroupStmt (stmts', usedBinderMap) groupByClause')], thing), fvs)
rnNormalStmtsAndFindUsedBinders :: HsStmtContext Name
-> [LStmt RdrName]
......@@ -858,39 +884,12 @@ rnMDoStmts :: [LStmt RdrName]
-> RnM (thing, FreeVars)
-> RnM (([LStmt Name], thing), FreeVars)
rnMDoStmts stmts thing_inside
= -- Step1: Bring all the binders of the mdo into scope
-- (Remember that this also removes the binders from the
-- finally-returned free-vars.)
-- And rename each individual stmt, making a
-- singleton segment. At this stage the FwdRefs field
-- isn't finished: it's empty for all except a BindStmt
-- for which it's the fwd refs within the bind itself
-- (This set may not be empty, because we're in a recursive
-- context.)
rn_rec_stmts_and_then stmts $ \ segs -> do {
; (thing, fvs_later) <- thing_inside
; let
-- Step 2: Fill in the fwd refs.
-- The segments are all singletons, but their fwd-ref
-- field mentions all the things used by the segment
-- that are bound after their use
segs_w_fwd_refs = addFwdRefs segs
-- Step 3: Group together the segments to make bigger segments
-- Invariant: in the result, no segment uses a variable
-- bound in a later segment
= rn_rec_stmts_and_then stmts $ \ segs -> do
{ (thing, fvs_later) <- thing_inside
; let segs_w_fwd_refs = addFwdRefs segs
grouped_segs = glomSegments segs_w_fwd_refs
-- Step 4: Turn the segments into Stmts
-- Use RecStmt when and only when there are fwd refs
-- Also gather up the uses from the end towards the
-- start, so we can tell the RecStmt which things are
-- used 'after' the RecStmt
(stmts', fvs) = segsToStmts grouped_segs fvs_later
; return ((stmts', thing), fvs) }
(stmts', fvs) = segsToStmts emptyRecStmt grouped_segs fvs_later
; return ((stmts', thing), fvs) }
---------------------------------------------
......@@ -957,7 +956,8 @@ rn_rec_stmt_lhs fix_env (L loc (LetStmt (HsValBinds binds)))
emptyFVs
)]
rn_rec_stmt_lhs fix_env (L _ (RecStmt stmts _ _ _ _)) -- Flatten Rec inside Rec
-- XXX Do we need to do something with the return and mfix names?
rn_rec_stmt_lhs fix_env (L _ (RecStmt { recS_stmts = stmts })) -- Flatten Rec inside Rec
= rn_rec_stmts_lhs fix_env stmts
rn_rec_stmt_lhs _ stmt@(L _ (ParStmt _)) -- Syntactically illegal in mdo
......@@ -1020,16 +1020,16 @@ rn_rec_stmt all_bndrs (L loc (LetStmt (HsValBinds binds'))) _ = do
emptyNameSet, L loc (LetStmt (HsValBinds binds')))]
-- no RecStmt case becuase they get flattened above when doing the LHSes
rn_rec_stmt _ stmt@(L _ (RecStmt _ _ _ _ _)) _
rn_rec_stmt _ stmt@(L _ (RecStmt {})) _
= pprPanic "rn_rec_stmt: RecStmt" (ppr stmt)
rn_rec_stmt _ stmt@(L _ (ParStmt _)) _ -- Syntactically illegal in mdo
rn_rec_stmt _ stmt@(L _ (ParStmt {})) _ -- Syntactically illegal in mdo
= pprPanic "rn_rec_stmt: ParStmt" (ppr stmt)
rn_rec_stmt _ stmt@(L _ (TransformStmt _ _ _)) _ -- Syntactically illegal in mdo
rn_rec_stmt _ stmt@(L _ (TransformStmt {})) _ -- Syntactically illegal in mdo
= pprPanic "rn_rec_stmt: TransformStmt" (ppr stmt)
rn_rec_stmt _ stmt@(L _ (GroupStmt _ _)) _ -- Syntactically illegal in mdo
rn_rec_stmt _ stmt@(L _ (GroupStmt {})) _ -- Syntactically illegal in mdo
= pprPanic "rn_rec_stmt: GroupStmt" (ppr stmt)
rn_rec_stmt _ (L _ (LetStmt EmptyLocalBinds)) _
......@@ -1120,23 +1120,24 @@ glomSegments ((defs,uses,fwds,stmt) : segs)
----------------------------------------------------
segsToStmts :: [Segment [LStmt Name]]
segsToStmts :: Stmt Name -- A RecStmt with the SyntaxOps filled in
-> [Segment [LStmt Name]]
-> FreeVars -- Free vars used 'later'
-> ([LStmt Name], FreeVars)
segsToStmts [] fvs_later = ([], fvs_later)
segsToStmts ((defs, uses, fwds, ss) : segs) fvs_later
segsToStmts _ [] fvs_later = ([], fvs_later)
segsToStmts empty_rec_stmt ((defs, uses, fwds, ss) : segs) fvs_later
= ASSERT( not (null ss) )
(new_stmt : later_stmts, later_uses `plusFV` uses)
where
(later_stmts, later_uses) = segsToStmts segs fvs_later
(later_stmts, later_uses) = segsToStmts empty_rec_stmt segs fvs_later
new_stmt | non_rec = head ss
| otherwise = L (getLoc (head ss)) $
RecStmt ss (nameSetToList used_later) (nameSetToList fwds)
[] emptyLHsBinds
where
non_rec = isSingleton ss && isEmptyNameSet fwds
used_later = defs `intersectNameSet` later_uses
| otherwise = L (getLoc (head ss)) rec_stmt
rec_stmt = empty_rec_stmt { recS_stmts = ss
, recS_later_ids = nameSetToList used_later
, recS_rec_ids = nameSetToList fwds }
non_rec = isSingleton ss && isEmptyNameSet fwds
used_later = defs `intersectNameSet` later_uses
-- The ones needed after the RecStmt
\end{code}
......@@ -1187,10 +1188,7 @@ checkLetStmt _ctxt _binds = return ()
---------
checkRecStmt :: HsStmtContext Name -> RnM ()
checkRecStmt (MDoExpr {}) = return () -- Recursive stmt ok in 'mdo'
checkRecStmt (DoExpr {}) = return () -- ..and in 'do' but only because of arrows:
-- proc x -> do { ...rec... }
-- We don't have enough context to distinguish this situation here
-- so we leave it to the type checker
checkRecStmt (DoExpr {}) = return () -- and in 'do'
checkRecStmt ctxt = addErr msg
where
msg = ptext (sLit "Illegal 'rec' stmt in") <+> pprStmtContext ctxt
......
......@@ -10,7 +10,7 @@ tcPolyExpr ::
-> BoxySigmaType
-> TcM (LHsExpr TcId)
tcMonoExpr ::
tcMonoExpr, tcMonoExprNC ::
LHsExpr Name
-> BoxyRhoType
-> TcM (LHsExpr TcId)
......
......@@ -682,21 +682,26 @@ zonkStmt env (ParStmt stmts_w_bndrs)