Commit e3a4d6c3 authored by simonpj's avatar simonpj

[project @ 2005-08-10 11:05:06 by simonpj]

It turned out that doing all binding dependency analysis in the typechecker
meant that the renamer's unused-binding error messages got worse.  So now
I've put the first dep anal back into the renamer, while the second (which
is specific to type checking) remains in the type checker.

I've also made the pretty printer sort the decls back into source order
before printing them (except with -dppr-debug).

Fixes rn041.
parent 31578e22
...@@ -83,7 +83,7 @@ dsLocalBinds (HsIPBinds binds) body = dsIPBinds binds body ...@@ -83,7 +83,7 @@ dsLocalBinds (HsIPBinds binds) body = dsIPBinds binds body
------------------------- -------------------------
dsValBinds :: HsValBinds Id -> CoreExpr -> DsM CoreExpr dsValBinds :: HsValBinds Id -> CoreExpr -> DsM CoreExpr
dsValBinds (ValBindsOut binds) body = foldrDs ds_val_bind body binds dsValBinds (ValBindsOut binds _) body = foldrDs ds_val_bind body binds
------------------------- -------------------------
dsIPBinds (IPBinds ip_binds dict_binds) body dsIPBinds (IPBinds ip_binds dict_binds) body
...@@ -680,7 +680,7 @@ dsMDo tbl stmts body result_ty ...@@ -680,7 +680,7 @@ dsMDo tbl stmts body result_ty
go (new_bind_stmt : let_stmt : stmts) go (new_bind_stmt : let_stmt : stmts)
where where
new_bind_stmt = mkBindStmt (mk_tup_pat later_pats) mfix_app new_bind_stmt = mkBindStmt (mk_tup_pat later_pats) mfix_app
let_stmt = LetStmt (HsValBinds (ValBindsOut [(Recursive, binds)])) let_stmt = LetStmt (HsValBinds (ValBindsOut [(Recursive, binds)] []))
-- Remove the later_ids that appear (without fancy coercions) -- Remove the later_ids that appear (without fancy coercions)
......
...@@ -20,9 +20,10 @@ import Name ( Name ) ...@@ -20,9 +20,10 @@ import Name ( Name )
import NameSet ( NameSet, elemNameSet ) import NameSet ( NameSet, elemNameSet )
import BasicTypes ( IPName, RecFlag(..), Activation(..), Fixity ) import BasicTypes ( IPName, RecFlag(..), Activation(..), Fixity )
import Outputable import Outputable
import SrcLoc ( Located(..), unLoc ) import SrcLoc ( Located(..), SrcSpan, unLoc )
import Util ( sortLe )
import Var ( TyVar, DictId, Id ) import Var ( TyVar, DictId, Id )
import Bag ( Bag, emptyBag, isEmptyBag, bagToList, unionBags ) import Bag ( Bag, emptyBag, isEmptyBag, bagToList, unionBags, unionManyBags )
\end{code} \end{code}
%************************************************************************ %************************************************************************
...@@ -45,9 +46,9 @@ data HsValBinds id -- Value bindings (not implicit parameters) ...@@ -45,9 +46,9 @@ data HsValBinds id -- Value bindings (not implicit parameters)
(LHsBinds id) [LSig id] -- Not dependency analysed (LHsBinds id) [LSig id] -- Not dependency analysed
-- Recursive by default -- Recursive by default
| ValBindsOut -- After typechecking | ValBindsOut -- After renaming
[(RecFlag, LHsBinds id)] -- Dependency analysed [(RecFlag, LHsBinds id)] -- Dependency analysed
[LSig Name]
type LHsBinds id = Bag (LHsBind id) type LHsBinds id = Bag (LHsBind id)
type DictBinds id = LHsBinds id -- Used for dictionary or method bindings type DictBinds id = LHsBinds id -- Used for dictionary or method bindings
...@@ -115,17 +116,32 @@ instance OutputableBndr id => Outputable (HsLocalBinds id) where ...@@ -115,17 +116,32 @@ instance OutputableBndr id => Outputable (HsLocalBinds id) where
instance OutputableBndr id => Outputable (HsValBinds id) where instance OutputableBndr id => Outputable (HsValBinds id) where
ppr (ValBindsIn binds sigs) ppr (ValBindsIn binds sigs)
= vcat [vcat (map ppr sigs), = pprValBindsForUser binds sigs
vcat (map ppr (bagToList binds))
-- *not* pprLHsBinds because we don't want braces; 'let' and ppr (ValBindsOut sccs sigs)
-- 'where' include a list of HsBindGroups and we don't want = getPprStyle $ \ sty ->
-- several groups of bindings each with braces around. if debugStyle sty then -- Print with sccs showing
] vcat (map ppr sigs) $$ vcat (map ppr_scc sccs)
ppr (ValBindsOut sccs) = vcat (map ppr_scc sccs) else
where pprValBindsForUser (unionManyBags (map snd sccs)) sigs
ppr_scc (rec_flag, binds) = pp_rec rec_flag <+> pprLHsBinds binds where
pp_rec Recursive = ptext SLIT("rec") ppr_scc (rec_flag, binds) = pp_rec rec_flag <+> pprLHsBinds binds
pp_rec NonRecursive = ptext SLIT("nonrec") pp_rec Recursive = ptext SLIT("rec")
pp_rec NonRecursive = ptext SLIT("nonrec")
-- *not* pprLHsBinds because we don't want braces; 'let' and
-- 'where' include a list of HsBindGroups and we don't want
-- several groups of bindings each with braces around.
-- Sort by location before printing
pprValBindsForUser binds sigs
= vcat (map snd (sort_by_loc decls))
where
decls :: [(SrcSpan, SDoc)]
decls = [(loc, ppr sig) | L loc sig <- sigs] ++
[(loc, ppr bind) | L loc bind <- bagToList binds]
sort_by_loc decls = sortLe (\(l1,_) (l2,_) -> l1 <= l2) decls
pprLHsBinds :: OutputableBndr id => LHsBinds id -> SDoc pprLHsBinds :: OutputableBndr id => LHsBinds id -> SDoc
pprLHsBinds binds pprLHsBinds binds
...@@ -142,12 +158,12 @@ isEmptyLocalBinds (HsIPBinds ds) = isEmptyIPBinds ds ...@@ -142,12 +158,12 @@ isEmptyLocalBinds (HsIPBinds ds) = isEmptyIPBinds ds
isEmptyLocalBinds EmptyLocalBinds = True isEmptyLocalBinds EmptyLocalBinds = True
isEmptyValBinds :: HsValBinds a -> Bool isEmptyValBinds :: HsValBinds a -> Bool
isEmptyValBinds (ValBindsIn ds sigs) = isEmptyLHsBinds ds && null sigs isEmptyValBinds (ValBindsIn ds sigs) = isEmptyLHsBinds ds && null sigs
isEmptyValBinds (ValBindsOut ds) = null ds isEmptyValBinds (ValBindsOut ds sigs) = null ds && null sigs
emptyValBindsIn, emptyValBindsOut :: HsValBinds a emptyValBindsIn, emptyValBindsOut :: HsValBinds a
emptyValBindsIn = ValBindsIn emptyBag [] emptyValBindsIn = ValBindsIn emptyBag []
emptyValBindsOut = ValBindsOut [] emptyValBindsOut = ValBindsOut [] []
emptyLHsBinds :: LHsBinds id emptyLHsBinds :: LHsBinds id
emptyLHsBinds = emptyBag emptyLHsBinds = emptyBag
...@@ -159,8 +175,8 @@ isEmptyLHsBinds = isEmptyBag ...@@ -159,8 +175,8 @@ isEmptyLHsBinds = isEmptyBag
plusHsValBinds :: HsValBinds a -> HsValBinds a -> HsValBinds a plusHsValBinds :: HsValBinds a -> HsValBinds a -> HsValBinds a
plusHsValBinds (ValBindsIn ds1 sigs1) (ValBindsIn ds2 sigs2) plusHsValBinds (ValBindsIn ds1 sigs1) (ValBindsIn ds2 sigs2)
= ValBindsIn (ds1 `unionBags` ds2) (sigs1 ++ sigs2) = ValBindsIn (ds1 `unionBags` ds2) (sigs1 ++ sigs2)
plusHsValBinds (ValBindsOut ds1) (ValBindsOut ds2) plusHsValBinds (ValBindsOut ds1 sigs1) (ValBindsOut ds2 sigs2)
= ValBindsOut (ds1 ++ ds2) = ValBindsOut (ds1 ++ ds2) (sigs1 ++ sigs2)
\end{code} \end{code}
What AbsBinds means What AbsBinds means
......
...@@ -100,7 +100,7 @@ mkHsDictLet binds expr ...@@ -100,7 +100,7 @@ mkHsDictLet binds expr
| isEmptyLHsBinds binds = expr | isEmptyLHsBinds binds = expr
| otherwise = L (getLoc expr) (HsLet (HsValBinds val_binds) expr) | otherwise = L (getLoc expr) (HsLet (HsValBinds val_binds) expr)
where where
val_binds = ValBindsOut [(Recursive, binds)] val_binds = ValBindsOut [(Recursive, binds)] []
mkHsConApp :: DataCon -> [Type] -> [HsExpr Id] -> LHsExpr Id mkHsConApp :: DataCon -> [Type] -> [HsExpr Id] -> LHsExpr Id
-- Used for constructing dictinoary terms etc, so no locations -- Used for constructing dictinoary terms etc, so no locations
...@@ -279,8 +279,8 @@ collectLocalBinders (HsIPBinds _) = [] ...@@ -279,8 +279,8 @@ collectLocalBinders (HsIPBinds _) = []
collectLocalBinders EmptyLocalBinds = [] collectLocalBinders EmptyLocalBinds = []
collectHsValBinders :: HsValBinds name -> [Located name] collectHsValBinders :: HsValBinds name -> [Located name]
collectHsValBinders (ValBindsIn binds sigs) = collectHsBindLocatedBinders binds collectHsValBinders (ValBindsIn binds sigs) = collectHsBindLocatedBinders binds
collectHsValBinders (ValBindsOut binds) = foldr collect_one [] binds collectHsValBinders (ValBindsOut binds sigs) = foldr collect_one [] binds
where where
collect_one (_,binds) acc = foldrBag (collectAcc . unLoc) acc binds collect_one (_,binds) acc = foldrBag (collectAcc . unLoc) acc binds
...@@ -312,8 +312,8 @@ collectHsBindLocatedBinders binds = foldrBag (collectAcc . unLoc) [] binds ...@@ -312,8 +312,8 @@ collectHsBindLocatedBinders binds = foldrBag (collectAcc . unLoc) [] binds
Get all the pattern type signatures out of a bunch of bindings Get all the pattern type signatures out of a bunch of bindings
\begin{code} \begin{code}
collectSigTysFromHsBinds :: [LHsBind name] -> [LHsType name] collectSigTysFromHsBinds :: LHsBinds name -> [LHsType name]
collectSigTysFromHsBinds binds = concat (map collectSigTysFromHsBind binds) collectSigTysFromHsBinds binds = concatMap collectSigTysFromHsBind (bagToList binds)
collectSigTysFromHsBind :: LHsBind name -> [LHsType name] collectSigTysFromHsBind :: LHsBind name -> [LHsType name]
collectSigTysFromHsBind bind collectSigTysFromHsBind bind
......
...@@ -21,7 +21,6 @@ module RnBinds ( ...@@ -21,7 +21,6 @@ module RnBinds (
import {-# SOURCE #-} RnExpr( rnLExpr, rnStmts ) import {-# SOURCE #-} RnExpr( rnLExpr, rnStmts )
import HsSyn import HsSyn
import HsBinds ( hsSigDoc, eqHsSig )
import RdrHsSyn import RdrHsSyn
import RnHsSyn import RnHsSyn
import TcRnMonad import TcRnMonad
...@@ -41,9 +40,11 @@ import PrelNames ( isUnboundName ) ...@@ -41,9 +40,11 @@ import PrelNames ( isUnboundName )
import RdrName ( RdrName, rdrNameOcc ) import RdrName ( RdrName, rdrNameOcc )
import SrcLoc ( mkSrcSpan, Located(..), unLoc ) import SrcLoc ( mkSrcSpan, Located(..), unLoc )
import ListSetOps ( findDupsEq ) import ListSetOps ( findDupsEq )
import BasicTypes ( RecFlag(..) )
import Digraph ( SCC(..), stronglyConnComp )
import Bag import Bag
import Outputable import Outputable
import Maybes ( orElse ) import Maybes ( orElse, fromJust, isJust )
import Monad ( foldM ) import Monad ( foldM )
\end{code} \end{code}
...@@ -177,7 +178,7 @@ rnTopBindsBoot (ValBindsIn mbinds sigs) ...@@ -177,7 +178,7 @@ rnTopBindsBoot (ValBindsIn mbinds sigs)
rnTopBindsSrc :: HsValBinds RdrName -> RnM (HsValBinds Name, DefUses) rnTopBindsSrc :: HsValBinds RdrName -> RnM (HsValBinds Name, DefUses)
rnTopBindsSrc binds@(ValBindsIn mbinds _) rnTopBindsSrc binds@(ValBindsIn mbinds _)
= bindPatSigTyVars (collectSigTysFromHsBinds (bagToList mbinds)) $ \ _ -> = bindPatSigTyVars (collectSigTysFromHsBinds mbinds) $ \ _ ->
-- Hmm; by analogy with Ids, this doesn't look right -- Hmm; by analogy with Ids, this doesn't look right
-- Top-level bound type vars should really scope over -- Top-level bound type vars should really scope over
-- everything, but we only scope them over the other bindings -- everything, but we only scope them over the other bindings
...@@ -185,7 +186,7 @@ rnTopBindsSrc binds@(ValBindsIn mbinds _) ...@@ -185,7 +186,7 @@ rnTopBindsSrc binds@(ValBindsIn mbinds _)
do { (binds', dus) <- rnValBinds noTrim binds do { (binds', dus) <- rnValBinds noTrim binds
-- Warn about missing signatures, -- Warn about missing signatures,
; let { ValBindsIn _ sigs' = binds' ; let { ValBindsOut _ sigs' = binds'
; ty_sig_vars = mkNameSet [ unLoc n | L _ (Sig n _) <- sigs'] ; ty_sig_vars = mkNameSet [ unLoc n | L _ (Sig n _) <- sigs']
; un_sigd_bndrs = duDefs dus `minusNameSet` ty_sig_vars } ; un_sigd_bndrs = duDefs dus `minusNameSet` ty_sig_vars }
...@@ -253,7 +254,7 @@ rnValBindsAndThen binds@(ValBindsIn mbinds sigs) thing_inside ...@@ -253,7 +254,7 @@ rnValBindsAndThen binds@(ValBindsIn mbinds sigs) thing_inside
-- current scope, inventing new names for the new binders -- current scope, inventing new names for the new binders
-- This also checks that the names form a set -- This also checks that the names form a set
bindLocatedLocalsRn doc mbinders_w_srclocs $ \ bndrs -> bindLocatedLocalsRn doc mbinders_w_srclocs $ \ bndrs ->
bindPatSigTyVarsFV (collectSigTysFromHsBinds (bagToList mbinds)) $ bindPatSigTyVarsFV (collectSigTysFromHsBinds mbinds) $
-- Then install local fixity declarations -- Then install local fixity declarations
-- Notice that they scope over thing_inside too -- Notice that they scope over thing_inside too
...@@ -267,12 +268,7 @@ rnValBindsAndThen binds@(ValBindsIn mbinds sigs) thing_inside ...@@ -267,12 +268,7 @@ rnValBindsAndThen binds@(ValBindsIn mbinds sigs) thing_inside
-- Final error checking -- Final error checking
let let
all_uses = duUses bind_dus `plusFV` result_fvs all_uses = duUses bind_dus `plusFV` result_fvs
unused_bndrs = [ b | b <- bndrs, not (b `elemNameSet` all_uses)]
in
warnUnusedLocalBinds unused_bndrs `thenM_`
returnM (result, delListFromNameSet all_uses bndrs)
-- duUses: It's important to return all the uses, not the 'real uses' -- duUses: It's important to return all the uses, not the 'real uses'
-- used for warning about unused bindings. Otherwise consider: -- used for warning about unused bindings. Otherwise consider:
-- x = 3 -- x = 3
...@@ -280,6 +276,12 @@ rnValBindsAndThen binds@(ValBindsIn mbinds sigs) thing_inside ...@@ -280,6 +276,12 @@ rnValBindsAndThen binds@(ValBindsIn mbinds sigs) thing_inside
-- If we don't "see" the dependency of 'y' on 'x', we may put the -- If we don't "see" the dependency of 'y' on 'x', we may put the
-- bindings in the wrong order, and the type checker will complain -- bindings in the wrong order, and the type checker will complain
-- that x isn't in scope -- that x isn't in scope
unused_bndrs = [ b | b <- bndrs, not (b `elemNameSet` all_uses)]
in
warnUnusedLocalBinds unused_bndrs `thenM_`
returnM (result, delListFromNameSet all_uses bndrs)
where where
mbinders_w_srclocs = collectHsBindLocatedBinders mbinds mbinders_w_srclocs = collectHsBindLocatedBinders mbinds
doc = text "In the binding group for:" doc = text "In the binding group for:"
...@@ -294,21 +296,46 @@ rnValBinds :: (FreeVars -> FreeVars) ...@@ -294,21 +296,46 @@ rnValBinds :: (FreeVars -> FreeVars)
rnValBinds trim (ValBindsIn mbinds sigs) rnValBinds trim (ValBindsIn mbinds sigs)
= do { sigs' <- rename_sigs sigs = do { sigs' <- rename_sigs sigs
; let { rn_bind = wrapLocFstM (rnBind sig_fn trim) ; binds_w_dus <- mapBagM (rnBind (mkSigTvFn sigs') trim) mbinds
; sig_fn = mkSigTvFn sigs' }
; (mbinds', du_bag) <- mapAndUnzipBagM rn_bind mbinds ; let (binds', bind_dus) = depAnalBinds binds_w_dus
; let defs, uses :: NameSet ; check_sigs (okBindSig (duDefs bind_dus)) sigs'
(defs, uses) = foldrBag plus (emptyNameSet, emptyNameSet) du_bag
plus (ds1,us1) (ds2,us2) = (ds1 `unionNameSets` ds2,
us1 `unionNameSets` us2)
; check_sigs (okBindSig defs) sigs' ; return (ValBindsOut binds' sigs',
usesOnly (hsSigsFVs sigs') `plusDU` bind_dus) }
---------------------
depAnalBinds :: Bag (LHsBind Name, [Name], Uses)
-> ([(RecFlag, LHsBinds Name)], DefUses)
-- Dependency analysis; this is important so that unused-binding
-- reporting is accurate
depAnalBinds binds_w_dus
= (map get_binds sccs, map get_du sccs)
where
sccs = stronglyConnComp edges
keyd_nodes = bagToList binds_w_dus `zip` [0::Int ..]
edges = [ (node, key, [fromJust mb_key | n <- nameSetToList uses,
let mb_key = lookupNameEnv key_map n,
isJust mb_key ])
| (node@(_,_,uses), key) <- keyd_nodes ]
key_map :: NameEnv Int -- Which binding it comes from
key_map = mkNameEnv [(bndr, key) | ((_, bndrs, _), key) <- keyd_nodes
, bndr <- bndrs ]
get_binds (AcyclicSCC (bind, _, _)) = (NonRecursive, unitBag bind)
get_binds (CyclicSCC binds_w_dus) = (Recursive, listToBag [b | (b,d,u) <- binds_w_dus])
get_du (AcyclicSCC (_, bndrs, uses)) = (Just (mkNameSet bndrs), uses)
get_du (CyclicSCC binds_w_dus) = (Just defs, uses)
where
defs = mkNameSet [b | (_,bs,_) <- binds_w_dus, b <- bs]
uses = unionManyNameSets [u | (_,_,u) <- binds_w_dus]
; traceRn (text "rnValBind" <+> (ppr defs $$ ppr uses))
; return (ValBindsIn mbinds' sigs',
[(Just defs, uses `plusFV` hsSigsFVs sigs')]) }
--------------------- ---------------------
-- Bind the top-level forall'd type variables in the sigs. -- Bind the top-level forall'd type variables in the sigs.
...@@ -348,31 +375,30 @@ trimWith bndrs = intersectNameSet (mkNameSet bndrs) ...@@ -348,31 +375,30 @@ trimWith bndrs = intersectNameSet (mkNameSet bndrs)
--------------------- ---------------------
rnBind :: (Name -> [Name]) -- Signature tyvar function rnBind :: (Name -> [Name]) -- Signature tyvar function
-> (FreeVars -> FreeVars) -- Trimming function for rhs free vars -> (FreeVars -> FreeVars) -- Trimming function for rhs free vars
-> HsBind RdrName -> LHsBind RdrName
-> RnM (HsBind Name, (Defs, Uses)) -> RnM (LHsBind Name, [Name], Uses)
rnBind sig_fn trim (PatBind pat grhss ty _) rnBind sig_fn trim (L loc (PatBind pat grhss ty _))
= do { (pat', pat_fvs) <- rnLPat pat = setSrcSpan loc $
do { (pat', pat_fvs) <- rnLPat pat
; let bndrs = collectPatBinders pat' ; let bndrs = collectPatBinders pat'
; (grhss', fvs) <- bindSigTyVarsFV (concatMap sig_fn bndrs) $ ; (grhss', fvs) <- bindSigTyVarsFV (concatMap sig_fn bndrs) $
rnGRHSs PatBindRhs grhss rnGRHSs PatBindRhs grhss
; return (PatBind pat' grhss' ty (trim fvs), ; return (L loc (PatBind pat' grhss' ty (trim fvs)), bndrs, pat_fvs `plusFV` fvs) }
(mkNameSet bndrs, pat_fvs `plusFV` fvs)) }
rnBind sig_fn trim (FunBind name inf matches _) rnBind sig_fn trim (L loc (FunBind name inf matches _))
= do { new_name <- lookupLocatedBndrRn name = setSrcSpan loc $
; let { plain_name = unLoc new_name do { new_name <- lookupLocatedBndrRn name
; bndrs = unitNameSet plain_name } ; let plain_name = unLoc new_name
; (matches', fvs) <- bindSigTyVarsFV (sig_fn plain_name) $ ; (matches', fvs) <- bindSigTyVarsFV (sig_fn plain_name) $
rnMatchGroup (FunRhs plain_name) matches rnMatchGroup (FunRhs plain_name) matches
; checkPrecMatch inf plain_name matches' ; checkPrecMatch inf plain_name matches'
; return (FunBind new_name inf matches' (trim fvs), ; return (L loc (FunBind new_name inf matches' (trim fvs)), [plain_name], fvs)
(bndrs, fvs))
} }
\end{code} \end{code}
......
...@@ -20,7 +20,7 @@ import HsSyn ( HsExpr(..), HsBind(..), LHsBinds, LHsBind, Sig(..), ...@@ -20,7 +20,7 @@ import HsSyn ( HsExpr(..), HsBind(..), LHsBinds, LHsBind, Sig(..),
LSig, Match(..), IPBind(..), Prag(..), LSig, Match(..), IPBind(..), Prag(..),
HsType(..), LHsType, HsExplicitForAll(..), hsLTyVarNames, HsType(..), LHsType, HsExplicitForAll(..), hsLTyVarNames,
isVanillaLSig, sigName, placeHolderNames, isPragLSig, isVanillaLSig, sigName, placeHolderNames, isPragLSig,
LPat, GRHSs, MatchGroup(..), isEmptyLHsBinds, LPat, GRHSs, MatchGroup(..), isEmptyLHsBinds, pprLHsBinds,
collectHsBindBinders, collectPatBinders, pprPatBind collectHsBindBinders, collectPatBinders, pprPatBind
) )
import TcHsSyn ( zonkId, (<$>) ) import TcHsSyn ( zonkId, (<$>) )
...@@ -59,8 +59,8 @@ import VarSet ...@@ -59,8 +59,8 @@ import VarSet
import SrcLoc ( Located(..), unLoc, getLoc ) import SrcLoc ( Located(..), unLoc, getLoc )
import Bag import Bag
import ErrUtils ( Message ) import ErrUtils ( Message )
import Digraph ( SCC(..), stronglyConnComp, flattenSCC ) import Digraph ( SCC(..), stronglyConnComp )
import Maybes ( fromJust, isJust, orElse, catMaybes ) import Maybes ( fromJust, isJust, isNothing, orElse, catMaybes )
import Util ( singleton ) import Util ( singleton )
import BasicTypes ( TopLevelFlag(..), isTopLevel, isNotTopLevel, import BasicTypes ( TopLevelFlag(..), isTopLevel, isNotTopLevel,
RecFlag(..), isNonRec ) RecFlag(..), isNonRec )
...@@ -105,7 +105,7 @@ tcTopBinds :: HsValBinds Name -> TcM (LHsBinds TcId, TcLclEnv) ...@@ -105,7 +105,7 @@ tcTopBinds :: HsValBinds Name -> TcM (LHsBinds TcId, TcLclEnv)
-- want. The bit we care about is the local bindings -- want. The bit we care about is the local bindings
-- and the free type variables thereof -- and the free type variables thereof
tcTopBinds binds tcTopBinds binds
= do { (ValBindsOut prs, env) <- tcValBinds TopLevel binds getLclEnv = do { (ValBindsOut prs _, env) <- tcValBinds TopLevel binds getLclEnv
; return (foldr (unionBags . snd) emptyBag prs, env) } ; return (foldr (unionBags . snd) emptyBag prs, env) }
-- The top level bindings are flattened into a giant -- The top level bindings are flattened into a giant
-- implicitly-mutually-recursive LHsBinds -- implicitly-mutually-recursive LHsBinds
...@@ -156,40 +156,12 @@ tcLocalBinds (HsIPBinds (IPBinds ip_binds _)) thing_inside ...@@ -156,40 +156,12 @@ tcLocalBinds (HsIPBinds (IPBinds ip_binds _)) thing_inside
tcCheckRho expr ty `thenM` \ expr' -> tcCheckRho expr ty `thenM` \ expr' ->
returnM (ip_inst, (IPBind ip' expr')) returnM (ip_inst, (IPBind ip' expr'))
------------------------
mkEdges :: (Name -> Bool) -> [LHsBind Name]
-> [(LHsBind Name, BKey, [BKey])]
type BKey = Int -- Just number off the bindings
mkEdges exclude_fn binds
= [ (bind, key, [fromJust mb_key | n <- nameSetToList (bind_fvs (unLoc bind)),
let mb_key = lookupNameEnv key_map n,
isJust mb_key,
not (exclude_fn n) ])
| (bind, key) <- keyd_binds
]
where
keyd_binds = binds `zip` [0::BKey ..]
bind_fvs (FunBind _ _ _ fvs) = fvs
bind_fvs (PatBind _ _ _ fvs) = fvs
bind_fvs bind = pprPanic "mkEdges" (ppr bind)
key_map :: NameEnv BKey -- Which binding it comes from
key_map = mkNameEnv [(bndr, key) | (L _ bind, key) <- keyd_binds
, bndr <- bindersOfHsBind bind ]
bindersOfHsBind :: HsBind Name -> [Name]
bindersOfHsBind (PatBind pat _ _ _) = collectPatBinders pat
bindersOfHsBind (FunBind (L _ f) _ _ _) = [f]
------------------------ ------------------------
tcValBinds :: TopLevelFlag tcValBinds :: TopLevelFlag
-> HsValBinds Name -> TcM thing -> HsValBinds Name -> TcM thing
-> TcM (HsValBinds TcId, thing) -> TcM (HsValBinds TcId, thing)
tcValBinds top_lvl (ValBindsIn binds sigs) thing_inside tcValBinds top_lvl (ValBindsOut binds sigs) thing_inside
= tcAddLetBoundTyVars binds $ = tcAddLetBoundTyVars binds $
-- BRING ANY SCOPED TYPE VARIABLES INTO SCOPE -- BRING ANY SCOPED TYPE VARIABLES INTO SCOPE
-- Notice that they scope over -- Notice that they scope over
...@@ -199,11 +171,7 @@ tcValBinds top_lvl (ValBindsIn binds sigs) thing_inside ...@@ -199,11 +171,7 @@ tcValBinds top_lvl (ValBindsIn binds sigs) thing_inside
do { -- Typecheck the signature do { -- Typecheck the signature
tc_ty_sigs <- recoverM (returnM []) (tcTySigs sigs) tc_ty_sigs <- recoverM (returnM []) (tcTySigs sigs)
; let { prag_fn = mkPragFun sigs
-- Do the basic strongly-connected component thing
; let { sccs :: [SCC (LHsBind Name)]
; sccs = stronglyConnComp (mkEdges (\n -> False) (bagToList binds))
; prag_fn = mkPragFun sigs
; sig_fn = lookupSig tc_ty_sigs ; sig_fn = lookupSig tc_ty_sigs
; sig_ids = map sig_id tc_ty_sigs } ; sig_ids = map sig_id tc_ty_sigs }
...@@ -211,13 +179,13 @@ tcValBinds top_lvl (ValBindsIn binds sigs) thing_inside ...@@ -211,13 +179,13 @@ tcValBinds top_lvl (ValBindsIn binds sigs) thing_inside
-- the Ids declared with type signatures -- the Ids declared with type signatures
; (binds', thing) <- tcExtendIdEnv sig_ids $ ; (binds', thing) <- tcExtendIdEnv sig_ids $
tc_val_binds top_lvl sig_fn prag_fn tc_val_binds top_lvl sig_fn prag_fn
sccs thing_inside binds thing_inside
; return (ValBindsOut binds', thing) } ; return (ValBindsOut binds' sigs, thing) }
------------------------ ------------------------
tc_val_binds :: TopLevelFlag -> TcSigFun -> TcPragFun tc_val_binds :: TopLevelFlag -> TcSigFun -> TcPragFun
-> [SCC (LHsBind Name)] -> TcM thing -> [(RecFlag, LHsBinds Name)] -> TcM thing
-> TcM ([(RecFlag, LHsBinds TcId)], thing) -> TcM ([(RecFlag, LHsBinds TcId)], thing)
-- Typecheck a whole lot of value bindings, -- Typecheck a whole lot of value bindings,
-- one strongly-connected component at a time -- one strongly-connected component at a time
...@@ -226,62 +194,94 @@ tc_val_binds top_lvl sig_fn prag_fn [] thing_inside ...@@ -226,62 +194,94 @@ tc_val_binds top_lvl sig_fn prag_fn [] thing_inside
= do { thing <- thing_inside = do { thing <- thing_inside
; return ([], thing) } ; return ([], thing) }
tc_val_binds top_lvl sig_fn prag_fn (scc : sccs) thing_inside tc_val_binds top_lvl sig_fn prag_fn (group : groups) thing_inside
= do { (group', (groups', thing)) = do { (group', (groups', thing))
<- tc_group top_lvl sig_fn prag_fn scc $ <- tc_group top_lvl sig_fn prag_fn group $
tc_val_binds top_lvl sig_fn prag_fn sccs thing_inside tc_val_binds top_lvl sig_fn prag_fn groups thing_inside
; return (group' ++ groups', thing) } ; return (group' ++ groups', thing) }
------------------------ ------------------------
tc_group :: TopLevelFlag -> TcSigFun -> TcPragFun tc_group :: TopLevelFlag -> TcSigFun -> TcPragFun
-> SCC (LHsBind Name) -> TcM thing -> (RecFlag, LHsBinds Name) -> TcM thing
-> TcM ([(RecFlag, LHsBinds TcId)], thing) -> TcM ([(RecFlag, LHsBinds TcId)], thing)
-- Typecheck one strongly-connected component of the original program. -- Typecheck one strongly-connected component of the original program.
-- We get a list of groups back, because there may -- We get a list of groups back, because there may
-- be specialisations etc as well -- be specialisations etc as well
tc_group top_lvl sig_fn prag_fn scc@(AcyclicSCC bind) thing_inside tc_group top_lvl sig_fn prag_fn (NonRecursive, binds) thing_inside
= -- A single non-recursive binding = -- A single non-recursive binding
-- We want to keep non-recursive things non-recursive -- We want to keep non-recursive things non-recursive
-- so that we desugar unlifted bindings correctly -- so that we desugar unlifted bindings correctly
do { (binds, thing) <- tcPolyBinds top_lvl NonRecursive do { (binds, thing) <- tcPolyBinds top_lvl NonRecursive NonRecursive
sig_fn prag_fn scc thing_inside sig_fn prag_fn binds thing_inside
; return ([(NonRecursive, b) | b <- binds], thing) } ; return ([(NonRecursive, b) | b <- binds], thing) }
tc_group top_lvl sig_fn prag_fn scc@(CyclicSCC binds) thing_inside tc_group top_lvl sig_fn prag_fn (Recursive, binds) thing_inside
= -- A recursive strongly-connected component = -- A recursive strongly-connected component
-- To maximise polymorphism (with -fglasgow-exts), we do a new -- To maximise polymorphism (with -fglasgow-exts), we do a new
-- strongly-connected component analysis, this time omitting -- strongly-connected component analysis, this time omitting
-- any references to variables with type signatures. -- any references to variables with type signatures.
-- --
-- Then we bring into scope all the variables with type signatures -- Then we bring into scope all the variables with type signatures
do { traceTc (text "tc_group rec" <+> vcat [ppr b $$ text "--and--" | b <- binds]) do { traceTc (text "tc_group rec" <+> pprLHsBinds binds)
; gla_exts <- doptM Opt_GlasgowExts ; gla_exts <- doptM Opt_GlasgowExts
; (binds,thing) <- if gla_exts ; (binds,thing) <- if gla_exts
then go new_sccs then go new_sccs
else go1 scc thing_inside else tc_binds Recursive binds thing_inside
; return ([(Recursive, unionManyBags binds)], thing) } ; return ([(Recursive, unionManyBags binds)], thing) }
-- Rec them all together -- Rec them all together
where where
new_sccs :: [SCC (LHsBind Name)] new_sccs :: [SCC (LHsBind Name)]
new_sccs = stronglyConnComp (mkEdges has_sig binds) new_sccs = stronglyConnComp (mkEdges sig_fn binds)
-- go :: SCC (LHsBind Name) -> TcM ([LHsBind TcId], thing) -- go :: SCC (LHsBind Name) -> TcM ([LHsBind TcId], thing)
go (scc:sccs) = do { (binds1, (binds2, thing)) <- go1 scc (go sccs) go (scc:sccs) = do { (binds1, (binds2, thing)) <- go1 scc (go sccs)
; return (binds1 ++ binds2, thing) } ; return (binds1 ++ binds2, thing) }
go [] = do { thing <- thing_inside; return ([], thing) } go [] = do { thing <- thing_inside; return ([], thing) }
go1 scc thing_inside = tcPolyBinds top_lvl Recursive go1 (AcyclicSCC bind) = tc_binds NonRecursive (unitBag bind)
sig_fn prag_fn scc thing_inside go1 (CyclicSCC binds) = tc_binds Recursive (listToBag binds)
has_sig :: Name -> Bool tc_binds rec_tc binds = tcPolyBinds top_lvl Recursive rec_tc sig_fn prag_fn binds
has_sig n = isJust (sig_fn n)
------------------------
mkEdges :: TcSigFun -> LHsBinds Name
-> [(LHsBind Name, BKey, [BKey])]
type BKey = Int -- Just number off the bindings
mkEdges sig_fn binds
= [ (bind, key, [fromJust mb_key | n <- nameSetToList (bind_fvs (unLoc bind)),
let mb_key = lookupNameEnv key_map n,
isJust mb_key,
no_sig n ])
| (bind, key) <- keyd_binds
]
where
no_sig :: Name -> Bool
no_sig n = isNothing (sig_fn n)
keyd_binds = bagToList binds `zip` [0::BKey ..]
bind_fvs (FunBind _ _ _ fvs) = fvs
bind_fvs (PatBind _ _ _ fvs) =