Commit 88e3c815 authored by Simon Peyton Jones's avatar Simon Peyton Jones Committed by Marge Bot

Fix specialisation for DFuns

When specialising a DFun we must take care to saturate the
unfolding.  See Note [Specialising DFuns] in Specialise.

Fixes #18120
parent 7a763cff
......@@ -1362,6 +1362,7 @@ specCalls mb_mod env existing_rules calls_for_me fn rhs
inl_prag = idInlinePragma fn
inl_act = inlinePragmaActivation inl_prag
is_local = isLocalId fn
is_dfun = isDFunId fn
-- Figure out whether the function has an INLINE pragma
-- See Note [Inline specialisations]
......@@ -1384,22 +1385,34 @@ specCalls mb_mod env existing_rules calls_for_me fn rhs
spec_call :: SpecInfo -- Accumulating parameter
-> CallInfo -- Call instance
-> SpecM SpecInfo
spec_call spec_acc@(rules_acc, pairs_acc, uds_acc) (CI { ci_key = call_args })
spec_call spec_acc@(rules_acc, pairs_acc, uds_acc) _ci@(CI { ci_key = call_args })
= -- See Note [Specialising Calls]
do { ( useful, rhs_env2, leftover_bndrs
do { let all_call_args | is_dfun = call_args ++ repeat UnspecArg
| otherwise = call_args
-- See Note [Specialising DFuns]
; ( useful, rhs_env2, leftover_bndrs
, rule_bndrs, rule_lhs_args
, spec_bndrs, dx_binds, spec_args) <- specHeader env rhs_bndrs call_args
, spec_bndrs1, dx_binds, spec_args) <- specHeader env rhs_bndrs all_call_args
-- ; pprTrace "spec_call" (vcat [ text "call info: " <+> ppr _ci
-- , text "useful: " <+> ppr useful
-- , text "rule_bndrs:" <+> ppr rule_bndrs
-- , text "lhs_args: " <+> ppr rule_lhs_args
-- , text "spec_bndrs:" <+> ppr spec_bndrs1
-- , text "spec_args: " <+> ppr spec_args
-- , text "dx_binds: " <+> ppr dx_binds
-- , text "rhs_env2: " <+> ppr (se_subst rhs_env2)
-- , ppr dx_binds ]) $
-- return ()
; dflags <- getDynFlags
; if not useful -- No useful specialisation
|| already_covered dflags rules_acc rule_lhs_args
then return spec_acc
else -- pprTrace "spec_call" (vcat [ ppr _call_info, ppr fn, ppr rhs_dict_ids
-- , text "rhs_env2" <+> ppr (se_subst rhs_env2)
-- , ppr dx_binds ]) $
else
do { -- Run the specialiser on the specialised RHS
-- The "1" suffix is before we maybe add the void arg
; (spec_rhs1, rhs_uds) <- specLam rhs_env2 (spec_bndrs ++ leftover_bndrs) rhs_body
; (spec_rhs1, rhs_uds) <- specLam rhs_env2 (spec_bndrs1 ++ leftover_bndrs) rhs_body
; let spec_fn_ty1 = exprType spec_rhs1
-- Maybe add a void arg to the specialised function,
......@@ -1407,14 +1420,13 @@ specCalls mb_mod env existing_rules calls_for_me fn rhs
-- See Note [Specialisations Must Be Lifted]
-- C.f. GHC.Core.Opt.WorkWrap.Utils.mkWorkerArgs
add_void_arg = isUnliftedType spec_fn_ty1 && not (isJoinId fn)
(spec_rhs, spec_fn_ty, rule_rhs_args)
| add_void_arg = ( Lam voidArgId spec_rhs1
, mkVisFunTy voidPrimTy spec_fn_ty1
, voidPrimId : spec_bndrs)
| otherwise = (spec_rhs1, spec_fn_ty1, spec_bndrs)
arity_decr = count isValArg rule_lhs_args - count isId rule_rhs_args
join_arity_decr = length rule_lhs_args - length rule_rhs_args
(spec_bndrs, spec_rhs, spec_fn_ty)
| add_void_arg = ( voidPrimId : spec_bndrs1
, Lam voidArgId spec_rhs1
, mkVisFunTy voidPrimTy spec_fn_ty1)
| otherwise = (spec_bndrs1, spec_rhs1, spec_fn_ty1)
join_arity_decr = length rule_lhs_args - length spec_bndrs
spec_join_arity | Just orig_join_arity <- isJoinId_maybe fn
= Just (orig_join_arity - join_arity_decr)
| otherwise
......@@ -1449,7 +1461,7 @@ specCalls mb_mod env existing_rules calls_for_me fn rhs
(idName fn)
rule_bndrs
rule_lhs_args
(mkVarApps (Var spec_fn) rule_rhs_args)
(mkVarApps (Var spec_fn) spec_bndrs)
spec_rule
= case isJoinId_maybe fn of
......@@ -1472,15 +1484,15 @@ specCalls mb_mod env existing_rules calls_for_me fn rhs
= (inl_prag { inl_inline = NoUserInline }, noUnfolding)
| otherwise
= (inl_prag, specUnfolding dflags fn spec_bndrs spec_app arity_decr fn_unf)
spec_app e = e `mkApps` spec_args
= (inl_prag, specUnfolding dflags spec_bndrs (`mkApps` spec_args)
rule_lhs_args fn_unf)
--------------------------------------
-- Adding arity information just propagates it a bit faster
-- See Note [Arity decrease] in GHC.Core.Opt.Simplify
-- Copy InlinePragma information from the parent Id.
-- So if f has INLINE[1] so does spec_fn
arity_decr = count isValArg rule_lhs_args - count isId spec_bndrs
spec_f_w_arity = spec_fn `setIdArity` max 0 (fn_arity - arity_decr)
`setInlinePragma` spec_inl_prag
`setIdUnfolding` spec_unf
......@@ -1498,8 +1510,19 @@ specCalls mb_mod env existing_rules calls_for_me fn rhs
, spec_uds `plusUDs` uds_acc
) } }
{- Note [Specialisation Must Preserve Sharing]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
{- Note [Specialising DFuns]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
DFuns have a special sort of unfolding (DFunUnfolding), and these are
hard to specialise a DFunUnfolding to give another DFunUnfolding
unless the DFun is fully applied (#18120). So, in the case of DFunIds
we simply extend the CallKey with trailing UnspecArgs, so we'll
generate a rule that completely saturates the DFun.
There is an ASSERT that checks this, in the DFunUnfolding case of
GHC.Core.Unfold.specUnfolding.
Note [Specialisation Must Preserve Sharing]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Consider a function:
f :: forall a. Eq a => a -> blah
......@@ -2089,7 +2112,7 @@ isSpecDict _ = False
-- -- Specialised function helpers
-- , [c, i, x]
-- , [dShow1 = $dfShow dShowT2]
-- , [T1, T2, dEqT1, dShow1]
-- , [T1, T2, c, i, dEqT1, dShow1]
-- )
specHeader
:: SpecEnv
......@@ -2106,12 +2129,13 @@ specHeader
-- RULE helpers
, [OutBndr] -- Binders for the RULE
, [CoreArg] -- Args for the LHS of the rule
, [OutExpr] -- Args for the LHS of the rule
-- Specialised function helpers
, [OutBndr] -- Binders for $sf
, [DictBind] -- Auxiliary dictionary bindings
, [OutExpr] -- Specialised arguments for unfolding
-- Same length as "args for LHS of rule"
)
-- We want to specialise on type 'T1', and so we must construct a substitution
......
......@@ -173,47 +173,47 @@ mkInlinableUnfolding dflags expr
where
expr' = simpleOptExpr dflags expr
specUnfolding :: DynFlags -> Id -> [Var] -> (CoreExpr -> CoreExpr) -> Arity
specUnfolding :: DynFlags
-> [Var] -> (CoreExpr -> CoreExpr)
-> [CoreArg] -- LHS arguments in the RULE
-> Unfolding -> Unfolding
-- See Note [Specialising unfoldings]
-- specUnfolding spec_bndrs spec_app arity_decrease unf
-- = \spec_bndrs. spec_app( unf )
-- specUnfolding spec_bndrs spec_args unf
-- = \spec_bndrs. unf spec_args
--
specUnfolding dflags fn spec_bndrs spec_app arity_decrease
specUnfolding dflags spec_bndrs spec_app rule_lhs_args
df@(DFunUnfolding { df_bndrs = old_bndrs, df_con = con, df_args = args })
= ASSERT2( arity_decrease == count isId old_bndrs - count isId spec_bndrs
, ppr df $$ ppr spec_bndrs $$ ppr (spec_app (Var fn)) $$ ppr arity_decrease )
= ASSERT2( rule_lhs_args `equalLength` old_bndrs
, ppr df $$ ppr rule_lhs_args )
-- For this ASSERT see Note [DFunUnfoldings] in GHC.Core.Opt.Specialise
mkDFunUnfolding spec_bndrs con (map spec_arg args)
-- There is a hard-to-check assumption here that the spec_app has
-- enough applications to exactly saturate the old_bndrs
-- For DFunUnfoldings we transform
-- \old_bndrs. MkD <op1> ... <opn>
-- \obs. MkD <op1> ... <opn>
-- to
-- \new_bndrs. MkD (spec_app(\old_bndrs. <op1>)) ... ditto <opn>
-- The ASSERT checks the value part of that
-- \sbs. MkD ((\obs. <op1>) spec_args) ... ditto <opn>
where
spec_arg arg = simpleOptExpr dflags (spec_app (mkLams old_bndrs arg))
spec_arg arg = simpleOptExpr dflags $
spec_app (mkLams old_bndrs arg)
-- The beta-redexes created by spec_app will be
-- simplified away by simplOptExpr
specUnfolding dflags _ spec_bndrs spec_app arity_decrease
specUnfolding dflags spec_bndrs spec_app rule_lhs_args
(CoreUnfolding { uf_src = src, uf_tmpl = tmpl
, uf_is_top = top_lvl
, uf_guidance = old_guidance })
| isStableSource src -- See Note [Specialising unfoldings]
, UnfWhen { ug_arity = old_arity
, ug_unsat_ok = unsat_ok
, ug_boring_ok = boring_ok } <- old_guidance
= let guidance = UnfWhen { ug_arity = old_arity - arity_decrease
, ug_unsat_ok = unsat_ok
, ug_boring_ok = boring_ok }
new_tmpl = simpleOptExpr dflags (mkLams spec_bndrs (spec_app tmpl))
-- The beta-redexes created by spec_app will be
-- simplified away by simplOptExpr
, UnfWhen { ug_arity = old_arity } <- old_guidance
= mkCoreUnfolding src top_lvl new_tmpl
(old_guidance { ug_arity = old_arity - arity_decrease })
where
new_tmpl = simpleOptExpr dflags $
mkLams spec_bndrs $
spec_app tmpl -- The beta-redexes created by spec_app
-- will besimplified away by simplOptExpr
arity_decrease = count isValArg rule_lhs_args - count isId spec_bndrs
in mkCoreUnfolding src top_lvl new_tmpl guidance
specUnfolding _ _ _ _ _ _ = noUnfolding
specUnfolding _ _ _ _ _ = noUnfolding
{- Note [Specialising unfoldings]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
......@@ -693,20 +693,19 @@ dsSpec mb_poly_rhs (L loc (SpecPrag poly_id spec_co spec_inl))
dflags <- getDynFlags
; case decomposeRuleLhs dflags spec_bndrs ds_lhs of {
Left msg -> do { warnDs NoReason msg; return Nothing } ;
Right (rule_bndrs, _fn, args) -> do
Right (rule_bndrs, _fn, rule_lhs_args) -> do
{ this_mod <- getModule
; let fn_unf = realIdUnfolding poly_id
spec_unf = specUnfolding dflags poly_id spec_bndrs core_app arity_decrease fn_unf
spec_unf = specUnfolding dflags spec_bndrs core_app rule_lhs_args fn_unf
spec_id = mkLocalId spec_name spec_ty
`setInlinePragma` inl_prag
`setIdUnfolding` spec_unf
arity_decrease = count isValArg args - count isId spec_bndrs
; rule <- dsMkUserRule this_mod is_local_id
(mkFastString ("SPEC " ++ showPpr dflags poly_name))
rule_act poly_name
rule_bndrs args
rule_bndrs rule_lhs_args
(mkVarApps (Var spec_id) spec_bndrs)
; let spec_rhs = mkLams spec_bndrs (core_app poly_rhs)
......
......@@ -634,7 +634,6 @@ to connect the two, something like
This wrapper is put in the TcSpecPrag, in the ABExport record of
the AbsBinds.
f :: (Eq a, Ix b) => a -> b -> Bool
{-# SPECIALISE f :: (Ix p, Ix q) => Int -> (p,q) -> Bool #-}
f = <poly_rhs>
......@@ -662,8 +661,6 @@ Note that
* The RHS of f_spec, <poly_rhs> has a *copy* of 'binds', so that it
can fully specialise it.
From the TcSpecPrag, in GHC.HsToCore.Binds we generate a binding for f_spec and a RULE:
f_spec :: Int -> b -> Int
......@@ -702,14 +699,14 @@ Some wrinkles
So we simply do this:
- Generate a constraint to check that the specialised type (after
skolemiseation) is equal to the instantiated function type.
skolemisation) is equal to the instantiated function type.
- But *discard* the evidence (coercion) for that constraint,
so that we ultimately generate the simpler code
f_spec :: Int -> F Int
f_spec = <f rhs> Int dNumInt
RULE: forall d. f Int d = f_spec
You can see this discarding happening in
You can see this discarding happening in tcSpecPrag
3. Note that the HsWrapper can transform *any* function with the right
type prefix
......
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
module Bug where
import Data.Kind
type family
AllF (c :: k -> Constraint) (xs :: [k]) :: Constraint where
AllF _c '[] = ()
AllF c (x ': xs) = (c x, All c xs)
class (AllF c xs, SListI xs) => All (c :: k -> Constraint) (xs :: [k]) where
instance All c '[] where
instance (c x, All c xs) => All c (x ': xs) where
class Top x
instance Top x
type SListI = All Top
class All SListI (Code a) => Generic (a :: Type) where
type Code a :: [[Type]]
data T = MkT Int
instance Generic T where
type Code T = '[ '[Int] ]
......@@ -318,3 +318,4 @@ test('T17966',
test('T17810', normal, multimod_compile, ['T17810', '-fspecialise-aggressively -dcore-lint -O -v0'])
test('T18013', normal, multimod_compile, ['T18013', '-v0 -O'])
test('T18098', normal, compile, ['-dcore-lint -O2'])
test('T18120', normal, compile, ['-dcore-lint -O'])
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment