Commit 67c793a3 authored by Simon Peyton Jones's avatar Simon Peyton Jones

Tidy up a remaining glitch in unification

There was one place, in type checking parallel list comprehensions
where we were unifying types, but had no convenient way to use the
resulting coercion; instead we just checked that it was Refl.  This
was Wrong Wrong; it might fail unpredicably in a GADT-like situation,
and it led to extra error-generation code used only in this one place.

This patch tidies it all up, by moving the 'return' method from the
*comprehension* to the ParStmtBlock. The latter is a new data type,
now used for each sub-chunk of a parallel list comprehension.

Because of the data type change, quite a few modules are touched,
but only in a fairly trivial way. The real changes are in TcMatches
(and corresponding desugaring); plus deleting code from TcUnify.

This patch also fixes the pretty-printing bug in Trac #6060
parent 2822e00d
......@@ -620,12 +620,11 @@ addTickStmt isGuard (ExprStmt e bind' guard' ty) = do
addTickStmt _isGuard (LetStmt binds) = do
liftM LetStmt
(addTickHsLocalBinds binds)
addTickStmt isGuard (ParStmt pairs mzipExpr bindExpr returnExpr) = do
liftM4 ParStmt
addTickStmt isGuard (ParStmt pairs mzipExpr bindExpr) = do
liftM3 ParStmt
(mapM (addTickStmtAndBinders isGuard) pairs)
(addTickSyntaxExpr hpcSrcSpan mzipExpr)
(addTickSyntaxExpr hpcSrcSpan bindExpr)
(addTickSyntaxExpr hpcSrcSpan returnExpr)
addTickStmt isGuard stmt@(TransStmt { trS_stmts = stmts
, trS_by = by, trS_using = using
......@@ -652,12 +651,13 @@ addTick :: Maybe (Bool -> BoxLabel) -> LHsExpr Id -> TM (LHsExpr Id)
addTick isGuard e | Just fn <- isGuard = addBinTickLHsExpr fn e
| otherwise = addTickLHsExprRHS e
addTickStmtAndBinders :: Maybe (Bool -> BoxLabel) -> ([LStmt Id], a)
-> TM ([LStmt Id], a)
addTickStmtAndBinders isGuard (stmts, ids) =
liftM2 (,)
addTickStmtAndBinders :: Maybe (Bool -> BoxLabel) -> ParStmtBlock Id Id
-> TM (ParStmtBlock Id Id)
addTickStmtAndBinders isGuard (ParStmtBlock stmts ids returnExpr) =
liftM3 ParStmtBlock
(addTickLStmts isGuard stmts)
(return ids)
(addTickSyntaxExpr hpcSrcSpan returnExpr)
addTickHsLocalBinds :: HsLocalBinds Id -> TM (HsLocalBinds Id)
addTickHsLocalBinds (HsValBinds binds) =
......
......@@ -1124,8 +1124,8 @@ collectStmtBinders (BindStmt pat _ _ _) = collectPatBinders pat
collectStmtBinders (LetStmt binds) = collectLocalBinders binds
collectStmtBinders (ExprStmt {}) = []
collectStmtBinders (LastStmt {}) = []
collectStmtBinders (ParStmt xs _ _ _) = collectLStmtsBinders
$ concatMap fst xs
collectStmtBinders (ParStmt xs _ _) = collectLStmtsBinders
$ [ s | ParStmtBlock ss _ _ <- xs, s <- ss]
collectStmtBinders (TransStmt { trS_stmts = stmts }) = collectLStmtsBinders stmts
collectStmtBinders (RecStmt { recS_later_ids = later_ids }) = later_ids
......
......@@ -19,7 +19,6 @@ import TcHsSyn
import CoreSyn
import MkCore
import TcEvidence
import DsMonad -- the monadery used in the desugarer
import DsUtils
......@@ -71,15 +70,15 @@ dsListComp lquals res_ty = do
-- mix of possibly a single element in length, so we do this to leave the possibility open
isParallelComp = any isParallelStmt
isParallelStmt (ParStmt _ _ _ _) = True
isParallelStmt _ = False
isParallelStmt (ParStmt {}) = True
isParallelStmt _ = False
-- This function lets you desugar a inner list comprehension and a list of the binders
-- of that comprehension that we need in the outer comprehension into such an expression
-- and the type of the elements that it outputs (tuples of binders)
dsInnerListComp :: ([LStmt Id], [Id]) -> DsM (CoreExpr, Type)
dsInnerListComp (stmts, bndrs)
dsInnerListComp :: (ParStmtBlock Id Id) -> DsM (CoreExpr, Type)
dsInnerListComp (ParStmtBlock stmts bndrs _)
= do { expr <- dsListComp (stmts ++ [noLoc $ mkLastStmt (mkBigLHsVarTup bndrs)])
(mkListTy bndrs_tuple_type)
; return (expr, bndrs_tuple_type) }
......@@ -98,7 +97,7 @@ dsTransStmt (TransStmt { trS_form = form, trS_stmts = stmts, trS_bndrs = binderM
to_bndrs_tup_ty = mkBigCoreTupTy to_bndrs_tys
-- Desugar an inner comprehension which outputs a list of tuples of the "from" binders
(expr, from_tup_ty) <- dsInnerListComp (stmts, from_bndrs)
(expr, from_tup_ty) <- dsInnerListComp (ParStmtBlock stmts from_bndrs noSyntaxExpr)
-- Work out what arguments should be supplied to that expression: i.e. is an extraction
-- function required? If so, create that desugared function and add to arguments
......@@ -233,7 +232,7 @@ deListComp (BindStmt pat list1 _ _ : quals) core_list2 = do -- rule A' above
core_list1 <- dsLExpr list1
deBindComp pat core_list1 quals core_list2
deListComp (ParStmt stmtss_w_bndrs _ _ _ : quals) list
deListComp (ParStmt stmtss_w_bndrs _ _ : quals) list
= do { exps_and_qual_tys <- mapM dsInnerListComp stmtss_w_bndrs
; let (exps, qual_tys) = unzip exps_and_qual_tys
......@@ -243,7 +242,7 @@ deListComp (ParStmt stmtss_w_bndrs _ _ _ : quals) list
; deBindComp pat (Let (Rec [(zip_fn, zip_rhs)]) (mkApps (Var zip_fn) exps))
quals list }
where
bndrs_s = map snd stmtss_w_bndrs
bndrs_s = [bs | ParStmtBlock _ bs _ <- stmtss_w_bndrs]
-- pat is the pattern ((x1,..,xn), (y1,..,ym)) in the example above
pat = mkBigLHsPatTup pats
......@@ -473,7 +472,7 @@ dsPArrComp :: [Stmt Id]
-> DsM CoreExpr
-- Special case for parallel comprehension
dsPArrComp (ParStmt qss _ _ _ : quals) = dePArrParComp qss quals
dsPArrComp (ParStmt qss _ _ : quals) = dePArrParComp qss quals
-- Special case for simple generators:
--
......@@ -590,7 +589,7 @@ dePArrComp (LetStmt ds : qs) pa cea = do
-- singeltons qualifier lists, which we already special case in the caller.
-- So, encountering one here is a bug.
--
dePArrComp (ParStmt _ _ _ _ : _) _ _ =
dePArrComp (ParStmt {} : _) _ _ =
panic "DsListComp.dePArrComp: malformed comprehension AST: ParStmt"
dePArrComp (TransStmt {} : _) _ _ = panic "DsListComp.dePArrComp: TransStmt"
dePArrComp (RecStmt {} : _) _ _ = panic "DsListComp.dePArrComp: RecStmt"
......@@ -601,7 +600,7 @@ dePArrComp (RecStmt {} : _) _ _ = panic "DsListComp.dePArrComp: RecStmt"
-- where
-- {x_1, ..., x_n} = DV (qs)
--
dePArrParComp :: [([LStmt Id], [Id])] -> [Stmt Id] -> DsM CoreExpr
dePArrParComp :: [ParStmtBlock Id Id] -> [Stmt Id] -> DsM CoreExpr
dePArrParComp qss quals = do
(pQss, ceQss) <- deParStmt qss
dePArrComp quals pQss ceQss
......@@ -609,13 +608,13 @@ dePArrParComp qss quals = do
deParStmt [] =
-- empty parallel statement lists have no source representation
panic "DsListComp.dePArrComp: Empty parallel list comprehension"
deParStmt ((qs, xs):qss) = do -- first statement
deParStmt (ParStmtBlock qs xs _:qss) = do -- first statement
let res_expr = mkLHsVarTuple xs
cqs <- dsPArrComp (map unLoc qs ++ [mkLastStmt res_expr])
parStmts qss (mkLHsVarPatTup xs) cqs
---
parStmts [] pa cea = return (pa, cea)
parStmts ((qs, xs):qss) pa cea = do -- subsequent statements (zip'ed)
parStmts (ParStmtBlock qs xs _:qss) pa cea = do -- subsequent statements (zip'ed)
zipP <- dsDPHBuiltin zipPVar
let pa' = mkLHsPatTup [pa, mkLHsVarPatTup xs]
ty'cea = parrElemType cea
......@@ -763,12 +762,12 @@ dsMcStmt (TransStmt { trS_stmts = stmts, trS_bndrs = bndrs
-- mzip :: forall a b. m a -> m b -> m (a,b)
-- NB: we need a polymorphic mzip because we call it several times
dsMcStmt (ParStmt pairs mzip_op bind_op return_op) stmts_rest
= do { exps_w_tys <- mapM ds_inner pairs -- Pairs (exp :: m ty, ty)
dsMcStmt (ParStmt blocks mzip_op bind_op) stmts_rest
= do { exps_w_tys <- mapM ds_inner blocks -- Pairs (exp :: m ty, ty)
; mzip_op' <- dsExpr mzip_op
; let -- The pattern variables
pats = map (mkBigLHsVarPatTup . snd) pairs
pats = [ mkBigLHsVarPatTup bs | ParStmtBlock _ bs _ <- blocks]
-- Pattern with tuples of variables
-- [v1,v2,v3] => (v1, (v2, v3))
pat = foldr1 (\p1 p2 -> mkLHsPatTup [p1, p2]) pats
......@@ -779,11 +778,9 @@ dsMcStmt (ParStmt pairs mzip_op bind_op return_op) stmts_rest
; dsMcBindStmt pat rhs bind_op noSyntaxExpr stmts_rest }
where
ds_inner (stmts, bndrs) = do { exp <- dsInnerMonadComp stmts bndrs mono_ret_op
; return (exp, tup_ty) }
where
mono_ret_op = HsWrap (WpTyApp tup_ty) return_op
tup_ty = mkBigCoreVarTupTy bndrs
ds_inner (ParStmtBlock stmts bndrs return_op)
= do { exp <- dsInnerMonadComp stmts bndrs return_op
; return (exp, mkBigCoreVarTupTy bndrs) }
dsMcStmt stmt _ = pprPanic "dsMcStmt: unexpected stmt" (ppr stmt)
......
......@@ -652,9 +652,9 @@ cvtStmt (NoBindS e) = do { e' <- cvtl e; returnL $ mkExprStmt e' }
cvtStmt (TH.BindS p e) = do { p' <- cvtPat p; e' <- cvtl e; returnL $ mkBindStmt p' e' }
cvtStmt (TH.LetS ds) = do { ds' <- cvtLocalDecs (ptext (sLit "a let binding")) ds
; returnL $ LetStmt ds' }
cvtStmt (TH.ParS dss) = do { dss' <- mapM cvt_one dss; returnL $ ParStmt dss' noSyntaxExpr noSyntaxExpr noSyntaxExpr }
cvtStmt (TH.ParS dss) = do { dss' <- mapM cvt_one dss; returnL $ ParStmt dss' noSyntaxExpr noSyntaxExpr }
where
cvt_one ds = do { ds' <- cvtStmts ds; return (ds', undefined) }
cvt_one ds = do { ds' <- cvtStmts ds; return (ParStmtBlock ds' undefined noSyntaxExpr) }
cvtMatch :: TH.Match -> CvtM (Hs.LMatch RdrName)
cvtMatch (TH.Match p body decs)
......
......@@ -875,11 +875,9 @@ data StmtLR idL idR
| LetStmt (HsLocalBindsLR idL idR)
-- ParStmts only occur in a list/monad comprehension
| ParStmt [([LStmt idL], [idR])]
| ParStmt [ParStmtBlock idL idR]
(SyntaxExpr idR) -- Polymorphic `mzip` for monad comprehensions
(SyntaxExpr idR) -- The `>>=` operator
(SyntaxExpr idR) -- Polymorphic `return` operator
-- with type (forall a. a -> m a)
-- See notes [Monad Comprehensions]
-- After renaming, the ids are the binders
-- bound by the stmts and used after themp
......@@ -943,6 +941,13 @@ data TransForm -- The 'f' below is the 'using' function, 'e' is the by function
= ThenForm -- then f or then f by e (depending on trS_by)
| GroupForm -- then group using f or then group by e using f (depending on trS_by)
deriving (Data, Typeable)
data ParStmtBlock idL idR
= ParStmtBlock
[LStmt idL]
[idR] -- The variables to be returned
(SyntaxExpr idR) -- The return operator
deriving( Data, Typeable )
\end{code}
Note [The type of bind in Stmts]
......@@ -1082,6 +1087,10 @@ In any other context than 'MonadComp', the fields for most of these
\begin{code}
instance (OutputableBndr idL, OutputableBndr idR)
=> Outputable (ParStmtBlock idL idR) where
ppr (ParStmtBlock stmts _ _) = interpp'SP stmts
instance (OutputableBndr idL, OutputableBndr idR) => Outputable (StmtLR idL idR) where
ppr stmt = pprStmt stmt
......@@ -1090,11 +1099,12 @@ pprStmt (LastStmt expr _) = ifPprDebug (ptext (sLit "[last]")) <+> ppr e
pprStmt (BindStmt pat expr _ _) = hsep [ppr pat, ptext (sLit "<-"), ppr expr]
pprStmt (LetStmt binds) = hsep [ptext (sLit "let"), pprBinds binds]
pprStmt (ExprStmt expr _ _ _) = ppr expr
pprStmt (ParStmt stmtss _ _ _) = hsep (map doStmts stmtss)
where doStmts stmts = ptext (sLit "| ") <> ppr stmts
pprStmt (ParStmt stmtss _ _) = sep (map doStmts stmtss)
where
doStmts stmts = ptext (sLit "|") <+> ppr stmts
pprStmt (TransStmt { trS_stmts = stmts, trS_by = by, trS_using = using, trS_form = form })
= sep (ppr_lc_stmts stmts ++ [pprTransStmt by using form])
= sep $ punctuate comma (map ppr stmts ++ [pprTransStmt by using form])
pprStmt (RecStmt { recS_stmts = segment, recS_rec_ids = rec_ids
, recS_later_ids = later_ids })
......@@ -1138,16 +1148,17 @@ ppr_do_stmts stmts
= lbrace <+> pprDeeperList vcat (punctuate semi (map ppr stmts))
<+> rbrace
ppr_lc_stmts :: OutputableBndr id => [LStmt id] -> [SDoc]
ppr_lc_stmts stmts = [ppr s <> comma | s <- stmts]
pprComp :: OutputableBndr id => [LStmt id] -> SDoc
pprComp quals -- Prints: body | qual1, ..., qualn
| not (null quals)
, L _ (LastStmt body _) <- last quals
= hang (ppr body <+> char '|') 2 (interpp'SP (dropTail 1 quals))
= hang (ppr body <+> char '|') 2 (pprQuals (dropTail 1 quals))
| otherwise
= pprPanic "pprComp" (interpp'SP quals)
= pprPanic "pprComp" (pprQuals quals)
pprQuals :: OutputableBndr id => [LStmt id] -> SDoc
-- Show list comprehension qualifiers separated by commas
pprQuals quals = interpp'SP quals
\end{code}
%************************************************************************
......
......@@ -93,7 +93,7 @@ import SrcLoc
import FastString
import Util
import Bag
import Outputable
import Data.Either
\end{code}
......@@ -216,7 +216,8 @@ mkGroupUsingStmt :: [LStmt idL] -> LHsExpr idR -> StmtLR idL id
mkGroupByUsingStmt :: [LStmt idL] -> LHsExpr idR -> LHsExpr idR -> StmtLR idL idR
emptyTransStmt :: StmtLR idL idR
emptyTransStmt = TransStmt { trS_form = undefined, trS_stmts = [], trS_bndrs = []
emptyTransStmt = TransStmt { trS_form = panic "emptyTransStmt: form"
, trS_stmts = [], trS_bndrs = []
, trS_by = Nothing, trS_using = noLoc noSyntaxExpr
, trS_ret = noSyntaxExpr, trS_bind = noSyntaxExpr
, trS_fmap = noSyntaxExpr }
......@@ -538,8 +539,8 @@ collectStmtBinders (BindStmt pat _ _ _) = collectPatBinders pat
collectStmtBinders (LetStmt binds) = collectLocalBinders binds
collectStmtBinders (ExprStmt {}) = []
collectStmtBinders (LastStmt {}) = []
collectStmtBinders (ParStmt xs _ _ _) = collectLStmtsBinders
$ concatMap fst xs
collectStmtBinders (ParStmt xs _ _) = collectLStmtsBinders
$ [s | ParStmtBlock ss _ _ <- xs, s <- ss]
collectStmtBinders (TransStmt { trS_stmts = stmts }) = collectLStmtsBinders stmts
collectStmtBinders (RecStmt { recS_stmts = ss }) = collectLStmtsBinders ss
......@@ -714,8 +715,7 @@ lStmtsImplicits = hs_lstmts
hs_stmt (LetStmt binds) = hs_local_binds binds
hs_stmt (ExprStmt {}) = emptyNameSet
hs_stmt (LastStmt {}) = emptyNameSet
hs_stmt (ParStmt xs _ _ _) = hs_lstmts $ concatMap fst xs
hs_stmt (ParStmt xs _ _) = hs_lstmts [s | ParStmtBlock ss _ _ <- xs, s <- ss]
hs_stmt (TransStmt { trS_stmts = stmts }) = hs_lstmts stmts
hs_stmt (RecStmt { recS_stmts = ss }) = hs_lstmts ss
......
......@@ -1477,8 +1477,8 @@ hscDeclsWithLocation hsc_env0 str source linenumber =
{- Desugar it -}
-- We use a basically null location for iNTERACTIVE
let iNTERACTIVELoc = ModLocation{ ml_hs_file = Nothing,
ml_hi_file = undefined,
ml_obj_file = undefined}
ml_hi_file = panic "hsDeclsWithLocation:ml_hi_file",
ml_obj_file = panic "hsDeclsWithLocation:ml_hi_file"}
ds_result <- hscDesugar' iNTERACTIVELoc tc_gblenv
{- Simplify -}
......
......@@ -1582,7 +1582,8 @@ flattenedpquals :: { Located [LStmt RdrName] }
-- We just had one thing in our "parallel" list so
-- we simply return that thing directly
qss -> L1 [L1 $ ParStmt [(qs, undefined) | qs <- qss] noSyntaxExpr noSyntaxExpr noSyntaxExpr]
qss -> L1 [L1 $ ParStmt [ParStmtBlock qs undefined noSyntaxExpr | qs <- qss]
noSyntaxExpr noSyntaxExpr]
-- We actually found some actual parallel lists so
-- we wrap them into as a ParStmt
}
......
......@@ -544,8 +544,8 @@ methodNamesStmt (LastStmt cmd _) = methodNamesLCmd cmd
methodNamesStmt (ExprStmt cmd _ _ _) = methodNamesLCmd cmd
methodNamesStmt (BindStmt _ cmd _ _) = methodNamesLCmd cmd
methodNamesStmt (RecStmt { recS_stmts = stmts }) = methodNamesStmts stmts `addOneFV` loopAName
methodNamesStmt (LetStmt _) = emptyFVs
methodNamesStmt (ParStmt _ _ _ _) = emptyFVs
methodNamesStmt (LetStmt {}) = emptyFVs
methodNamesStmt (ParStmt {}) = emptyFVs
methodNamesStmt (TransStmt {}) = emptyFVs
-- ParStmt and TransStmt can't occur in commands, but it's not convenient to error
-- here so we just do what's convenient
......@@ -767,12 +767,12 @@ rnStmt ctxt (L _ (RecStmt { recS_stmts = rec_stmts })) thing_inside
; return ((rec_stmts', thing), fvs `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) } }
rnStmt ctxt (L loc (ParStmt segs _ _ _)) thing_inside
rnStmt ctxt (L loc (ParStmt segs _ _)) thing_inside
= do { (mzip_op, fvs1) <- lookupStmtName ctxt mzipName
; (bind_op, fvs2) <- lookupStmtName ctxt bindMName
; (return_op, fvs3) <- lookupStmtName ctxt returnMName
; ((segs', thing), fvs4) <- rnParallelStmts (ParStmtCtxt ctxt) segs thing_inside
; return ( ([L loc (ParStmt segs' mzip_op bind_op return_op)], thing)
; ((segs', thing), fvs4) <- rnParallelStmts (ParStmtCtxt ctxt) return_op segs thing_inside
; return ( ([L loc (ParStmt segs' mzip_op bind_op)], thing)
, fvs1 `plusFV` fvs2 `plusFV` fvs3 `plusFV` fvs4) }
rnStmt ctxt (L loc (TransStmt { trS_stmts = stmts, trS_by = by, trS_form = form
......@@ -810,27 +810,26 @@ rnStmt ctxt (L loc (TransStmt { trS_stmts = stmts, trS_by = by, trS_form = form
, trS_ret = return_op, trS_bind = bind_op
, trS_fmap = fmap_op })], thing), all_fvs) }
type ParSeg id = ([LStmt id], [id]) -- The Names are bound by the Stmts
rnParallelStmts :: forall thing. HsStmtContext Name
-> [ParSeg RdrName]
-> SyntaxExpr Name
-> [ParStmtBlock RdrName RdrName]
-> ([Name] -> RnM (thing, FreeVars))
-> RnM (([ParSeg Name], thing), FreeVars)
-> RnM (([ParStmtBlock Name Name], thing), FreeVars)
-- Note [Renaming parallel Stmts]
rnParallelStmts ctxt segs thing_inside
rnParallelStmts ctxt return_op segs thing_inside
= do { orig_lcl_env <- getLocalRdrEnv
; rn_segs orig_lcl_env [] segs }
where
rn_segs :: LocalRdrEnv
-> [Name] -> [ParSeg RdrName]
-> RnM (([ParSeg Name], thing), FreeVars)
-> [Name] -> [ParStmtBlock RdrName RdrName]
-> RnM (([ParStmtBlock Name Name], thing), FreeVars)
rn_segs _ bndrs_so_far []
= do { let (bndrs', dups) = removeDups cmpByOcc bndrs_so_far
; mapM_ dupErr dups
; (thing, fvs) <- bindLocalNames bndrs' (thing_inside bndrs')
; return (([], thing), fvs) }
rn_segs env bndrs_so_far ((stmts,_) : segs)
rn_segs env bndrs_so_far (ParStmtBlock stmts _ _ : segs)
= do { ((stmts', (used_bndrs, segs', thing)), fvs)
<- rnStmts ctxt stmts $ \ bndrs ->
setLocalRdrEnv env $ do
......@@ -838,7 +837,7 @@ rnParallelStmts ctxt segs thing_inside
; let used_bndrs = filter (`elemNameSet` fvs) bndrs
; return ((used_bndrs, segs', thing), fvs) }
; let seg' = (stmts', used_bndrs)
; let seg' = ParStmtBlock stmts' used_bndrs return_op
; return ((seg':segs', thing), fvs) }
cmpByOcc n1 n2 = nameOccName n1 `compare` nameOccName n2
......@@ -973,7 +972,7 @@ rn_rec_stmt_lhs fix_env (L loc (LetStmt (HsValBinds binds)))
rn_rec_stmt_lhs fix_env (L _ (RecStmt { recS_stmts = stmts })) -- Flatten Rec inside Rec
= rn_rec_stmts_lhs fix_env stmts
rn_rec_stmt_lhs _ stmt@(L _ (ParStmt _ _ _ _)) -- Syntactically illegal in mdo
rn_rec_stmt_lhs _ stmt@(L _ (ParStmt {})) -- Syntactically illegal in mdo
= pprPanic "rn_rec_stmt" (ppr stmt)
rn_rec_stmt_lhs _ stmt@(L _ (TransStmt {})) -- Syntactically illegal in mdo
......
......@@ -10,8 +10,6 @@
module TcErrors(
reportUnsolved, ErrEnv,
warnDefaulting,
unifyCtxt,
misMatchMsg,
flattenForAllErrorTcS,
solverDepthErrorTcS
......@@ -641,12 +639,6 @@ kindErrorMsg ty1 ty2
k2 = typeKind ty2
--------------------
unifyCtxt :: EqOrigin -> TidyEnv -> TcM (TidyEnv, SDoc)
unifyCtxt (UnifyOrigin { uo_actual = act_ty, uo_expected = exp_ty }) tidy_env
= do { (env1, act_ty') <- zonkTidyTcType tidy_env act_ty
; (env2, exp_ty') <- zonkTidyTcType env1 exp_ty
; return (env2, mkExpectedActualMsg exp_ty' act_ty') }
misMatchMsg :: Bool -> TcType -> TcType -> SDoc -- Types are already tidy
-- If oriented then ty1 is expected, ty2 is actual
misMatchMsg oriented ty1 ty2
......
......@@ -770,19 +770,18 @@ zonkStmts env (s:ss) = do { (env1, s') <- wrapLocSndM (zonkStmt env) s
; return (env2, s' : ss') }
zonkStmt :: ZonkEnv -> Stmt TcId -> TcM (ZonkEnv, Stmt Id)
zonkStmt env (ParStmt stmts_w_bndrs mzip_op bind_op return_op)
= mappM zonk_branch stmts_w_bndrs `thenM` \ new_stmts_w_bndrs ->
let
new_binders = concat (map snd new_stmts_w_bndrs)
env1 = extendIdZonkEnv env new_binders
in
zonkExpr env1 mzip_op `thenM` \ new_mzip ->
zonkExpr env1 bind_op `thenM` \ new_bind ->
zonkExpr env1 return_op `thenM` \ new_return ->
return (env1, ParStmt new_stmts_w_bndrs new_mzip new_bind new_return)
zonkStmt env (ParStmt stmts_w_bndrs mzip_op bind_op)
= do { new_stmts_w_bndrs <- mapM zonk_branch stmts_w_bndrs
; let new_binders = [b | ParStmtBlock _ bs _ <- new_stmts_w_bndrs, b <- bs]
env1 = extendIdZonkEnv env new_binders
; new_mzip <- zonkExpr env1 mzip_op
; new_bind <- zonkExpr env1 bind_op
; return (env1, ParStmt new_stmts_w_bndrs new_mzip new_bind) }
where
zonk_branch (stmts, bndrs) = zonkStmts env stmts `thenM` \ (env1, new_stmts) ->
returnM (new_stmts, zonkIdOccs env1 bndrs)
zonk_branch (ParStmtBlock stmts bndrs return_op)
= do { (env1, new_stmts) <- zonkStmts env stmts
; new_return <- zonkExpr env1 return_op
; return (ParStmtBlock new_stmts (zonkIdOccs env1 bndrs) new_return) }
zonkStmt env (RecStmt { recS_stmts = segStmts, recS_later_ids = lvs, recS_rec_ids = rvs
, recS_ret_fn = ret_id, recS_mfix_fn = mfix_id, recS_bind_fn = bind_id
......
......@@ -31,7 +31,6 @@ import TcMType
import TcType
import TcBinds
import TcUnify
import TcErrors ( misMatchMsg )
import Name
import TysWiredIn
import Id
......@@ -398,21 +397,21 @@ tcLcStmt _ _ (ExprStmt rhs _ _ _) elt_ty thing_inside
; return (ExprStmt rhs' noSyntaxExpr noSyntaxExpr boolTy, thing) }
-- ParStmt: See notes with tcMcStmt
tcLcStmt m_tc ctxt (ParStmt bndr_stmts_s _ _ _) elt_ty thing_inside
tcLcStmt m_tc ctxt (ParStmt bndr_stmts_s _ _) elt_ty thing_inside
= do { (pairs', thing) <- loop bndr_stmts_s
; return (ParStmt pairs' noSyntaxExpr noSyntaxExpr noSyntaxExpr, thing) }
; return (ParStmt pairs' noSyntaxExpr noSyntaxExpr, thing) }
where
-- loop :: [([LStmt Name], [Name])] -> TcM ([([LStmt TcId], [TcId])], thing)
loop [] = do { thing <- thing_inside elt_ty
; return ([], thing) } -- matching in the branches
loop ((stmts, names) : pairs)
loop (ParStmtBlock stmts names _ : pairs)
= do { (stmts', (ids, pairs', thing))
<- tcStmtsAndThen ctxt (tcLcStmt m_tc) stmts elt_ty $ \ _elt_ty' ->
do { ids <- tcLookupLocalIds names
; (pairs', thing) <- loop pairs
; return (ids, pairs', thing) }
; return ( (stmts', ids) : pairs', thing ) }
; return ( ParStmtBlock stmts' ids noSyntaxExpr : pairs', thing ) }
tcLcStmt m_tc ctxt (TransStmt { trS_form = form, trS_stmts = stmts
, trS_bndrs = bindersMap
......@@ -675,7 +674,7 @@ tcMcStmt ctxt (TransStmt { trS_stmts = stmts, trS_bndrs = bindersMap
-- -> (m st2 -> m st3 -> m (st2, st3)) -- recursive call
-- -> m (st1, (st2, st3))
--
tcMcStmt ctxt (ParStmt bndr_stmts_s mzip_op bind_op return_op) res_ty thing_inside
tcMcStmt ctxt (ParStmt bndr_stmts_s mzip_op bind_op) res_ty thing_inside
= do { let star_star_kind = liftedTypeKind `mkArrowKind` liftedTypeKind
; m_ty <- newFlexiTyVarTy star_star_kind
......@@ -687,14 +686,10 @@ tcMcStmt ctxt (ParStmt bndr_stmts_s mzip_op bind_op return_op) res_ty thing_insi
(m_ty `mkAppTy` mkBoxedTupleTy [alphaTy, betaTy])
; mzip_op' <- unLoc `fmap` tcPolyExpr (noLoc mzip_op) mzip_ty
; return_op' <- fmap unLoc . tcPolyExpr (noLoc return_op) $
mkForAllTy alphaTyVar $
alphaTy `mkFunTy` (m_ty `mkAppTy` alphaTy)
; (pairs', thing) <- loop m_ty bndr_stmts_s
; (blocks', thing) <- loop m_ty bndr_stmts_s
-- Typecheck bind:
; let tys = map (mkBigCoreVarTupTy . snd) pairs'
; let tys = [ mkBigCoreVarTupTy bs | ParStmtBlock _ bs _ <- blocks']
tuple_ty = mk_tuple_ty tys
; bind_op' <- tcSyntaxOp MCompOrigin bind_op $
......@@ -702,7 +697,7 @@ tcMcStmt ctxt (ParStmt bndr_stmts_s mzip_op bind_op return_op) res_ty thing_insi
`mkFunTy` (tuple_ty `mkFunTy` res_ty)
`mkFunTy` res_ty
; return (ParStmt pairs' mzip_op' bind_op' return_op', thing) }
; return (ParStmt blocks' mzip_op' bind_op', thing) }
where
mk_tuple_ty tys = foldr1 (\tn tm -> mkBoxedTupleTy [tn, tm]) tys
......@@ -713,31 +708,19 @@ tcMcStmt ctxt (ParStmt bndr_stmts_s mzip_op bind_op return_op) res_ty thing_insi
loop _ [] = do { thing <- thing_inside res_ty
; return ([], thing) } -- matching in the branches
loop m_ty ((stmts, names) : pairs)
loop m_ty (ParStmtBlock stmts names return_op : pairs)
= do { -- type dummy since we don't know all binder types yet
ty_dummy <- newFlexiTyVarTy liftedTypeKind
; (stmts', (ids, pairs', thing))
<- tcStmtsAndThen ctxt tcMcStmt stmts ty_dummy $ \res_ty' ->
id_tys <- mapM (const (newFlexiTyVarTy liftedTypeKind)) names
; let m_tup_ty = m_ty `mkAppTy` mkBigCoreTupTy id_tys
; (stmts', (ids, return_op', pairs', thing))
<- tcStmtsAndThen ctxt tcMcStmt stmts m_tup_ty $ \m_tup_ty' ->
do { ids <- tcLookupLocalIds names
; let m_tup_ty = m_ty `mkAppTy` mkBigCoreVarTupTy ids
; check_same m_tup_ty res_ty'
; check_same m_tup_ty ty_dummy
; let tup_ty = mkBigCoreVarTupTy ids
; return_op' <- tcSyntaxOp MCompOrigin return_op
(tup_ty `mkFunTy` m_tup_ty')
; (pairs', thing) <- loop m_ty pairs
; return (ids, pairs', thing) }
; return ( (stmts', ids) : pairs', thing ) }
-- Check that the types match up.
-- This is a grevious hack. They always *will* match
-- If (>>=) and (>>) are polymorpic in the return type,
-- but we don't have any good way to incorporate the coercion
-- so for now we just check that it's the identity
check_same actual expected
= do { co <- unifyType actual expected
; unless (isTcReflCo co) $
failWithMisMatch [UnifyOrigin { uo_expected = expected
, uo_actual = actual }] }
; return (ids, return_op', pairs', thing) }
; return (ParStmtBlock stmts' ids return_op' : pairs', thing) }
tcMcStmt _ stmt _ _
= pprPanic "tcMcStmt: unexpected Stmt" (ppr stmt)
......@@ -877,22 +860,5 @@ checkArgs fun (MatchGroup (match1:matches) _)
args_in_match :: LMatch Name -> Int
args_in_match (L _ (Match pats _ _)) = length pats
checkArgs fun _ = pprPanic "TcPat.checkArgs" (ppr fun) -- Matches always non-empty
failWithMisMatch :: [EqOrigin] -> TcM a
-- Generate the message when two types fail to match,
-- going to some trouble to make it helpful.
-- We take the failing types from the top of the origin stack
-- rather than reporting the particular ones we are looking
-- at right now
failWithMisMatch (item:origin)
= wrapEqCtxt origin $
do { ty_act <- zonkTcType (uo_actual item)
; ty_exp <- zonkTcType (uo_expected item)
; env0 <- tcInitTidyEnv
; let (env1, pp_exp) = tidyOpenType env0 ty_exp
(env2, pp_act) = tidyOpenType env1 ty_act
; failWithTcM (env2, misMatchMsg True pp_exp pp_act) }
failWithMisMatch []
= panic "failWithMisMatch"
\end{code}
......@@ -31,7 +31,6 @@ module TcUnify (
matchExpectedFunTys,
matchExpectedFunKind,
wrapFunResCoercion,
wrapEqCtxt,
--------------------------------
-- Errors
......@@ -43,7 +42,6 @@ module TcUnify (
import HsSyn
import TypeRep
import TcErrors ( unifyCtxt )
import TcMType
import TcIface
import TcRnMonad
......@@ -1005,15 +1003,6 @@ we return a made-up TcTyVarDetails, but I think it works smoothly.
pushOrigin :: TcType -> TcType -> [EqOrigin] -> [EqOrigin]
pushOrigin ty_act ty_exp origin
= UnifyOrigin { uo_actual = ty_act, uo_expected = ty_exp } : origin
---------------
wrapEqCtxt :: [EqOrigin] -> TcM a -> TcM a
-- Build a suitable error context from the origin and do the thing inside
-- The "couldn't match" error comes from the innermost item on the stack,
-- and, if there is more than one item, the "Expected/inferred" part
-- comes from the outermost item
wrapEqCtxt [] thing_inside = thing_inside
wrapEqCtxt items thing_inside = addErrCtxtM (unifyCtxt (last items)) thing_inside
\end{code}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment