Commit e01036f8 authored by Simon Peyton Jones's avatar Simon Peyton Jones
Browse files

More hacking on monad-comp

Lots of refactoring. In particular I have now combined
TansformStmt and GroupStmt into a single constructor TransStmt.
This gives lots of useful code sharing.
parent f6d254cc
This diff is collapsed.
...@@ -455,26 +455,18 @@ addTickStmt isGuard (ParStmt pairs mzipExpr bindExpr returnExpr) = do ...@@ -455,26 +455,18 @@ addTickStmt isGuard (ParStmt pairs mzipExpr bindExpr returnExpr) = do
(addTickSyntaxExpr hpcSrcSpan bindExpr) (addTickSyntaxExpr hpcSrcSpan bindExpr)
(addTickSyntaxExpr hpcSrcSpan returnExpr) (addTickSyntaxExpr hpcSrcSpan returnExpr)
addTickStmt isGuard (TransformStmt stmts ids usingExpr maybeByExpr returnExpr bindExpr) = do addTickStmt isGuard stmt@(TransStmt { trS_stmts = stmts
t_s <- (addTickLStmts isGuard stmts) , trS_by = by, trS_using = using
t_u <- (addTickLHsExprAlways usingExpr) , trS_ret = returnExpr, trS_bind = bindExpr
t_m <- (addTickMaybeByLHsExpr maybeByExpr) , trS_fmap = liftMExpr }) = do
t_r <- (addTickSyntaxExpr hpcSrcSpan returnExpr)
t_b <- (addTickSyntaxExpr hpcSrcSpan bindExpr)
return $ TransformStmt t_s ids t_u t_m t_r t_b
addTickStmt isGuard stmt@(GroupStmt { grpS_stmts = stmts
, grpS_by = by, grpS_using = using
, grpS_ret = returnExpr, grpS_bind = bindExpr
, grpS_fmap = liftMExpr }) = do
t_s <- addTickLStmts isGuard stmts t_s <- addTickLStmts isGuard stmts
t_y <- fmapMaybeM addTickLHsExprAlways by t_y <- fmapMaybeM addTickLHsExprAlways by
t_u <- addTickLHsExprAlways using t_u <- addTickLHsExprAlways using
t_f <- addTickSyntaxExpr hpcSrcSpan returnExpr t_f <- addTickSyntaxExpr hpcSrcSpan returnExpr
t_b <- addTickSyntaxExpr hpcSrcSpan bindExpr t_b <- addTickSyntaxExpr hpcSrcSpan bindExpr
t_m <- addTickSyntaxExpr hpcSrcSpan liftMExpr t_m <- addTickSyntaxExpr hpcSrcSpan liftMExpr
return $ stmt { grpS_stmts = t_s, grpS_by = t_y, grpS_using = t_u return $ stmt { trS_stmts = t_s, trS_by = t_y, trS_using = t_u
, grpS_ret = t_f, grpS_bind = t_b, grpS_fmap = t_m } , trS_ret = t_f, trS_bind = t_b, trS_fmap = t_m }
addTickStmt isGuard stmt@(RecStmt {}) addTickStmt isGuard stmt@(RecStmt {})
= do { stmts' <- addTickLStmts isGuard (recS_stmts stmt) = do { stmts' <- addTickLStmts isGuard (recS_stmts stmt)
...@@ -495,12 +487,6 @@ addTickStmtAndBinders isGuard (stmts, ids) = ...@@ -495,12 +487,6 @@ addTickStmtAndBinders isGuard (stmts, ids) =
(addTickLStmts isGuard stmts) (addTickLStmts isGuard stmts)
(return ids) (return ids)
addTickMaybeByLHsExpr :: Maybe (LHsExpr Id) -> TM (Maybe (LHsExpr Id))
addTickMaybeByLHsExpr maybeByExpr =
case maybeByExpr of
Nothing -> return Nothing
Just byExpr -> addTickLHsExprAlways byExpr >>= (return . Just)
addTickHsLocalBinds :: HsLocalBinds Id -> TM (HsLocalBinds Id) addTickHsLocalBinds :: HsLocalBinds Id -> TM (HsLocalBinds Id)
addTickHsLocalBinds (HsValBinds binds) = addTickHsLocalBinds (HsValBinds binds) =
liftM HsValBinds liftM HsValBinds
......
...@@ -91,45 +91,19 @@ dsInnerListComp (stmts, bndrs) ...@@ -91,45 +91,19 @@ dsInnerListComp (stmts, bndrs)
where where
bndrs_tuple_type = mkBigCoreVarTupTy bndrs bndrs_tuple_type = mkBigCoreVarTupTy bndrs
-- This function factors out commonality between the desugaring strategies for TransformStmt.
-- Given such a statement it gives you back an expression representing how to compute the transformed
-- list and the tuple that you need to bind from that list in order to proceed with your desugaring
dsTransformStmt :: Stmt Id -> DsM (CoreExpr, LPat Id)
dsTransformStmt (TransformStmt stmts binders usingExpr maybeByExpr _ _)
= do { (expr, binders_tuple_type) <- dsInnerListComp (stmts, binders)
; usingExpr' <- dsLExpr usingExpr
; using_args <-
case maybeByExpr of
Nothing -> return [expr]
Just byExpr -> do
byExpr' <- dsLExpr byExpr
us <- newUniqueSupply
[tuple_binder] <- newSysLocalsDs [binders_tuple_type]
let byExprWrapper = mkTupleCase us binders byExpr' tuple_binder (Var tuple_binder)
return [Lam tuple_binder byExprWrapper, expr]
; let inner_list_expr = mkApps usingExpr' ((Type binders_tuple_type) : using_args)
pat = mkBigLHsVarPatTup binders
; return (inner_list_expr, pat) }
-- This function factors out commonality between the desugaring strategies for GroupStmt. -- This function factors out commonality between the desugaring strategies for GroupStmt.
-- Given such a statement it gives you back an expression representing how to compute the transformed -- Given such a statement it gives you back an expression representing how to compute the transformed
-- list and the tuple that you need to bind from that list in order to proceed with your desugaring -- list and the tuple that you need to bind from that list in order to proceed with your desugaring
dsGroupStmt :: Stmt Id -> DsM (CoreExpr, LPat Id) dsTransStmt :: Stmt Id -> DsM (CoreExpr, LPat Id)
dsGroupStmt (GroupStmt { grpS_stmts = stmts, grpS_bndrs = binderMap dsTransStmt (TransStmt { trS_form = form, trS_stmts = stmts, trS_bndrs = binderMap
, grpS_by = by, grpS_using = using }) = do , trS_by = by, trS_using = using }) = do
let (fromBinders, toBinders) = unzip binderMap let (from_bndrs, to_bndrs) = unzip binderMap
from_bndrs_tys = map idType from_bndrs
fromBindersTypes = map idType fromBinders to_bndrs_tys = map idType to_bndrs
toBindersTypes = map idType toBinders to_bndrs_tup_ty = mkBigCoreTupTy to_bndrs_tys
toBindersTupleType = mkBigCoreTupTy toBindersTypes
-- Desugar an inner comprehension which outputs a list of tuples of the "from" binders -- Desugar an inner comprehension which outputs a list of tuples of the "from" binders
(expr, from_tup_ty) <- dsInnerListComp (stmts, fromBinders) (expr, from_tup_ty) <- dsInnerListComp (stmts, from_bndrs)
-- Work out what arguments should be supplied to that expression: i.e. is an extraction -- Work out what arguments should be supplied to that expression: i.e. is an extraction
-- function required? If so, create that desugared function and add to arguments -- function required? If so, create that desugared function and add to arguments
...@@ -137,31 +111,34 @@ dsGroupStmt (GroupStmt { grpS_stmts = stmts, grpS_bndrs = binderMap ...@@ -137,31 +111,34 @@ dsGroupStmt (GroupStmt { grpS_stmts = stmts, grpS_bndrs = binderMap
usingArgs <- case by of usingArgs <- case by of
Nothing -> return [expr] Nothing -> return [expr]
Just by_e -> do { by_e' <- dsLExpr by_e Just by_e -> do { by_e' <- dsLExpr by_e
; us <- newUniqueSupply ; lam <- matchTuple from_bndrs by_e'
; [from_tup_id] <- newSysLocalsDs [from_tup_ty] ; return [lam, expr] }
; let by_wrap = mkTupleCase us fromBinders by_e'
from_tup_id (Var from_tup_id)
; return [Lam from_tup_id by_wrap, expr] }
-- Create an unzip function for the appropriate arity and element types and find "map" -- Create an unzip function for the appropriate arity and element types and find "map"
(unzip_fn, unzip_rhs) <- mkUnzipBind fromBindersTypes unzip_stuff <- mkUnzipBind form from_bndrs_tys
map_id <- dsLookupGlobalId mapName map_id <- dsLookupGlobalId mapName
-- Generate the expressions to build the grouped list -- Generate the expressions to build the grouped list
let -- First we apply the grouping function to the inner list let -- First we apply the grouping function to the inner list
inner_list_expr = mkApps usingExpr' ((Type from_tup_ty) : usingArgs) inner_list_expr = mkApps usingExpr' (Type from_tup_ty : usingArgs)
-- Then we map our "unzip" across it to turn the lists of tuples into tuples of lists -- Then we map our "unzip" across it to turn the lists of tuples into tuples of lists
-- We make sure we instantiate the type variable "a" to be a list of "from" tuples and -- We make sure we instantiate the type variable "a" to be a list of "from" tuples and
-- the "b" to be a tuple of "to" lists! -- the "b" to be a tuple of "to" lists!
unzipped_inner_list_expr = mkApps (Var map_id)
[Type (mkListTy from_tup_ty), Type toBindersTupleType, Var unzip_fn, inner_list_expr]
-- Then finally we bind the unzip function around that expression -- Then finally we bind the unzip function around that expression
bound_unzipped_inner_list_expr = Let (Rec [(unzip_fn, unzip_rhs)]) unzipped_inner_list_expr bound_unzipped_inner_list_expr
= case unzip_stuff of
-- Build a pattern that ensures the consumer binds into the NEW binders, which hold lists rather than single values Nothing -> inner_list_expr
let pat = mkBigLHsVarPatTup toBinders Just (unzip_fn, unzip_rhs) -> Let (Rec [(unzip_fn, unzip_rhs)]) $
mkApps (Var map_id) $
[ Type (mkListTy from_tup_ty)
, Type to_bndrs_tup_ty
, Var unzip_fn
, inner_list_expr]
-- Build a pattern that ensures the consumer binds into the NEW binders,
-- which hold lists rather than single values
let pat = mkBigLHsVarPatTup to_bndrs
return (bound_unzipped_inner_list_expr, pat) return (bound_unzipped_inner_list_expr, pat)
\end{code} \end{code}
%************************************************************************ %************************************************************************
...@@ -251,12 +228,8 @@ deListComp (LetStmt binds : quals) list = do ...@@ -251,12 +228,8 @@ deListComp (LetStmt binds : quals) list = do
core_rest <- deListComp quals list core_rest <- deListComp quals list
dsLocalBinds binds core_rest dsLocalBinds binds core_rest
deListComp (stmt@(TransformStmt {}) : quals) list = do deListComp (stmt@(TransStmt {}) : quals) list = do
(inner_list_expr, pat) <- dsTransformStmt stmt (inner_list_expr, pat) <- dsTransStmt stmt
deBindComp pat inner_list_expr quals list
deListComp (stmt@(GroupStmt {}) : quals) list = do
(inner_list_expr, pat) <- dsGroupStmt stmt
deBindComp pat inner_list_expr quals list deBindComp pat inner_list_expr quals list
deListComp (BindStmt pat list1 _ _ : quals) core_list2 = do -- rule A' above deListComp (BindStmt pat list1 _ _ : quals) core_list2 = do -- rule A' above
...@@ -264,16 +237,14 @@ deListComp (BindStmt pat list1 _ _ : quals) core_list2 = do -- rule A' above ...@@ -264,16 +237,14 @@ deListComp (BindStmt pat list1 _ _ : quals) core_list2 = do -- rule A' above
deBindComp pat core_list1 quals core_list2 deBindComp pat core_list1 quals core_list2
deListComp (ParStmt stmtss_w_bndrs _ _ _ : quals) list deListComp (ParStmt stmtss_w_bndrs _ _ _ : quals) list
= do = do { exps_and_qual_tys <- mapM dsInnerListComp stmtss_w_bndrs
exps_and_qual_tys <- mapM dsInnerListComp stmtss_w_bndrs ; let (exps, qual_tys) = unzip exps_and_qual_tys
let (exps, qual_tys) = unzip exps_and_qual_tys
(zip_fn, zip_rhs) <- mkZipBind qual_tys ; (zip_fn, zip_rhs) <- mkZipBind qual_tys
-- Deal with [e | pat <- zip l1 .. ln] in example above -- Deal with [e | pat <- zip l1 .. ln] in example above
deBindComp pat (Let (Rec [(zip_fn, zip_rhs)]) (mkApps (Var zip_fn) exps)) ; deBindComp pat (Let (Rec [(zip_fn, zip_rhs)]) (mkApps (Var zip_fn) exps))
quals list quals list }
where where
bndrs_s = map snd stmtss_w_bndrs bndrs_s = map snd stmtss_w_bndrs
...@@ -361,13 +332,8 @@ dfListComp c_id n_id (LetStmt binds : quals) = do ...@@ -361,13 +332,8 @@ dfListComp c_id n_id (LetStmt binds : quals) = do
core_rest <- dfListComp c_id n_id quals core_rest <- dfListComp c_id n_id quals
dsLocalBinds binds core_rest dsLocalBinds binds core_rest
dfListComp c_id n_id (stmt@(TransformStmt {}) : quals) = do dfListComp c_id n_id (stmt@(TransStmt {}) : quals) = do
(inner_list_expr, pat) <- dsTransformStmt stmt (inner_list_expr, pat) <- dsTransStmt stmt
-- Anyway, we bind the newly transformed list via the generic binding function
dfBindComp c_id n_id (pat, inner_list_expr) quals
dfListComp c_id n_id (stmt@(GroupStmt {}) : quals) = do
(inner_list_expr, pat) <- dsGroupStmt stmt
-- Anyway, we bind the newly grouped list via the generic binding function -- Anyway, we bind the newly grouped list via the generic binding function
dfBindComp c_id n_id (pat, inner_list_expr) quals dfBindComp c_id n_id (pat, inner_list_expr) quals
...@@ -445,7 +411,7 @@ mkZipBind elt_tys = do ...@@ -445,7 +411,7 @@ mkZipBind elt_tys = do
-- Increasing order of tag -- Increasing order of tag
mkUnzipBind :: [Type] -> DsM (Id, CoreExpr) mkUnzipBind :: TransForm -> [Type] -> DsM (Maybe (Id, CoreExpr))
-- mkUnzipBind [t1, t2] -- mkUnzipBind [t1, t2]
-- = (unzip, \ys :: [(t1, t2)] -> foldr (\ax :: (t1, t2) axs :: ([t1], [t2]) -- = (unzip, \ys :: [(t1, t2)] -> foldr (\ax :: (t1, t2) axs :: ([t1], [t2])
-- -> case ax of -- -> case ax of
...@@ -455,28 +421,29 @@ mkUnzipBind :: [Type] -> DsM (Id, CoreExpr) ...@@ -455,28 +421,29 @@ mkUnzipBind :: [Type] -> DsM (Id, CoreExpr)
-- ys) -- ys)
-- --
-- We use foldr here in all cases, even if rules are turned off, because we may as well! -- We use foldr here in all cases, even if rules are turned off, because we may as well!
mkUnzipBind elt_tys = do mkUnzipBind ThenForm _
ax <- newSysLocalDs elt_tuple_ty = return Nothing -- No unzipping for ThenForm
axs <- newSysLocalDs elt_list_tuple_ty mkUnzipBind _ elt_tys
ys <- newSysLocalDs elt_tuple_list_ty = do { ax <- newSysLocalDs elt_tuple_ty
xs <- mapM newSysLocalDs elt_tys ; axs <- newSysLocalDs elt_list_tuple_ty
xss <- mapM newSysLocalDs elt_list_tys ; ys <- newSysLocalDs elt_tuple_list_ty
; xs <- mapM newSysLocalDs elt_tys
; xss <- mapM newSysLocalDs elt_list_tys
unzip_fn <- newSysLocalDs unzip_fn_ty ; unzip_fn <- newSysLocalDs unzip_fn_ty
[us1, us2] <- sequence [newUniqueSupply, newUniqueSupply] ; [us1, us2] <- sequence [newUniqueSupply, newUniqueSupply]
let nil_tuple = mkBigCoreTup (map mkNilExpr elt_tys) ; let nil_tuple = mkBigCoreTup (map mkNilExpr elt_tys)
concat_expressions = map mkConcatExpression (zip3 elt_tys (map Var xs) (map Var xss))
concat_expressions = map mkConcatExpression (zip3 elt_tys (map Var xs) (map Var xss)) tupled_concat_expression = mkBigCoreTup concat_expressions
tupled_concat_expression = mkBigCoreTup concat_expressions
folder_body_inner_case = mkTupleCase us1 xss tupled_concat_expression axs (Var axs)
folder_body_inner_case = mkTupleCase us1 xss tupled_concat_expression axs (Var axs) folder_body_outer_case = mkTupleCase us2 xs folder_body_inner_case ax (Var ax)
folder_body_outer_case = mkTupleCase us2 xs folder_body_inner_case ax (Var ax) folder_body = mkLams [ax, axs] folder_body_outer_case
folder_body = mkLams [ax, axs] folder_body_outer_case
; unzip_body <- mkFoldrExpr elt_tuple_ty elt_list_tuple_ty folder_body nil_tuple (Var ys)
unzip_body <- mkFoldrExpr elt_tuple_ty elt_list_tuple_ty folder_body nil_tuple (Var ys) ; return (Just (unzip_fn, mkLams [ys] unzip_body)) }
return (unzip_fn, mkLams [ys] unzip_body)
where where
elt_tuple_ty = mkBigCoreTupTy elt_tys elt_tuple_ty = mkBigCoreTupTy elt_tys
elt_tuple_list_ty = mkListTy elt_tuple_ty elt_tuple_list_ty = mkListTy elt_tuple_ty
...@@ -730,30 +697,6 @@ dsMcStmt (ExprStmt exp then_exp guard_exp _) stmts ...@@ -730,30 +697,6 @@ dsMcStmt (ExprStmt exp then_exp guard_exp _) stmts
; return $ mkApps then_exp' [ mkApps guard_exp' [exp'] ; return $ mkApps then_exp' [ mkApps guard_exp' [exp']
, rest ] } , rest ] }
-- Transform statements desugar like this:
--
-- [ .. | qs, then f by e ] -> f (\q_v -> e) [| qs |]
--
-- where [| qs |] is the desugared inner monad comprehenion generated by the
-- statements `qs`.
dsMcStmt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_op) stmts_rest
= do { expr <- dsInnerMonadComp stmts binders return_op
; let binders_tup_type = mkBigCoreTupTy $ map idType binders
; usingExpr' <- dsLExpr usingExpr
; using_args <- case maybeByExpr of
Nothing -> return [expr]
Just byExpr -> do
byExpr' <- dsLExpr byExpr
us <- newUniqueSupply
tup_binder <- newSysLocalDs binders_tup_type
let byExprWrapper = mkTupleCase us binders byExpr' tup_binder (Var tup_binder)
return [Lam tup_binder byExprWrapper, expr]
; let pat = mkBigLHsVarPatTup binders
rhs = mkApps usingExpr' ((Type binders_tup_type) : using_args)
; dsMcBindStmt pat rhs bind_op noSyntaxExpr stmts_rest }
-- Group statements desugar like this: -- Group statements desugar like this:
-- --
-- [| (q, then group by e using f); rest |] -- [| (q, then group by e using f); rest |]
...@@ -768,10 +711,10 @@ dsMcStmt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_op) s ...@@ -768,10 +711,10 @@ dsMcStmt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_op) s
-- n_tup :: n qt -- n_tup :: n qt
-- unzip :: n qt -> (n t1, ..., n tk) (needs Functor n) -- unzip :: n qt -> (n t1, ..., n tk) (needs Functor n)
dsMcStmt (GroupStmt { grpS_stmts = stmts, grpS_bndrs = bndrs dsMcStmt (TransStmt { trS_stmts = stmts, trS_bndrs = bndrs
, grpS_by = by, grpS_using = using , trS_by = by, trS_using = using
, grpS_ret = return_op, grpS_bind = bind_op , trS_ret = return_op, trS_bind = bind_op
, grpS_fmap = fmap_op }) stmts_rest , trS_fmap = fmap_op, trS_form = form }) stmts_rest
= do { let (from_bndrs, to_bndrs) = unzip bndrs = do { let (from_bndrs, to_bndrs) = unzip bndrs
from_bndr_tys = map idType from_bndrs -- Types ty from_bndr_tys = map idType from_bndrs -- Types ty
...@@ -790,16 +733,15 @@ dsMcStmt (GroupStmt { grpS_stmts = stmts, grpS_bndrs = bndrs ...@@ -790,16 +733,15 @@ dsMcStmt (GroupStmt { grpS_stmts = stmts, grpS_bndrs = bndrs
-- Generate the expressions to build the grouped list -- Generate the expressions to build the grouped list
-- Build a pattern that ensures the consumer binds into the NEW binders, -- Build a pattern that ensures the consumer binds into the NEW binders,
-- which hold monads rather than single values -- which hold monads rather than single values
; fmap_op' <- dsExpr fmap_op
; bind_op' <- dsExpr bind_op ; bind_op' <- dsExpr bind_op
; let bind_ty = exprType bind_op' -- m2 (n (a,b,c)) -> (n (a,b,c) -> r1) -> r2 ; let bind_ty = exprType bind_op' -- m2 (n (a,b,c)) -> (n (a,b,c) -> r1) -> r2
n_tup_ty = funArgTy $ funArgTy $ funResultTy bind_ty -- n (a,b,c) n_tup_ty = funArgTy $ funArgTy $ funResultTy bind_ty -- n (a,b,c)
tup_n_ty = mkBigCoreVarTupTy to_bndrs tup_n_ty = mkBigCoreVarTupTy to_bndrs
; body <- dsMcStmts stmts_rest ; body <- dsMcStmts stmts_rest
; n_tup_var <- newSysLocalDs n_tup_ty ; n_tup_var <- newSysLocalDs n_tup_ty
; tup_n_var <- newSysLocalDs tup_n_ty ; tup_n_var <- newSysLocalDs tup_n_ty
; tup_n_expr <- mkMcUnzipM fmap_op' n_tup_var from_bndr_tys ; tup_n_expr <- mkMcUnzipM form fmap_op n_tup_var from_bndr_tys
; us <- newUniqueSupply ; us <- newUniqueSupply
; let rhs' = mkApps usingExpr' usingArgs ; let rhs' = mkApps usingExpr' usingArgs
body' = mkTupleCase us to_bndrs body tup_n_var tup_n_expr body' = mkTupleCase us to_bndrs body tup_n_var tup_n_expr
...@@ -908,16 +850,21 @@ dsInnerMonadComp stmts bndrs ret_op ...@@ -908,16 +850,21 @@ dsInnerMonadComp stmts bndrs ret_op
-- = ( fmap (selN1 :: (t1, t2) -> t1) ys -- = ( fmap (selN1 :: (t1, t2) -> t1) ys
-- , fmap (selN2 :: (t1, t2) -> t2) ys ) -- , fmap (selN2 :: (t1, t2) -> t2) ys )
mkMcUnzipM :: CoreExpr -- fmap mkMcUnzipM :: TransForm
-> SyntaxExpr TcId -- fmap
-> Id -- Of type n (a,b,c) -> Id -- Of type n (a,b,c)
-> [Type] -- [a,b,c] -> [Type] -- [a,b,c]
-> DsM CoreExpr -- Of type (n a, n b, n c) -> DsM CoreExpr -- Of type (n a, n b, n c)
mkMcUnzipM fmap_op ys elt_tys mkMcUnzipM ThenForm _ ys _
= do { xs <- mapM newSysLocalDs elt_tys = return (Var ys) -- No unzipping to do
; tup_xs <- newSysLocalDs (mkBigCoreTupTy elt_tys)
mkMcUnzipM _ fmap_op ys elt_tys
= do { fmap_op' <- dsExpr fmap_op
; xs <- mapM newSysLocalDs elt_tys
; tup_xs <- newSysLocalDs (mkBigCoreTupTy elt_tys)
; let arg_ty = idType ys ; let arg_ty = idType ys
mk_elt i = mkApps fmap_op -- fmap :: forall a b. (a -> b) -> n a -> n b mk_elt i = mkApps fmap_op' -- fmap :: forall a b. (a -> b) -> n a -> n b
[ Type arg_ty, Type (elt_tys !! i) [ Type arg_ty, Type (elt_tys !! i)
, mk_sel i, Var ys] , mk_sel i, Var ys]
......
...@@ -864,48 +864,24 @@ data StmtLR idL idR ...@@ -864,48 +864,24 @@ data StmtLR idL idR
-- with type (forall a. a -> m a) -- with type (forall a. a -> m a)
-- See notes [Monad Comprehensions] -- See notes [Monad Comprehensions]
-- After renaming, the ids are the binders -- After renaming, the ids are the binders
-- bound by the stmts and used after them -- bound by the stmts and used after themp
-- "qs, then f by e" ==> TransformStmt qs binders f (Just e) (return) (>>=) | TransStmt {
-- "qs, then f" ==> TransformStmt qs binders f Nothing (return) (>>=) trS_form :: TransForm,
| TransformStmt trS_stmts :: [LStmt idL], -- Stmts to the *left* of the 'group'
[LStmt idL] -- Stmts are the ones to the left of the 'then'
[idR] -- After renaming, the Ids are the binders occurring
-- within this transform statement that are used after it
(LHsExpr idR) -- "then f"
(Maybe (LHsExpr idR)) -- "by e" (optional)
(SyntaxExpr idR) -- The 'return' function for inner monad
-- comprehensions
(SyntaxExpr idR) -- The '(>>=)' operator.
-- See Note [Monad Comprehensions]
| GroupStmt {
grpS_stmts :: [LStmt idL], -- Stmts to the *left* of the 'group'
-- which generates the tuples to be grouped -- which generates the tuples to be grouped
grpS_bndrs :: [(idR, idR)], -- See Note [GroupStmt binder map] trS_bndrs :: [(idR, idR)], -- See Note [TransStmt binder map]
grpS_by :: Maybe (LHsExpr idR), -- "by e" (optional) trS_using :: LHsExpr idR,
trS_by :: Maybe (LHsExpr idR), -- "by e" (optional)
grpS_using :: LHsExpr idR, -- Invariant: if trS_form = GroupBy, then grp_by = Just e
grpS_explicit :: Bool, -- True <=> explicit "using f"
-- False <=> implicit; grpS_using is filled in with trS_ret :: SyntaxExpr idR, -- The monomorphic 'return' function for
-- 'groupWith' (list comprehensions) or -- the inner monad comprehensions
-- 'groupM' (monad comprehensions) trS_bind :: SyntaxExpr idR, -- The '(>>=)' operator
trS_fmap :: SyntaxExpr idR -- The polymorphic 'fmap' function for desugaring
-- Invariant: if grpS_explicit = False, then grp_by = Just e -- Only for 'group' forms
-- That is, we can have group by e
-- group using f
-- group by e using f
grpS_ret :: SyntaxExpr idR, -- The 'return' function for inner monad
-- comprehensions
grpS_bind :: SyntaxExpr idR, -- The '(>>=)' operator
grpS_fmap :: SyntaxExpr idR -- The polymorphic 'fmap' function for desugaring
} -- See Note [Monad Comprehensions] } -- See Note [Monad Comprehensions]
-- Recursive statement (see Note [How RecStmt works] below) -- Recursive statement (see Note [How RecStmt works] below)
...@@ -943,6 +919,15 @@ data StmtLR idL idR ...@@ -943,6 +919,15 @@ data StmtLR idL idR
-- be quite as simple as (m (tya, tyb, tyc)). -- be quite as simple as (m (tya, tyb, tyc)).
} }
deriving (Data, Typeable) deriving (Data, Typeable)
data TransForm -- The 'f' below is the 'using' function, 'e' is the by function
= ThenForm -- then f or then f by e
| GroupFormU -- group using f or group using f by e
| GroupFormB -- group by e
-- In the GroupByFormB, trS_using is filled in with
-- 'groupWith' (list comprehensions) or
-- 'groupM' (monad comprehensions)
deriving (Data, Typeable)
\end{code} \end{code}
Note [The type of bind in Stmts] Note [The type of bind in Stmts]
...@@ -956,9 +941,9 @@ exotic type, such as ...@@ -956,9 +941,9 @@ exotic type, such as
So we must be careful not to make assumptions about the type. So we must be careful not to make assumptions about the type.
In particular, the monad may not be uniform throughout. In particular, the monad may not be uniform throughout.
Note [GroupStmt binder map] Note [TransStmt binder map]
~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~
The [(idR,idR)] in a GroupStmt behaves as follows: The [(idR,idR)] in a TransStmt behaves as follows:
* Before renaming: [] * Before renaming: []
...@@ -1098,11 +1083,8 @@ pprStmt (ExprStmt expr _ _ _) = ppr expr ...@@ -1098,11 +1083,8 @@ pprStmt (ExprStmt expr _ _ _) = ppr expr
pprStmt (ParStmt stmtss _ _ _) = hsep (map doStmts stmtss) pprStmt (ParStmt stmtss _ _ _) = hsep (map doStmts stmtss)
where doStmts stmts = ptext (sLit "| ") <> ppr stmts where doStmts stmts = ptext (sLit "| ") <> ppr stmts
pprStmt (TransformStmt stmts bndrs using by _ _) pprStmt (TransStmt { trS_stmts = stmts, trS_by = by, trS_using = using, trS_form = form })
= sep (ppr_lc_stmts stmts ++ [pprTransformStmt bndrs using by]) = sep (ppr_lc_stmts stmts ++ [pprTransStmt by using form])
pprStmt (GroupStmt { grpS_stmts = stmts, grpS_by = by, grpS_using = using, grpS_explicit = explicit })
= sep (ppr_lc_stmts stmts ++ [pprGroupStmt by using explicit])
pprStmt (RecStmt { recS_stmts = segment, recS_rec_ids = rec_ids pprStmt (RecStmt { recS_stmts = segment, recS_rec_ids = rec_ids
, recS_later_ids = later_ids }) , recS_later_ids = later_ids })
...@@ -1117,14 +1099,15 @@ pprTransformStmt bndrs using by ...@@ -1117,14 +1099,15 @@ pprTransformStmt bndrs using by
, nest 2 (ppr using) , nest 2 (ppr using)
, nest 2 (pprBy by)] , nest 2 (pprBy by)]
pprGroupStmt :: OutputableBndr id => Maybe (LHsExpr id) pprTransStmt :: OutputableBndr id => Maybe (LHsExpr id)
-> LHsExpr id -> Bool -> LHsExpr id -> TransForm
-> SDoc -> SDoc
pprGroupStmt by using explicit pprTransStmt by using ThenForm
= sep [ ptext (sLit "then group"), nest 2 (pprBy by), nest 2 pp_using ] = sep [ ptext (sLit "then"), nest 2 (ppr using), nest 2 (pprBy by)]
where pprTransStmt by _ GroupFormB
pp_using | explicit = ptext (sLit "using") <+> ppr using = sep [ ptext (sLit "then group"), nest 2 (pprBy by) ]
| otherwise = empty pprTransStmt by using GroupFormU
= sep [ ptext (sLit "then group"), nest 2 (pprBy by), nest 2 (ptext (sLit "using") <+> ppr using)]
pprBy :: OutputableBndr id => Maybe (LHsExpr id) -> SDoc pprBy :: OutputableBndr id => Maybe (LHsExpr id) -> SDoc
pprBy Nothing = empty pprBy Nothing = empty
...@@ -1412,8 +1395,7 @@ pprStmtInCtxt ctxt stmt ...@@ -1412,8 +1395,7 @@ pprStmtInCtxt ctxt stmt
2 (ppr_stmt stmt) 2 (ppr_stmt stmt)
where where
-- For Group and Transform Stmts, don't print the nested stmts! -- For Group and Transform Stmts, don't print the nested stmts!
ppr_stmt (GroupStmt { grpS_by = by, grpS_using = using ppr_stmt (TransStmt { trS_by = by, trS_using = using
, grpS_explicit = explicit }) = pprGroupStmt by using explicit , trS_form = form }) = pprTransStmt by using form
ppr_stmt (TransformStmt _ bndrs using by _ _) = pprTransformStmt bndrs using by ppr_stmt stmt = pprStmt stmt
ppr_stmt stmt = pprStmt stmt
\end{code} \end{code}
...@@ -43,7 +43,7 @@ module HsUtils( ...@@ -43,7 +43,7 @@ module HsUtils(
-- Stmts -- Stmts
mkTransformStmt, mkTransformByStmt, mkExprStmt, mkBindStmt, mkLastStmt, mkTransformStmt, mkTransformByStmt, mkExprStmt, mkBindStmt, mkLastStmt,
emptyGroupStmt, mkGroupUsingStmt, mkGroupByStmt, mkGroupByUsingStmt, emptyTransStmt, mkGroupUsingStmt, mkGroupByStmt, mkGroupByUsingStmt,
emptyRecStmt, mkRecStmt, emptyRecStmt, mkRecStmt,
-- Template Haskell -- Template Haskell
...@@ -196,9 +196,6 @@ mkHsComp :: HsStmtContext Name -> [LStmt id] -> LHsExpr id -> HsExpr id ...@@ -196,9 +196,6 @@ mkHsComp :: HsStmtContext Name -> [LStmt id] -> LHsExpr id -> HsExpr id
mkNPat :: HsOverLit id -> Maybe (SyntaxExpr id) -> Pat id mkNPat :: HsOverLit id -> Maybe (SyntaxExpr id) -> Pat id
mkNPlusKPat :: Located id -> HsOverLit id -> Pat id mkNPlusKPat :: Located id -> HsOverLit id -> Pat id
mkTransformStmt :: [LStmt idL] -> LHsExpr idR -> StmtLR idL idR
mkTransformByStmt :: [LStmt idL] -> LHsExpr idR -> LHsExpr idR -> StmtLR idL idR
mkLastStmt :: LHsExpr idR -> StmtLR idL idR mkLastStmt :: LHsExpr idR -> StmtLR idL idR
mkExprStmt :: LHsExpr idR -> StmtLR idL idR mkExprStmt :: LHsExpr idR -> StmtLR idL idR
mkBindStmt :: LPat idL -> LHsExpr idR -> StmtLR idL idR mkBindStmt :: LPat idL -> LHsExpr idR -> StmtLR idL idR
...@@ -225,22 +222,23 @@ mkHsIf c a b = HsIf (Just noSyntaxExpr) c a b ...@@ -225,22 +222,23 @@ mkHsIf c a b = HsIf (Just noSyntaxExpr) c a b
mkNPat lit neg = NPat lit neg noSyntaxExpr mkNPat lit neg = NPat lit neg noSyntaxExpr
mkNPlusKPat id lit = NPlusKPat id lit noSyntaxExpr noSyntaxExpr mkNPlusKPat id lit = NPlusKPat id lit noSyntaxExpr noSyntaxExpr
mkTransformStmt stmts usingExpr = TransformStmt stmts [] usingExpr Nothing noSyntaxExpr noSyntaxExpr mkTransformStmt :: [LStmt idL] -> LHsExpr idR -> StmtLR idL idR
mkTransformByStmt stmts usingExpr byExpr = TransformStmt stmts [] usingExpr (Just byExpr) noSyntaxExpr noSyntaxExpr mkTransformByStmt :: [LStmt idL] -> LHsExpr idR -> LHsExpr idR -> StmtLR idL idR
mkGroupUsingStmt :: [LStmt idL] -> LHsExpr idR -> StmtLR idL idR mkGroupUsingStmt :: [LStmt idL] -> LHsExpr idR -> StmtLR idL idR
mkGroupByStmt :: [LStmt idL] -> LHsExpr idR -> StmtLR idL idR mkGroupByStmt :: [LStmt idL] -> LHsExpr idR -> StmtLR idL idR
mkGroupByUsingStmt :: [LStmt idL] -> LHsExpr idR -> LHsExpr idR -> StmtLR idL idR