Commit 043604c7 authored by Ben Gamari's avatar Ben Gamari Committed by Ben Gamari
Browse files

RnExpr: Fix ApplicativeDo desugaring with RebindableSyntax

We need to compare against the local return and pure, not returnMName
and pureAName.

Fixes #12490.

Test Plan: Validate, add testcase

Reviewers: austin, simonmar

Reviewed By: simonmar

Subscribers: thomie

Differential Revision: https://phabricator.haskell.org/D2499

GHC Trac Issues: #12490
parent e9b0bf4e
......@@ -25,7 +25,8 @@ module RnEnv (
lookupFieldFixityRn, lookupTyFixityRn,
lookupInstDeclBndr, lookupRecFieldOcc, lookupFamInstName,
lookupConstructorFields,
lookupSyntaxName, lookupSyntaxNames, lookupIfThenElse,
lookupSyntaxName, lookupSyntaxName', lookupSyntaxNames,
lookupIfThenElse,
lookupGreAvailRn,
getLookupOccRn,mkUnboundName, mkUnboundNameRdr, isUnboundName,
addUsedGRE, addUsedGREs, addUsedDataCons,
......@@ -1600,6 +1601,16 @@ lookupIfThenElse
; return ( Just (mkRnSyntaxExpr ite)
, unitFV ite ) } }
lookupSyntaxName' :: Name -- ^ The standard name
-> RnM Name -- ^ Possibly a non-standard name
lookupSyntaxName' std_name
= do { rebindable_on <- xoptM LangExt.RebindableSyntax
; if not rebindable_on then
return std_name
else
-- Get the similarly named thing from the local environment
lookupOccRn (mkRdrUnqual (nameOccName std_name)) }
lookupSyntaxName :: Name -- The standard name
-> RnM (SyntaxExpr Name, FreeVars) -- Possibly a non-standard name
lookupSyntaxName std_name
......
......@@ -1452,6 +1452,10 @@ dsDo {(arg_1 | ... | arg_n); stmts} expr =
-}
-- | The 'Name's of @return@ and @pure@. These may not be 'returnName' and
-- 'pureName' due to @RebindableSyntax@.
data MonadNames = MonadNames { return_name, pure_name :: Name }
-- | rearrange a list of statements using ApplicativeDoStmt. See
-- Note [ApplicativeDo].
rearrangeForApplicativeDo
......@@ -1465,7 +1469,11 @@ rearrangeForApplicativeDo ctxt stmts0 = do
optimal_ado <- goptM Opt_OptimalApplicativeDo
let stmt_tree | optimal_ado = mkStmtTreeOptimal stmts
| otherwise = mkStmtTreeHeuristic stmts
stmtTreeToStmts ctxt stmt_tree [last] last_fvs
return_name <- lookupSyntaxName' returnMName
pure_name <- lookupSyntaxName' pureAName
let monad_names = MonadNames { return_name = return_name
, pure_name = pure_name }
stmtTreeToStmts monad_names ctxt stmt_tree [last] last_fvs
where
(stmts,(last,last_fvs)) = findLast stmts0
findLast [] = error "findLast"
......@@ -1568,7 +1576,8 @@ mkStmtTreeOptimal stmts =
-- | Turn the ExprStmtTree back into a sequence of statements, using
-- ApplicativeStmt where necessary.
stmtTreeToStmts
:: HsStmtContext Name
:: MonadNames
-> HsStmtContext Name
-> ExprStmtTree
-> [ExprLStmt Name] -- ^ the "tail"
-> FreeVars -- ^ free variables of the tail
......@@ -1581,9 +1590,9 @@ stmtTreeToStmts
-- In the spec, but we do it here rather than in the desugarer,
-- because we need the typechecker to typecheck the <$> form rather than
-- the bind form, which would give rise to a Monad constraint.
stmtTreeToStmts ctxt (StmtTreeOne (L _ (BindStmt pat rhs _ _ _),_))
stmtTreeToStmts monad_names ctxt (StmtTreeOne (L _ (BindStmt pat rhs _ _ _),_))
tail _tail_fvs
| isIrrefutableHsPat pat, (False,tail') <- needJoin tail
| isIrrefutableHsPat pat, (False,tail') <- needJoin monad_names tail
-- WARNING: isIrrefutableHsPat on (HsPat Name) doesn't have enough info
-- to know which types have only one constructor. So only
-- tuples come out as irrefutable; other single-constructor
......@@ -1591,19 +1600,19 @@ stmtTreeToStmts ctxt (StmtTreeOne (L _ (BindStmt pat rhs _ _ _),_))
-- isIrrefuatableHsPat
= mkApplicativeStmt ctxt [ApplicativeArgOne pat rhs] False tail'
stmtTreeToStmts _ctxt (StmtTreeOne (s,_)) tail _tail_fvs =
stmtTreeToStmts _monad_names _ctxt (StmtTreeOne (s,_)) tail _tail_fvs =
return (s : tail, emptyNameSet)
stmtTreeToStmts ctxt (StmtTreeBind before after) tail tail_fvs = do
(stmts1, fvs1) <- stmtTreeToStmts ctxt after tail tail_fvs
stmtTreeToStmts monad_names ctxt (StmtTreeBind before after) tail tail_fvs = do
(stmts1, fvs1) <- stmtTreeToStmts monad_names ctxt after tail tail_fvs
let tail1_fvs = unionNameSets (tail_fvs : map snd (flattenStmtTree after))
(stmts2, fvs2) <- stmtTreeToStmts ctxt before stmts1 tail1_fvs
(stmts2, fvs2) <- stmtTreeToStmts monad_names ctxt before stmts1 tail1_fvs
return (stmts2, fvs1 `plusFV` fvs2)
stmtTreeToStmts ctxt (StmtTreeApplicative trees) tail tail_fvs = do
stmtTreeToStmts monad_names ctxt (StmtTreeApplicative trees) tail tail_fvs = do
pairs <- mapM (stmtTreeArg ctxt tail_fvs) trees
let (stmts', fvss) = unzip pairs
let (need_join, tail') = needJoin tail
let (need_join, tail') = needJoin monad_names tail
(stmts, fvs) <- mkApplicativeStmt ctxt stmts' need_join tail'
return (stmts, unionNameSets (fvs:fvss))
where
......@@ -1617,7 +1626,7 @@ stmtTreeToStmts ctxt (StmtTreeApplicative trees) tail tail_fvs = do
-- See Note [Deterministic ApplicativeDo and RecursiveDo desugaring]
pat = mkBigLHsVarPatTup pvars
tup = mkBigLHsVarTup pvars
(stmts',fvs2) <- stmtTreeToStmts ctxt tree [] pvarset
(stmts',fvs2) <- stmtTreeToStmts monad_names ctxt tree [] pvarset
(mb_ret, fvs1) <-
if | L _ ApplicativeStmt{} <- last stmts' ->
return (unLoc tup, emptyNameSet)
......@@ -1763,17 +1772,22 @@ mkApplicativeStmt ctxt args need_join body_stmts
-- | Given the statements following an ApplicativeStmt, determine whether
-- we need a @join@ or not, and remove the @return@ if necessary.
needJoin :: [ExprLStmt Name] -> (Bool, [ExprLStmt Name])
needJoin [] = (False, []) -- we're in an ApplicativeArg
needJoin [L loc (LastStmt e _ t)]
| Just arg <- isReturnApp e = (False, [L loc (LastStmt arg True t)])
needJoin stmts = (True, stmts)
needJoin :: MonadNames
-> [ExprLStmt Name]
-> (Bool, [ExprLStmt Name])
needJoin _monad_names [] = (False, []) -- we're in an ApplicativeArg
needJoin monad_names [L loc (LastStmt e _ t)]
| Just arg <- isReturnApp monad_names e =
(False, [L loc (LastStmt arg True t)])
needJoin _monad_names stmts = (True, stmts)
-- | @Just e@, if the expression is @return e@ or @return $ e@,
-- otherwise @Nothing@
isReturnApp :: LHsExpr Name -> Maybe (LHsExpr Name)
isReturnApp (L _ (HsPar expr)) = isReturnApp expr
isReturnApp (L _ e) = case e of
isReturnApp :: MonadNames
-> LHsExpr Name
-> Maybe (LHsExpr Name)
isReturnApp monad_names (L _ (HsPar expr)) = isReturnApp monad_names expr
isReturnApp monad_names (L _ e) = case e of
OpApp l op _ r | is_return l, is_dollar op -> Just r
HsApp f arg | is_return f -> Just arg
_otherwise -> Nothing
......@@ -1784,7 +1798,8 @@ isReturnApp (L _ e) = case e of
-- TODO: I don't know how to get this right for rebindable syntax
is_var _ _ = False
is_return = is_var (\n -> n == returnMName || n == pureAName)
is_return = is_var (\n -> n == return_name monad_names
|| n == pure_name monad_names)
is_dollar = is_var (`hasKey` dollarIdKey)
{-
......
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE ApplicativeDo #-}
module T12490 where
import Prelude (Int, String, Functor(..), ($), undefined, (+))
join :: Monad f => f (f a) -> f a
join = undefined
class Functor f => Applicative f where
pure :: a -> f a
(<*>) :: f (a -> b) -> f a -> f b
class Applicative f => Monad f where
return :: a -> f a
(>>=) :: f a -> (a -> f b) -> f b
fail :: String -> f a
f_app :: Applicative f => f Int -> f Int -> f Int
f_app a b = do
a' <- a
b' <- b
pure (a' + b')
f_monad :: Monad f => f Int -> f Int -> f Int
f_monad a b = do
a' <- a
b' <- b
return $ a' + b'
......@@ -7,3 +7,4 @@ test('ado006', normal, compile, [''])
test('ado007', normal, compile, [''])
test('T11607', normal, compile_and_run, [''])
test('ado-optimal', normal, compile_and_run, [''])
test('T12490', normal, compile, [''])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment