Commit 00cbbab3 authored by eir@cis.upenn.edu's avatar eir@cis.upenn.edu

Refactor the typechecker to use ExpTypes.

The idea here is described in [wiki:Typechecker]. Briefly,
this refactor keeps solid track of "synthesis" mode vs
"checking" in GHC's bidirectional type-checking algorithm.
When in synthesis mode, the expected type is just an IORef
to write to.

In addition, this patch does a significant reworking of
RebindableSyntax, allowing much more freedom in the types
of the rebindable operators. For example, we can now have
`negate :: Int -> Bool` and
`(>>=) :: m a -> (forall x. a x -> m b) -> m b`. The magic
is in tcSyntaxOp.

This addresses tickets #11397, #11452, and #11458.

Tests:
  typecheck/should_compile/{RebindHR,RebindNegate,T11397,T11458}
  th/T11452
parent 2899aa58
This diff is collapsed.
......@@ -592,8 +592,9 @@ addTickHsExpr (ExplicitList ty wit es) =
(addTickWit wit)
(mapM (addTickLHsExpr) es)
where addTickWit Nothing = return Nothing
addTickWit (Just fln) = do fln' <- addTickHsExpr fln
return (Just fln')
addTickWit (Just fln)
= do fln' <- addTickSyntaxExpr hpcSrcSpan fln
return (Just fln')
addTickHsExpr (ExplicitPArr ty es) =
liftM2 ExplicitPArr
(return ty)
......@@ -621,7 +622,7 @@ addTickHsExpr (ArithSeq ty wit arith_seq) =
(addTickWit wit)
(addTickArithSeqInfo arith_seq)
where addTickWit Nothing = return Nothing
addTickWit (Just fl) = do fl' <- addTickHsExpr fl
addTickWit (Just fl) = do fl' <- addTickSyntaxExpr hpcSrcSpan fl
return (Just fl')
-- We might encounter existing ticks (multiple Coverage passes)
......@@ -732,12 +733,13 @@ addTickStmt _isGuard (LastStmt e noret ret) = do
(addTickLHsExpr e)
(pure noret)
(addTickSyntaxExpr hpcSrcSpan ret)
addTickStmt _isGuard (BindStmt pat e bind fail) = do
liftM4 BindStmt
addTickStmt _isGuard (BindStmt pat e bind fail ty) = do
liftM5 BindStmt
(addTickLPat pat)
(addTickLHsExprRHS e)
(addTickSyntaxExpr hpcSrcSpan bind)
(addTickSyntaxExpr hpcSrcSpan fail)
(return ty)
addTickStmt isGuard (BodyStmt e bind' guard' ty) = do
liftM4 BodyStmt
(addTick isGuard e)
......@@ -747,11 +749,12 @@ addTickStmt isGuard (BodyStmt e bind' guard' ty) = do
addTickStmt _isGuard (LetStmt (L l binds)) = do
liftM (LetStmt . L l)
(addTickHsLocalBinds binds)
addTickStmt isGuard (ParStmt pairs mzipExpr bindExpr) = do
liftM3 ParStmt
addTickStmt isGuard (ParStmt pairs mzipExpr bindExpr ty) = do
liftM4 ParStmt
(mapM (addTickStmtAndBinders isGuard) pairs)
(addTickSyntaxExpr hpcSrcSpan mzipExpr)
(unLoc <$> addTickLHsExpr (L hpcSrcSpan mzipExpr))
(addTickSyntaxExpr hpcSrcSpan bindExpr)
(return ty)
addTickStmt isGuard (ApplicativeStmt args mb_join body_ty) = do
args' <- mapM (addTickApplicativeArg isGuard) args
return (ApplicativeStmt args' mb_join body_ty)
......@@ -765,7 +768,7 @@ addTickStmt isGuard stmt@(TransStmt { trS_stmts = stmts
t_u <- addTickLHsExprRHS using
t_f <- addTickSyntaxExpr hpcSrcSpan returnExpr
t_b <- addTickSyntaxExpr hpcSrcSpan bindExpr
t_m <- addTickSyntaxExpr hpcSrcSpan liftMExpr
L _ t_m <- addTickLHsExpr (L hpcSrcSpan liftMExpr)
return $ stmt { trS_stmts = t_s, trS_by = t_y, trS_using = t_u
, trS_ret = t_f, trS_bind = t_b, trS_fmap = t_m }
......@@ -792,7 +795,7 @@ addTickApplicativeArg isGuard (op, arg) =
addTickArg (ApplicativeArgMany stmts ret pat) =
ApplicativeArgMany
<$> addTickLStmts isGuard stmts
<*> addTickSyntaxExpr hpcSrcSpan ret
<*> (unLoc <$> addTickLHsExpr (L hpcSrcSpan ret))
<*> addTickLPat pat
addTickStmtAndBinders :: Maybe (Bool -> BoxLabel) -> ParStmtBlock Id Id
......@@ -837,9 +840,9 @@ addTickIPBind (IPBind nm e) =
-- There is no location here, so we might need to use a context location??
addTickSyntaxExpr :: SrcSpan -> SyntaxExpr Id -> TM (SyntaxExpr Id)
addTickSyntaxExpr pos x = do
addTickSyntaxExpr pos syn@(SyntaxExpr { syn_expr = x }) = do
L _ x' <- addTickLHsExpr (L pos x)
return $ x'
return $ syn { syn_expr = x' }
-- we do not walk into patterns.
addTickLPat :: LPat Id -> TM (LPat Id)
addTickLPat pat = return pat
......@@ -951,12 +954,13 @@ addTickLCmdStmts' lstmts res
binders = collectLStmtsBinders lstmts
addTickCmdStmt :: Stmt Id (LHsCmd Id) -> TM (Stmt Id (LHsCmd Id))
addTickCmdStmt (BindStmt pat c bind fail) = do
liftM4 BindStmt
addTickCmdStmt (BindStmt pat c bind fail ty) = do
liftM5 BindStmt
(addTickLPat pat)
(addTickLHsCmd c)
(return bind)
(return fail)
(return ty)
addTickCmdStmt (LastStmt c noret ret) = do
liftM3 LastStmt
(addTickLHsCmd c)
......
......@@ -25,7 +25,7 @@ import qualified HsUtils
-- So WATCH OUT; check each use of split*Ty functions.
-- Sigh. This is a pain.
import {-# SOURCE #-} DsExpr ( dsExpr, dsLExpr, dsLocalBinds )
import {-# SOURCE #-} DsExpr ( dsExpr, dsLExpr, dsLocalBinds, dsSyntaxExpr )
import TcType
import TcEvidence
......@@ -465,9 +465,8 @@ dsCmd ids local_vars stack_ty res_ty (HsCmdIf mb_fun cond then_cmd else_cmd)
core_right = mk_right_expr then_ty else_ty (buildEnvStack else_ids stack_id)
core_if <- case mb_fun of
Just fun -> do { core_fun <- dsExpr fun
; matchEnvStack env_ids stack_id $
mkCoreApps core_fun [core_cond, core_left, core_right] }
Just fun -> do { fun_apps <- dsSyntaxExpr fun [core_cond, core_left, core_right]
; matchEnvStack env_ids stack_id fun_apps }
Nothing -> matchEnvStack env_ids stack_id $
mkIfThenElse core_cond core_left core_right
......@@ -782,7 +781,7 @@ dsCmdStmt ids local_vars out_ids (BodyStmt cmd _ _ c_ty) env_ids = do
-- It would be simpler and more consistent to do this using second,
-- but that's likely to be defined in terms of first.
dsCmdStmt ids local_vars out_ids (BindStmt pat cmd _ _) env_ids = do
dsCmdStmt ids local_vars out_ids (BindStmt pat cmd _ _ _) env_ids = do
let pat_ty = hsLPatType pat
(core_cmd, fv_cmd, env_ids1) <- dsfixCmd ids local_vars unitTy pat_ty cmd
let pat_vars = mkVarSet (collectPatBinders pat)
......@@ -1142,8 +1141,8 @@ collectl (L _ pat) bndrs
collectEvBinders ds
++ foldr collectl bndrs (hsConPatArgs ps)
go (LitPat _) = bndrs
go (NPat _ _ _) = bndrs
go (NPlusKPat (L _ n) _ _ _) = n : bndrs
go (NPat {}) = bndrs
go (NPlusKPat (L _ n) _ _ _ _ _) = n : bndrs
go (SigPatIn pat _) = collectl pat bndrs
go (SigPatOut pat _) = collectl pat bndrs
......
......@@ -8,7 +8,8 @@ Desugaring exporessions.
{-# LANGUAGE CPP #-}
module DsExpr ( dsExpr, dsLExpr, dsLocalBinds, dsValBinds, dsLit ) where
module DsExpr ( dsExpr, dsLExpr, dsLocalBinds
, dsValBinds, dsLit, dsSyntaxExpr ) where
#include "HsVersions.h"
......@@ -221,7 +222,8 @@ dsExpr (HsWrap co_fn e)
; return wrapped_e }
dsExpr (NegApp expr neg_expr)
= App <$> dsExpr neg_expr <*> dsLExpr expr
= do { expr' <- dsLExpr expr
; dsSyntaxExpr neg_expr [expr'] }
dsExpr (HsLam a_Match)
= uncurry mkLams <$> matchWrapper LambdaExpr Nothing a_Match
......@@ -354,8 +356,7 @@ dsExpr (HsIf mb_fun guard_expr then_expr else_expr)
; b1 <- dsLExpr then_expr
; b2 <- dsLExpr else_expr
; case mb_fun of
Just fun -> do { core_fun <- dsExpr fun
; return (mkCoreApps core_fun [pred,b1,b2]) }
Just fun -> dsSyntaxExpr fun [pred, b1, b2]
Nothing -> return $ mkIfThenElse pred b1 b2 }
dsExpr (HsMultiIf res_ty alts)
......@@ -398,10 +399,8 @@ dsExpr (ExplicitPArr ty xs) = do
dsExpr (ArithSeq expr witness seq)
= case witness of
Nothing -> dsArithSeq expr seq
Just fl -> do {
; fl' <- dsExpr fl
; newArithSeq <- dsArithSeq expr seq
; return (App fl' newArithSeq)}
Just fl -> do { newArithSeq <- dsArithSeq expr seq
; dsSyntaxExpr fl [newArithSeq] }
dsExpr (PArrSeq expr (FromTo from to))
= mkApps <$> dsExpr expr <*> mapM dsLExpr [from, to]
......@@ -741,6 +740,16 @@ dsExpr (HsRecFld {}) = panic "dsExpr:HsRecFld"
dsExpr (HsTypeOut {})
= panic "dsExpr: tried to desugar a naked type application argument (HsTypeOut)"
------------------------------
dsSyntaxExpr :: SyntaxExpr Id -> [CoreExpr] -> DsM CoreExpr
dsSyntaxExpr (SyntaxExpr { syn_expr = expr
, syn_arg_wraps = arg_wraps
, syn_res_wrap = res_wrap })
arg_exprs
= do { args <- zipWithM dsHsWrapper arg_wraps arg_exprs
; fun <- dsExpr expr
; dsHsWrapper res_wrap $ mkApps fun args }
findField :: [LHsRecField Id arg] -> Name -> [arg]
findField rbinds sel
= [hsRecFieldArg fld | L _ fld <- rbinds
......@@ -832,10 +841,9 @@ dsExplicitList elt_ty Nothing xs
; return (foldr (App . App (Var c)) folded_suffix prefix) }
dsExplicitList elt_ty (Just fln) xs
= do { fln' <- dsExpr fln
; list <- dsExplicitList elt_ty Nothing xs
= do { list <- dsExplicitList elt_ty Nothing xs
; dflags <- getDynFlags
; return (App (App fln' (mkIntExprInt dflags (length xs))) list) }
; dsSyntaxExpr fln [mkIntExprInt dflags (length xs), list] }
spanTail :: (a -> Bool) -> [a] -> ([a], [a])
spanTail f xs = (reverse rejected, reverse satisfying)
......@@ -882,25 +890,21 @@ dsDo stmts
go _ (BodyStmt rhs then_expr _ _) stmts
= do { rhs2 <- dsLExpr rhs
; warnDiscardedDoBindings rhs (exprType rhs2)
; then_expr2 <- dsExpr then_expr
; rest <- goL stmts
; return (mkApps then_expr2 [rhs2, rest]) }
; dsSyntaxExpr then_expr [rhs2, rest] }
go _ (LetStmt (L _ binds)) stmts
= do { rest <- goL stmts
; dsLocalBinds binds rest }
go _ (BindStmt pat rhs bind_op fail_op) stmts
go _ (BindStmt pat rhs bind_op fail_op res1_ty) stmts
= do { body <- goL stmts
; rhs' <- dsLExpr rhs
; bind_op' <- dsExpr bind_op
; var <- selectSimpleMatchVarL pat
; let bind_ty = exprType bind_op' -- rhs -> (pat -> res1) -> res2
res1_ty = funResultTy (funArgTy (funResultTy bind_ty))
; match <- matchSinglePat (Var var) (StmtCtxt DoExpr) pat
res1_ty (cantFailMatchResult body)
; match_code <- handle_failure pat match fail_op
; return (mkApps bind_op' [rhs', Lam var match_code]) }
; dsSyntaxExpr bind_op [rhs', Lam var match_code] }
go _ (ApplicativeStmt args mb_join body_ty) stmts
= do {
......@@ -915,7 +919,6 @@ dsDo stmts
arg_tys = map hsLPatType pats
; rhss' <- sequence rhss
; ops' <- mapM dsExpr (map fst args)
; let body' = noLoc $ HsDo DoExpr (noLoc stmts) body_ty
......@@ -926,30 +929,30 @@ dsDo stmts
, mg_origin = Generated }
; fun' <- dsLExpr fun
; let mk_ap_call l (op,r) = mkApps op [l,r]
expr = foldl mk_ap_call fun' (zip ops' rhss')
; let mk_ap_call l (op,r) = dsSyntaxExpr op [l,r]
; expr <- foldlM mk_ap_call fun' (zip (map fst args) rhss')
; case mb_join of
Nothing -> return expr
Just join_op ->
do { join_op' <- dsExpr join_op
; return (App join_op' expr) } }
Just join_op -> dsSyntaxExpr join_op [expr] }
go loc (RecStmt { recS_stmts = rec_stmts, recS_later_ids = later_ids
, recS_rec_ids = rec_ids, recS_ret_fn = return_op
, recS_mfix_fn = mfix_op, recS_bind_fn = bind_op
, recS_bind_ty = bind_ty
, recS_rec_rets = rec_rets, recS_ret_ty = body_ty }) stmts
= goL (new_bind_stmt : stmts) -- rec_ids can be empty; eg rec { print 'x' }
where
new_bind_stmt = L loc $ BindStmt (mkBigLHsPatTupId later_pats)
mfix_app bind_op
noSyntaxExpr -- Tuple cannot fail
bind_ty
tup_ids = rec_ids ++ filterOut (`elem` rec_ids) later_ids
tup_ty = mkBigCoreTupTy (map idType tup_ids) -- Deals with singleton case
rec_tup_pats = map nlVarPat tup_ids
later_pats = rec_tup_pats
rets = map noLoc rec_rets
mfix_app = nlHsApp (noLoc mfix_op) mfix_arg
mfix_app = nlHsSyntaxApps mfix_op [mfix_arg]
mfix_arg = noLoc $ HsLam
(MG { mg_alts = noLoc [mkSimpleMatch [mfix_pat] body]
, mg_arg_tys = [tup_ty], mg_res_ty = body_ty
......@@ -957,7 +960,7 @@ dsDo stmts
mfix_pat = noLoc $ LazyPat $ mkBigLHsPatTupId rec_tup_pats
body = noLoc $ HsDo
DoExpr (noLoc (rec_stmts ++ [ret_stmt])) body_ty
ret_app = nlHsApp (noLoc return_op) (mkBigLHsTupId rets)
ret_app = nlHsSyntaxApps return_op [mkBigLHsTupId rets]
ret_stmt = noLoc $ mkLastStmt ret_app
-- This LastStmt will be desugared with dsDo,
-- which ignores the return_op in the LastStmt,
......@@ -971,10 +974,10 @@ handle_failure :: LPat Id -> MatchResult -> SyntaxExpr Id -> DsM CoreExpr
-- the monadic 'fail' rather than throwing an exception
handle_failure pat match fail_op
| matchCanFail match
= do { fail_op' <- dsExpr fail_op
; dflags <- getDynFlags
= do { dflags <- getDynFlags
; fail_msg <- mkStringExpr (mk_fail_msg dflags pat)
; extractMatchResult match (App fail_op' fail_msg) }
; fail_expr <- dsSyntaxExpr fail_op [fail_msg]
; extractMatchResult match fail_expr }
| otherwise
= extractMatchResult match (error "It can't fail")
......
module DsExpr where
import HsSyn ( HsExpr, LHsExpr, HsLocalBinds )
import HsSyn ( HsExpr, LHsExpr, HsLocalBinds, SyntaxExpr )
import Var ( Id )
import DsMonad ( DsM )
import CoreSyn ( CoreExpr )
dsExpr :: HsExpr Id -> DsM CoreExpr
dsLExpr :: LHsExpr Id -> DsM CoreExpr
dsSyntaxExpr :: SyntaxExpr Id -> [CoreExpr] -> DsM CoreExpr
dsLocalBinds :: HsLocalBinds Id -> CoreExpr -> DsM CoreExpr
......@@ -114,7 +114,7 @@ matchGuards (LetStmt (L _ binds) : stmts) ctx rhs rhs_ty = do
-- so we can't desugar the bindings without the
-- body expression in hand
matchGuards (BindStmt pat bind_rhs _ _ : stmts) ctx rhs rhs_ty = do
matchGuards (BindStmt pat bind_rhs _ _ _ : stmts) ctx rhs rhs_ty = do
match_result <- matchGuards stmts ctx rhs rhs_ty
core_rhs <- dsLExpr bind_rhs
matchSinglePat core_rhs (StmtCtxt ctx) pat rhs_ty match_result
......
......@@ -12,7 +12,7 @@ module DsListComp ( dsListComp, dsPArrComp, dsMonadComp ) where
#include "HsVersions.h"
import {-# SOURCE #-} DsExpr ( dsExpr, dsLExpr, dsLocalBinds )
import {-# SOURCE #-} DsExpr ( dsExpr, dsLExpr, dsLocalBinds, dsSyntaxExpr )
import HsSyn
import TcHsSyn
......@@ -233,11 +233,11 @@ deListComp (stmt@(TransStmt {}) : quals) list = do
(inner_list_expr, pat) <- dsTransStmt stmt
deBindComp pat inner_list_expr quals list
deListComp (BindStmt pat list1 _ _ : quals) core_list2 = do -- rule A' above
deListComp (BindStmt pat list1 _ _ _ : quals) core_list2 = do -- rule A' above
core_list1 <- dsLExpr list1
deBindComp pat core_list1 quals core_list2
deListComp (ParStmt stmtss_w_bndrs _ _ : quals) list
deListComp (ParStmt stmtss_w_bndrs _ _ _ : quals) list
= do { exps_and_qual_tys <- mapM dsInnerListComp stmtss_w_bndrs
; let (exps, qual_tys) = unzip exps_and_qual_tys
......@@ -339,7 +339,7 @@ dfListComp c_id n_id (stmt@(TransStmt {}) : quals) = do
-- Anyway, we bind the newly grouped list via the generic binding function
dfBindComp c_id n_id (pat, inner_list_expr) quals
dfListComp c_id n_id (BindStmt pat list1 _ _ : quals) = do
dfListComp c_id n_id (BindStmt pat list1 _ _ _ : quals) = do
-- evaluate the two lists
core_list1 <- dsLExpr list1
......@@ -476,7 +476,7 @@ dsPArrComp :: [ExprStmt Id]
-> DsM CoreExpr
-- Special case for parallel comprehension
dsPArrComp (ParStmt qss _ _ : quals) = dePArrParComp qss quals
dsPArrComp (ParStmt qss _ _ _ : quals) = dePArrParComp qss quals
-- Special case for simple generators:
--
......@@ -487,7 +487,7 @@ dsPArrComp (ParStmt qss _ _ : quals) = dePArrParComp qss quals
-- <<[:e' | p <- e, qs:]>> =
-- <<[:e' | qs:]>> p (filterP (\x -> case x of {p -> True; _ -> False}) e)
--
dsPArrComp (BindStmt p e _ _ : qs) = do
dsPArrComp (BindStmt p e _ _ _ : qs) = do
filterP <- dsDPHBuiltin filterPVar
ce <- dsLExpr e
let ety'ce = parrElemType ce
......@@ -546,7 +546,7 @@ dePArrComp (BodyStmt b _ _ _ : qs) pa cea = do
-- in
-- <<[:e' | qs:]>> (pa, p) (crossMapP ea ef)
--
dePArrComp (BindStmt p e _ _ : qs) pa cea = do
dePArrComp (BindStmt p e _ _ _ : qs) pa cea = do
filterP <- dsDPHBuiltin filterPVar
crossMapP <- dsDPHBuiltin crossMapPVar
ce <- dsLExpr e
......@@ -679,8 +679,7 @@ dsMcStmt :: ExprStmt Id -> [ExprLStmt Id] -> DsM CoreExpr
dsMcStmt (LastStmt body _ ret_op) stmts
= ASSERT( null stmts )
do { body' <- dsLExpr body
; ret_op' <- dsExpr ret_op
; return (App ret_op' body') }
; dsSyntaxExpr ret_op [body'] }
-- [ .. | let binds, stmts ]
dsMcStmt (LetStmt (L _ binds)) stmts
......@@ -688,9 +687,9 @@ dsMcStmt (LetStmt (L _ binds)) stmts
; dsLocalBinds binds rest }
-- [ .. | a <- m, stmts ]
dsMcStmt (BindStmt pat rhs bind_op fail_op) stmts
dsMcStmt (BindStmt pat rhs bind_op fail_op bind_ty) stmts
= do { rhs' <- dsLExpr rhs
; dsMcBindStmt pat rhs' bind_op fail_op stmts }
; dsMcBindStmt pat rhs' bind_op fail_op bind_ty stmts }
-- Apply `guard` to the `exp` expression
--
......@@ -698,11 +697,9 @@ dsMcStmt (BindStmt pat rhs bind_op fail_op) stmts
--
dsMcStmt (BodyStmt exp then_exp guard_exp _) stmts
= do { exp' <- dsLExpr exp
; guard_exp' <- dsExpr guard_exp
; then_exp' <- dsExpr then_exp
; rest <- dsMcStmts stmts
; return $ mkApps then_exp' [ mkApps guard_exp' [exp']
, rest ] }
; guard_exp' <- dsSyntaxExpr guard_exp [exp']
; dsSyntaxExpr then_exp [guard_exp', rest] }
-- Group statements desugar like this:
--
......@@ -721,6 +718,7 @@ dsMcStmt (BodyStmt exp then_exp guard_exp _) stmts
dsMcStmt (TransStmt { trS_stmts = stmts, trS_bndrs = bndrs
, trS_by = by, trS_using = using
, trS_ret = return_op, trS_bind = bind_op
, trS_bind_arg_ty = n_tup_ty' -- n (a,b,c)
, trS_fmap = fmap_op, trS_form = form }) stmts_rest
= do { let (from_bndrs, to_bndrs) = unzip bndrs
......@@ -742,10 +740,7 @@ dsMcStmt (TransStmt { trS_stmts = stmts, trS_bndrs = bndrs
-- Generate the expressions to build the grouped list
-- Build a pattern that ensures the consumer binds into the NEW binders,
-- which hold monads rather than single values
; bind_op' <- dsExpr bind_op
; let bind_ty' = exprType bind_op' -- m2 (n (a,b,c)) -> (n (a,b,c) -> r1) -> r2
n_tup_ty' = funArgTy $ funArgTy $ funResultTy bind_ty' -- n (a,b,c)
tup_n_ty' = mkBigCoreVarTupTy to_bndrs
; let tup_n_ty' = mkBigCoreVarTupTy to_bndrs
; body <- dsMcStmts stmts_rest
; n_tup_var' <- newSysLocalDs n_tup_ty'
......@@ -755,7 +750,7 @@ dsMcStmt (TransStmt { trS_stmts = stmts, trS_bndrs = bndrs
; let rhs' = mkApps usingExpr' usingArgs'
body' = mkTupleCase us to_bndrs body tup_n_var' tup_n_expr'
; return (mkApps bind_op' [rhs', Lam n_tup_var' body']) }
; dsSyntaxExpr bind_op [rhs', Lam n_tup_var' body'] }
-- Parallel statements. Use `Control.Monad.Zip.mzip` to zip parallel
-- statements, for example:
......@@ -768,7 +763,7 @@ dsMcStmt (TransStmt { trS_stmts = stmts, trS_bndrs = bndrs
-- mzip :: forall a b. m a -> m b -> m (a,b)
-- NB: we need a polymorphic mzip because we call it several times
dsMcStmt (ParStmt blocks mzip_op bind_op) stmts_rest
dsMcStmt (ParStmt blocks mzip_op bind_op bind_ty) stmts_rest
= do { exps_w_tys <- mapM ds_inner blocks -- Pairs (exp :: m ty, ty)
; mzip_op' <- dsExpr mzip_op
......@@ -782,7 +777,7 @@ dsMcStmt (ParStmt blocks mzip_op bind_op) stmts_rest
mkBoxedTupleTy [t1,t2]))
exps_w_tys
; dsMcBindStmt pat rhs bind_op noSyntaxExpr stmts_rest }
; dsMcBindStmt pat rhs bind_op noSyntaxExpr bind_ty stmts_rest }
where
ds_inner (ParStmtBlock stmts bndrs return_op)
= do { exp <- dsInnerMonadComp stmts bndrs return_op
......@@ -806,28 +801,26 @@ dsMcBindStmt :: LPat Id
-> CoreExpr -- ^ the desugared rhs of the bind statement
-> SyntaxExpr Id
-> SyntaxExpr Id
-> Type -- ^ S in (>>=) :: Q -> (R -> S) -> T
-> [ExprLStmt Id]
-> DsM CoreExpr
dsMcBindStmt pat rhs' bind_op fail_op stmts
dsMcBindStmt pat rhs' bind_op fail_op res1_ty stmts
= do { body <- dsMcStmts stmts
; bind_op' <- dsExpr bind_op
; var <- selectSimpleMatchVarL pat
; let bind_ty = exprType bind_op' -- rhs -> (pat -> res1) -> res2
res1_ty = funResultTy (funArgTy (funResultTy bind_ty))
; match <- matchSinglePat (Var var) (StmtCtxt DoExpr) pat
res1_ty (cantFailMatchResult body)
; match_code <- handle_failure pat match fail_op
; return (mkApps bind_op' [rhs', Lam var match_code]) }
; dsSyntaxExpr bind_op [rhs', Lam var match_code] }
where
-- In a monad comprehension expression, pattern-match failure just calls
-- the monadic `fail` rather than throwing an exception
handle_failure pat match fail_op
| matchCanFail match
= do { fail_op' <- dsExpr fail_op
; dflags <- getDynFlags
= do { dflags <- getDynFlags
; fail_msg <- mkStringExpr (mk_fail_msg dflags pat)
; extractMatchResult match (App fail_op' fail_msg) }
; fail_expr <- dsSyntaxExpr fail_op [fail_msg]
; extractMatchResult match fail_expr }
| otherwise
= extractMatchResult match (error "It can't fail")
......@@ -842,8 +835,8 @@ dsMcBindStmt pat rhs' bind_op fail_op stmts
-- [ (a,b,c) | quals ]
dsInnerMonadComp :: [ExprLStmt Id]
-> [Id] -- Return a tuple of these variables
-> HsExpr Id -- The monomorphic "return" operator
-> [Id] -- Return a tuple of these variables
-> SyntaxExpr Id -- The monomorphic "return" operator
-> DsM CoreExpr
dsInnerMonadComp stmts bndrs ret_op
= dsMcStmts (stmts ++ [noLoc (LastStmt (mkBigLHsVarTupId bndrs) False ret_op)])
......@@ -860,7 +853,7 @@ dsInnerMonadComp stmts bndrs ret_op
-- , fmap (selN2 :: (t1, t2) -> t2) ys )
mkMcUnzipM :: TransForm
-> SyntaxExpr TcId -- fmap
-> HsExpr TcId -- fmap
-> Id -- Of type n (a,b,c)
-> [Type] -- [a,b,c]
-> DsM CoreExpr -- Of type (n a, n b, n c)
......
......@@ -1279,7 +1279,7 @@ repLSts :: [LStmt Name (LHsExpr Name)] -> DsM ([GenSymBind], [Core TH.StmtQ])
repLSts stmts = repSts (map unLoc stmts)
repSts :: [Stmt Name (LHsExpr Name)] -> DsM ([GenSymBind], [Core TH.StmtQ])
repSts (BindStmt p e _ _ : ss) =
repSts (BindStmt p e _ _ _ : ss) =
do { e2 <- repLE e
; ss1 <- mkGenSyms (collectPatBinders p)
; addBinds ss1 $ do {
......@@ -1297,7 +1297,7 @@ repSts (BodyStmt e _ _ _ : ss) =
; z <- repNoBindSt e2
; (ss2,zs) <- repSts ss
; return (ss2, z : zs) }
repSts (ParStmt stmt_blocks _ _ : ss) =
repSts (ParStmt stmt_blocks _ _ _ : ss) =
do { (ss_s, stmt_blocks1) <- mapAndUnzipM rep_stmt_block stmt_blocks
; let stmt_blocks2 = nonEmptyCoreList stmt_blocks1
ss1 = concat ss_s
......@@ -1463,7 +1463,7 @@ repP (BangPat p) = do { p1 <- repLP p; repPbang p1 }
repP (AsPat x p) = do { x' <- lookupLBinder x; p1 <- repLP p; repPaspat x' p1 }
repP (ParPat p) = repLP p
repP (ListPat ps _ Nothing) = do { qs <- repLPs ps; repPlist qs }
repP (ListPat ps ty1 (Just (_,e))) = do { p <- repP (ListPat ps ty1 Nothing); e' <- repE e; repPview e' p}
repP (ListPat ps ty1 (Just (_,e))) = do { p <- repP (ListPat ps ty1 Nothing); e' <- repE (syn_expr e); repPview e' p}
repP (TuplePat ps boxed _)
| isBoxed boxed = do { qs <- repLPs ps; repPtup qs }
| otherwise = do { qs <- repLPs ps; repPunboxedTup qs }
......@@ -1483,9 +1483,9 @@ repP (ConPatIn dc details)
; MkC p <- repLP (hsRecFieldArg fld)
; rep2 fieldPatName [v,p] }
repP (NPat (L _ l) Nothing _) = do { a <- repOverloadedLiteral l; repPlit a }
repP (NPat (L _ l) Nothing _ _) = do { a <- repOverloadedLiteral l; repPlit a }
repP (ViewPat e p _) = do { e' <- repLE e; p' <- repLP p; repPview e' p' }
repP p@(NPat _ (Just _) _) = notHandled "Negative overloaded patterns" (ppr p)
repP p@(NPat _ (Just _) _ _) = notHandled "Negative overloaded patterns" (ppr p)
repP p@(SigPatIn {}) = notHandled "Type signatures in patterns" (ppr p)
-- The problem is to do with scoped type variables.
-- To implement them, we have to implement the scoping rules
......
......@@ -239,11 +239,11 @@ seqVar var body = Case (Var var) var (exprType body)
mkCoLetMatchResult :: CoreBind -> MatchResult -> MatchResult
mkCoLetMatchResult bind = adjustMatchResult (mkCoreLet bind)
-- (mkViewMatchResult var' viewExpr var mr) makes the expression
-- let var' = viewExpr var in mr
mkViewMatchResult :: Id -> CoreExpr -> Id -> MatchResult -> MatchResult
mkViewMatchResult var' viewExpr var =
adjustMatchResult (mkCoreLet (NonRec var' (mkCoreAppDs (text "mkView" <+> ppr var') viewExpr (Var var))))
-- (mkViewMatchResult var' viewExpr mr) makes the expression
-- let var' = viewExpr in mr
mkViewMatchResult :: Id -> CoreExpr -> MatchResult -> MatchResult
mkViewMatchResult var' viewExpr =
adjustMatchResult (mkCoreLet (NonRec var' viewExpr))
mkEvalMatchResult :: Id -> Type -> MatchResult -> MatchResult
mkEvalMatchResult var ty
......
......@@ -12,7 +12,7 @@ module Match ( match, matchEquations, matchWrapper, matchSimply, matchSinglePat
#include "HsVersions.h"
import {-#SOURCE#-} DsExpr (dsLExpr, dsExpr)
import {-#SOURCE#-} DsExpr (dsLExpr, dsSyntaxExpr)
import DynFlags
import HsSyn
......@@ -269,7 +269,9 @@ matchView (var:vars) ty (eqns@(eqn1:_))
map (decomposeFirstPat getViewPat) eqns
-- compile the view expressions
; viewExpr' <- dsLExpr viewExpr
; return (mkViewMatchResult var' viewExpr' var match_result) }
; return (mkViewMatchResult var'
(mkCoreAppDs (text "matchView") viewExpr' (Var var))
match_result) }
matchView _ _ _ = panic "matchView"
matchOverloadedList :: [Id] -> Type -> [EquationInfo] -> DsM MatchResult
......@@ -280,8 +282,8 @@ matchOverloadedList (var:vars) ty (eqns@(eqn1:_))
; var' <- newUniqueId var (mkListTy elt_ty) -- we construct the overall type by hand
; match_result <- match (var':vars) ty $
map (decomposeFirstPat getOLPat) eqns -- getOLPat builds the pattern inside as a non-overloaded version of the overloaded list pattern
; e' <- dsExpr e
; return (mkViewMatchResult var' e' var match_result) }
; e' <- dsSyntaxExpr e [Var var]
; return (mkViewMatchResult var' e' match_result) }
matchOverloadedList _ _ _ = panic "matchOverloadedList"
-- decompose the first pattern and leave the rest alone
......@@ -457,8 +459,8 @@ tidy1 _ (LitPat lit)
= return (idDsWrapper, tidyLitPat lit)
-- NPats: we *might* be able to replace these w/ a simpler form
tidy1 _ (NPat (L _ lit) mb_neg eq)
= return (idDsWrapper, tidyNPat tidyLitPat lit mb_neg eq)
tidy1 _ (NPat (L _ lit) mb_neg eq ty)
= return (idDsWrapper, tidyNPat tidyLitPat lit mb_neg eq ty)
-- Everything else goes through unchanged...
......@@ -939,7 +941,7 @@ viewLExprEq (e1,_) (e2,_) = lexp e1 e2
-- to ignore them?
exp (OpApp l o _ ri) (OpApp l' o' _ ri') =
lexp l l' && lexp o o' && lexp ri ri'
exp (NegApp e n) (NegApp e' n') = lexp e e' && exp n n'
exp (NegApp e n) (NegApp e' n') = lexp e e' && syn_exp n n'
exp (SectionL e1 e2) (SectionL e1' e2') =
lexp e1 e1' && lexp e2 e2'
exp (SectionR e1 e2) (SectionR e1' e2') =
......@@ -955,6 +957,18 @@ viewLExprEq (e1,_) (e2,_) = lexp e1 e2
-- because they cannot be functions
exp _ _ = False
---------
syn_exp :: SyntaxExpr Id -> SyntaxExpr Id -> Bool
syn_exp (SyntaxExpr { syn_expr = expr1
, syn_arg_wraps = arg_wraps1
, syn_res_wrap = res_wrap1 })
(SyntaxExpr { syn_expr = expr2
, syn_arg_wraps = arg_wraps2
, syn_res_wrap = res_wrap2 })
= exp expr1 expr2 &&
and (zipWithEqual "viewLExprEq" wrap arg_wraps1 arg_wraps2) &&
wrap res_wrap1 res_wrap2
---------
tup_arg (L _ (Present e1)) (L _ (Present e2)) = lexp e1 e2
tup_arg (L _ (Missing t1)) (L _ (Missing t2)) = eqType t1 t2
......@@ -998,8 +1012,8 @@ patGroup _ (ConPatOut { pat_con = L _ con
| PatSynCon psyn <- con = PgSyn psyn tys
patGroup _ (WildPat {}) = PgAny
patGroup _ (BangPat {}) = PgBang
patGroup _ (NPat (L _ olit) mb_neg _) = PgN (hsOverLitKey olit (isJust mb_neg))
patGroup _ (NPlusKPat _ (L _ olit) _ _) = PgNpK (hsOverLitKey olit False)
patGroup _ (NPat (L _ olit) mb_neg _ _) = PgN (hsOverLitKey olit (isJust mb_neg))
patGroup _ (NPlusKPat _ (L _ olit) _ _ _ _)= PgNpK (hsOverLitKey olit False)
patGroup _ (CoPat _ p _) = PgCo (hsPatType p) -- Type of innelexp pattern
patGroup _ (ViewPat