Cast worker/wrapper doesn't work with INLINABLE
Consider this code, distilled from a comment in #19874 (now closed):
-- A newtype wrapping a function type. In Core, converting between
-- 'Wombat a' and 'a -> a' is done with the coercion 'N:Wombat'.
newtype Wombat a = Wombat (a->a)
-- | Minimal reproducer. 'f' is INLINABLE (so it gets a stable unfolding),
-- its result type is the newtype 'Wombat', and the 'True' equation calls
-- 'f' recursively, which makes 'f' a loop breaker in the dumps below.
-- NOTE(review): keep this exactly as written; changing its shape may
-- change the Core GHC produces and stop the problem from reproducing.
f :: Num a => Bool -> Wombat a
{-# INLINEABLE f #-}
f True = f False
f False = Wombat (\x -> x+1)
With GHC HEAD we get this rather unpleasant Core, full of casts:
Rec {
-- RHS size: {terms: 18, types: 9, coercions: 14, joins: 0/0}
f [InlPrag=INLINABLE, Occ=LoopBreaker]
:: forall a. Num a => Bool -> Wombat a
[GblId,
Arity=3,
Str=<SP(SCS(C1(L)),A,A,A,A,A,L)><1L><L>,
Unf=Unf{Src=InlineStable, TopLvl=True, Value=True, ConLike=True,
WorkFree=True, Expandable=True, Guidance=IF_ARGS [60 70] 240 60
Tmpl= \ (@a_aAg)
($dNum_aAh :: Num a_aAg)
(ds_dAv [Occ=OnceL1!] :: Bool) ->
(\ (eta_B0 [Occ=Once2] :: a_aAg) ->
case ds_dAv of {
False ->
+ @a_aAg $dNum_aAh eta_B0 (fromInteger @a_aAg $dNum_aAh 1);
True ->
((f @a_aAg $dNum_aAh GHC.Types.False)
`cast` (Foo.N:Wombat[0] <a_aAg>_R
:: Wombat a_aAg ~R# (a_aAg -> a_aAg)))
eta_B0
})
`cast` (Sym (Foo.N:Wombat[0] <a_aAg>_R)
:: (a_aAg -> a_aAg) ~R# Wombat a_aAg)}]
f = (\ (@a_aAg)
($dNum_aAh :: Num a_aAg)
(ds_dAv :: Bool)
(eta_B0 :: a_aAg) ->
case ds_dAv of {
False ->
+ @a_aAg $dNum_aAh eta_B0 (fromInteger @a_aAg $dNum_aAh lvl_rAF);
True ->
((f @a_aAg $dNum_aAh GHC.Types.False)
`cast` (Foo.N:Wombat[0] <a_aAg>_R
:: Wombat a_aAg ~R# (a_aAg -> a_aAg)))
eta_B0
})
`cast` (forall (a :: <*>_N).
<Num a>_R
%<'Many>_N ->_R <Bool>_R
%<'Many>_N ->_R Sym (Foo.N:Wombat[0] <a>_R)
:: (forall {a}. Num a => Bool -> a -> a)
~R# (forall {a}. Num a => Bool -> Wombat a))
end Rec }
Notice that the cast worker/wrapper transformation isn't happening. (See Note [Cast worker/wrappers]
in GHC.Core.Opt.Simplify for what cast
worker/wrapper is.)
The problem is the stable unfolding that the INLINABLE pragma gives f. What we'd like to get is this:
Rec {
-- RHS size: {terms: 18, types: 9, coercions: 0, joins: 0/0}
Foo.f1 [InlPrag=INLINABLE, Occ=LoopBreaker]
:: forall {a}. Num a => Bool -> a -> a
[GblId,
Arity=3,
Str=<SP(SCS(C1(L)),A,A,A,A,A,L)><1L><L>,
Unf=Unf{Src=InlineStable, TopLvl=True, Value=True, ConLike=True,
WorkFree=True, Expandable=True, Guidance=IF_ARGS [60 70 0] 230 0
Tmpl= \ (@a_aAn)
($dNum_aAo :: Num a_aAn)
(ds_dAC [Occ=Once1!] :: Bool)
(eta_B0 [Occ=Once2] :: a_aAn) ->
case ds_dAC of {
False ->
+ @a_aAn $dNum_aAo eta_B0 (fromInteger @a_aAn $dNum_aAo 1);
True -> Foo.f1 @a_aAn $dNum_aAo GHC.Types.False eta_B0
}}]
Foo.f1
= \ (@a_aAn)
($dNum_aAo :: Num a_aAn)
(ds_dAC :: Bool)
(eta_B0 :: a_aAn) ->
case ds_dAC of {
False ->
+ @a_aAn $dNum_aAo eta_B0 (fromInteger @a_aAn $dNum_aAo lvl_rAO);
True -> Foo.f1 @a_aAn $dNum_aAo GHC.Types.False eta_B0
}
end Rec }
-- RHS size: {terms: 1, types: 0, coercions: 12, joins: 0/0}
f :: forall a. Num a => Bool -> Wombat a
[GblId,
Arity=3,
Str=<SP(SCS(C1(L)),A,A,A,A,A,L)><1L><L>]
f = Foo.f1
`cast` (forall (a :: <*>_N).
<Num a>_R
%<'Many>_N ->_R <Bool>_R
%<'Many>_N ->_R Sym (Foo.N:Wombat[0] <a>_R)
:: (forall {a}. Num a => Bool -> a -> a)
~R# (forall {a}. Num a => Bool -> Wombat a))
That is: perform cast worker/wrapper, and move the stable unfolding to the worker (Foo.f1), so that the wrapper f is just a cast of the worker.