Commit 6bab649b authored by Simon Peyton Jones's avatar Simon Peyton Jones
Browse files

Improve checking of joins in Core Lint

This patch addresses the rather expensive treatment of join points,
identified in Trac #13220 comment:17

Before we were tracking the "bad joins".  Now we track the good ones.
That is easier to think about, and much more efficient; see CoreLint
Note [Join points].

On the way I did some other modest refactoring, among other things
removing a duplicated call of lintIdBndr for let-bindings.

On the
parent fc9d152b
......@@ -151,7 +151,6 @@ find an occurrence of an Id, we fetch it from the in-scope set.
Note [Bad unsafe coercion]
~~~~~~~~~~~~~~~~~~~~~~~~~~
For discussion see https://ghc.haskell.org/trac/ghc/wiki/BadUnsafeCoercions
The linter introduces additional rules that check for improper coercions
between different types, called bad coercions. The following coercions are forbidden:
......@@ -170,12 +169,10 @@ different types, called bad coercions. Following coercions are forbidden:
Note [Join points]
~~~~~~~~~~~~~~~~~~
We check the rules listed in Note [Invariants on join points] in CoreSyn. The
only one that causes any difficulty is the first: All occurrences must be tail
calls. To this end, along with the in-scope set, we remember in le_bad_joins the
subset of join ids that are no longer allowed because they were declared "too
far away." For example:
calls. To this end, along with the in-scope set, we remember in le_joins the
subset of in-scope Ids that are valid join ids. For example:
join j x = ... in
case e of
......@@ -184,11 +181,11 @@ far away." For example:
C -> join h = jump j w in ... -- good
D -> let x = jump j v in ... -- BAD
A join point remains valid in case branches, so when checking the A branch, j
is still valid. When we check the scrutinee of the inner case, however, we add j
to le_bad_joins and catch the error. Similarly, join points can occur free in
RHSes of other join points but not the RHSes of value bindings (thunks and
functions).
A join point remains valid in case branches, so when checking the A
branch, j is still valid. When we check the scrutinee of the inner
case, however, we set le_joins to empty, and catch the
error. Similarly, join points can occur free in RHSes of other join
points but not the RHSes of value bindings (thunks and functions).
************************************************************************
* *
......@@ -387,10 +384,9 @@ lintCoreBindings :: DynFlags -> CoreToDo -> [Var] -> CoreProgram -> (Bag MsgDoc,
-- If you edit this function, you may need to update the GHC formalism
-- See Note [GHC Formalism]
lintCoreBindings dflags pass local_in_scope binds
= initL dflags flags $
addLoc TopLevelBindings $
addInScopeVars local_in_scope $
addInScopeVars binders $
= initL dflags flags in_scope_set $
addLoc TopLevelBindings $
lintIdBndrs TopLevel binders $
-- Put all the top-level binders in scope at the start
-- This is because transformation rules can bring something
-- into use 'unexpectedly'
......@@ -398,6 +394,8 @@ lintCoreBindings dflags pass local_in_scope binds
; checkL (null ext_dups) (dupExtVars ext_dups)
; mapM lint_bind binds }
where
in_scope_set = mkInScopeSet (mkVarSet local_in_scope)
flags = LF { lf_check_global_ids = check_globals
, lf_check_inline_loop_breakers = check_lbs
, lf_check_static_ptrs = check_static_ptrs }
......@@ -463,9 +461,9 @@ lintUnfolding dflags locn vars expr
| isEmptyBag errs = Nothing
| otherwise = Just (pprMessageBag errs)
where
(_warns, errs) = initL dflags defaultLintFlags linter
in_scope = mkInScopeSet vars
(_warns, errs) = initL dflags defaultLintFlags in_scope linter
linter = addLoc (ImportedUnfolding locn) $
addInScopeVarSet vars $
lintCoreExpr expr
lintExpr :: DynFlags
......@@ -477,9 +475,9 @@ lintExpr dflags vars expr
| isEmptyBag errs = Nothing
| otherwise = Just (pprMessageBag errs)
where
(_warns, errs) = initL dflags defaultLintFlags linter
in_scope = mkInScopeSet (mkVarSet vars)
(_warns, errs) = initL dflags defaultLintFlags in_scope linter
linter = addLoc TopLevelBindings $
addInScopeVars vars $
lintCoreExpr expr
{-
......@@ -499,7 +497,6 @@ lintSingleBinding top_lvl_flag rec_flag (binder,rhs)
= addLoc (RhsOf binder) $
-- Check the rhs
do { ty <- lintRhs binder rhs
; lint_bndr binder -- Check match to RHS type
; binder_ty <- applySubstTy (idType binder)
; ensureEqTys binder_ty ty (mkRhsMsg binder (text "RHS") ty)
......@@ -571,11 +568,6 @@ lintSingleBinding top_lvl_flag rec_flag (binder,rhs)
-- We should check the unfolding, if any, but this is tricky because
-- the unfolding is a SimplifiableCoreExpr. Give up for now.
where
-- If you edit this function, you may need to update the GHC formalism
-- See Note [GHC Formalism]
lint_bndr var | isId var = lintIdBndr top_lvl_flag var $ \_ -> return ()
| otherwise = return ()
-- | Checks the RHS of bindings. It only differs from 'lintCoreExpr'
-- in that it doesn't reject occurrences of the function 'makeStatic' when they
......@@ -680,7 +672,7 @@ lintCoreExpr :: CoreExpr -> LintM OutType
-- If you edit this function, you may need to update the GHC formalism
-- See Note [GHC Formalism]
lintCoreExpr (Var var)
= lintCoreVar var 0
= lintVarOcc var 0
lintCoreExpr (Lit lit)
= return (literalType lit)
......@@ -726,13 +718,16 @@ lintCoreExpr (Let (NonRec bndr rhs) body)
| isId bndr
= do { lintSingleBinding NotTopLevel NonRecursive (bndr,rhs)
; addLoc (BodyOfLetRec [bndr])
(lintIdBndr NotTopLevel bndr $ \_ -> lintCoreExpr body) }
(lintIdBndr NotTopLevel bndr $ \_ ->
addGoodJoins [bndr] $
lintCoreExpr body) }
| otherwise
= failWithL (mkLetErr bndr rhs) -- Not quite accurate
lintCoreExpr (Let (Rec pairs) body)
= lintIdBndrs bndrs $ \_ ->
= lintIdBndrs NotTopLevel bndrs $
addGoodJoins bndrs $
do { checkL (null dups) (dupVars dups)
; checkL (all isJoinId bndrs || all (not . isJoinId) bndrs) $
mkInconsistentRecMsg bndrs
......@@ -812,51 +807,38 @@ lintCoreExpr (Coercion co)
= do { (k1, k2, ty1, ty2, role) <- lintInCo co
; return (mkHeteroCoercionType role k1 k2 ty1 ty2) }
lintCoreVar :: Var -> Int -- Number of arguments (type or value) being passed
----------------------
lintVarOcc :: Var -> Int -- Number of arguments (type or value) being passed
-> LintM Type -- returns type of the *variable*
lintCoreVar var nargs
lintVarOcc var nargs
= do { checkL (isNonCoVarId var)
(text "Non term variable" <+> ppr var)
; lf <- getLintFlags
-- Check that the type of the occurrence is the same
-- as the type of the binding site
; ty <- applySubstTy (idType var)
; var' <- lookupIdInScope var
; let ty' = idType var'
; ensureEqTys ty ty' $ mkBndrOccTypeMismatchMsg var' var ty' ty
-- Check for a nested occurrence of the StaticPtr constructor.
-- See Note [Checking StaticPtrs].
; lf <- getLintFlags
; when (nargs /= 0 && lf_check_static_ptrs lf /= AllowAnywhere) $
checkL (idName var /= makeStaticName) $
text "Found makeStatic nested in an expression"
; checkDeadIdOcc var
; ty <- applySubstTy (idType var)
; var' <- lookupIdInScope var
; let ty' = idType var'
; ensureEqTys ty ty' $ mkBndrOccTypeMismatchMsg var' var ty' ty
; mb_join_arity
<- case isJoinId_maybe var' of
Just join_arity ->
do { checkL (isJoinId_maybe var == Just join_arity) $
mkJoinBndrOccMismatchMsg var' var
; return $ Just join_arity }
Nothing ->
case tailCallInfo (idOccInfo var') of
AlwaysTailCalled join_arity -> return $ Just join_arity
-- This function will be turned into a join point by the
-- simplifier; typecheck it as if it already were one
NoTailCallInfo -> return $ Nothing
; case mb_join_arity of
Just join_arity ->
do { bad <- isBadJoin var'
; checkL (not bad) $ mkJoinOutOfScopeMsg var'
; checkL (nargs == join_arity) $
mkBadJumpMsg var' join_arity nargs }
Nothing ->
do { checkL (not (isJoinId var)) $
mkJoinBndrOccMismatchMsg var' var }
; checkJoinOcc var nargs
; return (idType var') }
lintCoreFun :: CoreExpr -> Int -- Number of arguments (type or val) being passed
-> LintM Type -- returns type of the *function*
lintCoreFun :: CoreExpr
-> Int -- Number of arguments (type or val) being passed
-> LintM Type -- Returns type of the *function*
lintCoreFun (Var var) nargs
= lintCoreVar var nargs
= lintVarOcc var nargs
lintCoreFun (Lam var body) nargs
-- Act like lintCoreExpr of Lam, but *don't* call markAllJoinsBad; see
-- Note [Beta redexes]
......@@ -865,10 +847,47 @@ lintCoreFun (Lam var body) nargs
lintBinder var $ \ var' ->
do { body_ty <- lintCoreFun body (nargs - 1)
; return $ mkLamType var' body_ty }
lintCoreFun expr nargs
= markAllJoinsBadIf (nargs /= 0) $
lintCoreExpr expr
------------------
checkDeadIdOcc :: Id -> LintM ()
-- An occurrence whose occurrence-info says "dead" is reported as an
-- error, unless we are currently linting a case pattern (as reported
-- by inCasePat). Ids that are not marked dead pass unchecked.
checkDeadIdOcc id
  | not (isDeadOcc (idOccInfo id))
  = return ()
  | otherwise
  = do { in_case_pat <- inCasePat
       ; checkL in_case_pat
                (text "Occurrence of a dead Id" <+> ppr id) }
------------------
checkJoinOcc :: Id -> JoinArity -> LintM ()
-- Lint an occurrence 'var' applied to 'n_args' arguments.
--   * If the occurrence is not a JoinId there is nothing to check.
--   * If it is a JoinId, lookupJoinId must find it (else we report
--     invalidJoinOcc), the arity found there must agree with the
--     arity on the occurrence, and the occurrence must be applied to
--     exactly that many arguments.
checkJoinOcc var n_args
  = case isJoinId_maybe var of
      Nothing -> return ()
      Just occ_arity ->
        do { mb_bndr_arity <- lookupJoinId var
           ; case mb_bndr_arity of
               Nothing ->
                 -- Binder is not a join point
                 addErrL (invalidJoinOcc var)
               Just bndr_arity ->
                 do { -- Arity differs at binding site and occurrence
                      checkL (bndr_arity == occ_arity) $
                        mkJoinBndrOccMismatchMsg var bndr_arity occ_arity
                      -- Arity doesn't match #args
                    ; checkL (n_args == occ_arity) $
                        mkBadJumpMsg var occ_arity n_args } }
{-
Note [No alternatives lint check]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......@@ -1010,17 +1029,6 @@ lintTyKind tyvar arg_ty
where
tyvar_kind = tyVarKind tyvar
checkDeadIdOcc :: Id -> LintM ()
-- Occurrences of an Id should never be dead....
-- except when we are checking a case pattern
checkDeadIdOcc id
| isDeadOcc (idOccInfo id)
= do { in_case <- inCasePat
; checkL in_case
(text "Occurrence of a dead Id" <+> ppr id) }
| otherwise
= return ()
{-
************************************************************************
* *
......@@ -1152,21 +1160,22 @@ lintCoBndr cv thing_inside
(text "CoVar with non-coercion type:" <+> pprTyVar cv)
; updateTCvSubst subst' (thing_inside cv') }
lintIdBndrs :: [Var] -> ([Var] -> LintM a) -> LintM a
lintIdBndrs ids linterF
lintIdBndrs :: TopLevelFlag -> [Var] -> LintM a -> LintM a
lintIdBndrs top_lvl ids linterF
= go ids
where
go [] = linterF []
go (id:ids) = lintIdBndr NotTopLevel id $ \id ->
lintIdBndrs ids $ \ids ->
linterF (id:ids)
go [] = linterF
go (id:ids) = lintIdBndr top_lvl id $ \_ ->
lintIdBndrs top_lvl ids $
linterF
lintIdBndr :: TopLevelFlag -> InVar -> (OutVar -> LintM a) -> LintM a
-- Do substitution on the type of a binder and add the var with this
-- new type to the in-scope set of the second argument
-- ToDo: lint its rules
lintIdBndr top_lvl id linterF
= do { flags <- getLintFlags
= ASSERT2( isId id, ppr id )
do { flags <- getLintFlags
; checkL (not (lf_check_global_ids flags) || isLocalId id)
(text "Non-local Id binder" <+> ppr id)
-- See Note [Checking for global Ids]
......@@ -1784,7 +1793,8 @@ data LintEnv
, le_subst :: TCvSubst -- Current type substitution; we also use this
-- to keep track of all the variables in scope,
-- both Ids and TyVars
, le_bad_joins :: IdSet -- Join points that are no longer valid
, le_joins :: IdSet -- Join points in scope that are valid
-- A subset of the InScopeSet in le_subst
-- See Note [Join points]
, le_dynflags :: DynFlags -- DynamicFlags
}
......@@ -1891,13 +1901,17 @@ data LintLocInfo
| InType Type -- Inside a type
| InCo Coercion -- Inside a coercion
initL :: DynFlags -> LintFlags -> LintM a -> WarnsAndErrs -- Errors and warnings
initL dflags flags m
initL :: DynFlags -> LintFlags -> InScopeSet
-> LintM a -> WarnsAndErrs -- Errors and warnings
initL dflags flags in_scope m
= case unLintM m env (emptyBag, emptyBag) of
(_, errs) -> errs
where
env = LE { le_flags = flags, le_subst = emptyTCvSubst, le_loc = []
, le_dynflags = dflags, le_bad_joins = emptyVarSet }
env = LE { le_flags = flags
, le_subst = mkEmptyTCvSubst in_scope
, le_joins = emptyVarSet
, le_loc = []
, le_dynflags = dflags }
getLintFlags :: LintM LintFlags
-- Read the le_flags field of the environment; the accumulated
-- warnings/errors are passed through unchanged
getLintFlags = LintM (\ env errs -> (Just (le_flags env), errs))
......@@ -1952,29 +1966,12 @@ inCasePat = LintM $ \ env errs -> (Just (is_case_pat env), errs)
is_case_pat (LE { le_loc = CasePat {} : _ }) = True
is_case_pat _other = False
addInScopeVars :: [Var] -> LintM a -> LintM a
addInScopeVars vars m
= LintM $ \ env errs ->
unLintM m (env { le_subst = extendTCvInScopeList (le_subst env) vars
, le_bad_joins = bad_joins' env })
errs
where
bad_joins' env = delVarSetList (le_bad_joins env) (filter isJoinId vars)
addInScopeVarSet :: VarSet -> LintM a -> LintM a
addInScopeVarSet vars m
= LintM $ \ env errs ->
unLintM m (env { le_subst = extendTCvInScopeSet (le_subst env) vars })
errs
addInScopeVar :: Var -> LintM a -> LintM a
addInScopeVar var m
= LintM $ \ env errs ->
unLintM m (env { le_subst = extendTCvInScope (le_subst env) var
, le_bad_joins = bad_joins' env }) errs
where
bad_joins' env | isJoinId var = delVarSet (le_bad_joins env) var
| otherwise = le_bad_joins env
unLintM m (env { le_subst = extendTCvInScope (le_subst env) var
, le_joins = delVarSet (le_joins env) var
}) errs
extendSubstL :: TyVar -> Type -> LintM a -> LintM a
extendSubstL tv ty m
......@@ -1987,16 +1984,25 @@ updateTCvSubst subst' m
markAllJoinsBad :: LintM a -> LintM a
markAllJoinsBad m
= LintM $ \ env errs -> unLintM m (marked env) errs
where
marked env = env { le_bad_joins = filterVarSet isJoinId in_set }
where
in_set = getInScopeVars (getTCvInScope (le_subst env))
= LintM $ \ env errs -> unLintM m (env { le_joins = emptyVarSet }) errs
markAllJoinsBadIf :: Bool -> LintM a -> LintM a
-- Conditional version of markAllJoinsBad: apply it to the action
-- when the flag is True, otherwise run the action unchanged
markAllJoinsBadIf should_mark m
  | should_mark = markAllJoinsBad m
  | otherwise   = m
addGoodJoins :: [Var] -> LintM a -> LintM a
-- Run 'thing_inside' with le_joins extended by those of 'vars' that
-- are join Ids. If none of them is a join Id the environment is left
-- alone and 'thing_inside' is returned as-is.
addGoodJoins vars thing_inside
  = case filter isJoinId vars of
      []       -> thing_inside
      join_ids -> LintM $ \ env errs ->
                    unLintM thing_inside
                            (env { le_joins = extendVarSetList (le_joins env) join_ids })
                            errs
getValidJoins :: LintM IdSet
-- Read the le_joins field of the environment; the accumulated
-- warnings/errors are passed through unchanged
getValidJoins = LintM $ \ env errs -> (Just (le_joins env), errs)
getTCvSubst :: LintM TCvSubst
-- Read the le_subst field of the environment; the accumulated
-- warnings/errors are passed through unchanged
getTCvSubst = LintM $ \ env errs -> (Just (le_subst env), errs)
......@@ -2022,9 +2028,14 @@ lookupIdInScope id
where
out_of_scope = pprBndr LetBind id <+> text "is out of scope"
isBadJoin :: Id -> LintM Bool
isBadJoin id = LintM $ \env errs -> (Just (id `elemVarSet` le_bad_joins env),
errs)
lookupJoinId :: Id -> LintM (Maybe JoinArity)
-- Look the Id up in the set returned by getValidJoins.
-- The result is the join arity of the Id found in that set, or
-- Nothing when the Id is absent from the set (or the Id found there
-- is not a join Id)
lookupJoinId id
  = do { valid_joins <- getValidJoins
       ; return (lookupVarSet valid_joins id >>= isJoinId_maybe) }
lintTyCoVarInScope :: Var -> LintM ()
-- lintInScope specialised to the "is out of scope" message
lintTyCoVarInScope = lintInScope (text "is out of scope")
......@@ -2294,9 +2305,10 @@ mkBadJoinArityMsg var ar nlams
text "Join arity:" <+> ppr ar,
text "Number of lambdas:" <+> ppr nlams ]
mkJoinOutOfScopeMsg :: Var -> SDoc
mkJoinOutOfScopeMsg var
= text "Join variable no longer in scope:" <+> ppr var
invalidJoinOcc :: Var -> SDoc
-- Two-line error message for a join-Id occurrence that is not
-- acceptable at this point; names the offending variable
invalidJoinOcc var
  = vcat [ headline, explanation ]
  where
    headline    = text "Invalid occurrence of a join variable:" <+> ppr var
    explanation = text "The binder is either not a join point, or not valid here"
mkBadJumpMsg :: Var -> Int -> Int -> SDoc
mkBadJumpMsg var ar nargs
......@@ -2312,17 +2324,12 @@ mkInconsistentRecMsg bndrs
where
ppr_with_details bndr = ppr bndr <> ppr (idDetails bndr)
mkJoinBndrOccMismatchMsg :: Var -> Var -> SDoc
mkJoinBndrOccMismatchMsg bndr var
= vcat [ text "Mismatch in join point status between binder and occurrence",
text "Var:" <+> ppr bndr,
text "Binder:" <+> ppr_join_status bndr,
text "Occ:" <+> ppr_join_status var ]
where
ppr_join_status v = case details of JoinId _ -> ppr details
_ -> text "not a join id"
where
details = idDetails v
mkJoinBndrOccMismatchMsg :: Var -> JoinArity -> JoinArity -> SDoc
-- Error message for a join point whose arity at the binding site
-- disagrees with the arity recorded on an occurrence.
-- Arguments: the variable, the binder's arity, the occurrence's arity
mkJoinBndrOccMismatchMsg bndr join_arity_bndr join_arity_occ
  = vcat [ headline, var_line, bndr_line, occ_line ]
  where
    headline  = text "Mismatch in join point arity between binder and occurrence"
    var_line  = text "Var:" <+> ppr bndr
    bndr_line = text "Arity at binding site:" <+> ppr join_arity_bndr
    occ_line  = text "Arity at occurrence: " <+> ppr join_arity_occ
mkBndrOccTypeMismatchMsg :: Var -> Var -> OutType -> OutType -> SDoc
mkBndrOccTypeMismatchMsg bndr var bndr_ty var_ty
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment