Commit c86e9006 authored by simonpj's avatar simonpj

[project @ 2003-02-26 17:04:11 by simonpj]

----------------------------------
	Improve higher-rank type inference
	----------------------------------

Yanling Wang pointed out that if we have

	f = \ (x :: forall a. a->a). x

it would be reasonable to expect that type inference would get the "right"
rank-2 type for f.  She also found that the plausible definition

	f :: (forall a. a->a) = \x -> x

acutally failed to type check.

This commit fixes up TcBinds.tcMonoBinds so that it does a better job.
The main idea is that there are three cases to consider in a function binding:

  a) 'f' has a separate type signature
	In this case, we know f's type everywhere

  b) The binding is recursive, and there is no type sig
	In this case we must give f a monotype in its RHS

  c) The binding is non-recursive, and there is no type sig
	Then we do not need to add 'f' to the envt, and can
	simply infer a type for the RHS, which may be higher
	ranked.
parent c68c1f2e
......@@ -40,7 +40,8 @@ import {-# SOURCE #-} TcExpr( tcExpr )
import HsSyn ( HsLit(..), HsOverLit(..), HsExpr(..) )
import TcHsSyn ( TcExpr, TcId, TcIdSet, TypecheckedHsExpr,
mkHsTyApp, mkHsDictApp, mkHsConApp, zonkId
mkHsTyApp, mkHsDictApp, mkHsConApp, zonkId,
mkCoercion, ExprCoFn
)
import TcRnMonad
import TcEnv ( tcGetInstEnv, tcLookupId, tcLookupTyCon, checkWellStaged, topIdLvl )
......@@ -256,7 +257,7 @@ newIPDict orig ip_name ty
\begin{code}
tcInstCall :: InstOrigin -> TcType -> TcM (TypecheckedHsExpr -> TypecheckedHsExpr, TcType)
tcInstCall :: InstOrigin -> TcType -> TcM (ExprCoFn, TcType)
tcInstCall orig fun_ty -- fun_ty is usually a sigma-type
= tcInstType VanillaTv fun_ty `thenM` \ (tyvars, theta, tau) ->
newDicts orig theta `thenM` \ dicts ->
......@@ -264,7 +265,7 @@ tcInstCall orig fun_ty -- fun_ty is usually a sigma-type
let
inst_fn e = mkHsDictApp (mkHsTyApp e (mkTyVarTys tyvars)) (map instToId dicts)
in
returnM (inst_fn, tau)
returnM (mkCoercion inst_fn, tau)
tcInstDataCon :: InstOrigin -> DataCon
-> TcM ([TcType], -- Types to instantiate at
......
......@@ -255,8 +255,9 @@ tcBindWithSigs top_lvl mbind sigs is_rec
) $
-- TYPECHECK THE BINDINGS
getLIE (tcMonoBinds mbind tc_ty_sigs is_rec) `thenM` \ ((mbind', binder_names, mono_ids), lie_req) ->
getLIE (tcMonoBinds mbind tc_ty_sigs is_rec) `thenM` \ ((mbind', bndr_names_w_ids), lie_req) ->
let
(binder_names, mono_ids) = unzip (bagToList bndr_names_w_ids)
tau_tvs = foldr (unionVarSet . tyVarsOfType . idType) emptyVarSet mono_ids
in
......@@ -620,91 +621,86 @@ The signatures have been dealt with already.
\begin{code}
tcMonoBinds :: RenamedMonoBinds
-> [TcSigInfo]
-> RecFlag
-> [TcSigInfo] -> RecFlag
-> TcM (TcMonoBinds,
[Name], -- Bound names
[TcId]) -- Corresponding monomorphic bound things
Bag (Name, -- Bound names
TcId)) -- Corresponding monomorphic bound things
tcMonoBinds mbinds tc_ty_sigs is_rec
= tc_mb_pats mbinds `thenM` \ (complete_it, tvs, ids, lie_avail) ->
let
id_list = bagToList ids
(names, mono_ids) = unzip id_list
-- This last defn is the key one:
-- extend the val envt with bindings for the
-- things bound in this group, overriding the monomorphic
-- ids with the polymorphic ones from the pattern
extra_val_env = case is_rec of
Recursive -> map mk_bind id_list
NonRecursive -> []
in
-- Don't know how to deal with pattern-bound existentials yet
checkTc (isEmptyBag tvs && null lie_avail)
(existentialExplode mbinds) `thenM_`
-- *Before* checking the RHSs, but *after* checking *all* the patterns,
-- extend the envt with bindings for all the bound ids;
-- and *then* override with the polymorphic Ids from the signatures
-- That is the whole point of the "complete_it" stuff.
--
-- There's a further wrinkle: we have to delay extending the environment
-- until after we've dealt with any pattern-bound signature type variables
-- Consider f (x::a) = ...f...
-- We're going to check that a isn't unified with anything in the envt,
-- so f itself had better not be! So we pass the envt binding f into
-- complete_it, which extends the actual envt in TcMatches.tcMatch, after
-- dealing with the signature tyvars
complete_it extra_val_env `thenM` \ mbinds' ->
returnM (mbinds', names, mono_ids)
-- Three stages:
-- 1. Check the patterns, building up an environment binding
-- the variables in this group (in the recursive case)
-- 2. Extend the environment
-- 3. Check the RHSs
= tc_mb_pats mbinds `thenM` \ (complete_it, xve) ->
tcExtendLocalValEnv2 (bagToList xve) complete_it
where
mk_bind (name, mono_id) = case maybeSig tc_ty_sigs name of
Nothing -> (name, mono_id)
Just sig -> (idName poly_id, poly_id)
where
poly_id = tcSigPolyId sig
tc_mb_pats EmptyMonoBinds
= returnM (\ xve -> returnM EmptyMonoBinds, emptyBag, emptyBag, [])
tc_mb_pats EmptyMonoBinds
= returnM (returnM (EmptyMonoBinds, emptyBag), emptyBag)
tc_mb_pats (AndMonoBinds mb1 mb2)
= tc_mb_pats mb1 `thenM` \ (complete_it1, tvs1, ids1, lie_avail1) ->
tc_mb_pats mb2 `thenM` \ (complete_it2, tvs2, ids2, lie_avail2) ->
= tc_mb_pats mb1 `thenM` \ (complete_it1, xve1) ->
tc_mb_pats mb2 `thenM` \ (complete_it2, xve2) ->
let
complete_it xve = complete_it1 xve `thenM` \ mb1' ->
complete_it2 xve `thenM` \ mb2' ->
returnM (AndMonoBinds mb1' mb2')
complete_it = complete_it1 `thenM` \ (mb1', bs1) ->
complete_it2 `thenM` \ (mb2', bs2) ->
returnM (AndMonoBinds mb1' mb2', bs1 `unionBags` bs2)
in
returnM (complete_it,
tvs1 `unionBags` tvs2,
ids1 `unionBags` ids2,
lie_avail1 ++ lie_avail2)
returnM (complete_it, xve1 `unionBags` xve2)
tc_mb_pats (FunMonoBind name inf matches locn)
= (case maybeSig tc_ty_sigs name of
Just sig -> returnM (tcSigMonoId sig)
Nothing -> newLocalName name `thenM` \ bndr_name ->
newTyVarTy openTypeKind `thenM` \ bndr_ty ->
-- NB: not a 'hole' tyvar; since there is no type
-- signature, we revert to ordinary H-M typechecking
-- which means the variable gets an inferred tau-type
returnM (mkLocalId bndr_name bndr_ty)
) `thenM` \ bndr_id ->
-- Three cases:
-- a) Type sig supplied
-- b) No type sig and recursive
-- c) No type sig and non-recursive
| Just sig <- maybeSig tc_ty_sigs name
= let -- (a) There is a type signature
-- Use it for the environment extension, and check
-- the RHS has the appropriate type (with outer for-alls stripped off)
mono_id = tcSigMonoId sig
mono_ty = idType mono_id
complete_it = addSrcLoc locn $
tcMatchesFun name mono_ty matches `thenM` \ matches' ->
returnM (FunMonoBind mono_id inf matches' locn,
unitBag (name, mono_id))
in
returnM (complete_it, if isRec is_rec then unitBag (name,tcSigPolyId sig)
else emptyBag)
| isRec is_rec
= -- (b) No type signature, and recursive
-- So we must use an ordinary H-M type variable
-- which means the variable gets an inferred tau-type
newLocalName name `thenM` \ mono_name ->
newTyVarTy openTypeKind `thenM` \ mono_ty ->
let
bndr_ty = idType bndr_id
complete_it xve = addSrcLoc locn $
tcMatchesFun xve name bndr_ty matches `thenM` \ matches' ->
returnM (FunMonoBind bndr_id inf matches' locn)
mono_id = mkLocalId mono_name mono_ty
complete_it = addSrcLoc locn $
tcMatchesFun name mono_ty matches `thenM` \ matches' ->
returnM (FunMonoBind mono_id inf matches' locn,
unitBag (name, mono_id))
in
returnM (complete_it, emptyBag, unitBag (name, bndr_id), [])
returnM (complete_it, unitBag (name, mono_id))
| otherwise -- (c) No type signature, and non-recursive
= let -- So we can use a 'hole' type to infer a higher-rank type
complete_it
= addSrcLoc locn $
newHoleTyVarTy `thenM` \ fun_ty ->
tcMatchesFun name fun_ty matches `thenM` \ matches' ->
readHoleResult fun_ty `thenM` \ fun_ty' ->
newLocalName name `thenM` \ mono_name ->
let
mono_id = mkLocalId mono_name fun_ty'
in
returnM (FunMonoBind mono_id inf matches' locn,
unitBag (name, mono_id))
in
returnM (complete_it, emptyBag)
tc_mb_pats bind@(PatMonoBind pat grhss locn)
= addSrcLoc locn $
newHoleTyVarTy `thenM` \ pat_ty ->
-- Now typecheck the pattern
-- We do now support binding fresh (not-already-in-scope) scoped
......@@ -714,16 +710,21 @@ tcMonoBinds mbinds tc_ty_sigs is_rec
-- The type variables are brought into scope in tc_binds_and_then,
-- so we don't have to do anything here.
tcPat tc_pat_bndr pat pat_ty `thenM` \ (pat', tvs, ids, lie_avail) ->
readHoleResult pat_ty `thenM` \ pat_ty' ->
newHoleTyVarTy `thenM` \ pat_ty ->
tcPat tc_pat_bndr pat pat_ty `thenM` \ (pat', tvs, ids, lie_avail) ->
readHoleResult pat_ty `thenM` \ pat_ty' ->
-- Don't know how to deal with pattern-bound existentials yet
checkTc (isEmptyBag tvs && null lie_avail)
(existentialExplode bind) `thenM_`
let
complete_it xve = addSrcLoc locn $
addErrCtxt (patMonoBindsCtxt bind) $
tcExtendLocalValEnv2 xve $
tcGRHSs PatBindRhs grhss pat_ty' `thenM` \ grhss' ->
returnM (PatMonoBind pat' grhss' locn)
complete_it = addSrcLoc locn $
addErrCtxt (patMonoBindsCtxt bind) $
tcGRHSs PatBindRhs grhss pat_ty' `thenM` \ grhss' ->
returnM (PatMonoBind pat' grhss' locn, ids)
in
returnM (complete_it, tvs, ids, lie_avail)
returnM (complete_it, if isRec is_rec then ids else emptyBag)
-- tc_pat_bndr is used when dealing with a LHS binder in a pattern.
-- If there was a type sig for that Id, we want to make it much
......@@ -735,9 +736,8 @@ tcMonoBinds mbinds tc_ty_sigs is_rec
tc_pat_bndr name pat_ty
= case maybeSig tc_ty_sigs name of
Nothing
-> newLocalName name `thenM` \ bndr_name ->
tcMonoPatBndr bndr_name pat_ty
Nothing -> newLocalName name `thenM` \ bndr_name ->
tcMonoPatBndr bndr_name pat_ty
Just sig -> addSrcLoc (getSrcLoc name) $
tcSubPat (idType mono_id) pat_ty `thenM` \ co_fn ->
......
......@@ -457,7 +457,7 @@ tcMethodBind xtve inst_tyvars inst_theta avail_insts prags
tcExtendTyVarEnv2 xtve (
addErrCtxt (methodCtxt sel_id) $
getLIE (tcMonoBinds meth_bind [meth_sig] NonRecursive)
) `thenM` \ ((meth_bind, _, _), meth_lie) ->
) `thenM` \ ((meth_bind, _), meth_lie) ->
-- Now do context reduction. We simplify wrt both the local tyvars
-- and the ones of the class/instance decl, so that there is
......
......@@ -19,11 +19,10 @@ import qualified DsMeta
import HsSyn ( HsExpr(..), HsLit(..), ArithSeqInfo(..), recBindFields )
import RnHsSyn ( RenamedHsExpr, RenamedRecordBinds )
import TcHsSyn ( TcExpr, TcRecordBinds, hsLitType, mkHsDictApp, mkHsTyApp, mkHsLet )
import TcHsSyn ( TcExpr, TcRecordBinds, hsLitType, mkHsDictApp, mkHsTyApp, mkHsLet, (<$>) )
import TcRnMonad
import TcUnify ( tcSubExp, tcGen, (<$>),
unifyTauTy, unifyFunTy, unifyListTy, unifyPArrTy,
unifyTupleTy )
import TcUnify ( tcSubExp, tcGen,
unifyTauTy, unifyFunTy, unifyListTy, unifyPArrTy, unifyTupleTy )
import BasicTypes ( isMarkedStrict )
import Inst ( InstOrigin(..),
newOverloadedLit, newMethodFromName, newIPDict,
......@@ -34,7 +33,7 @@ import TcBinds ( tcBindsAndThen )
import TcEnv ( tcLookupClass, tcLookupGlobal_maybe, tcLookupIdLvl,
tcLookupTyCon, tcLookupDataCon, tcLookupId
)
import TcMatches ( tcMatchesCase, tcMatchLambda, tcDoStmts )
import TcMatches ( tcMatchesCase, tcMatchLambda, tcDoStmts, tcThingWithSig )
import TcMonoType ( tcHsSigType, UserTypeCtxt(..) )
import TcPat ( badFieldCon )
import TcMType ( tcInstTyVars, tcInstType, newHoleTyVarTy, zapToType,
......@@ -136,17 +135,10 @@ tcMonoExpr (HsIPVar ip) res_ty
\begin{code}
tcMonoExpr in_expr@(ExprWithTySig expr poly_ty) res_ty
= addErrCtxt (exprSigCtxt in_expr) $
tcHsSigType ExprSigCtxt poly_ty `thenM` \ sig_tc_ty ->
tcExpr expr sig_tc_ty `thenM` \ expr' ->
-- Must instantiate the outer for-alls of sig_tc_ty
-- else we risk instantiating a ? res_ty to a forall-type
-- which breaks the invariant that tcMonoExpr only returns phi-types
tcInstCall SignatureOrigin sig_tc_ty `thenM` \ (inst_fn, inst_sig_ty) ->
tcSubExp res_ty inst_sig_ty `thenM` \ co_fn ->
returnM (co_fn <$> inst_fn expr')
= addErrCtxt (exprSigCtxt in_expr) $
tcHsSigType ExprSigCtxt poly_ty `thenM` \ sig_tc_ty ->
tcThingWithSig sig_tc_ty (tcMonoExpr expr) res_ty `thenM` \ (co_fn, expr') ->
returnM (co_fn <$> expr')
tcMonoExpr (HsType ty) res_ty
= failWithTc (text "Can't handle type argument:" <+> ppr ty)
......@@ -832,7 +824,7 @@ tcId name -- Look up the Id and instantiate its type
loop fun fun_ty
| isSigmaTy fun_ty
= tcInstCall orig fun_ty `thenM` \ (inst_fn, tau) ->
loop (inst_fn fun) tau
loop (inst_fn <$> fun) tau
| otherwise
= returnM (fun, fun_ty)
......
......@@ -27,6 +27,11 @@ module TcHsSyn (
mkHsTyLam, mkHsDictLam, mkHsLet,
hsLitType, hsPatType,
-- Coercions
Coercion, ExprCoFn, PatCoFn,
(<$>), (<.>), mkCoercion,
idCoercion, isIdCoercion,
-- re-exported from TcMonad
TcId, TcIdSet,
......@@ -65,6 +70,7 @@ import VarSet
import VarEnv
import BasicTypes ( RecFlag(..), Boxity(..), IPName(..), ipNameName, mapIPName )
import Maybes ( orElse )
import Maybe ( isNothing )
import Unique ( Uniquable(..) )
import SrcLoc ( noSrcLoc )
import Bag
......@@ -182,12 +188,37 @@ hsLitType (HsDoublePrim d) = doublePrimTy
hsLitType (HsLitLit _ ty) = ty
\end{code}
%************************************************************************
%* *
\subsection{Coercion functions}
%* *
%************************************************************************
\begin{code}
-- zonkId is used *during* typechecking just to zonk the Id's type
zonkId :: TcId -> TcM TcId
zonkId id
= zonkTcType (idType id) `thenM` \ ty' ->
returnM (setIdType id ty')
type Coercion a = Maybe (a -> a)
-- Nothing => identity fn
type ExprCoFn = Coercion TypecheckedHsExpr
type PatCoFn = Coercion TcPat
(<.>) :: Coercion a -> Coercion a -> Coercion a -- Composition
Nothing <.> Nothing = Nothing
Nothing <.> Just f = Just f
Just f <.> Nothing = Just f
Just f1 <.> Just f2 = Just (f1 . f2)
(<$>) :: Coercion a -> a -> a
Just f <$> e = f e
Nothing <$> e = e
mkCoercion :: (a -> a) -> Coercion a
mkCoercion f = Just f
idCoercion :: Coercion a
idCoercion = Nothing
isIdCoercion :: Coercion a -> Bool
isIdCoercion = isNothing
\end{code}
......@@ -197,7 +228,16 @@ zonkId id
%* *
%************************************************************************
This zonking pass runs over the bindings
\begin{code}
-- zonkId is used *during* typechecking just to zonk the Id's type
zonkId :: TcId -> TcM TcId
zonkId id
= zonkTcType (idType id) `thenM` \ ty' ->
returnM (setIdType id ty')
\end{code}
The rest of the zonking is done *after* typechecking.
The main zonking pass runs over the bindings
a) to convert TcTyVars to TyVars etc, dereferencing any bindings etc
b) convert unbound TcTyVar to Void
......
......@@ -8,8 +8,7 @@ _declarations_
-> TcType.TcType
-> TcMonad.TcM s (TcHsSyn.TcGRHSs, TcMonad.LIE) ;;
3 tcMatchesFun _:_ _forall_ [s] =>
[(Name.Name,Var.Id)]
-> Name.Name
Name.Name
-> TcType.TcType
-> [RnHsSyn.RenamedMatch]
-> TcMonad.TcM s ([TcHsSyn.TcMatch], TcMonad.LIE) ;;
......
......@@ -5,8 +5,7 @@ __export TcMatches tcGRHSs tcMatchesFun;
-> TcType.TcType
-> TcRnTypes.TcM TcHsSyn.TcGRHSs ;
1 tcMatchesFun ::
[(Name.Name,Var.Id)]
-> Name.Name
Name.Name
-> TcType.TcType
-> [RnHsSyn.RenamedMatch]
-> TcRnTypes.TcM [TcHsSyn.TcMatch] ;
......
......@@ -5,9 +5,7 @@ tcGRHSs :: HsExpr.HsMatchContext Name.Name
-> TcType.TcType
-> TcRnTypes.TcM TcHsSyn.TcGRHSs
tcMatchesFun ::
[(Name.Name,Var.Id)]
-> Name.Name
tcMatchesFun :: Name.Name
-> TcType.TcType
-> [RnHsSyn.RenamedMatch]
-> TcRnTypes.TcM [TcHsSyn.TcMatch]
......
......@@ -5,7 +5,7 @@
\begin{code}
module TcMatches ( tcMatchesFun, tcMatchesCase, tcMatchLambda,
tcDoStmts, tcStmtsAndThen, tcGRHSs
tcDoStmts, tcStmtsAndThen, tcGRHSs, tcThingWithSig
) where
#include "HsVersions.h"
......@@ -21,20 +21,22 @@ import HsSyn ( HsExpr(..), HsBinds(..), Match(..), GRHSs(..), GRHS(..),
import RnHsSyn ( RenamedMatch, RenamedGRHSs, RenamedStmt,
RenamedPat, RenamedMatchContext )
import TcHsSyn ( TcMatch, TcGRHSs, TcStmt, TcDictBinds, TcHsBinds,
TcMonoBinds, TcPat, TcStmt )
TcMonoBinds, TcPat, TcStmt, ExprCoFn,
isIdCoercion, (<$>), (<.>) )
import TcRnMonad
import TcMonoType ( tcAddScopedTyVars, tcHsSigType, UserTypeCtxt(..) )
import Inst ( tcSyntaxName )
import Inst ( tcSyntaxName, tcInstCall )
import TcEnv ( TcId, tcLookupLocalIds, tcLookupId, tcExtendLocalValEnv, tcExtendLocalValEnv2 )
import TcPat ( tcPat, tcMonoPatBndr )
import TcMType ( newTyVarTy, newTyVarTys, zonkTcType, zapToType )
import TcType ( TcType, TcTyVar, tyVarsOfType, tidyOpenTypes, tidyOpenType,
import TcType ( TcType, TcTyVar, TcSigmaType, TcRhoType,
tyVarsOfType, tidyOpenTypes, tidyOpenType, isSigmaTy,
mkFunTy, isOverloadedTy, liftedTypeKind, openTypeKind,
mkArrowKind, mkAppTy )
import TcBinds ( tcBindsAndThen )
import TcUnify ( unifyPArrTy,subFunTy, unifyListTy, unifyTauTy,
checkSigTyVarsWrt, tcSubExp, isIdCoercion, (<$>) )
checkSigTyVarsWrt, tcSubExp, tcGen )
import TcSimplify ( tcSimplifyCheck, bindInstsOfLocalFuns )
import Name ( Name )
import PrelNames ( monadNames, mfixName )
......@@ -63,13 +65,12 @@ is used in error messages. It checks that all the equations have the
same number of arguments before using @tcMatches@ to do the work.
\begin{code}
tcMatchesFun :: [(Name,Id)] -- Bindings for the variables bound in this group
-> Name
tcMatchesFun :: Name
-> TcType -- Expected type
-> [RenamedMatch]
-> TcM [TcMatch]
tcMatchesFun xve fun_name expected_ty matches@(first_match:_)
tcMatchesFun fun_name expected_ty matches@(first_match:_)
= -- Check that they all have the same no of arguments
-- Set the location to that of the first equation, so that
-- any inter-equation error messages get some vaguely
......@@ -86,7 +87,7 @@ tcMatchesFun xve fun_name expected_ty matches@(first_match:_)
-- may show up as something wrong with the (non-existent) type signature
-- No need to zonk expected_ty, because subFunTy does that on the fly
tcMatches xve (FunRhs fun_name) matches expected_ty
tcMatches (FunRhs fun_name) matches expected_ty
\end{code}
@tcMatchesCase@ doesn't do the argument-count check because the
......@@ -100,22 +101,21 @@ tcMatchesCase :: [RenamedMatch] -- The case alternatives
tcMatchesCase matches expr_ty
= newTyVarTy openTypeKind `thenM` \ scrut_ty ->
tcMatches [] CaseAlt matches (mkFunTy scrut_ty expr_ty) `thenM` \ matches' ->
tcMatches CaseAlt matches (mkFunTy scrut_ty expr_ty) `thenM` \ matches' ->
returnM (scrut_ty, matches')
tcMatchLambda :: RenamedMatch -> TcType -> TcM TcMatch
tcMatchLambda match res_ty = tcMatch [] LambdaExpr match res_ty
tcMatchLambda match res_ty = tcMatch LambdaExpr match res_ty
\end{code}
\begin{code}
tcMatches :: [(Name,Id)]
-> RenamedMatchContext
tcMatches :: RenamedMatchContext
-> [RenamedMatch]
-> TcType
-> TcM [TcMatch]
tcMatches xve ctxt matches expected_ty
tcMatches ctxt matches expected_ty
= -- If there is more than one branch, and expected_ty is a 'hole',
-- all branches must be types, not type schemes, otherwise the
-- in which we check them would affect the result.
......@@ -126,7 +126,7 @@ tcMatches xve ctxt matches expected_ty
mappM (tc_match expected_ty') matches
where
tc_match expected_ty match = tcMatch xve ctxt match expected_ty
tc_match expected_ty match = tcMatch ctxt match expected_ty
\end{code}
......@@ -137,8 +137,7 @@ tcMatches xve ctxt matches expected_ty
%************************************************************************
\begin{code}
tcMatch :: [(Name,Id)]
-> RenamedMatchContext
tcMatch :: RenamedMatchContext
-> RenamedMatch
-> TcType -- Expected result-type of the Match.
-- Early unification with this guy gives better error messages
......@@ -147,7 +146,7 @@ tcMatch :: [(Name,Id)]
-- where there are n patterns.
-> TcM TcMatch
tcMatch xve1 ctxt match@(Match pats maybe_rhs_sig grhss) expected_ty
tcMatch ctxt match@(Match pats maybe_rhs_sig grhss) expected_ty
= addSrcLoc (getMatchLoc match) $ -- At one stage I removed this;
addErrCtxt (matchCtxt ctxt match) $ -- I'm not sure why, so I put it back
tcMatchPats pats expected_ty tc_grhss `thenM` \ (pats', grhss', ex_binds) ->
......@@ -155,17 +154,14 @@ tcMatch xve1 ctxt match@(Match pats maybe_rhs_sig grhss) expected_ty
where
tc_grhss rhs_ty
= tcExtendLocalValEnv2 xve1 $
-- Deal with the result signature
= -- Deal with the result signature
case maybe_rhs_sig of
Nothing -> tcGRHSs ctxt grhss rhs_ty
Just sig -> tcAddScopedTyVars [sig] $
-- Bring into scope the type variables in the signature
tcHsSigType ResSigCtxt sig `thenM` \ sig_ty ->
tcGRHSs ctxt grhss sig_ty `thenM` \ grhss' ->
tcSubExp rhs_ty sig_ty `thenM` \ co_fn ->
tcHsSigType ResSigCtxt sig `thenM` \ sig_ty ->
tcThingWithSig sig_ty (tcGRHSs ctxt grhss) rhs_ty `thenM` \ (co_fn, grhss') ->
returnM (lift_grhss co_fn rhs_ty grhss')
-- lift_grhss pushes the coercion down to the right hand sides,
......@@ -173,7 +169,7 @@ tcMatch xve1 ctxt match@(Match pats maybe_rhs_sig grhss) expected_ty
lift_grhss co_fn rhs_ty grhss
| isIdCoercion co_fn = grhss
lift_grhss co_fn rhs_ty (GRHSs grhss binds ty)
= GRHSs (map lift_grhs grhss) binds rhs_ty -- Change the type, since we
= GRHSs (map lift_grhs grhss) binds rhs_ty -- Change the type, since the coercion does
where
lift_grhs (GRHS stmts loc) = GRHS (map lift_stmt stmts) loc
......@@ -206,6 +202,31 @@ tcGRHSs ctxt (GRHSs grhss binds _) expected_ty
\end{code}
\begin{code}
tcThingWithSig :: TcSigmaType -- Type signature
-> (TcRhoType -> TcM r) -- How to type check the thing inside
-> TcRhoType -- Overall expected result type
-> TcM (ExprCoFn, r)
-- Used for expressions with a type signature, and for result type signatures
tcThingWithSig sig_ty thing_inside res_ty
| not (isSigmaTy sig_ty)
= thing_inside sig_ty `thenM` \ result ->
tcSubExp res_ty sig_ty `thenM` \ co_fn ->
returnM (co_fn, result)
| otherwise -- The signature has some outer foralls
= -- Must instantiate the outer for-alls of sig_tc_ty
-- else we risk instantiating a ? res_ty to a forall-type
-- which breaks the invariant that tcMonoExpr only returns phi-types
tcGen sig_ty emptyVarSet thing_inside `thenM` \ (gen_fn, result) ->
tcInstCall SignatureOrigin sig_ty `thenM` \ (inst_fn, inst_sig_ty) ->
tcSubExp res_ty inst_sig_ty `thenM` \ co_fn ->
returnM (co_fn <.> inst_fn <.> gen_fn, result)
-- Note that we generalise, then instantiate. Ah well.
\end{code}
%************************************************************************
%* *
\subsection{tcMatchPats}
......
......@@ -12,7 +12,9 @@ module TcPat ( tcPat, tcMonoPatBndr, tcSubPat,
import HsSyn ( Pat(..), HsConDetails(..), HsLit(..), HsOverLit(..), HsExpr(..) )
import RnHsSyn ( RenamedPat )
import TcHsSyn ( TcPat, TcId, hsLitType )
import TcHsSyn ( TcPat, TcId, hsLitType,
mkCoercion, idCoercion, isIdCoercion,
(<$>), PatCoFn )
import TcRnMonad
import Inst ( InstOrigin(..),
......@@ -27,9 +29,7 @@ import TcMType ( newTyVarTy, zapToType, arityErr )
import TcType ( TcType, TcTyVar, TcSigmaType,
mkClassPred, liftedTypeKind )
import TcUnify ( tcSubOff, TcHoleType,
unifyTauTy, unifyListTy, unifyPArrTy, unifyTupleTy,
mkCoercion, idCoercion, isIdCoercion,
(<$>), PatCoFn )
unifyTauTy, unifyListTy, unifyPArrTy, unifyTupleTy )
import TcMonoType ( tcHsSigType, UserTypeCtxt(..) )
import TysWiredIn ( stringTy )
......
......@@ -12,12 +12,7 @@ module TcUnify (
-- Various unifications
unifyTauTy, unifyTauTyList, unifyTauTyLists,
unifyFunTy, unifyListTy, unifyPArrTy, unifyTupleTy,
unifyKind, unifyKinds, unifyOpenTypeKind, unifyFunKind,
-- Coercions
Coercion, ExprCoFn, PatCoFn,
(<$>), (<.>), mkCoercion,
idCoercion, isIdCoercion
unifyKind, unifyKinds, unifyOpenTypeKind, unifyFunKind
) where
......@@ -25,7 +20,8 @@ module TcUnify (
import HsSyn ( HsExpr(..) )
import TcHsSyn ( TypecheckedHsExpr, TcPat, mkHsLet )
import TcHsSyn ( TypecheckedHsExpr, TcPat, mkHsLet,
ExprCoFn, idCoercion, isIdCoercion, mkCoercion, (<.>), (<$>) )
import TypeRep ( Type(..), SourceType(..), TyNote(..), openKindCon )
import TcRnMonad -- TcType, amongst others
......@@ -181,7 +177,7 @@ tc_sub exp_sty expected_ty act_sty actual_ty
| isSigmaTy actual_ty
= tcInstCall Rank2Origin actual_ty `thenM` \ (inst_fn, body_ty) ->
tc_sub exp_sty expected_ty body_ty body_ty `thenM` \ co_fn ->
returnM (co_fn <.> mkCoercion inst_fn)
returnM (co_fn <.> inst_fn)
-----------------------------------
-- Function case
......@@ -351,39 +347,6 @@ tcGen expected_ty extra_tvs thing_inside -- We expect expected_ty to be a forall
%************************************************************************
%* *
\subsection{Coercion functions}
%* *
%************************************************************************
\begin{code}
type Coercion a = Maybe (a -> a)
-- Nothing => identity fn
type ExprCoFn = Coercion TypecheckedHsExpr
type PatCoFn = Coercion TcPat
(<.>) :: Coercion a -> Coercion a -> Coercion a -- Composition
Nothing <.> Nothing = Nothing
Nothing <.> Just f = Just f
Just f <.> Nothing = Just f
Just f1 <.> Just f2 = Just (f1 . f2)
(<$>) :: Coercion a -> a -> a
Just f <$> e = f e
Nothing <$> e = e
mkCoercion :: (a -> a) -> Coercion a
mkCoercion f = Just f
idCoercion :: Coercion a
idCoercion = Nothing
isIdCoercion :: Coercion a -> Bool
isIdCoercion = isNothing
\end{code}
%************************************************************************
%* *
\subsection[Unify-exported]{Exported unification functions}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment