Commit 478e69b3 authored by Simon Peyton Jones's avatar Simon Peyton Jones

Preliminary monad-comprehension patch (Trac #4370)

This is the work of Nils Schweinsberg <mail@n-sch.de>

It adds the language extension -XMonadComprehensions, which
generalises list comprehension syntax [ e | x <- xs] to work over
arbitrary monads.
parent 66a733f2
......@@ -301,10 +301,11 @@ addTickHsExpr (HsLet binds e) =
liftM2 HsLet
(addTickHsLocalBinds binds) -- to think about: !patterns.
(addTickLHsExprNeverOrAlways e)
addTickHsExpr (HsDo cxt stmts last_exp srcloc) = do
addTickHsExpr (HsDo cxt stmts last_exp return_exp srcloc) = do
(stmts', last_exp') <- addTickLStmts' forQual stmts
(addTickLHsExpr last_exp)
return (HsDo cxt stmts' last_exp' srcloc)
return_exp' <- addTickSyntaxExpr hpcSrcSpan return_exp
return (HsDo cxt stmts' last_exp' return_exp' srcloc)
where
forQual = case cxt of
ListComp -> Just $ BinBox QualBinBox
......@@ -438,31 +439,38 @@ addTickStmt _isGuard (BindStmt pat e bind fail) = do
(addTickLHsExprAlways e)
(addTickSyntaxExpr hpcSrcSpan bind)
(addTickSyntaxExpr hpcSrcSpan fail)
addTickStmt isGuard (ExprStmt e bind' ty) = do
liftM3 ExprStmt
addTickStmt isGuard (ExprStmt e bind' guard' ty) = do
liftM4 ExprStmt
(addTick isGuard e)
(addTickSyntaxExpr hpcSrcSpan bind')
(addTickSyntaxExpr hpcSrcSpan guard')
(return ty)
addTickStmt _isGuard (LetStmt binds) = do
liftM LetStmt
(addTickHsLocalBinds binds)
addTickStmt isGuard (ParStmt pairs) = do
liftM ParStmt
addTickStmt isGuard (ParStmt pairs mzipExpr bindExpr returnExpr) = do
liftM4 ParStmt
(mapM (addTickStmtAndBinders isGuard) pairs)
addTickStmt isGuard (TransformStmt stmts ids usingExpr maybeByExpr) = do
liftM4 TransformStmt
(addTickLStmts isGuard stmts)
(return ids)
(addTickLHsExprAlways usingExpr)
(addTickMaybeByLHsExpr maybeByExpr)
addTickStmt isGuard (GroupStmt stmts binderMap by using) = do
liftM4 GroupStmt
(addTickLStmts isGuard stmts)
(return binderMap)
(fmapMaybeM addTickLHsExprAlways by)
(fmapEitherM addTickLHsExprAlways (addTickSyntaxExpr hpcSrcSpan) using)
(addTickSyntaxExpr hpcSrcSpan mzipExpr)
(addTickSyntaxExpr hpcSrcSpan bindExpr)
(addTickSyntaxExpr hpcSrcSpan returnExpr)
addTickStmt isGuard (TransformStmt stmts ids usingExpr maybeByExpr returnExpr bindExpr) = do
t_s <- (addTickLStmts isGuard stmts)
t_u <- (addTickLHsExprAlways usingExpr)
t_m <- (addTickMaybeByLHsExpr maybeByExpr)
t_r <- (addTickSyntaxExpr hpcSrcSpan returnExpr)
t_b <- (addTickSyntaxExpr hpcSrcSpan bindExpr)
return $ TransformStmt t_s ids t_u t_m t_r t_b
addTickStmt isGuard (GroupStmt stmts binderMap by using returnExpr bindExpr liftMExpr) = do
t_s <- (addTickLStmts isGuard stmts)
t_y <- (fmapMaybeM addTickLHsExprAlways by)
t_u <- (fmapEitherM addTickLHsExprAlways (addTickSyntaxExpr hpcSrcSpan) using)
t_f <- (addTickSyntaxExpr hpcSrcSpan returnExpr)
t_b <- (addTickSyntaxExpr hpcSrcSpan bindExpr)
t_m <- (addTickSyntaxExpr hpcSrcSpan liftMExpr)
return $ GroupStmt t_s binderMap t_y t_u t_b t_f t_m
addTickStmt isGuard stmt@(RecStmt {})
= do { stmts' <- addTickLStmts isGuard (recS_stmts stmt)
......@@ -569,9 +577,10 @@ addTickHsCmd (HsLet binds c) =
liftM2 HsLet
(addTickHsLocalBinds binds) -- to think about: !patterns.
(addTickLHsCmd c)
addTickHsCmd (HsDo cxt stmts last_exp srcloc) = do
addTickHsCmd (HsDo cxt stmts last_exp return_exp srcloc) = do
(stmts', last_exp') <- addTickLCmdStmts' stmts (addTickLHsCmd last_exp)
return (HsDo cxt stmts' last_exp' srcloc)
return_exp' <- addTickSyntaxExpr hpcSrcSpan return_exp
return (HsDo cxt stmts' last_exp' return_exp' srcloc)
addTickHsCmd (HsArrApp e1 e2 ty1 arr_ty lr) =
liftM5 HsArrApp
......@@ -635,10 +644,11 @@ addTickCmdStmt (BindStmt pat c bind fail) = do
(addTickLHsCmd c)
(return bind)
(return fail)
addTickCmdStmt (ExprStmt c bind' ty) = do
liftM3 ExprStmt
addTickCmdStmt (ExprStmt c bind' guard' ty) = do
liftM4 ExprStmt
(addTickLHsCmd c)
(return bind')
(addTickSyntaxExpr hpcSrcSpan bind')
(addTickSyntaxExpr hpcSrcSpan guard')
(return ty)
addTickCmdStmt (LetStmt binds) = do
liftM LetStmt
......
......@@ -541,7 +541,7 @@ dsCmd ids local_vars env_ids stack res_ty (HsLet binds body) = do
core_body,
exprFreeVars core_binds `intersectVarSet` local_vars)
dsCmd ids local_vars env_ids [] res_ty (HsDo _ctxt stmts body _)
dsCmd ids local_vars env_ids [] res_ty (HsDo _ctxt stmts body _ _)
= dsCmdDo ids local_vars env_ids res_ty stmts body
-- A |- e :: forall e. a1 (e*ts1) t1 -> ... an (e*tsn) tn -> a (e*ts) t
......@@ -674,7 +674,7 @@ dsCmdStmt
-- ---> arr (\ (xs) -> ((xs1),(xs'))) >>> first c >>>
-- arr snd >>> ss
dsCmdStmt ids local_vars env_ids out_ids (ExprStmt cmd _ c_ty) = do
dsCmdStmt ids local_vars env_ids out_ids (ExprStmt cmd _ _ c_ty) = do
(core_cmd, fv_cmd, env_ids1) <- dsfixCmd ids local_vars [] c_ty cmd
core_mux <- matchEnvStack env_ids []
(mkCorePairExpr (mkBigCoreVarTup env_ids1) (mkBigCoreVarTup out_ids))
......
......@@ -325,22 +325,25 @@ dsExpr (HsLet binds body) = do
-- We need the `ListComp' form to use `deListComp' (rather than the "do" form)
-- because the interpretation of `stmts' depends on what sort of thing it is.
--
dsExpr (HsDo ListComp stmts body result_ty)
dsExpr (HsDo ListComp stmts body _ result_ty)
= -- Special case for list comprehensions
dsListComp stmts body elt_ty
where
[elt_ty] = tcTyConAppArgs result_ty
dsExpr (HsDo DoExpr stmts body result_ty)
dsExpr (HsDo DoExpr stmts body _ result_ty)
= dsDo stmts body result_ty
dsExpr (HsDo GhciStmt stmts body result_ty)
dsExpr (HsDo GhciStmt stmts body _ result_ty)
= dsDo stmts body result_ty
dsExpr (HsDo MDoExpr stmts body result_ty)
dsExpr (HsDo MDoExpr stmts body _ result_ty)
= dsDo stmts body result_ty
dsExpr (HsDo PArrComp stmts body result_ty)
dsExpr (HsDo MonadComp stmts body return_op result_ty)
= dsMonadComp stmts return_op body result_ty
dsExpr (HsDo PArrComp stmts body _ result_ty)
= -- Special case for array comprehensions
dsPArrComp (map unLoc stmts) body elt_ty
where
......@@ -722,7 +725,7 @@ dsDo stmts body result_ty
goL [] = dsLExpr body
goL ((L loc stmt):lstmts) = putSrcSpanDs loc (go loc stmt lstmts)
go _ (ExprStmt rhs then_expr _) stmts
go _ (ExprStmt rhs then_expr _ _) stmts
= do { rhs2 <- dsLExpr rhs
; case tcSplitAppTy_maybe (exprType rhs2) of
Just (container_ty, returning_ty) -> warnDiscardedDoBindings rhs container_ty returning_ty
......@@ -769,7 +772,7 @@ dsDo stmts body result_ty
mfix_arg = noLoc $ HsLam (MatchGroup [mkSimpleMatch [mfix_pat] body]
(mkFunTy tup_ty body_ty))
mfix_pat = noLoc $ LazyPat $ mkLHsPatTup rec_tup_pats
body = noLoc $ HsDo DoExpr rec_stmts return_app body_ty
body = noLoc $ HsDo DoExpr rec_stmts return_app noSyntaxExpr body_ty
return_app = nlHsApp (noLoc return_op) (mkLHsTupleExpr rets)
body_ty = mkAppTy m_ty tup_ty
tup_ty = mkBoxedTupleTy (map idType tup_ids) -- Deals with singleton case
......@@ -869,7 +872,7 @@ dsMDo ctxt tbl stmts body result_ty
rets = map nlHsVar later_ids' ++ map noLoc rec_rets
mfix_pat = noLoc $ LazyPat $ mk_tup_pat rec_tup_pats
body = noLoc $ HsDo ctxt rec_stmts return_app body_ty
body = noLoc $ HsDo ctxt rec_stmts return_app noSyntaxExpr body_ty
body_ty = mkAppTy m_ty tup_ty
tup_ty = mkBoxedTupleTy (map idType (later_ids' ++ rec_ids)) -- Deals with singleton case
......@@ -888,7 +891,6 @@ dsMDo ctxt tbl stmts body result_ty
-}
\end{code}
%************************************************************************
%* *
Warning about identities
......
......@@ -106,11 +106,11 @@ matchGuards [] _ rhs _
-- NB: The success of this clause depends on the typechecker not
-- wrapping the 'otherwise' in empty HsTyApp or HsWrap constructors
-- If it does, you'll get bogus overlap warnings
matchGuards (ExprStmt e _ _ : stmts) ctx rhs rhs_ty
matchGuards (ExprStmt e _ _ _ : stmts) ctx rhs rhs_ty
| Just addTicks <- isTrueLHsExpr e = do
match_result <- matchGuards stmts ctx rhs rhs_ty
return (adjustMatchResultDs addTicks match_result)
matchGuards (ExprStmt expr _ _ : stmts) ctx rhs rhs_ty = do
matchGuards (ExprStmt expr _ _ _ : stmts) ctx rhs rhs_ty = do
match_result <- matchGuards stmts ctx rhs rhs_ty
pred_expr <- dsLExpr expr
return (mkGuardedMatchResult pred_expr match_result)
......
......@@ -3,9 +3,10 @@
% (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
%
Desugaring list comprehensions and array comprehensions
Desugaring list comprehensions, monad comprehensions and array comprehensions
\begin{code}
{-# LANGUAGE NamedFieldPuns #-}
{-# OPTIONS -fno-warn-incomplete-patterns #-}
-- The above warning supression flag is a temporary kludge.
-- While working on this module you are encouraged to remove it and fix
......@@ -13,11 +14,11 @@ Desugaring list comprehensions and array comprehensions
-- http://hackage.haskell.org/trac/ghc/wiki/Commentary/CodingStyle#Warnings
-- for details
module DsListComp ( dsListComp, dsPArrComp ) where
module DsListComp ( dsListComp, dsPArrComp, dsMonadComp ) where
#include "HsVersions.h"
import {-# SOURCE #-} DsExpr ( dsLExpr, dsLocalBinds )
import {-# SOURCE #-} DsExpr ( dsExpr, dsLExpr, dsLocalBinds )
import HsSyn
import TcHsSyn
......@@ -37,6 +38,7 @@ import PrelNames
import SrcLoc
import Outputable
import FastString
import TcType
\end{code}
List comprehensions may be desugared in one of two ways: ``ordinary''
......@@ -72,8 +74,8 @@ dsListComp lquals body elt_ty = do
-- mix of possibly a single element in length, so we do this to leave the possibility open
isParallelComp = any isParallelStmt
isParallelStmt (ParStmt _) = True
isParallelStmt _ = False
isParallelStmt (ParStmt _ _ _ _) = True
isParallelStmt _ = False
-- This function lets you desugar a inner list comprehension and a list of the binders
......@@ -92,7 +94,7 @@ dsInnerListComp (stmts, bndrs) = do
-- Given such a statement it gives you back an expression representing how to compute the transformed
-- list and the tuple that you need to bind from that list in order to proceed with your desugaring
dsTransformStmt :: Stmt Id -> DsM (CoreExpr, LPat Id)
dsTransformStmt (TransformStmt stmts binders usingExpr maybeByExpr)
dsTransformStmt (TransformStmt stmts binders usingExpr maybeByExpr _ _)
= do { (expr, binders_tuple_type) <- dsInnerListComp (stmts, binders)
; usingExpr' <- dsLExpr usingExpr
......@@ -116,7 +118,7 @@ dsTransformStmt (TransformStmt stmts binders usingExpr maybeByExpr)
-- Given such a statement it gives you back an expression representing how to compute the transformed
-- list and the tuple that you need to bind from that list in order to proceed with your desugaring
dsGroupStmt :: Stmt Id -> DsM (CoreExpr, LPat Id)
dsGroupStmt (GroupStmt stmts binderMap by using) = do
dsGroupStmt (GroupStmt stmts binderMap by using _ _ _) = do
let (fromBinders, toBinders) = unzip binderMap
fromBindersTypes = map idType fromBinders
......@@ -228,7 +230,7 @@ with the Unboxed variety.
deListComp :: [Stmt Id] -> LHsExpr Id -> CoreExpr -> DsM CoreExpr
deListComp (ParStmt stmtss_w_bndrs : quals) body list
deListComp (ParStmt stmtss_w_bndrs _ _ _ : quals) body list
= do
exps_and_qual_tys <- mapM dsInnerListComp stmtss_w_bndrs
let (exps, qual_tys) = unzip exps_and_qual_tys
......@@ -252,7 +254,7 @@ deListComp [] body list = do -- Figure 7.4, SLPJ, p 135, rule C above
return (mkConsExpr (exprType core_body) core_body list)
-- Non-last: must be a guard
deListComp (ExprStmt guard _ _ : quals) body list = do -- rule B above
deListComp (ExprStmt guard _ _ _ : quals) body list = do -- rule B above
core_guard <- dsLExpr guard
core_rest <- deListComp quals body list
return (mkIfThenElse core_guard core_rest list)
......@@ -344,7 +346,7 @@ dfListComp c_id n_id [] body = do
return (mkApps (Var c_id) [core_body, Var n_id])
-- Non-last: must be a guard
dfListComp c_id n_id (ExprStmt guard _ _ : quals) body = do
dfListComp c_id n_id (ExprStmt guard _ _ _ : quals) body = do
core_guard <- dsLExpr guard
core_rest <- dfListComp c_id n_id quals body
return (mkIfThenElse core_guard core_rest (Var n_id))
......@@ -501,7 +503,7 @@ dsPArrComp :: [Stmt Id]
-> LHsExpr Id
-> Type -- Don't use; called with `undefined' below
-> DsM CoreExpr
dsPArrComp [ParStmt qss] body _ = -- parallel comprehension
dsPArrComp [ParStmt qss _ _ _] body _ = -- parallel comprehension
dePArrParComp qss body
-- Special case for simple generators:
......@@ -550,7 +552,7 @@ dePArrComp [] e' pa cea = do
--
-- <<[:e' | b, qs:]>> pa ea = <<[:e' | qs:]>> pa (filterP (\pa -> b) ea)
--
dePArrComp (ExprStmt b _ _ : qs) body pa cea = do
dePArrComp (ExprStmt b _ _ _ : qs) body pa cea = do
filterP <- dsLookupDPHId filterPName
let ty = parrElemType cea
(clam,_) <- deLambda ty pa b
......@@ -616,7 +618,7 @@ dePArrComp (LetStmt ds : qs) body pa cea = do
-- singeltons qualifier lists, which we already special case in the caller.
-- So, encountering one here is a bug.
--
dePArrComp (ParStmt _ : _) _ _ _ =
dePArrComp (ParStmt _ _ _ _ : _) _ _ _ =
panic "DsListComp.dePArrComp: malformed comprehension AST"
-- <<[:e' | qs | qss:]>> pa ea =
......@@ -682,3 +684,341 @@ parrElemType e =
_ -> panic
"DsListComp.parrElemType: not a parallel array type"
\end{code}
Translation for monad comprehensions
\begin{code}
-- | Keep the "context" of a monad comprehension in a small data type to avoid
-- some boilerplate...
data DsMonadComp = DsMonadComp
{ mc_return :: Either (SyntaxExpr Id) (Expr CoreBndr)
, mc_body :: LHsExpr Id
, mc_m_ty :: Type
}
--
-- Entry point for monad comprehension desugaring
--
dsMonadComp :: [LStmt Id] -- the statements
-> SyntaxExpr Id -- the "return" function
-> LHsExpr Id -- the body
-> Type -- the final type
-> DsM CoreExpr
dsMonadComp stmts return_op body res_ty
= dsMcStmts stmts (DsMonadComp (Left return_op) body m_ty)
where
(m_ty, _) = tcSplitAppTy res_ty
dsMcStmts :: [LStmt Id]
-> DsMonadComp
-> DsM CoreExpr
-- No statements left for desugaring. Desugar the body after calling "return"
-- on it.
dsMcStmts [] DsMonadComp { mc_return, mc_body }
= case mc_return of
Left ret -> dsLExpr $ noLoc ret `nlHsApp` mc_body
Right ret' -> do
{ body' <- dsLExpr mc_body
; return $ mkApps ret' [body'] }
-- Otherwise desugar each statement step by step
dsMcStmts ((L loc stmt) : lstmts) mc
= putSrcSpanDs loc (dsMcStmt stmt lstmts mc)
dsMcStmt :: Stmt Id
-> [LStmt Id]
-> DsMonadComp
-> DsM CoreExpr
-- [ .. | let binds, stmts ]
dsMcStmt (LetStmt binds) stmts mc
= do { rest <- dsMcStmts stmts mc
; dsLocalBinds binds rest }
-- [ .. | a <- m, stmts ]
dsMcStmt (BindStmt pat rhs bind_op fail_op) stmts mc
= do { rhs' <- dsLExpr rhs
; dsMcBindStmt pat rhs' bind_op fail_op stmts mc }
-- Apply `guard` to the `exp` expression
--
-- [ .. | exp, stmts ]
--
dsMcStmt (ExprStmt exp then_exp guard_exp _) stmts mc
= do { exp' <- dsLExpr exp
; guard_exp' <- dsExpr guard_exp
; then_exp' <- dsExpr then_exp
; rest <- dsMcStmts stmts mc
; return $ mkApps then_exp' [ mkApps guard_exp' [exp']
, rest ] }
-- Transform statements desugar like this:
--
-- [ .. | qs, then f by e ] -> f (\q_v -> e) [| qs |]
--
-- where [| qs |] is the desugared inner monad comprehenion generated by the
-- statements `qs`.
dsMcStmt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_op) stmts_rest mc
= do { (expr, _) <- dsInnerMonadComp (stmts, binders) (mc { mc_return = Left return_op })
; let binders_tuple_type = mkBigCoreTupTy $ map idType binders
; usingExpr' <- dsLExpr usingExpr
; using_args <- case maybeByExpr of
Nothing -> return [expr]
Just byExpr -> do
byExpr' <- dsLExpr byExpr
us <- newUniqueSupply
tuple_binder <- newSysLocalDs binders_tuple_type
let byExprWrapper = mkTupleCase us binders byExpr' tuple_binder (Var tuple_binder)
return [Lam tuple_binder byExprWrapper, expr]
; let pat = mkBigLHsVarPatTup binders
rhs = mkApps usingExpr' ((Type binders_tuple_type) : using_args)
; dsMcBindStmt pat rhs bind_op noSyntaxExpr stmts_rest mc }
-- Group statements desugar like this:
--
-- [| q, then group by e using f |] -> (f (\q_v -> e) [| q |]) >>= (return . (unzip q_v))
--
-- which is equal to
--
-- [| q, then group by e using f |] -> liftM (unzip q_v) (f (\q_v -> e) [| q |])
--
-- where unzip is of the form
--
-- unzip :: m (a,b,c,..) -> (m a,m b,m c,..)
-- unzip m_tuple = ( liftM selN1 m_tuple
-- , liftM selN2 m_tuple
-- , .. )
-- where selN1 (a,b,c,..) = a
-- selN2 (a,b,c,..) = b
-- ..
--
dsMcStmt (GroupStmt stmts binderMap by using return_op bind_op liftM_op) stmts_rest mc
= do { let (fromBinders, toBinders) = unzip binderMap
fromBindersTypes = map idType fromBinders
fromBindersTupleTy = mkBigCoreTupTy fromBindersTypes
toBindersTypes = map idType toBinders
toBindersTupleTy = mkBigCoreTupTy toBindersTypes
m_ty = mc_m_ty mc
-- Desugar an inner comprehension which outputs a list of tuples of the "from" binders
; (expr, _) <- dsInnerMonadComp (stmts, fromBinders) (mc { mc_return = Left return_op })
-- Work out what arguments should be supplied to that expression: i.e. is an extraction
-- function required? If so, create that desugared function and add to arguments
; usingExpr' <- dsLExpr (either id noLoc using)
; usingArgs <- case by of
Nothing -> return [expr]
Just by_e -> do { by_e' <- dsLExpr by_e
; us <- newUniqueSupply
; from_tup_id <- newSysLocalDs fromBindersTupleTy
; let by_wrap = mkTupleCase us fromBinders by_e'
from_tup_id (Var from_tup_id)
; return [Lam from_tup_id by_wrap, expr] }
-- Create an unzip function for the appropriate arity and element types
; liftM_op' <- dsExpr liftM_op
; (unzip_fn, unzip_rhs) <- mkMcUnzipM liftM_op' m_ty fromBindersTypes
-- Generate the expressions to build the grouped list
; let -- First we apply the grouping function to the inner monad
inner_monad_expr = mkApps usingExpr' ((Type fromBindersTupleTy) : usingArgs)
-- Then we map our "unzip" across it to turn the "monad of tuples" into "tuples of monads"
-- We make sure we instantiate the type variable "a" to be a "monad of 'from' tuples" and
-- the "b" to be a "tuple of 'to' monads"!
unzipped_inner_monad_expr = mkApps liftM_op' -- !
-- Types:
[ Type (m_ty `mkAppTy` fromBindersTupleTy), Type toBindersTupleTy
-- And arguments:
, Var unzip_fn, inner_monad_expr ]
-- Then finally we bind the unzip function around that expression
bound_unzipped_inner_monad_expr = Let (Rec [(unzip_fn, unzip_rhs)]) unzipped_inner_monad_expr
-- Build a pattern that ensures the consumer binds into the NEW binders, which hold monads
-- rather than single values
; let pat = mkBigLHsVarPatTup toBinders
rhs = bound_unzipped_inner_monad_expr
; dsMcBindStmt pat rhs bind_op noSyntaxExpr stmts_rest mc }
-- Parallel statements. Use `Control.Monad.Zip.mzip` to zip parallel
-- statements, for example:
--
-- [ body | qs1 | qs2 | qs3 ]
-- -> [ body | (bndrs1, (bndrs2, bndrs3)) <- mzip qs1 (mzip qs2 qs3) ]
--
-- where `mzip` is of the form
--
-- mzip :: m a -> m b -> m (a,b)
--
dsMcStmt (ParStmt pairs mzip_op bind_op return_op) stmts_rest mc
= do { -- Get types for `return`
return_op' <- dsExpr return_op
; let pairs_with_return = map (\tp@(_,b) -> (mkReturn b,tp)) pairs
mkReturn bndrs = mkApps return_op' [Type (mkBigCoreTupTy (map idType bndrs))]
; pairs' <- mapM (\(r,tp) -> dsInnerMonadComp tp mc{mc_return = Right r})
pairs_with_return
; let (exps, _qual_tys) = unzip pairs'
-- Types of our `Id`s are getting messed up by `dsInnerMonadComp`
-- so we construct them by hand:
qual_tys = map (mkBigCoreTupTy . map idType . snd) pairs
; mzip_op' <- dsExpr mzip_op
; (zip_fn, zip_rhs) <- mkMcZipM mzip_op' (mc_m_ty mc) qual_tys
; let -- The pattern variables
vars = map (mkBigLHsVarPatTup . snd) pairs
-- Pattern with tuples of variables
-- [v1,v2,v3] => (v1, (v2, v3))
pat = foldr (\tn tm -> mkBigLHsPatTup [tn, tm]) (last vars) (init vars)
rhs = Let (Rec [(zip_fn, zip_rhs)]) (mkApps (Var zip_fn) exps)
; dsMcBindStmt pat rhs bind_op noSyntaxExpr stmts_rest mc }
dsMcStmt stmt _ _ = pprPanic "dsMcStmt: unexpected stmt" (ppr stmt)
-- general `rhs' >>= \pat -> stmts` desugaring where `rhs'` is already a
-- desugared `CoreExpr`
dsMcBindStmt :: LPat Id
-> CoreExpr -- ^ the desugared rhs of the bind statement
-> SyntaxExpr Id
-> SyntaxExpr Id
-> [LStmt Id]
-> DsMonadComp
-> DsM CoreExpr
dsMcBindStmt pat rhs' bind_op fail_op stmts mc
= do { body <- dsMcStmts stmts mc
; bind_op' <- dsExpr bind_op
; var <- selectSimpleMatchVarL pat
; let bind_ty = exprType bind_op' -- rhs -> (pat -> res1) -> res2
res1_ty = funResultTy (funArgTy (funResultTy bind_ty))
; match <- matchSinglePat (Var var) (StmtCtxt DoExpr) pat
res1_ty (cantFailMatchResult body)
; match_code <- handle_failure pat match fail_op
; return (mkApps bind_op' [rhs', Lam var match_code]) }
where
-- In a monad comprehension expression, pattern-match failure just calls
-- the monadic `fail` rather than throwing an exception
handle_failure pat match fail_op
| matchCanFail match
= do { fail_op' <- dsExpr fail_op
; fail_msg <- mkStringExpr (mk_fail_msg pat)
; extractMatchResult match (App fail_op' fail_msg) }
| otherwise
= extractMatchResult match (error "It can't fail")
mk_fail_msg :: Located e -> String
mk_fail_msg pat = "Pattern match failure in monad comprehension at " ++
showSDoc (ppr (getLoc pat))
-- Desugar nested monad comprehensions, for example in `then..` constructs
dsInnerMonadComp :: ([LStmt Id], [Id])
-> DsMonadComp
-> DsM (CoreExpr, Type)
dsInnerMonadComp (stmts, bndrs) DsMonadComp{ mc_return, mc_m_ty }
= do { expr <- dsMcStmts stmts mc'
; return (expr, bndrs_tuple_type) }
where
bndrs_types = map idType bndrs
bndrs_tuple_type = mkAppTy mc_m_ty $ mkBigCoreTupTy bndrs_types
mc' = DsMonadComp mc_return (mkBigLHsVarTup bndrs) mc_m_ty
-- The `unzip` function for `GroupStmt` in a monad comprehensions
--
-- unzip :: m (a,b,..) -> (m a,m b,..)
-- unzip m_tuple = ( liftM selN1 m_tuple
-- , liftM selN2 m_tuple
-- , .. )
--
-- mkMcUnzipM m [t1, t2]
-- = (unzip_fn, \ys :: m (t1, t2) ->
-- ( liftM (selN1 :: (t1, t2) -> t1) ys
-- , liftM (selN2 :: (t1, t2) -> t2) ys
-- ))
--
mkMcUnzipM :: CoreExpr
-> Type -- m
-> [Type] -- [a,b,c,..]
-> DsM (Id, CoreExpr)
mkMcUnzipM liftM_op m_ty elt_tys
= do { ys <- newSysLocalDs monad_tuple_ty
; xs <- mapM newSysLocalDs elt_tys
; scrut <- newSysLocalDs tuple_tys
; unzip_fn <- newSysLocalDs unzip_fn_ty
; let -- Select one Id from our tuple
selectExpr n = mkLams [scrut] $ mkTupleSelector xs (xs !! n) scrut (Var scrut)
-- Apply 'selectVar' and 'ys' to 'liftM'
tupleElem n = mkApps liftM_op
-- Types (m is figured out by the type checker):
-- liftM :: forall a b. (a -> b) -> m a -> m b
[ Type tuple_tys, Type (elt_tys !! n)
-- Arguments:
, selectExpr n, Var ys ]
-- The final expression with the big tuple
unzip_body = mkBigCoreTup [ tupleElem n | n <- [0..length elt_tys - 1] ]
; return (unzip_fn, mkLams [ys] unzip_body) }
where monad_tys = map (m_ty `mkAppTy`) elt_tys -- [m a,m b,m c,..]
tuple_monad_tys = mkBigCoreTupTy monad_tys -- (m a,m b,m c,..)
tuple_tys = mkBigCoreTupTy elt_tys -- (a,b,c,..)
monad_tuple_ty = m_ty `mkAppTy` tuple_tys -- m (a,b,c,..)
unzip_fn_ty = monad_tuple_ty `mkFunTy` tuple_monad_tys -- m (a,b,c,..) -> (m a,m b,m c,..)
-- Generate the `mzip` function for `ParStmt` in monad comprehensions, for
-- example:
--
-- mzip :: m t1
-- -> (m t2 -> m t3 -> m (t2, t3))
-- -> m (t1, (t2, t3))
--
-- mkMcZipM m [t1, t2, t3]
-- = (zip_fn, \(q1::t1) (q2::t2) (q3::t3) ->
-- mzip q1 (mzip q2 q3))
--
mkMcZipM :: CoreExpr
-> Type
-> [Type]
-> DsM (Id, CoreExpr)
mkMcZipM mzip_op m_ty tys@(_:_:_) -- min. 2 types
= do { (ids, t1, tuple_ty, zip_body) <- loop tys
; zip_fn <- newSysLocalDs $
(m_ty `mkAppTy` t1)
`mkFunTy`
(m_ty `mkAppTy` tuple_ty)
`mkFunTy`
(m_ty `mkAppTy` mkBigCoreTupTy [t1, tuple_ty])
; return (zip_fn, mkLams ids zip_body) }
where
-- loop :: [Type] -> DsM ([Id], Type, [Type], CoreExpr)
loop [t1, t2] = do -- last run of the `loop`
{ ids@[a,b] <- newSysLocalsDs (map (m_ty `mkAppTy`) [t1,t2])
; let zip_body = mkApps mzip_op [ Type t1, Type t2 , Var a, Var b ]
; return (ids, t1, t2, zip_body) }
loop (t1:tr) = do
{ -- Get ty, ids etc from the "inner" zip
(ids', t1', t2', zip_body') <- loop tr
; a <- newSysLocalDs $ m_ty `mkAppTy` t1
; let tuple_ty' = mkBigCoreTupTy [t1', t2']
zip_body = mkApps mzip_op [ Type t1, Type tuple_ty', Var a, zip_body' ]
; return ((a:ids'), t1, tuple_ty', zip_body) }
-- This case should never happen:
mkMcZipM _ _ tys = pprPanic "mkMcZipM: unexpected argument" (ppr tys)
\end{code}
......@@ -721,7 +721,7 @@ repE (HsLet bs e) = do { (ss,ds) <- repBinds bs
; wrapGenSyms ss z }
-- FIXME: I haven't got the types here right yet
repE e@(HsDo ctxt sts body _)
repE e@(HsDo ctxt sts body _ _)
| case ctxt of { DoExpr -> True; GhciStmt -> True; _ -> False }
= do { (ss,zs) <- repLSts sts;
body' <- addBinds ss $ repLE body;
......@@ -737,7 +737,7 @@ repE e@(HsDo ctxt sts body _)
wrapGenSyms ss e' }
| otherwise