Commit 8ecf6d8f authored by Simon Marlow's avatar Simon Marlow

ApplicativeDo transformation

Summary:
This is an implementation of the ApplicativeDo proposal.  See the Note
[ApplicativeDo] in RnExpr for details on the current implementation,
and the wiki page https://ghc.haskell.org/trac/ghc/wiki/ApplicativeDo
for design notes.

Test Plan: validate

Reviewers: simonpj, goldfire, austin

Subscribers: thomie

Differential Revision: https://phabricator.haskell.org/D729
parent 43eb1dc5
...@@ -22,10 +22,6 @@ module MkCore ( ...@@ -22,10 +22,6 @@ module MkCore (
-- * Constructing equality evidence boxes -- * Constructing equality evidence boxes
mkEqBox, mkEqBox,
-- * Constructing general big tuples
-- $big_tuples
mkChunkified,
-- * Constructing small tuples -- * Constructing small tuples
mkCoreVarTup, mkCoreVarTupTy, mkCoreTup, mkCoreVarTup, mkCoreVarTupTy, mkCoreTup,
...@@ -67,6 +63,7 @@ import HscTypes ...@@ -67,6 +63,7 @@ import HscTypes
import TysWiredIn import TysWiredIn
import PrelNames import PrelNames
import HsUtils ( mkChunkified, chunkify )
import TcType ( mkSigmaTy ) import TcType ( mkSigmaTy )
import Type import Type
import Coercion import Coercion
...@@ -82,7 +79,6 @@ import UniqSupply ...@@ -82,7 +79,6 @@ import UniqSupply
import BasicTypes import BasicTypes
import Util import Util
import Pair import Pair
import Constants
import DynFlags import DynFlags
import Data.Char ( ord ) import Data.Char ( ord )
...@@ -319,47 +315,6 @@ mkEqBox co = ASSERT2( typeKind ty2 `eqKind` k, ppr co $$ ppr ty1 $$ ppr ty2 $$ p ...@@ -319,47 +315,6 @@ mkEqBox co = ASSERT2( typeKind ty2 `eqKind` k, ppr co $$ ppr ty1 $$ ppr ty2 $$ p
************************************************************************ ************************************************************************
-} -}
-- $big_tuples
-- #big_tuples#
--
-- GHCs built in tuples can only go up to 'mAX_TUPLE_SIZE' in arity, but
-- we might concievably want to build such a massive tuple as part of the
-- output of a desugaring stage (notably that for list comprehensions).
--
-- We call tuples above this size \"big tuples\", and emulate them by
-- creating and pattern matching on >nested< tuples that are expressible
-- by GHC.
--
-- Nesting policy: it's better to have a 2-tuple of 10-tuples (3 objects)
-- than a 10-tuple of 2-tuples (11 objects), so we want the leaves of any
-- construction to be big.
--
-- If you just use the 'mkBigCoreTup', 'mkBigCoreVarTupTy', 'mkTupleSelector'
-- and 'mkTupleCase' functions to do all your work with tuples you should be
-- fine, and not have to worry about the arity limitation at all.
-- | Lifts a \"small\" constructor into a \"big\" constructor by recursive decompositon
mkChunkified :: ([a] -> a) -- ^ \"Small\" constructor function, of maximum input arity 'mAX_TUPLE_SIZE'
-> [a] -- ^ Possible \"big\" list of things to construct from
-> a -- ^ Constructed thing made possible by recursive decomposition
mkChunkified small_tuple as = mk_big_tuple (chunkify as)
where
-- Each sub-list is short enough to fit in a tuple
mk_big_tuple [as] = small_tuple as
mk_big_tuple as_s = mk_big_tuple (chunkify (map small_tuple as_s))
chunkify :: [a] -> [[a]]
-- ^ Split a list into lists that are small enough to have a corresponding
-- tuple arity. The sub-lists of the result all have length <= 'mAX_TUPLE_SIZE'
-- But there may be more than 'mAX_TUPLE_SIZE' sub-lists
chunkify xs
| n_xs <= mAX_TUPLE_SIZE = [xs]
| otherwise = split xs
where
n_xs = length xs
split [] = []
split xs = take mAX_TUPLE_SIZE xs : split (drop mAX_TUPLE_SIZE xs)
{- {-
Creating tuples and their types for Core expressions Creating tuples and their types for Core expressions
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
(c) University of Glasgow, 2007 (c) University of Glasgow, 2007
-} -}
{-# LANGUAGE NondecreasingIndentation #-} {-# LANGUAGE CPP, NondecreasingIndentation #-}
module Coverage (addTicksToBinds, hpcInitCode) where module Coverage (addTicksToBinds, hpcInitCode) where
...@@ -660,9 +660,10 @@ addTickLStmts' isGuard lstmts res ...@@ -660,9 +660,10 @@ addTickLStmts' isGuard lstmts res
; return (lstmts', a) } ; return (lstmts', a) }
addTickStmt :: (Maybe (Bool -> BoxLabel)) -> Stmt Id (LHsExpr Id) -> TM (Stmt Id (LHsExpr Id)) addTickStmt :: (Maybe (Bool -> BoxLabel)) -> Stmt Id (LHsExpr Id) -> TM (Stmt Id (LHsExpr Id))
addTickStmt _isGuard (LastStmt e ret) = do addTickStmt _isGuard (LastStmt e noret ret) = do
liftM2 LastStmt liftM3 LastStmt
(addTickLHsExpr e) (addTickLHsExpr e)
(pure noret)
(addTickSyntaxExpr hpcSrcSpan ret) (addTickSyntaxExpr hpcSrcSpan ret)
addTickStmt _isGuard (BindStmt pat e bind fail) = do addTickStmt _isGuard (BindStmt pat e bind fail) = do
liftM4 BindStmt liftM4 BindStmt
...@@ -684,6 +685,9 @@ addTickStmt isGuard (ParStmt pairs mzipExpr bindExpr) = do ...@@ -684,6 +685,9 @@ addTickStmt isGuard (ParStmt pairs mzipExpr bindExpr) = do
(mapM (addTickStmtAndBinders isGuard) pairs) (mapM (addTickStmtAndBinders isGuard) pairs)
(addTickSyntaxExpr hpcSrcSpan mzipExpr) (addTickSyntaxExpr hpcSrcSpan mzipExpr)
(addTickSyntaxExpr hpcSrcSpan bindExpr) (addTickSyntaxExpr hpcSrcSpan bindExpr)
addTickStmt isGuard (ApplicativeStmt args mb_join body_ty) = do
args' <- mapM (addTickApplicativeArg isGuard) args
return (ApplicativeStmt args' mb_join body_ty)
addTickStmt isGuard stmt@(TransStmt { trS_stmts = stmts addTickStmt isGuard stmt@(TransStmt { trS_stmts = stmts
, trS_by = by, trS_using = using , trS_by = by, trS_using = using
...@@ -710,6 +714,20 @@ addTick :: Maybe (Bool -> BoxLabel) -> LHsExpr Id -> TM (LHsExpr Id) ...@@ -710,6 +714,20 @@ addTick :: Maybe (Bool -> BoxLabel) -> LHsExpr Id -> TM (LHsExpr Id)
addTick isGuard e | Just fn <- isGuard = addBinTickLHsExpr fn e addTick isGuard e | Just fn <- isGuard = addBinTickLHsExpr fn e
| otherwise = addTickLHsExprRHS e | otherwise = addTickLHsExprRHS e
addTickApplicativeArg
:: Maybe (Bool -> BoxLabel) -> (SyntaxExpr Id, ApplicativeArg Id Id)
-> TM (SyntaxExpr Id, ApplicativeArg Id Id)
addTickApplicativeArg isGuard (op, arg) =
liftM2 (,) (addTickSyntaxExpr hpcSrcSpan op) (addTickArg arg)
where
addTickArg (ApplicativeArgOne pat expr) =
ApplicativeArgOne <$> addTickLPat pat <*> addTickLHsExpr expr
addTickArg (ApplicativeArgMany stmts ret pat) =
ApplicativeArgMany
<$> addTickLStmts isGuard stmts
<*> addTickSyntaxExpr hpcSrcSpan ret
<*> addTickLPat pat
addTickStmtAndBinders :: Maybe (Bool -> BoxLabel) -> ParStmtBlock Id Id addTickStmtAndBinders :: Maybe (Bool -> BoxLabel) -> ParStmtBlock Id Id
-> TM (ParStmtBlock Id Id) -> TM (ParStmtBlock Id Id)
addTickStmtAndBinders isGuard (ParStmtBlock stmts ids returnExpr) = addTickStmtAndBinders isGuard (ParStmtBlock stmts ids returnExpr) =
...@@ -872,9 +890,10 @@ addTickCmdStmt (BindStmt pat c bind fail) = do ...@@ -872,9 +890,10 @@ addTickCmdStmt (BindStmt pat c bind fail) = do
(addTickLHsCmd c) (addTickLHsCmd c)
(return bind) (return bind)
(return fail) (return fail)
addTickCmdStmt (LastStmt c ret) = do addTickCmdStmt (LastStmt c noret ret) = do
liftM2 LastStmt liftM3 LastStmt
(addTickLHsCmd c) (addTickLHsCmd c)
(pure noret)
(addTickSyntaxExpr hpcSrcSpan ret) (addTickSyntaxExpr hpcSrcSpan ret)
addTickCmdStmt (BodyStmt c bind' guard' ty) = do addTickCmdStmt (BodyStmt c bind' guard' ty) = do
liftM4 BodyStmt liftM4 BodyStmt
...@@ -892,6 +911,8 @@ addTickCmdStmt stmt@(RecStmt {}) ...@@ -892,6 +911,8 @@ addTickCmdStmt stmt@(RecStmt {})
; bind' <- addTickSyntaxExpr hpcSrcSpan (recS_bind_fn stmt) ; bind' <- addTickSyntaxExpr hpcSrcSpan (recS_bind_fn stmt)
; return (stmt { recS_stmts = stmts', recS_ret_fn = ret' ; return (stmt { recS_stmts = stmts', recS_ret_fn = ret'
, recS_mfix_fn = mfix', recS_bind_fn = bind' }) } , recS_mfix_fn = mfix', recS_bind_fn = bind' }) }
addTickCmdStmt ApplicativeStmt{} =
panic "ToDo: addTickCmdStmt ApplicativeLastStmt"
-- Others should never happen in a command context. -- Others should never happen in a command context.
addTickCmdStmt stmt = pprPanic "addTickHsCmd" (ppr stmt) addTickCmdStmt stmt = pprPanic "addTickHsCmd" (ppr stmt)
......
...@@ -18,6 +18,7 @@ import DsMonad ...@@ -18,6 +18,7 @@ import DsMonad
import HsSyn hiding (collectPatBinders, collectPatsBinders, collectLStmtsBinders, collectLStmtBinders, collectStmtBinders ) import HsSyn hiding (collectPatBinders, collectPatsBinders, collectLStmtsBinders, collectLStmtBinders, collectStmtBinders )
import TcHsSyn import TcHsSyn
import qualified HsUtils
-- NB: The desugarer, which straddles the source and Core worlds, sometimes -- NB: The desugarer, which straddles the source and Core worlds, sometimes
-- needs to see source types (newtypes etc), and sometimes not -- needs to see source types (newtypes etc), and sometimes not
...@@ -694,7 +695,7 @@ dsCmdDo _ _ _ [] _ = panic "dsCmdDo" ...@@ -694,7 +695,7 @@ dsCmdDo _ _ _ [] _ = panic "dsCmdDo"
-- --
-- ---> premap (\ (xs) -> ((xs), ())) c -- ---> premap (\ (xs) -> ((xs), ())) c
dsCmdDo ids local_vars res_ty [L _ (LastStmt body _)] env_ids = do dsCmdDo ids local_vars res_ty [L _ (LastStmt body _ _)] env_ids = do
(core_body, env_ids') <- dsLCmd ids local_vars unitTy res_ty body env_ids (core_body, env_ids') <- dsLCmd ids local_vars unitTy res_ty body env_ids
let env_ty = mkBigCoreVarTupTy env_ids let env_ty = mkBigCoreVarTupTy env_ids
env_var <- newSysLocalDs env_ty env_var <- newSysLocalDs env_ty
...@@ -1167,11 +1168,5 @@ collectLStmtBinders :: LStmt Id body -> [Id] ...@@ -1167,11 +1168,5 @@ collectLStmtBinders :: LStmt Id body -> [Id]
collectLStmtBinders = collectStmtBinders . unLoc collectLStmtBinders = collectStmtBinders . unLoc
collectStmtBinders :: Stmt Id body -> [Id] collectStmtBinders :: Stmt Id body -> [Id]
collectStmtBinders (BindStmt pat _ _ _) = collectPatBinders pat
collectStmtBinders (LetStmt binds) = collectLocalBinders binds
collectStmtBinders (BodyStmt {}) = []
collectStmtBinders (LastStmt {}) = []
collectStmtBinders (ParStmt xs _ _) = collectLStmtsBinders
$ [ s | ParStmtBlock ss _ _ <- xs, s <- ss]
collectStmtBinders (TransStmt { trS_stmts = stmts }) = collectLStmtsBinders stmts
collectStmtBinders (RecStmt { recS_later_ids = later_ids }) = later_ids collectStmtBinders (RecStmt { recS_later_ids = later_ids }) = later_ids
collectStmtBinders stmt = HsUtils.collectStmtBinders stmt
...@@ -33,6 +33,7 @@ import TcType ...@@ -33,6 +33,7 @@ import TcType
import Coercion ( Role(..) ) import Coercion ( Role(..) )
import TcEvidence import TcEvidence
import TcRnMonad import TcRnMonad
import TcHsSyn
import Type import Type
import CoreSyn import CoreSyn
import CoreUtils import CoreUtils
...@@ -819,7 +820,7 @@ dsDo stmts ...@@ -819,7 +820,7 @@ dsDo stmts
goL [] = panic "dsDo" goL [] = panic "dsDo"
goL (L loc stmt:lstmts) = putSrcSpanDs loc (go loc stmt lstmts) goL (L loc stmt:lstmts) = putSrcSpanDs loc (go loc stmt lstmts)
go _ (LastStmt body _) stmts go _ (LastStmt body _ _) stmts
= ASSERT( null stmts ) dsLExpr body = ASSERT( null stmts ) dsLExpr body
-- The 'return' op isn't used for 'do' expressions -- The 'return' op isn't used for 'do' expressions
...@@ -846,13 +847,45 @@ dsDo stmts ...@@ -846,13 +847,45 @@ dsDo stmts
; match_code <- handle_failure pat match fail_op ; match_code <- handle_failure pat match fail_op
; return (mkApps bind_op' [rhs', Lam var match_code]) } ; return (mkApps bind_op' [rhs', Lam var match_code]) }
go _ (ApplicativeStmt args mb_join body_ty) stmts
= do {
let
(pats, rhss) = unzip (map (do_arg . snd) args)
do_arg (ApplicativeArgOne pat expr) =
(pat, dsLExpr expr)
do_arg (ApplicativeArgMany stmts ret pat) =
(pat, dsDo (stmts ++ [noLoc $ mkLastStmt (noLoc ret)]))
arg_tys = map hsLPatType pats
; rhss' <- sequence rhss
; ops' <- mapM dsExpr (map fst args)
; let body' = noLoc $ HsDo DoExpr stmts body_ty
; let fun = L noSrcSpan $ HsLam $
MG { mg_alts = [mkSimpleMatch pats body']
, mg_arg_tys = arg_tys
, mg_res_ty = body_ty
, mg_origin = Generated }
; fun' <- dsLExpr fun
; let mk_ap_call l (op,r) = mkApps op [l,r]
expr = foldl mk_ap_call fun' (zip ops' rhss')
; case mb_join of
Nothing -> return expr
Just join_op ->
do { join_op' <- dsExpr join_op
; return (App join_op' expr) } }
go loc (RecStmt { recS_stmts = rec_stmts, recS_later_ids = later_ids go loc (RecStmt { recS_stmts = rec_stmts, recS_later_ids = later_ids
, recS_rec_ids = rec_ids, recS_ret_fn = return_op , recS_rec_ids = rec_ids, recS_ret_fn = return_op
, recS_mfix_fn = mfix_op, recS_bind_fn = bind_op , recS_mfix_fn = mfix_op, recS_bind_fn = bind_op
, recS_rec_rets = rec_rets, recS_ret_ty = body_ty }) stmts , recS_rec_rets = rec_rets, recS_ret_ty = body_ty }) stmts
= goL (new_bind_stmt : stmts) -- rec_ids can be empty; eg rec { print 'x' } = goL (new_bind_stmt : stmts) -- rec_ids can be empty; eg rec { print 'x' }
where where
new_bind_stmt = L loc $ BindStmt (mkBigLHsPatTup later_pats) new_bind_stmt = L loc $ BindStmt (mkBigLHsPatTupId later_pats)
mfix_app bind_op mfix_app bind_op
noSyntaxExpr -- Tuple cannot fail noSyntaxExpr -- Tuple cannot fail
...@@ -865,9 +898,9 @@ dsDo stmts ...@@ -865,9 +898,9 @@ dsDo stmts
mfix_arg = noLoc $ HsLam (MG { mg_alts = [mkSimpleMatch [mfix_pat] body] mfix_arg = noLoc $ HsLam (MG { mg_alts = [mkSimpleMatch [mfix_pat] body]
, mg_arg_tys = [tup_ty], mg_res_ty = body_ty , mg_arg_tys = [tup_ty], mg_res_ty = body_ty
, mg_origin = Generated }) , mg_origin = Generated })
mfix_pat = noLoc $ LazyPat $ mkBigLHsPatTup rec_tup_pats mfix_pat = noLoc $ LazyPat $ mkBigLHsPatTupId rec_tup_pats
body = noLoc $ HsDo DoExpr (rec_stmts ++ [ret_stmt]) body_ty body = noLoc $ HsDo DoExpr (rec_stmts ++ [ret_stmt]) body_ty
ret_app = nlHsApp (noLoc return_op) (mkBigLHsTup rets) ret_app = nlHsApp (noLoc return_op) (mkBigLHsTupId rets)
ret_stmt = noLoc $ mkLastStmt ret_app ret_stmt = noLoc $ mkLastStmt ret_app
-- This LastStmt will be desugared with dsDo, -- This LastStmt will be desugared with dsDo,
-- which ignores the return_op in the LastStmt, -- which ignores the return_op in the LastStmt,
......
...@@ -123,6 +123,8 @@ matchGuards (LastStmt {} : _) _ _ _ = panic "matchGuards LastStmt" ...@@ -123,6 +123,8 @@ matchGuards (LastStmt {} : _) _ _ _ = panic "matchGuards LastStmt"
matchGuards (ParStmt {} : _) _ _ _ = panic "matchGuards ParStmt" matchGuards (ParStmt {} : _) _ _ _ = panic "matchGuards ParStmt"
matchGuards (TransStmt {} : _) _ _ _ = panic "matchGuards TransStmt" matchGuards (TransStmt {} : _) _ _ _ = panic "matchGuards TransStmt"
matchGuards (RecStmt {} : _) _ _ _ = panic "matchGuards RecStmt" matchGuards (RecStmt {} : _) _ _ _ = panic "matchGuards RecStmt"
matchGuards (ApplicativeStmt {} : _) _ _ _ =
panic "matchGuards ApplicativeLastStmt"
isTrueLHsExpr :: LHsExpr Id -> Maybe (CoreExpr -> DsM CoreExpr) isTrueLHsExpr :: LHsExpr Id -> Maybe (CoreExpr -> DsM CoreExpr)
......
...@@ -81,7 +81,7 @@ dsListComp lquals res_ty = do ...@@ -81,7 +81,7 @@ dsListComp lquals res_ty = do
-- and the type of the elements that it outputs (tuples of binders) -- and the type of the elements that it outputs (tuples of binders)
dsInnerListComp :: (ParStmtBlock Id Id) -> DsM (CoreExpr, Type) dsInnerListComp :: (ParStmtBlock Id Id) -> DsM (CoreExpr, Type)
dsInnerListComp (ParStmtBlock stmts bndrs _) dsInnerListComp (ParStmtBlock stmts bndrs _)
= do { expr <- dsListComp (stmts ++ [noLoc $ mkLastStmt (mkBigLHsVarTup bndrs)]) = do { expr <- dsListComp (stmts ++ [noLoc $ mkLastStmt (mkBigLHsVarTupId bndrs)])
(mkListTy bndrs_tuple_type) (mkListTy bndrs_tuple_type)
; return (expr, bndrs_tuple_type) } ; return (expr, bndrs_tuple_type) }
where where
...@@ -133,7 +133,7 @@ dsTransStmt (TransStmt { trS_form = form, trS_stmts = stmts, trS_bndrs = binderM ...@@ -133,7 +133,7 @@ dsTransStmt (TransStmt { trS_form = form, trS_stmts = stmts, trS_bndrs = binderM
-- Build a pattern that ensures the consumer binds into the NEW binders, -- Build a pattern that ensures the consumer binds into the NEW binders,
-- which hold lists rather than single values -- which hold lists rather than single values
let pat = mkBigLHsVarPatTup to_bndrs let pat = mkBigLHsVarPatTupId to_bndrs
return (bound_unzipped_inner_list_expr, pat) return (bound_unzipped_inner_list_expr, pat)
dsTransStmt _ = panic "dsTransStmt: Not given a TransStmt" dsTransStmt _ = panic "dsTransStmt: Not given a TransStmt"
...@@ -208,7 +208,7 @@ deListComp :: [ExprStmt Id] -> CoreExpr -> DsM CoreExpr ...@@ -208,7 +208,7 @@ deListComp :: [ExprStmt Id] -> CoreExpr -> DsM CoreExpr
deListComp [] _ = panic "deListComp" deListComp [] _ = panic "deListComp"
deListComp (LastStmt body _ : quals) list deListComp (LastStmt body _ _ : quals) list
= -- Figure 7.4, SLPJ, p 135, rule C above = -- Figure 7.4, SLPJ, p 135, rule C above
ASSERT( null quals ) ASSERT( null quals )
do { core_body <- dsLExpr body do { core_body <- dsLExpr body
...@@ -246,11 +246,14 @@ deListComp (ParStmt stmtss_w_bndrs _ _ : quals) list ...@@ -246,11 +246,14 @@ deListComp (ParStmt stmtss_w_bndrs _ _ : quals) list
bndrs_s = [bs | ParStmtBlock _ bs _ <- stmtss_w_bndrs] bndrs_s = [bs | ParStmtBlock _ bs _ <- stmtss_w_bndrs]
-- pat is the pattern ((x1,..,xn), (y1,..,ym)) in the example above -- pat is the pattern ((x1,..,xn), (y1,..,ym)) in the example above
pat = mkBigLHsPatTup pats pat = mkBigLHsPatTupId pats
pats = map mkBigLHsVarPatTup bndrs_s pats = map mkBigLHsVarPatTupId bndrs_s
deListComp (RecStmt {} : _) _ = panic "deListComp RecStmt" deListComp (RecStmt {} : _) _ = panic "deListComp RecStmt"
deListComp (ApplicativeStmt {} : _) _ =
panic "deListComp ApplicativeStmt"
deBindComp :: OutPat Id deBindComp :: OutPat Id
-> CoreExpr -> CoreExpr
-> [ExprStmt Id] -> [ExprStmt Id]
...@@ -312,7 +315,7 @@ dfListComp :: Id -> Id -- 'c' and 'n' ...@@ -312,7 +315,7 @@ dfListComp :: Id -> Id -- 'c' and 'n'
dfListComp _ _ [] = panic "dfListComp" dfListComp _ _ [] = panic "dfListComp"
dfListComp c_id n_id (LastStmt body _ : quals) dfListComp c_id n_id (LastStmt body _ _ : quals)
= ASSERT( null quals ) = ASSERT( null quals )
do { core_body <- dsLExpr body do { core_body <- dsLExpr body
; return (mkApps (Var c_id) [core_body, Var n_id]) } ; return (mkApps (Var c_id) [core_body, Var n_id]) }
...@@ -342,6 +345,8 @@ dfListComp c_id n_id (BindStmt pat list1 _ _ : quals) = do ...@@ -342,6 +345,8 @@ dfListComp c_id n_id (BindStmt pat list1 _ _ : quals) = do
dfListComp _ _ (ParStmt {} : _) = panic "dfListComp ParStmt" dfListComp _ _ (ParStmt {} : _) = panic "dfListComp ParStmt"
dfListComp _ _ (RecStmt {} : _) = panic "dfListComp RecStmt" dfListComp _ _ (RecStmt {} : _) = panic "dfListComp RecStmt"
dfListComp _ _ (ApplicativeStmt {} : _) =
panic "dfListComp ApplicativeStmt"
dfBindComp :: Id -> Id -- 'c' and 'n' dfBindComp :: Id -> Id -- 'c' and 'n'
-> (LPat Id, CoreExpr) -> (LPat Id, CoreExpr)
...@@ -510,7 +515,7 @@ dePArrComp [] _ _ = panic "dePArrComp" ...@@ -510,7 +515,7 @@ dePArrComp [] _ _ = panic "dePArrComp"
-- --
-- <<[:e' | :]>> pa ea = mapP (\pa -> e') ea -- <<[:e' | :]>> pa ea = mapP (\pa -> e') ea
-- --
dePArrComp (LastStmt e' _ : quals) pa cea dePArrComp (LastStmt e' _ _ : quals) pa cea
= ASSERT( null quals ) = ASSERT( null quals )
do { mapP <- dsDPHBuiltin mapPVar do { mapP <- dsDPHBuiltin mapPVar
; let ty = parrElemType cea ; let ty = parrElemType cea
...@@ -589,6 +594,8 @@ dePArrComp (ParStmt {} : _) _ _ = ...@@ -589,6 +594,8 @@ dePArrComp (ParStmt {} : _) _ _ =
panic "DsListComp.dePArrComp: malformed comprehension AST: ParStmt" panic "DsListComp.dePArrComp: malformed comprehension AST: ParStmt"
dePArrComp (TransStmt {} : _) _ _ = panic "DsListComp.dePArrComp: TransStmt" dePArrComp (TransStmt {} : _) _ _ = panic "DsListComp.dePArrComp: TransStmt"
dePArrComp (RecStmt {} : _) _ _ = panic "DsListComp.dePArrComp: RecStmt" dePArrComp (RecStmt {} : _) _ _ = panic "DsListComp.dePArrComp: RecStmt"
dePArrComp (ApplicativeStmt {} : _) _ _ =
panic "DsListComp.dePArrComp: ApplicativeStmt"
-- <<[:e' | qs | qss:]>> pa ea = -- <<[:e' | qs | qss:]>> pa ea =
-- <<[:e' | qss:]>> (pa, (x_1, ..., x_n)) -- <<[:e' | qss:]>> (pa, (x_1, ..., x_n))
...@@ -666,7 +673,7 @@ dsMcStmts (L loc stmt : lstmts) = putSrcSpanDs loc (dsMcStmt stmt lstmts) ...@@ -666,7 +673,7 @@ dsMcStmts (L loc stmt : lstmts) = putSrcSpanDs loc (dsMcStmt stmt lstmts)
--------------- ---------------
dsMcStmt :: ExprStmt Id -> [ExprLStmt Id] -> DsM CoreExpr dsMcStmt :: ExprStmt Id -> [ExprLStmt Id] -> DsM CoreExpr
dsMcStmt (LastStmt body ret_op) stmts dsMcStmt (LastStmt body _ ret_op) stmts
= ASSERT( null stmts ) = ASSERT( null stmts )
do { body' <- dsLExpr body do { body' <- dsLExpr body
; ret_op' <- dsExpr ret_op ; ret_op' <- dsExpr ret_op
...@@ -761,7 +768,7 @@ dsMcStmt (ParStmt blocks mzip_op bind_op) stmts_rest ...@@ -761,7 +768,7 @@ dsMcStmt (ParStmt blocks mzip_op bind_op) stmts_rest
; mzip_op' <- dsExpr mzip_op ; mzip_op' <- dsExpr mzip_op
; let -- The pattern variables ; let -- The pattern variables
pats = [ mkBigLHsVarPatTup bs | ParStmtBlock _ bs _ <- blocks] pats = [ mkBigLHsVarPatTupId bs | ParStmtBlock _ bs _ <- blocks]
-- Pattern with tuples of variables -- Pattern with tuples of variables
-- [v1,v2,v3] => (v1, (v2, v3)) -- [v1,v2,v3] => (v1, (v2, v3))
pat = foldr1 (\p1 p2 -> mkLHsPatTup [p1, p2]) pats pat = foldr1 (\p1 p2 -> mkLHsPatTup [p1, p2]) pats
...@@ -834,7 +841,7 @@ dsInnerMonadComp :: [ExprLStmt Id] ...@@ -834,7 +841,7 @@ dsInnerMonadComp :: [ExprLStmt Id]
-> HsExpr Id -- The monomorphic "return" operator -> HsExpr Id -- The monomorphic "return" operator
-> DsM CoreExpr -> DsM CoreExpr
dsInnerMonadComp stmts bndrs ret_op dsInnerMonadComp stmts bndrs ret_op
= dsMcStmts (stmts ++ [noLoc (LastStmt (mkBigLHsVarTup bndrs) ret_op)]) = dsMcStmts (stmts ++ [noLoc (LastStmt (mkBigLHsVarTupId bndrs) False ret_op)])
-- The `unzip` function for `GroupStmt` in a monad comprehensions -- The `unzip` function for `GroupStmt` in a monad comprehensions
-- --
......
...@@ -1279,7 +1279,7 @@ repSts (ParStmt stmt_blocks _ _ : ss) = ...@@ -1279,7 +1279,7 @@ repSts (ParStmt stmt_blocks _ _ : ss) =
do { (ss1, zs) <- repSts (map unLoc stmts) do { (ss1, zs) <- repSts (map unLoc stmts)
; zs1 <- coreList stmtQTyConName zs ; zs1 <- coreList stmtQTyConName zs
; return (ss1, zs1) } ; return (ss1, zs1) }
repSts [LastStmt e _] repSts [LastStmt e _ _]
= do { e2 <- repLE e = do { e2 <- repLE e
; z <- repNoBindSt e2 ; z <- repNoBindSt e2
; return ([], [z]) } ; return ([], [z]) }
......
...@@ -30,7 +30,7 @@ module DsUtils ( ...@@ -30,7 +30,7 @@ module DsUtils (
-- LHs tuples -- LHs tuples
mkLHsVarPatTup, mkLHsPatTup, mkVanillaTuplePat, mkLHsVarPatTup, mkLHsPatTup, mkVanillaTuplePat,
mkBigLHsVarTup, mkBigLHsTup, mkBigLHsVarPatTup, mkBigLHsPatTup, mkBigLHsVarTupId, mkBigLHsTupId, mkBigLHsVarPatTupId, mkBigLHsPatTupId,
mkSelectorBinds, mkSelectorBinds,
...@@ -717,18 +717,18 @@ mkVanillaTuplePat :: [OutPat Id] -> Boxity -> Pat Id ...@@ -717,18 +717,18 @@ mkVanillaTuplePat :: [OutPat Id] -> Boxity -> Pat Id
mkVanillaTuplePat pats box = TuplePat pats box (map hsLPatType pats) mkVanillaTuplePat pats box = TuplePat pats box (map hsLPatType pats)
-- The Big equivalents for the source tuple expressions -- The Big equivalents for the source tuple expressions
mkBigLHsVarTup :: [Id] -> LHsExpr Id mkBigLHsVarTupId :: [Id] -> LHsExpr Id
mkBigLHsVarTup ids = mkBigLHsTup (map nlHsVar ids) mkBigLHsVarTupId ids = mkBigLHsTupId (map nlHsVar ids)
mkBigLHsTup :: [LHsExpr Id] -> LHsExpr Id mkBigLHsTupId :: [LHsExpr Id] -> LHsExpr Id
mkBigLHsTup = mkChunkified mkLHsTupleExpr mkBigLHsTupId = mkChunkified mkLHsTupleExpr
-- The Big equivalents for the source tuple patterns -- The Big equivalents for the source tuple patterns
mkBigLHsVarPatTup :: [Id] -> LPat Id mkBigLHsVarPatTupId :: [Id] -> LPat Id
mkBigLHsVarPatTup bs = mkBigLHsPatTup (map nlVarPat bs) mkBigLHsVarPatTupId bs = mkBigLHsPatTupId (map nlVarPat bs)
mkBigLHsPatTup :: [LPat Id] -> LPat Id mkBigLHsPatTupId :: [LPat Id] -> LPat Id
mkBigLHsPatTup = mkChunkified mkLHsPatTup mkBigLHsPatTupId = mkChunkified mkLHsPatTup
{- {-
************************************************************************ ************************************************************************
......
...@@ -39,6 +39,7 @@ import Type ...@@ -39,6 +39,7 @@ import Type
-- libraries: -- libraries:
import Data.Data hiding (Fixity) import Data.Data hiding (Fixity)
import Data.Maybe (isNothing)
{- {-
************************************************************************ ************************************************************************
...@@ -1266,12 +1267,15 @@ data StmtLR idL idR body -- body should always be (LHs**** idR) ...@@ -1266,12 +1267,15 @@ data StmtLR idL idR body -- body should always be (LHs**** idR)
= LastStmt -- Always the last Stmt in ListComp, MonadComp, PArrComp, = LastStmt -- Always the last Stmt in ListComp, MonadComp, PArrComp,
-- and (after the renamer) DoExpr, MDoExpr -- and (after the renamer) DoExpr, MDoExpr
-- Not used for GhciStmtCtxt, PatGuard, which scope over other stuff -- Not used for GhciStmtCtxt, PatGuard, which scope over other stuff
body body
(SyntaxExpr idR) -- The return operator, used only for MonadComp Bool -- True <=> return was stripped by ApplicativeDo
-- For ListComp, PArrComp, we use the baked-in 'return' (SyntaxExpr idR) -- The return operator, used only for
-- For DoExpr, MDoExpr, we don't apply a 'return' at all -- MonadComp For ListComp, PArrComp, we
-- See Note [Monad Comprehensions] -- use the baked-in 'return' For DoExpr,
-- | - 'ApiAnnotation.AnnKeywordId' : 'ApiAnnotation.AnnLarrow' -- MDoExpr, we don't apply a 'return' at
-- all See Note [Monad Comprehensions] |
-- - 'ApiAnnotation.AnnKeywordId' :
-- 'ApiAnnotation.AnnLarrow'
-- For details on above see note [Api annotations] in ApiAnnotation -- For details on above see note [Api annotations] in ApiAnnotation
| BindStmt (LPat idL) | BindStmt (LPat idL)
...@@ -1281,6 +1285,20 @@ data StmtLR idL idR body -- body should always be (LHs**** idR) ...@@ -1281,6 +1285,20 @@ data StmtLR idL idR body -- body should always be (LHs**** idR)
-- The fail operator is noSyntaxExpr -- The fail operator is noSyntaxExpr
-- if the pattern match can't fail -- if the pattern match can't fail
-- | 'ApplicativeStmt' represents an applicative expression built with
-- <$> and <*>. It is generated by the renamer, and is desugared into the
-- appropriate applicative expression by the desugarer, but it is intended
-- to be invisible in error messages.
--
-- For full details, see Note [ApplicativeDo] in RnExpr
--
| ApplicativeStmt
[ ( SyntaxExpr idR
, ApplicativeArg idL idR) ]
-- [(<$>, e1), (<*>, e2), ..., (<*>, en)]
(Maybe (SyntaxExpr idR)) -- 'join', if necessary
(PostTc idR Type) -- Type of the body
| BodyStmt body -- See Note [BodyStmt] | BodyStmt body -- See Note [BodyStmt]
(SyntaxExpr idR) -- The (>>) operator (SyntaxExpr idR) -- The (>>) operator
(SyntaxExpr idR) -- The `guard` operator; used only in MonadComp (SyntaxExpr idR) -- The `guard` operator; used only in MonadComp
...@@ -1375,6 +1393,17 @@ data ParStmtBlock idL idR ...@@ -1375,6 +1393,17 @@ data ParStmtBlock idL idR
deriving( Typeable ) deriving( Typeable )
deriving instance (DataId idL, DataId idR) => Data (ParStmtBlock idL idR) deriving instance (DataId idL, DataId idR) => Data (ParStmtBlock idL idR)
data ApplicativeArg idL idR
= ApplicativeArgOne -- pat <- expr (pat must be irrefutable)
(LPat idL)
(LHsExpr idL)
| ApplicativeArgMany -- do { stmts; return vars }
[ExprLStmt idL] -- stmts
(SyntaxExpr idL) -- return (v1,..,vn), or just (v1,..,vn)
(LPat idL) -- (v1,...,vn)
deriving( Typeable )
deriving instance (DataId idL, DataId idR) => Data (ApplicativeArg idL idR)
{- {-
Note [The type of bind in Stmts] Note [The type of bind in Stmts]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...@@ -1520,9 +1549,12 @@ instance (OutputableBndr idL, OutputableBndr idR, Outputable body) ...@@ -1520,9 +1549,12 @@ instance (OutputableBndr idL, OutputableBndr idR, Outputable body)
=> Outputable (StmtLR idL idR body) where => Outputable (StmtLR idL idR body) where
ppr stmt = pprStmt stmt ppr stmt = pprStmt stmt
pprStmt :: (OutputableBndr idL, OutputableBndr idR, Outputable body) pprStmt :: forall idL idR body . (OutputableBndr idL, OutputableBndr idR, Outputable body)
=> (StmtLR idL idR body) -> SDoc => (StmtLR idL idR body) -> SDoc
pprStmt (LastStmt expr _) = ifPprDebug (ptext (sLit "[last]")) <+> ppr expr pprStmt (LastStmt expr ret_stripped _)
= ifPprDebug (ptext (sLit "[last]")) <+>
(if ret_stripped then ptext (sLit "return") else empty) <+>
ppr expr
pprStmt (BindStmt pat expr _ _) = hsep [ppr pat, larrow, ppr expr] pprStmt (BindStmt pat expr _ _) = hsep [ppr pat, larrow, ppr expr]
pprStmt (LetStmt binds) = hsep [ptext (sLit "let"), pprBinds binds] pprStmt (LetStmt binds) = hsep [ptext (sLit "let"), pprBinds binds]
pprStmt (BodyStmt expr _ _ _) = ppr expr pprStmt (BodyStmt expr _ _ _) = ppr expr
...@@ -1538,6 +1570,45 @@ pprStmt (RecStmt { recS_stmts = segment, recS_rec_ids = rec_ids ...@@ -1538,6 +1570,45 @@ pprStmt (RecStmt { recS_stmts = segment, recS_rec_ids = rec_ids
, ifPprDebug (vcat [ ptext (sLit "rec_ids=") <> ppr rec_ids , ifPprDebug (vcat [ ptext (sLit "rec_ids=") <> ppr rec_ids
, ptext (sLit "later_ids=") <> ppr later_ids])] , ptext (sLit "later_ids=") <> ppr later_ids])]
pprStmt (ApplicativeStmt args mb_join _)
= getPprStyle $ \style ->
if userStyle style
then pp_for_user
else pp_debug
where
-- make all the Applicative stuff invisible in error messages by
-- flattening the whole ApplicativeStmt nest back to a sequence
-- of statements.
pp_for_user = vcat $ punctuate semi $ concatMap flattenArg args
-- ppr directly rather than transforming here, becuase we need to
-- inject a "return" which is hard when we're polymorphic in the id