Commit 1a4c04b1 authored by Simon Peyton Jones's avatar Simon Peyton Jones
Browse files

Fix 'SPECIALISE instance'

Trac #12944 showed that the DsBinds code that implemented a
SPECIALISE pragma was inadequate if the constraints solving
added let-bindings for dictionaries.  The result was that
we ended up with an unbound dictionary in a DFunUnfolding -- and
Lint didn't even check for that!

Fixing this was not entirely straightforward

* In DsBinds.dsSpec we use a new function
     TcEvidence.collectHsWrapBinders
  to pick off the lambda binders from the HsWapper

* dsWrapper now returns a (CoreExpr -> CoreExpr) function

* CoreUnfold.specUnfolding now takes a (CoreExpr -> CoreExpr)
  function it can use to specialise the unfolding.

On the whole the code is simpler than before.
parent c48595ee
......@@ -147,22 +147,28 @@ mkInlinableUnfolding dflags expr
expr' = simpleOptExpr expr
is_bot = isJust (exprBotStrictness_maybe expr')
specUnfolding :: DynFlags -> Subst -> [Var] -> [CoreExpr] -> Unfolding -> Unfolding
specUnfolding :: [Var] -> (CoreExpr -> CoreExpr) -> Arity -> Unfolding -> Unfolding
-- See Note [Specialising unfoldings]
-- specUnfolding subst new_bndrs spec_args unf
-- = \new_bndrs. (subst( unf ) spec_args)
-- specUnfolding spec_bndrs spec_app arity_decrease unf
-- = \spec_bndrs. spec_app( unf )
--
-- Precondition: in-scope(subst) `superset` fvs( spec_args )
specUnfolding _ subst new_bndrs spec_args
df@(DFunUnfolding { df_bndrs = bndrs, df_con = con , df_args = args })
= ASSERT2( length bndrs >= length spec_args, ppr df $$ ppr spec_args $$ ppr new_bndrs )
mkDFunUnfolding (new_bndrs ++ extra_bndrs) con
(map (substExpr spec_doc subst2) args)
specUnfolding spec_bndrs spec_app arity_decrease
df@(DFunUnfolding { df_bndrs = old_bndrs, df_con = con, df_args = args })
= ASSERT2( arity_decrease == count isId old_bndrs - count isId spec_bndrs, ppr df )
mkDFunUnfolding spec_bndrs con (map spec_arg args)
-- There is a hard-to-check assumption here that the spec_app has
-- enough applications to exactly saturate the old_bndrs
-- For DFunUnfoldings we transform
-- \old_bndrs. MkD <op1> ... <opn>
-- to
-- \new_bndrs. MkD (spec_app(\old_bndrs. <op1>)) ... ditto <opn>
-- The ASSERT checks the value part of that
where
subst1 = extendSubstList subst (bndrs `zip` spec_args)
(subst2, extra_bndrs) = substBndrs subst1 (dropList spec_args bndrs)
spec_arg arg = simpleOptExpr (spec_app (mkLams old_bndrs arg))
-- The beta-redexes created by spec_app will be
-- simplified away by simplOptExpr
specUnfolding _dflags subst new_bndrs spec_args
specUnfolding spec_bndrs spec_app arity_decrease
(CoreUnfolding { uf_src = src, uf_tmpl = tmpl
, uf_is_top = top_lvl
, uf_guidance = old_guidance })
......@@ -170,25 +176,19 @@ specUnfolding _dflags subst new_bndrs spec_args
, UnfWhen { ug_arity = old_arity
, ug_unsat_ok = unsat_ok
, ug_boring_ok = boring_ok } <- old_guidance
= let guidance = UnfWhen { ug_arity = old_arity - count isValArg spec_args
+ count isId new_bndrs
= let guidance = UnfWhen { ug_arity = old_arity - arity_decrease
, ug_unsat_ok = unsat_ok
, ug_boring_ok = boring_ok }
new_tmpl = simpleOptExpr $ mkLams new_bndrs $
mkApps (substExpr spec_doc subst tmpl) spec_args
-- The beta-redexes created here will be simplified
-- away by simplOptExpr in mkUnfolding
new_tmpl = simpleOptExpr (mkLams spec_bndrs (spec_app tmpl))
-- The beta-redexes created by spec_app will be
-- simplified away by simplOptExpr
in mkCoreUnfolding src top_lvl new_tmpl guidance
specUnfolding _ _ _ _ _ = noUnfolding
specUnfolding _ _ _ _ = noUnfolding
spec_doc :: SDoc
spec_doc = text "specUnfolding"
{-
Note [Specialising unfoldings]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
{- Note [Specialising unfoldings]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
When we specialise a function for some given type-class arguments, we use
specUnfolding to specialise its unfolding. Some important points:
......@@ -997,6 +997,13 @@ found that the WorkWrap phase thought that
y = case x of F# v -> F# (v +# v)
was certainlyWillInline, so the addition got duplicated.
Note [certainlyWillInline: INLINABLE]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
certainlyWillInline /must/ return Nothing for a large INLINABLE thing,
even though we have a stable inlining, so that strictness w/w takes
place. It makes a big difference to efficiency, and the w/w pass knows
how to transfer the INLINABLE info to the worker; see WorkWrap
Note [Worker-wrapper for INLINABLE functions]
************************************************************************
* *
......
......@@ -616,8 +616,8 @@ dsCmd _ids local_vars _stack_ty _res_ty (HsCmdArrForm op _ _ args) env_ids = do
dsCmd ids local_vars stack_ty res_ty (HsCmdWrap wrap cmd) env_ids = do
(core_cmd, env_ids') <- dsCmd ids local_vars stack_ty res_ty cmd env_ids
wrapped_cmd <- dsHsWrapper wrap core_cmd
return (wrapped_cmd, env_ids')
core_wrap <- dsHsWrapper wrap
return (core_wrap core_cmd, env_ids')
dsCmd _ _ _ _ _ c = pprPanic "dsCmd" (ppr c)
......
......@@ -127,9 +127,10 @@ dsHsBind dflags
= do { (args, body) <- matchWrapper
(FunRhs (noLoc $ idName fun) Prefix)
Nothing matches
; core_wrap <- dsHsWrapper co_fn
; let body' = mkOptTickBox tick body
; rhs <- dsHsWrapper co_fn (mkLams args body')
; let core_binds@(id,_) = makeCorePair dflags fun False 0 rhs
rhs = core_wrap (mkLams args body')
core_binds@(id,_) = makeCorePair dflags fun False 0 rhs
force_var =
if xopt LangExt.Strict dflags
&& matchGroupArity matches == 0 -- no need to force lambdas
......@@ -170,12 +171,13 @@ dsHsBind dflags
do { (_, bind_prs) <- ds_lhs_binds binds
; let core_bind = Rec bind_prs
; ds_binds <- dsTcEvBinds_s ev_binds
; rhs <- dsHsWrapper wrap $ -- Usually the identity
; core_wrap <- dsHsWrapper wrap -- Usually the identity
; let rhs = core_wrap $
mkLams tyvars $ mkLams dicts $
mkCoreLets ds_binds $
Let core_bind $
Var local
; (spec_binds, rules) <- dsSpecs rhs prags
; let global' = addIdSpecialisations global rules
......@@ -195,10 +197,10 @@ dsHsBind dflags
, abe_poly = global
, abe_mono = local
, abe_prags = prags })
= do { rhs <- dsHsWrapper wrap (Var local)
= do { core_wrap <- dsHsWrapper wrap
; return (makeCorePair dflags global
(isDefaultMethod prags)
0 rhs) }
0 (core_wrap (Var local))) }
; main_binds <- mapM mk_bind exports
; ds_binds <- dsTcEvBinds_s ev_binds
......@@ -238,11 +240,11 @@ dsHsBind dflags
, abe_mono = local, abe_prags = spec_prags })
-- See Note [AbsBinds wrappers] in HsBinds
= do { tup_id <- newSysLocalDs tup_ty
; rhs <- dsHsWrapper wrap $
mkLams tyvars $ mkLams dicts $
; core_wrap <- dsHsWrapper wrap
; let rhs = core_wrap $ mkLams tyvars $ mkLams dicts $
mkTupleSelector all_locals local tup_id $
mkVarApps (Var poly_tup_id) (tyvars ++ dicts)
; let rhs_for_spec = Let (NonRec poly_tup_id poly_tup_rhs) rhs
rhs_for_spec = Let (NonRec poly_tup_id poly_tup_rhs) rhs
; (spec_binds, rules) <- dsSpecs rhs_for_spec spec_prags
; let global' = (global `setInlinePragma` defaultInlinePragma)
`addIdSpecialisations` rules
......@@ -317,10 +319,10 @@ dsHsBind dflags (AbsBindsSig { abs_tvs = tyvars, abs_ev_vars = dicts
do { (args, body) <- matchWrapper
(FunRhs (noLoc $ idName global) Prefix)
Nothing matches
; core_wrap <- dsHsWrapper co_fn
; let body' = mkOptTickBox tick body
; fun_rhs <- dsHsWrapper co_fn $
mkLams args body'
; let force_vars
fun_rhs = core_wrap (mkLams args body')
force_vars
| xopt LangExt.Strict dflags
, matchGroupArity matches == 0 -- no need to force lambdas
= [global]
......@@ -629,32 +631,39 @@ dsSpec mb_poly_rhs (L loc (SpecPrag poly_id spec_co spec_inl))
; let poly_name = idName poly_id
spec_occ = mkSpecOcc (getOccName poly_name)
spec_name = mkInternalName uniq spec_occ (getSrcSpan poly_name)
; (bndrs, ds_lhs) <- liftM collectBinders
(dsHsWrapper spec_co (Var poly_id))
; let spec_ty = mkLamTypes bndrs (exprType ds_lhs)
(spec_bndrs, spec_app) = collectHsWrapBinders spec_co
-- spec_co looks like
-- \spec_bndrs. [] spec_args
-- perhaps with the body of the lambda wrapped in some WpLets
-- E.g. /\a \(d:Eq a). let d2 = $df d in [] (Maybe a) d2
; core_app <- dsHsWrapper spec_app
; let ds_lhs = core_app (Var poly_id)
spec_ty = mkLamTypes spec_bndrs (exprType ds_lhs)
; -- pprTrace "dsRule" (vcat [ text "Id:" <+> ppr poly_id
-- , text "spec_co:" <+> ppr spec_co
-- , text "ds_rhs:" <+> ppr ds_lhs ]) $
case decomposeRuleLhs bndrs ds_lhs of {
case decomposeRuleLhs spec_bndrs ds_lhs of {
Left msg -> do { warnDs NoReason msg; return Nothing } ;
Right (rule_bndrs, _fn, args) -> do
{ dflags <- getDynFlags
; this_mod <- getModule
; let fn_unf = realIdUnfolding poly_id
unf_fvs = stableUnfoldingVars fn_unf `orElse` emptyVarSet
in_scope = mkInScopeSet (unf_fvs `unionVarSet` exprsFreeVars args)
spec_unf = specUnfolding dflags (mkEmptySubst in_scope) bndrs args fn_unf
spec_unf = specUnfolding spec_bndrs core_app arity_decrease fn_unf
spec_id = mkLocalId spec_name spec_ty
`setInlinePragma` inl_prag
`setIdUnfolding` spec_unf
arity_decrease = count isValArg args - count isId spec_bndrs
; rule <- dsMkUserRule this_mod is_local_id
(mkFastString ("SPEC " ++ showPpr dflags poly_name))
rule_act poly_name
rule_bndrs args
(mkVarApps (Var spec_id) bndrs)
(mkVarApps (Var spec_id) spec_bndrs)
; spec_rhs <- dsHsWrapper spec_co poly_rhs
; let spec_rhs = mkLams spec_bndrs (core_app poly_rhs)
-- Commented out: see Note [SPECIALISE on INLINE functions]
-- ; when (isInlinePragma id_inl)
......@@ -1037,22 +1046,25 @@ a mistake. That's what the isDeadBinder call detects.
-}
dsHsWrapper :: HsWrapper -> CoreExpr -> DsM CoreExpr
dsHsWrapper WpHole e = return e
dsHsWrapper (WpTyApp ty) e = return $ App e (Type ty)
dsHsWrapper (WpLet ev_binds) e = do bs <- dsTcEvBinds ev_binds
return (mkCoreLets bs e)
dsHsWrapper (WpCompose c1 c2) e = do { e1 <- dsHsWrapper c2 e
; dsHsWrapper c1 e1 }
dsHsWrapper (WpFun c1 c2 t1) e = do { x <- newSysLocalDs t1
; e1 <- dsHsWrapper c1 (Var x)
; e2 <- dsHsWrapper c2 (mkCoreAppDs (text "dsHsWrapper") e e1)
; return (Lam x e2) }
dsHsWrapper (WpCast co) e = ASSERT(coercionRole co == Representational)
return $ mkCastDs e co
dsHsWrapper (WpEvLam ev) e = return $ Lam ev e
dsHsWrapper (WpTyLam tv) e = return $ Lam tv e
dsHsWrapper (WpEvApp tm) e = liftM (App e) (dsEvTerm tm)
dsHsWrapper :: HsWrapper -> DsM (CoreExpr -> CoreExpr)
dsHsWrapper WpHole = return $ \e -> e
dsHsWrapper (WpTyApp ty) = return $ \e -> App e (Type ty)
dsHsWrapper (WpEvLam ev) = return $ Lam ev
dsHsWrapper (WpTyLam tv) = return $ Lam tv
dsHsWrapper (WpLet ev_binds) = do { bs <- dsTcEvBinds ev_binds
; return (mkCoreLets bs) }
dsHsWrapper (WpCompose c1 c2) = do { w1 <- dsHsWrapper c1
; w2 <- dsHsWrapper c2
; return (w1 . w2) }
dsHsWrapper (WpFun c1 c2 t1) = do { x <- newSysLocalDs t1
; w1 <- dsHsWrapper c1
; w2 <- dsHsWrapper c2
; let app f a = mkCoreAppDs (text "dsHsWrapper") f a
; return (\e -> Lam x (w2 (app e (w1 (Var x))))) }
dsHsWrapper (WpCast co) = ASSERT(coercionRole co == Representational)
return $ \e -> mkCastDs e co
dsHsWrapper (WpEvApp tm) = do { core_tm <- dsEvTerm tm
; return (\e -> App e core_tm) }
--------------------------------------
dsTcEvBinds_s :: [TcEvBinds] -> DsM [CoreBind]
......
......@@ -214,8 +214,9 @@ dsExpr (HsOverLit lit) = dsOverLit lit
dsExpr (HsWrap co_fn e)
= do { e' <- dsExpr e
; wrapped_e <- dsHsWrapper co_fn e'
; wrap' <- dsHsWrapper co_fn
; dflags <- getDynFlags
; let wrapped_e = wrap' e'
; warnAboutIdentities dflags e' (exprType wrapped_e)
; return wrapped_e }
......@@ -748,9 +749,11 @@ dsSyntaxExpr (SyntaxExpr { syn_expr = expr
, syn_arg_wraps = arg_wraps
, syn_res_wrap = res_wrap })
arg_exprs
= do { args <- zipWithM dsHsWrapper arg_wraps arg_exprs
; fun <- dsExpr expr
; dsHsWrapper res_wrap $ mkApps fun args }
= do { fun <- dsExpr expr
; core_arg_wraps <- mapM dsHsWrapper arg_wraps
; core_res_wrap <- dsHsWrapper res_wrap
; let wrapped_args = zipWith ($) core_arg_wraps arg_exprs
; return (core_res_wrap (mkApps fun wrapped_args)) }
findField :: [LHsRecField Id arg] -> Name -> [arg]
findField rbinds sel
......
......@@ -253,8 +253,9 @@ matchCoercion (var:vars) ty (eqns@(eqn1:_))
; var' <- newUniqueId var pat_ty'
; match_result <- match (var':vars) ty $
map (decomposeFirstPat getCoPat) eqns
; rhs' <- dsHsWrapper co (Var var)
; return (mkCoLetMatchResult (NonRec var' rhs') match_result) }
; core_wrap <- dsHsWrapper co
; let bind = NonRec var' (core_wrap (Var var))
; return (mkCoLetMatchResult bind match_result) }
matchCoercion _ _ _ = panic "matchCoercion"
matchView :: [Id] -> Type -> [EquationInfo] -> DsM MatchResult
......
......@@ -1024,7 +1024,8 @@ to substitute sc -> sc_flt in the RHS
-}
specBind :: SpecEnv -- Use this for RHSs
-> CoreBind
-> CoreBind -- Binders are already cloned by cloneBindSM,
-- but RHSs are un-processed
-> UsageDetails -- Info on how the scope of the binding
-> SpecM ([CoreBind], -- New bindings
UsageDetails) -- And info to pass upstream
......@@ -1093,9 +1094,9 @@ specBind rhs_env (Rec pairs) body_uds
---------------------------
specDefns :: SpecEnv
-> UsageDetails -- Info on how it is used in its scope
-> [(Id,CoreExpr)] -- The things being bound and their un-processed RHS
-> SpecM ([Id], -- Original Ids with RULES added
[(Id,CoreExpr)], -- Extra, specialised bindings
-> [(OutId,InExpr)] -- The things being bound and their un-processed RHS
-> SpecM ([OutId], -- Original Ids with RULES added
[(OutId,OutExpr)], -- Extra, specialised bindings
UsageDetails) -- Stuff to fling upwards from the specialised versions
-- Specialise a list of bindings (the contents of a Rec), but flowing usages
......@@ -1114,7 +1115,7 @@ specDefns env uds ((bndr,rhs):pairs)
---------------------------
specDefn :: SpecEnv
-> UsageDetails -- Info on how it is used in its scope
-> Id -> CoreExpr -- The thing being bound and its un-processed RHS
-> OutId -> InExpr -- The thing being bound and its un-processed RHS
-> SpecM (Id, -- Original Id with added RULES
[(Id,CoreExpr)], -- Extra, specialised bindings
UsageDetails) -- Stuff to fling upwards from the specialised versions
......@@ -1140,7 +1141,7 @@ specCalls :: Maybe Module -- Just this_mod => specialising imported fn
-> SpecEnv
-> [CoreRule] -- Existing RULES for the fn
-> [CallInfo]
-> Id -> CoreExpr
-> OutId -> InExpr
-> SpecM ([CoreRule], -- New RULES for the fn
[(Id,CoreExpr)], -- Extra, specialised bindings
UsageDetails) -- New usage details from the specialised RHSs
......@@ -1317,17 +1318,11 @@ specCalls mb_mod env rules_for_me calls_for_me fn rhs
= (inl_prag { inl_inline = EmptyInlineSpec }, noUnfolding)
| otherwise
= (inl_prag, specUnfolding dflags spec_unf_subst poly_tyvars
spec_unf_args fn_unf)
spec_unf_args = ty_args ++ spec_dict_args
spec_unf_subst = CoreSubst.setInScope (se_subst env)
(CoreSubst.substInScope (se_subst rhs_env2))
-- Extend the in-scope set to satisfy the precondition of
-- specUnfolding, namely that in-scope(unf_subst) includes
-- the free vars of spec_unf_args. The in-scope set of rhs_env2
-- is just the ticket; but the actual substitution we want is
-- the same old one from 'env'
= (inl_prag, specUnfolding poly_tyvars spec_app
arity_decrease fn_unf)
arity_decrease = length spec_dict_args
spec_app e = (e `mkApps` ty_args) `mkApps` spec_dict_args
--------------------------------------
-- Adding arity information just propagates it a bit faster
......
......@@ -7,7 +7,7 @@ module TcEvidence (
-- HsWrapper
HsWrapper(..),
(<.>), mkWpTyApps, mkWpEvApps, mkWpEvVarApps, mkWpTyLams,
mkWpLams, mkWpLet, mkWpCastN, mkWpCastR,
mkWpLams, mkWpLet, mkWpCastN, mkWpCastR, collectHsWrapBinders,
mkWpFun, mkWpFuns, idHsWrapper, isIdHsWrapper, pprHsWrapper,
-- Evidence bindings
......@@ -267,6 +267,23 @@ isIdHsWrapper :: HsWrapper -> Bool
isIdHsWrapper WpHole = True
isIdHsWrapper _ = False
collectHsWrapBinders :: HsWrapper -> ([Var], HsWrapper)
-- Collect the outer lambda binders of a HsWrapper,
-- stopping as soon as you get to a non-lambda binder
collectHsWrapBinders wrap = go wrap []
where
-- go w ws = collectHsWrapBinders (w <.> w1 <.> ... <.> wn)
go :: HsWrapper -> [HsWrapper] -> ([Var], HsWrapper)
go (WpEvLam v) wraps = add_lam v (gos wraps)
go (WpTyLam v) wraps = add_lam v (gos wraps)
go (WpCompose w1 w2) wraps = go w1 (w2:wraps)
go wrap wraps = ([], foldl (<.>) wrap wraps)
gos [] = ([], WpHole)
gos (w:ws) = go w ws
add_lam v (vs,w) = (v:vs, w)
{-
************************************************************************
* *
......
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -O #-}
module T12944 () where
class AdditiveGroup v where
(^+^) :: v -> v -> v
negateV :: v -> v
(^-^) :: v -> v -> v
v ^-^ v' = v ^+^ negateV v'
class AdditiveGroup v => VectorSpace v where
type Scalar v :: *
(*^) :: Scalar v -> v -> v
data Poly1 a = Poly1 a a
data IntOfLog poly a = IntOfLog !a !(poly a)
instance Num a => AdditiveGroup (Poly1 a) where
{-# INLINE (^+^) #-}
{-# INLINE negateV #-}
Poly1 a b ^+^ Poly1 a' b' = Poly1 (a + a') (b + b')
negateV (Poly1 a b) = Poly1 (negate a) (negate b)
instance (AdditiveGroup (poly a), Num a) => AdditiveGroup (IntOfLog poly a) where
{-# INLINE (^+^) #-}
{-# INLINE negateV #-}
IntOfLog k p ^+^ IntOfLog k' p' = IntOfLog (k + k') (p ^+^ p')
negateV (IntOfLog k p) = IntOfLog (negate k) (negateV p)
{-# SPECIALISE instance Num a => AdditiveGroup (IntOfLog Poly1 a) #-}
-- This pragmas casued the crash
instance (VectorSpace (poly a), Scalar (poly a) ~ a, Num a) => VectorSpace (IntOfLog poly a) where
type Scalar (IntOfLog poly a) = a
s *^ IntOfLog k p = IntOfLog (s * k) (s *^ p)
......@@ -105,3 +105,4 @@ test('T10767', normal, compile, [''])
test('DsStrictWarn', normal, compile, [''])
test('T10662', normal, compile, ['-Wall'])
test('T11414', normal, compile, [''])
test('T12944', normal, compile, [''])
{-# LANGUAGE KindSignatures, TypeFamilies, GADTs, DataKinds #-}
module T12444a where
type family F a :: *
type instance F (Maybe x) = Maybe (F x)
foo :: a -> Maybe (F a)
foo = undefined
-- bad :: (F (Maybe t) ~ t) => Maybe t -> [Maybe t]
bad x = [x, foo x]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment