Commit f4510d27 authored by simonpj@microsoft.com's avatar simonpj@microsoft.com
Browse files

Use implication constraints to improve type inference

parent d29f86b1
......@@ -163,7 +163,7 @@ checkForConflicts inst_envs famInst
Nothing -> panic "FamInst.checkForConflicts"
Just (tc, tys) -> tc `mkTyConApp` tys
}
; (tvs', _, tau') <- tcInstSkolType (FamInstSkol tycon) ty
; (tvs', _, tau') <- tcInstSkolType FamInstSkol ty
; let (fam, tys') = tcSplitTyConApp tau'
......
......@@ -26,17 +26,17 @@ module Inst (
ipNamesOfInst, ipNamesOfInsts, fdPredsOfInst, fdPredsOfInsts,
instLoc, getDictClassTys, dictPred,
lookupInst, LookupInstResult(..), lookupPred,
lookupSimpleInst, LookupInstResult(..), lookupPred,
tcExtendLocalInstEnv, tcGetInstEnvs, getOverlapFlag,
isDict, isClassDict, isMethod,
isIPDict, isInheritableInst,
isTyVarDict, isMethodFor,
isDict, isClassDict, isMethod, isImplicInst,
isIPDict, isInheritableInst, isMethodOrLit,
isTyVarDict, isMethodFor, getDefaultableDicts,
zonkInst, zonkInsts,
instToId, instToVar, instName,
InstOrigin(..), InstLoc(..), pprInstLoc
InstOrigin(..), InstLoc, pprInstLoc
) where
#include "HsVersions.h"
......@@ -53,6 +53,7 @@ import FunDeps
import TcMType
import TcType
import Type
import Class
import Unify
import Module
import Coercion
......@@ -73,6 +74,7 @@ import BasicTypes
import SrcLoc
import DynFlags
import Maybes
import Util
import Outputable
\end{code}
......@@ -96,6 +98,25 @@ instToVar (Method {tci_id = id})
instToVar (Dict {tci_name = nm, tci_pred = pred})
| isEqPred pred = Var.mkCoVar nm (mkPredTy pred)
| otherwise = mkLocalId nm (mkPredTy pred)
instToVar (ImplicInst {tci_name = nm, tci_tyvars = tvs, tci_given = givens,
tci_wanted = wanteds})
= mkLocalId nm (mkImplicTy tvs givens wanteds)
instType :: Inst -> Type
instType (LitInst {tci_ty = ty}) = ty
instType (Method {tci_id = id}) = idType id
instType (Dict {tci_pred = pred}) = mkPredTy pred
instType imp@(ImplicInst {}) = mkImplicTy (tci_tyvars imp) (tci_given imp)
(tci_wanted imp)
mkImplicTy tvs givens wanteds -- The type of an implication constraint
= -- pprTrace "mkImplicTy" (ppr givens) $
mkForAllTys tvs $
mkPhiTy (map dictPred givens) $
if isSingleton wanteds then
instType (head wanteds)
else
mkTupleTy Boxed (length wanteds) (map instType wanteds)
instLoc inst = tci_loc inst
......@@ -111,9 +132,11 @@ getDictClassTys inst = pprPanic "getDictClassTys" (ppr inst)
-- Leaving these in is really important for the call to fdPredsOfInsts
-- in TcSimplify.inferLoop, because the result is fed to 'grow',
-- which is supposed to be conservative
fdPredsOfInst (Dict {tci_pred = pred}) = [pred]
fdPredsOfInst (Method {tci_theta = theta}) = theta
fdPredsOfInst other = [] -- LitInsts etc
fdPredsOfInst (Dict {tci_pred = pred}) = [pred]
fdPredsOfInst (Method {tci_theta = theta}) = theta
fdPredsOfInst (ImplicInst {tci_given = gs,
tci_wanted = ws}) = fdPredsOfInsts (gs ++ ws)
fdPredsOfInst (LitInst {}) = []
fdPredsOfInsts :: [Inst] -> [PredType]
fdPredsOfInsts insts = concatMap fdPredsOfInst insts
......@@ -123,22 +146,27 @@ isInheritableInst (Method {tci_theta = theta}) = all isInheritablePred theta
isInheritableInst other = True
---------------------------------
-- Get the implicit parameters mentioned by these Insts
-- NB: the results of these functions are insensitive to zonking
ipNamesOfInsts :: [Inst] -> [Name]
ipNamesOfInst :: Inst -> [Name]
-- Get the implicit parameters mentioned by these Insts
-- NB: ?x and %x get different Names
ipNamesOfInsts insts = [n | inst <- insts, n <- ipNamesOfInst inst]
ipNamesOfInst (Dict {tci_pred = IParam n _}) = [ipNameName n]
ipNamesOfInst (Method {tci_theta = theta}) = [ipNameName n | IParam n _ <- theta]
ipNamesOfInst other = []
---------------------------------
tyVarsOfInst :: Inst -> TcTyVarSet
tyVarsOfInst (LitInst {tci_ty = ty}) = tyVarsOfType ty
tyVarsOfInst (Dict {tci_pred = pred}) = tyVarsOfPred pred
tyVarsOfInst (Method {tci_oid = id, tci_tys = tys}) = tyVarsOfTypes tys `unionVarSet` idFreeTyVars id
-- The id might have free type variables; in the case of
-- locally-overloaded class methods, for example
tyVarsOfInst (ImplicInst {tci_tyvars = tvs, tci_given = givens, tci_wanted = wanteds})
= (tyVarsOfInsts givens `unionVarSet` tyVarsOfInsts wanteds) `minusVarSet` mkVarSet tvs
tyVarsOfInsts insts = foldr (unionVarSet . tyVarsOfInst) emptyVarSet insts
......@@ -164,6 +192,9 @@ isIPDict :: Inst -> Bool
isIPDict (Dict {tci_pred = pred}) = isIPPred pred
isIPDict other = False
isImplicInst (ImplicInst {}) = True
isImplicInst other = False
isMethod :: Inst -> Bool
isMethod (Method {}) = True
isMethod other = False
......@@ -171,9 +202,33 @@ isMethod other = False
isMethodFor :: TcIdSet -> Inst -> Bool
isMethodFor ids (Method {tci_oid = id}) = id `elemVarSet` ids
isMethodFor ids inst = False
\end{code}
isMethodOrLit :: Inst -> Bool
isMethodOrLit (Method {}) = True
isMethodOrLit (LitInst {}) = True
isMethodOrLit other = False
\end{code}
\begin{code}
getDefaultableDicts :: [Inst] -> ([(Inst, Class, TcTyVar)], TcTyVarSet)
-- Look for free dicts of the form (C tv), even inside implications
-- *and* the set of tyvars mentioned by all *other* constaints
-- This disgustingly ad-hoc function is solely to support defaulting
getDefaultableDicts insts
= (concat ps, unionVarSets tvs)
where
(ps, tvs) = mapAndUnzip get insts
get d@(Dict {tci_pred = ClassP cls [ty]})
| Just tv <- tcGetTyVar_maybe ty = ([(d,cls,tv)], emptyVarSet)
| otherwise = ([], tyVarsOfType ty)
get (ImplicInst {tci_tyvars = tvs, tci_wanted = wanteds})
= ([ up | up@(_,_,tv) <- ups, not (tv `elemVarSet` tv_set)],
ftvs `minusVarSet` tv_set)
where
tv_set = mkVarSet tvs
(ups, ftvs) = getDefaultableDicts wanteds
get inst = ([], tyVarsOfInst inst)
\end{code}
%************************************************************************
%* *
......@@ -197,7 +252,7 @@ newDictBndrs inst_loc theta = mapM (newDictBndr inst_loc) theta
newDictBndr :: InstLoc -> TcPredType -> TcM Inst
newDictBndr inst_loc pred
= do { uniq <- newUnique
; let name = mkPredName uniq (instLocSrcLoc inst_loc) pred
; let name = mkPredName uniq inst_loc pred
; return (Dict {tci_name = name, tci_pred = pred, tci_loc = inst_loc}) }
----------------
......@@ -240,7 +295,7 @@ instCallDicts loc (EqPred ty1 ty2 : preds)
instCallDicts loc (pred : preds)
= do { uniq <- newUnique
; let name = mkPredName uniq (instLocSrcLoc loc) pred
; let name = mkPredName uniq loc pred
dict = Dict {tci_name = name, tci_pred = pred, tci_loc = loc}
; (dicts, co_fn) <- instCallDicts loc preds
; return (dict:dicts, co_fn <.> WpApp (instToId dict)) }
......@@ -262,13 +317,22 @@ newIPDict orig ip_name ty
newUnique `thenM` \ uniq ->
let
pred = IParam ip_name ty
name = mkPredName uniq (instLocSrcLoc inst_loc) pred
name = mkPredName uniq inst_loc pred
dict = Dict {tci_name = name, tci_pred = pred, tci_loc = inst_loc}
in
returnM (mapIPName (\n -> instToId dict) ip_name, dict)
\end{code}
\begin{code}
mkPredName :: Unique -> InstLoc -> PredType -> Name
mkPredName uniq loc pred_ty
= mkInternalName uniq occ (srcSpanStart (instLocSpan loc))
where
occ = case pred_ty of
ClassP cls tys -> mkDictOcc (getOccName cls)
IParam ip ty -> getOccName (ipNameName ip)
\end{code}
%************************************************************************
%* *
......@@ -340,7 +404,7 @@ newMethod inst_loc id tys
meth_id = mkUserLocal (mkMethodOcc (getOccName id)) new_uniq tau loc
inst = Method {tci_id = meth_id, tci_oid = id, tci_tys = tys,
tci_theta = theta, tci_loc = inst_loc}
loc = instLocSrcLoc inst_loc
loc = srcSpanStart (instLocSpan inst_loc)
in
returnM inst
\end{code}
......@@ -411,6 +475,12 @@ zonkInst lit@(LitInst {tci_ty = ty})
= zonkTcType ty `thenM` \ new_ty ->
returnM (lit {tci_ty = new_ty})
zonkInst implic@(ImplicInst {})
= ASSERT( all isImmutableTyVar (tci_tyvars implic) )
do { givens' <- zonkInsts (tci_given implic)
; wanteds' <- zonkInsts (tci_wanted implic)
; return (implic {tci_given = givens',tci_wanted = wanteds'}) }
zonkInsts insts = mappM zonkInst insts
\end{code}
......@@ -430,36 +500,41 @@ instance Outputable Inst where
pprDictsTheta :: [Inst] -> SDoc
-- Print in type-like fashion (Eq a, Show b)
pprDictsTheta dicts = pprTheta (map dictPred dicts)
-- The Inst can be an implication constraint, but not a Method or LitInst
pprDictsTheta insts = parens (sep (punctuate comma (map (ppr . instType) insts)))
pprDictsInFull :: [Inst] -> SDoc
-- Print in type-like fashion, but with source location
pprDictsInFull dicts
= vcat (map go dicts)
where
go dict = sep [quotes (ppr (dictPred dict)), nest 2 (pprInstLoc (instLoc dict))]
go dict = sep [quotes (ppr (instType dict)), nest 2 (pprInstArising dict)]
pprInsts :: [Inst] -> SDoc
-- Debugging: print the evidence :: type
pprInsts insts = brackets (interpp'SP insts)
pprInsts insts = brackets (interpp'SP insts)
pprInst, pprInstInFull :: Inst -> SDoc
-- Debugging: print the evidence :: type
pprInst (LitInst {tci_name = nm, tci_ty = ty}) = ppr nm <+> dcolon <+> ppr ty
pprInst (Dict {tci_name = nm, tci_pred = pred}) = ppr nm <+> dcolon <+> pprPred pred
pprInst (Method {tci_id = inst_id, tci_oid = id, tci_tys = tys})
= ppr inst_id <+> dcolon <+>
braces (sep [ppr id <+> ptext SLIT("at"),
brackets (sep (map pprParendType tys))])
pprInst inst = ppr (instName inst) <+> dcolon
<+> (braces (ppr (instType inst)) $$
ifPprDebug implic_stuff)
where
implic_stuff | isImplicInst inst = ppr (tci_reft inst)
| otherwise = empty
pprInstInFull inst
= sep [quotes (pprInst inst), nest 2 (pprInstLoc (instLoc inst))]
pprInstInFull inst = sep [quotes (pprInst inst), nest 2 (pprInstArising inst)]
tidyInst :: TidyEnv -> Inst -> Inst
tidyInst env lit@(LitInst {tci_ty = ty}) = lit {tci_ty = tidyType env ty}
tidyInst env dict@(Dict {tci_pred = pred}) = dict {tci_pred = tidyPred env pred}
tidyInst env meth@(Method {tci_tys = tys}) = meth {tci_tys = tidyTypes env tys}
tidyInst env implic@(ImplicInst {})
= implic { tci_tyvars = tvs'
, tci_given = map (tidyInst env') (tci_given implic)
, tci_wanted = map (tidyInst env') (tci_wanted implic) }
where
(env', tvs') = mapAccumL tidyTyVarBndr env (tci_tyvars implic)
tidyMoreInsts :: TidyEnv -> [Inst] -> (TidyEnv, [Inst])
-- This function doesn't assume that the tyvars are in scope
......@@ -509,7 +584,7 @@ addLocalInst home_ie ispec
-- We use tcInstSkolType because we don't want to allocate fresh
-- *meta* type variables.
let dfun = instanceDFunId ispec
; (tvs', theta', tau') <- tcInstSkolType (InstSkol dfun) (idType dfun)
; (tvs', theta', tau') <- tcInstSkolType InstSkol (idType dfun)
; let (cls, tys') = tcSplitDFunHead tau'
dfun' = setIdType dfun (mkSigmaTy tvs' theta' tau')
ispec' = setInstanceDFunId ispec dfun'
......@@ -581,46 +656,46 @@ addDictLoc ispec thing_inside
\begin{code}
data LookupInstResult
= NoInstance
| SimpleInst (LHsExpr TcId) -- Just a variable, type application, or literal
| GenInst [Inst] (LHsExpr TcId) -- The expression and its needed insts
| GenInst [Inst] (LHsExpr TcId) -- The expression and its needed insts
lookupSimpleInst :: Inst -> TcM LookupInstResult
-- This is "simple" in tthat it returns NoInstance for implication constraints
lookupInst :: Inst -> TcM LookupInstResult
-- It's important that lookupInst does not put any new stuff into
-- the LIE. Instead, any Insts needed by the lookup are returned in
-- the LookupInstResult, where they can be further processed by tcSimplify
--------------------- Impliciations ------------------------
lookupSimpleInst (ImplicInst {}) = return NoInstance
-- Methods
lookupInst (Method {tci_oid = id, tci_tys = tys, tci_theta = theta, tci_loc = loc})
--------------------- Methods ------------------------
lookupSimpleInst (Method {tci_oid = id, tci_tys = tys, tci_theta = theta, tci_loc = loc})
= do { (dicts, dict_app) <- instCallDicts loc theta
; let co_fn = dict_app <.> mkWpTyApps tys
; return (GenInst dicts (L span $ HsWrap co_fn (HsVar id))) }
where
span = instLocSrcSpan loc
-- Literals
span = instLocSpan loc
--------------------- Literals ------------------------
-- Look for short cuts first: if the literal is *definitely* a
-- int, integer, float or a double, generate the real thing here.
-- This is essential (see nofib/spectral/nucleic).
-- [Same shortcut as in newOverloadedLit, but we
-- may have done some unification by now]
lookupInst (LitInst {tci_lit = HsIntegral i from_integer_name, tci_ty = ty, tci_loc = loc})
lookupSimpleInst (LitInst {tci_lit = HsIntegral i from_integer_name, tci_ty = ty, tci_loc = loc})
| Just expr <- shortCutIntLit i ty
= returnM (GenInst [] (noLoc expr)) -- GenInst, not SimpleInst, because
-- expr may be a constructor application
= returnM (GenInst [] (noLoc expr))
| otherwise
= ASSERT( from_integer_name `isHsVar` fromIntegerName ) -- A LitInst invariant
tcLookupId fromIntegerName `thenM` \ from_integer ->
tcInstClassOp loc from_integer [ty] `thenM` \ method_inst ->
mkIntegerLit i `thenM` \ integer_lit ->
returnM (GenInst [method_inst]
(mkHsApp (L (instLocSrcSpan loc)
(mkHsApp (L (instLocSpan loc)
(HsVar (instToId method_inst))) integer_lit))
lookupInst (LitInst {tci_lit = HsFractional f from_rat_name, tci_ty = ty, tci_loc = loc})
lookupSimpleInst (LitInst {tci_lit = HsFractional f from_rat_name, tci_ty = ty, tci_loc = loc})
| Just expr <- shortCutFracLit f ty
= returnM (GenInst [] (noLoc expr))
......@@ -629,11 +704,11 @@ lookupInst (LitInst {tci_lit = HsFractional f from_rat_name, tci_ty = ty, tci_lo
tcLookupId fromRationalName `thenM` \ from_rational ->
tcInstClassOp loc from_rational [ty] `thenM` \ method_inst ->
mkRatLit f `thenM` \ rat_lit ->
returnM (GenInst [method_inst] (mkHsApp (L (instLocSrcSpan loc)
returnM (GenInst [method_inst] (mkHsApp (L (instLocSpan loc)
(HsVar (instToId method_inst))) rat_lit))
-- Dictionaries
lookupInst (Dict {tci_pred = pred, tci_loc = loc})
--------------------- Dictionaries ------------------------
lookupSimpleInst (Dict {tci_pred = pred, tci_loc = loc})
= do { mb_result <- lookupPred pred
; case mb_result of {
Nothing -> return NoInstance ;
......@@ -668,11 +743,11 @@ lookupInst (Dict {tci_pred = pred, tci_loc = loc})
-- any nested for-alls in rho. So the in-scope set is unchanged
dfun_rho = substTy tenv' rho
(theta, _) = tcSplitPhiTy dfun_rho
src_loc = instLocSrcSpan loc
src_loc = instLocSpan loc
dfun = HsVar dfun_id
tys = map (substTyVar tenv') tyvars
; if null theta then
returnM (SimpleInst (L src_loc $ HsWrap (mkWpTyApps tys) dfun))
returnM (GenInst [] (L src_loc $ HsWrap (mkWpTyApps tys) dfun))
else do
{ (dicts, dict_app) <- instCallDicts loc theta
; let co_fn = dict_app <.> mkWpTyApps tys
......@@ -799,7 +874,7 @@ syntaxNameCtxt name orig ty tidy_env
msg = vcat [ptext SLIT("When checking that") <+> quotes (ppr name) <+>
ptext SLIT("(needed by a syntactic construct)"),
nest 2 (ptext SLIT("has the required type:") <+> ppr (tidyType tidy_env ty)),
nest 2 (pprInstLoc inst_loc)]
nest 2 (ptext SLIT("arising from") <+> pprInstLoc inst_loc)]
in
returnM (tidy_env, msg)
\end{code}
......@@ -238,7 +238,7 @@ tc_cmd env cmd@(HsArrForm expr fixity cmd_args) (cmd_stk, res_ty)
= addErrCtxt (cmdCtxt cmd) $
do { cmds_w_tys <- zipWithM new_cmd_ty cmd_args [1..]
; span <- getSrcSpanM
; [w_tv] <- tcInstSkolTyVars (ArrowSkol span) [alphaTyVar]
; [w_tv] <- tcInstSkolTyVars ArrowSkol [alphaTyVar]
; let w_ty = mkTyVarTy w_tv -- Just a convenient starting point
-- a ((w,t1) .. tn) t
......@@ -251,7 +251,8 @@ tc_cmd env cmd@(HsArrForm expr fixity cmd_args) (cmd_stk, res_ty)
-- Check expr
; (expr', lie) <- escapeArrowScope (getLIE (tcMonoExpr expr e_ty))
; inst_binds <- tcSimplifyCheck sig_msg [w_tv] [] lie
; loc <- getInstLoc (SigOrigin ArrowSkol)
; inst_binds <- tcSimplifyCheck loc [w_tv] [] lie
-- Check that the polymorphic variable hasn't been unified with anything
-- and is not free in res_ty or the cmd_stk (i.e. t, t1..tn)
......@@ -303,8 +304,6 @@ tc_cmd env cmd@(HsArrForm expr fixity cmd_args) (cmd_stk, res_ty)
other -> (ty, [])
sig_msg = ptext SLIT("expected type of a command form")
-----------------------------------------------------------------
-- Base case for illegal commands
-- This is where expressions that aren't commands get rejected
......
......@@ -710,16 +710,17 @@ generalise dflags top_lvl bind_list sig_fn mono_infos lie_req
= tcSimplifyInfer doc tau_tvs lie_req
| otherwise -- UNRESTRICTED CASE, WITH TYPE SIGS
= do { sig_lie <- unifyCtxts sigs -- sigs is non-empty
= do { sig_lie <- unifyCtxts sigs -- sigs is non-empty; sig_lie is zonked
; let -- The "sig_avails" is the stuff available. We get that from
-- the context of the type signature, BUT ALSO the lie_avail
-- so that polymorphic recursion works right (see Note [Polymorphic recursion])
local_meths = [mkMethInst sig mono_id | (_, Just sig, mono_id) <- mono_infos]
sig_avails = sig_lie ++ local_meths
loc = sig_loc (head sigs)
-- Check that the needed dicts can be
-- expressed in terms of the signature ones
; (forall_tvs, dict_binds) <- tcSimplifyInferCheck doc tau_tvs sig_avails lie_req
; (forall_tvs, dict_binds) <- tcSimplifyInferCheck loc tau_tvs sig_avails lie_req
-- Check that signature type variables are OK
; final_qtvs <- checkSigsTyVars forall_tvs sigs
......@@ -754,14 +755,16 @@ might not otherwise be related. This is a rather subtle issue.
\begin{code}
unifyCtxts :: [TcSigInfo] -> TcM [Inst]
-- Post-condition: the returned Insts are full zonked
unifyCtxts (sig1 : sigs) -- Argument is always non-empty
= do { mapM unify_ctxt sigs
; newDictBndrs (sig_loc sig1) (sig_theta sig1) }
; theta <- zonkTcThetaType (sig_theta sig1)
; newDictBndrs (sig_loc sig1) theta }
where
theta1 = sig_theta sig1
unify_ctxt :: TcSigInfo -> TcM ()
unify_ctxt sig@(TcSigInfo { sig_theta = theta })
= setSrcSpan (instLocSrcSpan (sig_loc sig)) $
= setSrcSpan (instLocSpan (sig_loc sig)) $
addErrCtxt (sigContextsCtxt sig1 sig) $
unifyTheta theta1 theta
......@@ -1060,8 +1063,7 @@ tcInstSig use_skols name scoped_names
= do { poly_id <- tcLookupId name -- Cannot fail; the poly ids are put into
-- scope when starting the binding group
; let skol_info = SigSkol (FunSigCtxt name)
inst_tyvars | use_skols = tcInstSkolTyVars skol_info
| otherwise = tcInstSigTyVars skol_info
inst_tyvars = tcInstSigTyVars use_skols skol_info
; (tvs, theta, tau) <- tcInstType inst_tyvars (idType poly_id)
; loc <- getInstLoc (SigOrigin skol_info)
; return (TcSigInfo { sig_id = poly_id,
......
......@@ -271,7 +271,7 @@ tcDefMeth origin clas tyvars binds_in sig_fn prag_fn sel_id
-- Check the context
{ dict_binds <- tcSimplifyCheck
(ptext SLIT("class") <+> ppr clas)
loc
tyvars
[this_dict]
insts_needed
......@@ -362,18 +362,18 @@ tcMethodBind inst_tyvars inst_theta avail_insts sig_fn prag_fn
let
[(_, Just sig, local_meth_id)] = mono_bind_infos
loc = sig_loc sig
in
addErrCtxtM (sigCtxt sel_id inst_tyvars inst_theta (idType meth_id)) $
newDictBndrs (sig_loc sig) (sig_theta sig) `thenM` \ meth_dicts ->
newDictBndrs loc (sig_theta sig) `thenM` \ meth_dicts ->
let
meth_tvs = sig_tvs sig
all_tyvars = meth_tvs ++ inst_tyvars
all_insts = avail_insts ++ meth_dicts
in
tcSimplifyCheck
(ptext SLIT("class or instance method") <+> quotes (ppr sel_id))
all_tyvars all_insts meth_lie `thenM` \ lie_binds ->
loc all_tyvars all_insts meth_lie `thenM` \ lie_binds ->
checkSigTyVars all_tyvars `thenM_`
......@@ -537,8 +537,8 @@ mkDefMethRhs origin clas inst_tys sel_id loc GenDefMeth
other -> Nothing
other -> Nothing
isInstDecl (SigOrigin (InstSkol _)) = True
isInstDecl (SigOrigin (ClsSkol _)) = False
isInstDecl (SigOrigin InstSkol) = True
isInstDecl (SigOrigin (ClsSkol _)) = False
\end{code}
......
......@@ -400,6 +400,9 @@ refineEnvironment :: Refinement -> TcM a -> TcM a
-- I don't think I have to refine the set of global type variables in scope
-- Reason: the refinement never increases that set
refineEnvironment reft thing_inside
| isEmptyRefinement reft -- Common case
= thing_inside
| otherwise
= do { env <- getLclEnv
; let le' = mapNameEnv refine (tcl_env env)
; setLclEnv (env {tcl_env = le'}) thing_inside }
......
......@@ -11,8 +11,9 @@
\begin{code}
module TcGadt (
Refinement, emptyRefinement, gadtRefine,
refineType, refineResType,
Refinement, emptyRefinement, isEmptyRefinement,
gadtRefine,
refineType, refinePred, refineResType,
dataConCanMatch,
tcUnifyTys, BindFlag(..)
) where
......@@ -22,6 +23,7 @@ module TcGadt (
import HsSyn
import Coercion
import Type
import TypeRep
import DataCon
import Var
......@@ -61,6 +63,8 @@ instance Outputable Refinement where
emptyRefinement :: Refinement
emptyRefinement = (Reft emptyInScopeSet emptyVarEnv)
isEmptyRefinement :: Refinement -> Bool
isEmptyRefinement (Reft _ env) = isEmptyVarEnv env
refineType :: Refinement -> Type -> Maybe (Coercion, Type)
-- Apply the refinement to the type.
......@@ -77,6 +81,17 @@ refineType (Reft in_scope env) ty
tv_subst = mkTvSubst in_scope (mapVarEnv snd env)
co_subst = mkTvSubst in_scope (mapVarEnv fst env)
refinePred :: Refinement -> PredType -> Maybe (Coercion, PredType)
refinePred (Reft in_scope env) pred
| not (isEmptyVarEnv env), -- Common case
any (`elemVarEnv` env) (varSetElems (tyVarsOfPred pred))
= Just (mkPredTy (substPred co_subst pred), substPred tv_subst pred)
| otherwise
= Nothing -- The type doesn't mention any refined type variables
where
tv_subst = mkTvSubst in_scope (mapVarEnv snd env)
co_subst = mkTvSubst in_scope (mapVarEnv fst env)
refineResType :: Refinement -> Type -> (HsWrapper, Type)
-- Like refineType, but returns the 'sym' coercion
-- If (refineResType r ty) = (co, ty')
......
......@@ -800,7 +800,6 @@ pprHsSigCtxt ctxt hs_ty = vcat [ ptext SLIT("In") <+> pprUserTypeCtxt ctxt <> co
pp_sig (FunSigCtxt n) = pp_n_colon n
pp_sig (ConArgCtxt n) = pp_n_colon n
pp_sig (ForSigCtxt n) = pp_n_colon n
pp_sig (RuleSigCtxt n) = pp_n_colon n
pp_sig other = ppr (unLoc hs_ty)
pp_n_colon n = ppr n <+> dcolon <+> ppr (unLoc hs_ty)
......
......@@ -483,7 +483,7 @@ tcInstDecl2 :: InstInfo -> TcM (LHsBinds Id)
tcInstDecl2 (InstInfo { iSpec = ispec, iBinds = NewTypeDerived mb_preds })
= do { let dfun_id = instanceDFunId ispec
rigid_info = InstSkol dfun_id
rigid_info = InstSkol
origin = SigOrigin rigid_info
inst_ty = idType dfun_id
; (tvs, theta, inst_head_ty) <- tcSkolSigType rigid_info inst_ty
......@@ -518,7 +518,8 @@ tcInstDecl2 (InstInfo { iSpec = ispec, iBinds = NewTypeDerived mb_preds })
make_wrapper inst_loc tvs theta (Just preds) -- Case (a)
= ASSERT( null tvs && null theta )
do { dicts <- newDictBndrs inst_loc preds
; sc_binds <- addErrCtxt superClassCtxt (tcSimplifySuperClasses [] [] dicts)
; sc_binds <- addErrCtxt superClassCtxt $
tcSimplifySuperClasses inst_loc [] dicts
-- Use tcSimplifySuperClasses to avoid creating loops, for the
-- same reason as Note [SUPERCLASS-LOOP 1] in TcSimplify
; return (map instToId dicts, idHsWrapper, sc_binds) }
......@@ -584,7 +585,7 @@ tcInstDecl2 (InstInfo { iSpec = ispec, iBinds = NewTypeDerived mb_preds })
tcInstDecl2 (InstInfo { iSpec = ispec, iBinds = VanillaInst monobinds uprags })
= let
dfun_id = instanceDFunId ispec
rigid_info = InstSkol dfun_id
rigid_info = InstSkol
inst_ty = idType dfun_id
in
-- Prime error recovery
......@@ -626,9 +627,8 @@ tcInstDecl2 (InstInfo { iSpec = ispec, iBinds = VanillaInst monobinds uprags })
-- Don't include this_dict in the 'givens', else
-- sc_dicts get bound by just selecting from this_dict!!
addErrCtxt superClassCtxt
(tcSimplifySuperClasses inst_tyvars'
dfun_arg_dicts
sc_dicts) `thenM` \ sc_binds ->
(tcSimplifySuperClasses inst_loc
dfun_arg_dicts sc_dicts) `thenM` \ sc_binds ->
-- It's possible that the superclass stuff might unified one
-- of the inst_tyavars' with something in the envt
......
......@@ -78,6 +78,7 @@ import Util
import Maybes
import ListSetOps
import UniqSupply
import SrcLoc
import Outputable
import Control.Monad ( when )
......@@ -160,19 +161,30 @@ tcSkolSigTyVars :: SkolemInfo -> [TyVar] -> [TcTyVar]
tcSkolSigTyVars info tyvars = [ mkSkolTyVar (tyVarName tv) (tyVarKind tv) info
| tv <- tyvars ]
tcInstSkolType :: SkolemInfo -> TcType -> TcM ([TcTyVar], TcThetaType, TcType)
-- Instantiate a type with fresh skolem constants
tcInstSkolType info ty = tcInstType (tcInstSkolTyVars info) ty
tcInstSkolTyVar :: SkolemInfo -> TyVar -> TcM TcTyVar
tcInstSkolTyVar info tyvar
tcInstSkolTyVar :: SkolemInfo -> Maybe SrcLoc -> TyVar -> TcM TcTyVar
-- Instantiate the tyvar, using
-- * the occ-name and kind of the supplied tyvar,
-- * the unique from the monad,
-- * the location either from the tyvar (mb_loc = Nothing)
-- or from mb_loc (Just loc)
tcInstSkolTyVar info mb_loc tyvar
= do { uniq <- newUnique
; let name = setNameUnique (tyVarName tyvar) uniq
kind = tyVarKind tyvar
; return (mkSkolTyVar name kind info) }
; let old_name = tyVarName tyvar
kind = tyVarKind tyvar
loc = mb_loc `orElse` getSrcLoc old_name
new_name = mkInternalName uniq (nameOccName old_name) loc
; return (mkSkolTyVar new_name kind info) }
tcInstSkolTyVars :: SkolemInfo -> [TyVar] -> TcM [TcTyVar]
tcInstSkolTyVars info tyvars = mapM (tcInstSkolTyVar info) tyvars
-- Get the location from the monad
tcInstSkolTyVars info tyvars
= do { span <- getSrcSpanM
; mapM (tcInstSkolTyVar info (Just (srcSpanStart span))) tyvars }
tcInstSkolType :: SkolemInfo -> TcType -> TcM ([TcTyVar], TcThetaType, TcType)
-- Instantiate a type with fresh skolem constants
-- Binding location comes from the monad
tcInstSkolType info ty = tcInstType (tcInstSkolTyVars info) ty
\end{code}