Commit 41f90559 authored by Simon Marlow's avatar Simon Marlow

ApplicativeDo: handle BodyStmt (#12143)

Summary:
It's simple to treat BodyStmt just like a BindStmt with a wildcard
pattern, which is enough to fix #12143 without going all the way to
using `<*` and `*>` (#10892).

Test Plan:
* new test cases in `ado004.hs`
* validate

Reviewers: niteria, simonpj, bgamari, austin, erikd

Subscribers: rwbarton, thomie

GHC Trac Issues: #12143

Differential Revision: https://phabricator.haskell.org/D4128
parent 7d7d94fb
......@@ -767,8 +767,11 @@ addTickApplicativeArg
addTickApplicativeArg isGuard (op, arg) =
liftM2 (,) (addTickSyntaxExpr hpcSrcSpan op) (addTickArg arg)
where
addTickArg (ApplicativeArgOne pat expr) =
ApplicativeArgOne <$> addTickLPat pat <*> addTickLHsExpr expr
addTickArg (ApplicativeArgOne pat expr isBody) =
ApplicativeArgOne
<$> addTickLPat pat
<*> addTickLHsExpr expr
<*> pure isBody
addTickArg (ApplicativeArgMany stmts ret pat) =
ApplicativeArgMany
<$> addTickLStmts isGuard stmts
......
......@@ -924,7 +924,7 @@ dsDo stmts
let
(pats, rhss) = unzip (map (do_arg . snd) args)
do_arg (ApplicativeArgOne pat expr) =
do_arg (ApplicativeArgOne pat expr _) =
(pat, dsLExpr expr)
do_arg (ApplicativeArgMany stmts ret pat) =
(pat, dsDo (stmts ++ [noLoc $ mkLastStmt (noLoc ret)]))
......
......@@ -1777,13 +1777,18 @@ deriving instance (DataId idL, DataId idR) => Data (ParStmtBlock idL idR)
-- | Applicative Argument
data ApplicativeArg idL idR
= ApplicativeArgOne -- pat <- expr (pat must be irrefutable)
(LPat idL)
= ApplicativeArgOne -- A single statement (BindStmt or BodyStmt)
(LPat idL) -- WildPat if it was a BodyStmt (see below)
(LHsExpr idL)
| ApplicativeArgMany -- do { stmts; return vars }
[ExprLStmt idL] -- stmts
(HsExpr idL) -- return (v1,..,vn), or just (v1,..,vn)
(LPat idL) -- (v1,...,vn)
Bool -- True <=> was a BodyStmt
-- False <=> was a BindStmt
-- See Note [Applicative BodyStmt]
| ApplicativeArgMany -- do { stmts; return vars }
[ExprLStmt idL] -- stmts
(HsExpr idL) -- return (v1,..,vn), or just (v1,..,vn)
(LPat idL) -- (v1,...,vn)
deriving instance (DataId idL, DataId idR) => Data (ApplicativeArg idL idR)
{-
......@@ -1921,6 +1926,34 @@ Parallel statements require the 'Control.Monad.Zip.mzip' function:
In any other context than 'MonadComp', the fields for most of these
'SyntaxExpr's stay bottom.
Note [Applicative BodyStmt]
(#12143) For the purposes of ApplicativeDo, we treat any BodyStmt
as if it was a BindStmt with a wildcard pattern. For example,
do
x <- A
B
return x
is transformed as if it were
do
x <- A
_ <- B
return x
so it transforms to
(\(x,_) -> x) <$> A <*> B
But we have to remember when we treat a BodyStmt like a BindStmt,
because in error messages we want to emit the original syntax the user
wrote, not our internal representation. So ApplicativeArgOne has a
Bool flag that is True when the original statement was a BodyStmt, so
that we can pretty-print it correctly.
-}
instance (SourceTextX idL, OutputableBndrId idL)
......@@ -1973,7 +2006,11 @@ pprStmt (ApplicativeStmt args mb_join _)
flattenStmt (L _ (ApplicativeStmt args _ _)) = concatMap flattenArg args
flattenStmt stmt = [ppr stmt]
flattenArg (_, ApplicativeArgOne pat expr) =
flattenArg (_, ApplicativeArgOne pat expr isBody)
| isBody = -- See Note [Applicative BodyStmt]
[ppr (BodyStmt expr noSyntaxExpr noSyntaxExpr (panic "pprStmt")
:: ExprStmt idL)]
| otherwise =
[ppr (BindStmt pat expr noSyntaxExpr noSyntaxExpr (panic "pprStmt")
:: ExprStmt idL)]
flattenArg (_, ApplicativeArgMany stmts _ _) =
......@@ -1987,7 +2024,11 @@ pprStmt (ApplicativeStmt args mb_join _)
then ap_expr
else text "join" <+> parens ap_expr
pp_arg (_, ApplicativeArgOne pat expr) =
pp_arg (_, ApplicativeArgOne pat expr isBody)
| isBody = -- See Note [Applicative BodyStmt]
ppr (BodyStmt expr noSyntaxExpr noSyntaxExpr (panic "pprStmt")
:: ExprStmt idL)
| otherwise =
ppr (BindStmt pat expr noSyntaxExpr noSyntaxExpr (panic "pprStmt")
:: ExprStmt idL)
pp_arg (_, ApplicativeArgMany stmts return pat) =
......
......@@ -1197,7 +1197,7 @@ lStmtsImplicits = hs_lstmts
hs_stmt :: StmtLR GhcRn idR (Located (body idR)) -> NameSet
hs_stmt (BindStmt pat _ _ _ _) = lPatImplicits pat
hs_stmt (ApplicativeStmt args _ _) = unionNameSets (map do_arg args)
where do_arg (_, ApplicativeArgOne pat _) = lPatImplicits pat
where do_arg (_, ApplicativeArgOne pat _ _) = lPatImplicits pat
do_arg (_, ApplicativeArgMany stmts _ _) = hs_lstmts stmts
hs_stmt (LetStmt binds) = hs_local_binds (unLoc binds)
hs_stmt (BodyStmt {}) = emptyNameSet
......
......@@ -1659,7 +1659,12 @@ stmtTreeToStmts monad_names ctxt (StmtTreeOne (L _ (BindStmt pat rhs _ _ _),_))
tail _tail_fvs
| not (isStrictPattern pat), (False,tail') <- needJoin monad_names tail
-- See Note [ApplicativeDo and strict patterns]
= mkApplicativeStmt ctxt [ApplicativeArgOne pat rhs] False tail'
= mkApplicativeStmt ctxt [ApplicativeArgOne pat rhs False] False tail'
stmtTreeToStmts monad_names ctxt (StmtTreeOne (L _ (BodyStmt rhs _ _ _),_))
tail _tail_fvs
| (False,tail') <- needJoin monad_names tail
= mkApplicativeStmt ctxt
[ApplicativeArgOne nlWildPatName rhs True] False tail'
stmtTreeToStmts _monad_names _ctxt (StmtTreeOne (s,_)) tail _tail_fvs =
return (s : tail, emptyNameSet)
......@@ -1678,7 +1683,9 @@ stmtTreeToStmts monad_names ctxt (StmtTreeApplicative trees) tail tail_fvs = do
return (stmts, unionNameSets (fvs:fvss))
where
stmtTreeArg _ctxt _tail_fvs (StmtTreeOne (L _ (BindStmt pat exp _ _ _), _)) =
return (ApplicativeArgOne pat exp, emptyFVs)
return (ApplicativeArgOne pat exp False, emptyFVs)
stmtTreeArg _ctxt _tail_fvs (StmtTreeOne (L _ (BodyStmt exp _ _ _), _)) =
return (ApplicativeArgOne nlWildPatName exp True, emptyFVs)
stmtTreeArg ctxt tail_fvs tree = do
let stmts = flattenStmtTree tree
pvarset = mkNameSet (concatMap (collectStmtBinders.unLoc.fst) stmts)
......
......@@ -1098,11 +1098,11 @@ zonkStmt env _zBody (ApplicativeStmt args mb_join body_ty)
zonk_join env Nothing = return (env, Nothing)
zonk_join env (Just j) = second Just <$> zonkSyntaxExpr env j
get_pat (_, ApplicativeArgOne pat _) = pat
get_pat (_, ApplicativeArgOne pat _ _) = pat
get_pat (_, ApplicativeArgMany _ _ pat) = pat
replace_pat pat (op, ApplicativeArgOne _ a)
= (op, ApplicativeArgOne pat a)
replace_pat pat (op, ApplicativeArgOne _ a isBody)
= (op, ApplicativeArgOne pat a isBody)
replace_pat pat (op, ApplicativeArgMany a b _)
= (op, ApplicativeArgMany a b pat)
......@@ -1121,9 +1121,9 @@ zonkStmt env _zBody (ApplicativeStmt args mb_join body_ty)
; return (env2, (new_op, new_arg) : new_args) }
zonk_args_rev env [] = return (env, [])
zonk_arg env (ApplicativeArgOne pat expr)
zonk_arg env (ApplicativeArgOne pat expr isBody)
= do { new_expr <- zonkLExpr env expr
; return (ApplicativeArgOne pat new_expr) }
; return (ApplicativeArgOne pat new_expr isBody) }
zonk_arg env (ApplicativeArgMany stmts ret pat)
= do { (env1, new_stmts) <- zonkStmts env zonkLExpr stmts
; new_ret <- zonkExpr env1 ret
......
......@@ -1055,13 +1055,13 @@ tcApplicativeStmts ctxt pairs rhs_ty thing_inside
goArg :: (ApplicativeArg GhcRn GhcRn, Type, Type)
-> TcM (ApplicativeArg GhcTcId GhcTcId)
goArg (ApplicativeArgOne pat rhs, pat_ty, exp_ty)
goArg (ApplicativeArgOne pat rhs isBody, pat_ty, exp_ty)
= setSrcSpan (combineSrcSpans (getLoc pat) (getLoc rhs)) $
addErrCtxt (pprStmtInCtxt ctxt (mkBindStmt pat rhs)) $
do { rhs' <- tcMonoExprNC rhs (mkCheckExpType exp_ty)
; (pat', _) <- tcPat (StmtCtxt ctxt) pat (mkCheckExpType pat_ty) $
return ()
; return (ApplicativeArgOne pat' rhs') }
; return (ApplicativeArgOne pat' rhs' isBody) }
goArg (ApplicativeArgMany stmts ret pat, pat_ty, exp_ty)
= do { (stmts', (ret',pat')) <-
......@@ -1075,7 +1075,7 @@ tcApplicativeStmts ctxt pairs rhs_ty thing_inside
; return (ApplicativeArgMany stmts' ret' pat') }
get_arg_bndrs :: ApplicativeArg GhcTcId GhcTcId -> [Id]
get_arg_bndrs (ApplicativeArgOne pat _) = collectPatBinders pat
get_arg_bndrs (ApplicativeArgOne pat _ _) = collectPatBinders pat
get_arg_bndrs (ApplicativeArgMany _ _ pat) = collectPatBinders pat
......
......@@ -16,6 +16,19 @@ test1a f = do
y <- f 4
return $ x + y
-- When one of the statements is a BodyStmt
test1b :: Applicative f => (Int -> f Int) -> f Int
test1b f = do
x <- f 3
f 4
return x
test1c :: Applicative f => (Int -> f Int) -> f Int
test1c f = do
f 3
x <- f 4
return x
-- Test we can also infer the Applicative version of the type
test2 f = do
x <- f 3
......@@ -32,6 +45,11 @@ test2c f = do
x <- f 3
return $ x + 1
-- with a BodyStmt
test2d f = do
f 3
return 4
-- Test for just one statement
test2b f = do
return (f 3)
......
......@@ -3,6 +3,10 @@ TYPE SIGNATURES
forall (f :: * -> *). Applicative f => (Int -> f Int) -> f Int
test1a ::
forall (f :: * -> *). Applicative f => (Int -> f Int) -> f Int
test1b ::
forall (f :: * -> *). Applicative f => (Int -> f Int) -> f Int
test1c ::
forall (f :: * -> *). Applicative f => (Int -> f Int) -> f Int
test2 ::
forall (f :: * -> *) t b.
(Num b, Num t, Applicative f) =>
......@@ -17,6 +21,10 @@ TYPE SIGNATURES
forall (f :: * -> *) t b.
(Num b, Num t, Functor f) =>
(t -> f b) -> f b
test2d ::
forall (f :: * -> *) t1 b t2.
(Num b, Num t1, Functor f) =>
(t1 -> f t2) -> f b
test3 ::
forall (m :: * -> *) t1 t2 a.
(Num t1, Monad m) =>
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment