Commit 6a78503e authored by cactus's avatar cactus

Typecheck the wrapper definition of a pattern synonym,

after everything in the same scope is typechecked
parent 25c2eebc
......@@ -16,7 +16,7 @@ module TcBinds ( tcLocalBinds, tcTopBinds, tcRecSelBinds,
import {-# SOURCE #-} TcMatches ( tcGRHSsPat, tcMatchesFun )
import {-# SOURCE #-} TcExpr ( tcMonoExpr )
import {-# SOURCE #-} TcPatSyn ( tcPatSynDecl )
import {-# SOURCE #-} TcPatSyn ( tcPatSynDecl, tcPatSynWrapper )
import DynFlags
import HsSyn
......@@ -315,14 +315,28 @@ tcValBinds top_lvl binds sigs thing_inside
-- Extend the envt right away with all
-- the Ids declared with type signatures
-- Use tcExtendIdEnv2 to avoid extending the TcIdBinder stack
; tcExtendIdEnv2 [(idName id, id) | id <- poly_ids] $
tcBindGroups top_lvl sig_fn prag_fn
binds thing_inside }
; tcExtendIdEnv2 [(idName id, id) | id <- poly_ids] $ do
{ (binds', (extra_binds', thing)) <- tcBindGroups top_lvl sig_fn prag_fn binds $ do
{ thing <- thing_inside
; patsyn_wrappers <- forM patsyns $ \(name, loc, args, lpat, dir) -> do
{ patsyn <- tcLookupPatSyn name
; case patSynWrapper patsyn of
Nothing -> return emptyBag
Just wrapper_id -> tcPatSynWrapper (L loc wrapper_id) lpat dir args }
; let extra_binds = [ (NonRecursive, wrapper) | wrapper <- patsyn_wrappers ]
; return (extra_binds, thing) }
; return (binds' ++ extra_binds', thing) }}
where
patsyns = [ (name, loc, args, lpat, dir)
| (_, lbinds) <- binds
, L loc (PatSynBind{ patsyn_id = L _ name, patsyn_args = details, patsyn_def = lpat, patsyn_dir = dir }) <- bagToList lbinds
, let args = map unLoc $ case details of
PrefixPatSyn args -> args
InfixPatSyn arg1 arg2 -> [arg1, arg2]
]
patsyn_placeholder_kinds -- See Note [Placeholder PatSyn kinds]
= [ (name, placeholder_patsyn_tything)
| (_, lbinds) <- binds
, L _ (PatSynBind{ patsyn_id = L _ name }) <- bagToList lbinds ]
| (name, _, _, _, _) <- patsyns ]
placeholder_patsyn_tything
= AGlobal $ AConLike $ PatSynCon $ panic "fakePatSynCon"
......
......@@ -7,7 +7,7 @@
\begin{code}
{-# LANGUAGE CPP #-}
module TcPatSyn (tcPatSynDecl) where
module TcPatSyn (tcPatSynDecl, tcPatSynWrapper) where
import HsSyn
import TcPat
......@@ -95,9 +95,10 @@ tcPatSynDecl lname@(L _ name) details lpat dir
prov_dicts req_dicts
prov_theta req_theta
pat_ty
; m_wrapper <- tcPatSynWrapper lname lpat dir args
univ_tvs ex_tvs theta pat_ty
; let binds = matcher_bind `unionBags` maybe emptyBag snd m_wrapper
; wrapper_id <- if isBidirectional dir
then fmap Just $ mkPatSynWrapperId lname args univ_tvs ex_tvs theta pat_ty
else return Nothing
; traceTc "tcPatSynDecl }" $ ppr name
; let patSyn = mkPatSyn name is_infix
......@@ -105,8 +106,8 @@ tcPatSynDecl lname@(L _ name) details lpat dir
univ_tvs ex_tvs
prov_theta req_theta
pat_ty
matcher_id (fmap fst m_wrapper)
; return (patSyn, binds) }
matcher_id wrapper_id
; return (patSyn, matcher_bind) }
\end{code}
......@@ -188,44 +189,41 @@ tcPatSynMatcher (L loc name) lpat args univ_tvs ex_tvs ev_binds prov_dicts req_d
name <- newName . mkVarOccFS . fsLit $ s
return $ mkLocalId name ty
tcPatSynWrapper :: Located Name
isBidirectional :: HsPatSynDir a -> Bool
isBidirectional Unidirectional = False
isBidirectional ImplicitBidirectional = True
isBidirectional ExplicitBidirectional{} = True
tcPatSynWrapper :: Located Id
-> LPat Name
-> HsPatSynDir Name
-> [Var]
-> [TyVar] -> [TyVar]
-> ThetaType
-> TcType
-> TcM (Maybe (Id, LHsBinds Id))
-> [Name]
-> TcM (LHsBinds Id)
-- See Note [Matchers and wrappers for pattern synonyms] in PatSyn
tcPatSynWrapper lname lpat dir args univ_tvs ex_tvs theta pat_ty
= do { let argNames = mkNameSet (map Var.varName args)
; case (dir, tcPatToExpr argNames lpat) of
(Unidirectional, _) ->
return Nothing
(ImplicitBidirectional, Nothing) ->
cannotInvertPatSynErr lpat
(ImplicitBidirectional, Just lexpr) ->
fmap Just $ mkWrapper $ \wrapper_lname args' ->
do { let wrapper_args = map (noLoc . VarPat . Var.varName) args'
wrapper_match = mkMatch wrapper_args lexpr EmptyLocalBinds
bind = mkTopFunBind Generated wrapper_lname [wrapper_match]
; return bind }
(ExplicitBidirectional mg, _) ->
fmap Just $ mkWrapper $ \wrapper_lname _args' ->
return FunBind{ fun_id = wrapper_lname
, fun_infix = False
, fun_matches = mg
, fun_co_fn = idHsWrapper
, bind_fvs = placeHolderNames
, fun_tick = Nothing } }
where
mkWrapper = mkPatSynWrapper lname args univ_tvs ex_tvs theta pat_ty
mkPatSynWrapper :: Located Name
-> [Var] -> [TyVar] -> [TyVar] -> ThetaType -> Type
-> (Located Name -> [Var] -> TcM (HsBind Name))
-> TcM (Id, LHsBinds Id)
mkPatSynWrapper (L loc name) args univ_tvs ex_tvs theta pat_ty mk_bind
tcPatSynWrapper _ _ Unidirectional _
= panic "tcPatSynWrapper"
tcPatSynWrapper (L _ wrapper_id) lpat ImplicitBidirectional args
= do { lexpr <- case tcPatToExpr (mkNameSet args) lpat of
Nothing -> cannotInvertPatSynErr lpat
Just lexpr -> return lexpr
; let wrapper_args = map (noLoc . VarPat) args
wrapper_lname = L (getLoc lpat) (idName wrapper_id)
wrapper_match = mkMatch wrapper_args lexpr EmptyLocalBinds
wrapper_bind = mkTopFunBind Generated wrapper_lname [wrapper_match]
; mkPatSynWrapper wrapper_id wrapper_bind }
tcPatSynWrapper (L loc wrapper_id) _ (ExplicitBidirectional mg) _
= mkPatSynWrapper wrapper_id $
FunBind{ fun_id = L loc (idName wrapper_id)
, fun_infix = False
, fun_matches = mg
, fun_co_fn = idHsWrapper
, bind_fvs = placeHolderNames
, fun_tick = Nothing }
mkPatSynWrapperId :: Located Name
-> [Var] -> [TyVar] -> [TyVar] -> ThetaType -> Type
-> TcM Id
mkPatSynWrapperId (L _ name) args univ_tvs ex_tvs theta pat_ty
= do { let qtvs = univ_tvs ++ ex_tvs
; (subst, wrapper_tvs) <- tcInstSkolTyVars qtvs
; let wrapper_theta = substTheta subst theta
......@@ -235,20 +233,25 @@ mkPatSynWrapper (L loc name) args univ_tvs ex_tvs theta pat_ty mk_bind
wrapper_sigma = mkSigmaTy wrapper_tvs wrapper_theta wrapper_tau
; wrapper_name <- newImplicitBinder name mkDataConWrapperOcc
; let wrapper_lname = L loc wrapper_name
wrapper_id = mkExportedLocalId VanillaId wrapper_name wrapper_sigma
; bind <- mk_bind wrapper_lname args'
; let sig = TcSigInfo{ sig_id = wrapper_id
, sig_tvs = map (\tv -> (Nothing, tv)) wrapper_tvs
, sig_theta = wrapper_theta
, sig_tau = wrapper_tau
, sig_loc = loc
}
; (wrapper_binds, _, _) <- tcPolyCheck NonRecursive (const []) sig (noLoc bind)
; return $ mkExportedLocalId VanillaId wrapper_name wrapper_sigma }
mkPatSynWrapper :: Id
-> HsBind Name
-> TcM (LHsBinds Id)
mkPatSynWrapper wrapper_id bind
= do { (wrapper_binds, _, _) <- tcPolyCheck NonRecursive (const []) sig (noLoc bind)
; traceTc "tcPatSynDecl wrapper" $ ppr wrapper_binds
; traceTc "tcPatSynDecl wrapper type" $ ppr (varType wrapper_id)
; return (wrapper_id, wrapper_binds) }
; return wrapper_binds }
where
sig = TcSigInfo{ sig_id = wrapper_id
, sig_tvs = map (\tv -> (Nothing, tv)) wrapper_tvs
, sig_theta = wrapper_theta
, sig_tau = wrapper_tau
, sig_loc = noSrcSpan
}
(wrapper_tvs, wrapper_theta, wrapper_tau) = tcSplitSigmaTy (idType wrapper_id)
\end{code}
Note [As-patterns in pattern synonym definitions]
......
......@@ -13,4 +13,10 @@ tcPatSynDecl :: Located Name
-> LPat Name
-> HsPatSynDir Name
-> TcM (PatSyn, LHsBinds Id)
tcPatSynWrapper :: Located Id
-> LPat Name
-> HsPatSynDir Name
-> [Name]
-> TcM (LHsBinds Id)
\end{code}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment