SpecConstr default settings are not sufficient for stream fusion
Summary
The `three_partitions` example fuses completely, but the resulting Core still contains some glaring call-pattern specialization opportunities. Unfortunately, I need to pass both `-fspec-constr-keen` and `-fspec-constr-count=4` to optimize it properly. Doing that speeds it up by a factor of 2.75!
Steps to reproduce
module T (three_partitions) where
import Prelude hiding (concatMap, Either (..), Maybe (..)) --, foldl', unlines, (++), map)
-- | A stream: a step function paired with its current state. The state type
-- @s@ is existentially quantified and the state field is strict.
data Stream a = forall s. Stream (s -> Step s a) !s
-- | Result of applying a step function to a state of type @s@: 'Yield' carries
-- an element and the next state, 'Skip' carries only the next state, and
-- 'Done' carries nothing. The state fields are strict.
data Step s a = Yield a !s | Done | Skip !s
-- | Pair with two strict fields, built with the infix constructor ':!:'.
data Tuple a b = !a :!: !b
-- | Two-alternative sum with strict fields; stands in for the Prelude's
-- 'Either', which the import line hides.
data Either a b = Left !a | Right !b
-- | Optional value with a strict payload; stands in for the hidden Prelude
-- 'Maybe'. No other definition visible in this reproducer refers to it.
data Maybe a = Nothing | Just !a
-- | Convert a list into a 'Stream' whose state is the part of the list that
-- has not been produced yet.
--
-- Written without arguments on the left-hand side, so the INLINE pragma
-- applies to the unapplied name.
stream :: [a] -> Stream a
stream = Stream next where
  -- An empty list ends the stream; otherwise the head is yielded and the tail
  -- becomes the next state.
  next [] = Done
  next (x:xs) = Yield x xs
-- Phase-controlled inlining: not inlined before simplifier phase 1
-- (presumably so that the "stream/unstream" rule gets a chance to match
-- first -- confirm).
{-# INLINE [1] stream #-}
-- | Convert a 'Stream' into a list by repeatedly applying its step function,
-- starting from the initial state.
unstream :: Stream a -> [a]
unstream (Stream next s0) = go s0 where
  -- The bang pattern forces the state before every step.
  go !s = case next s of
    -- An element was produced: cons it and continue from the new state.
    Yield x s' -> x : go s'
    -- The stream is exhausted: end the list.
    Done -> []
    -- No element this step: continue from the new state.
    Skip s' -> go s'
-- Phase-controlled inlining: not inlined before simplifier phase 1.
{-# INLINE [1] unstream #-}
-- Rewrite rule: turning a stream into a list and straight back into a stream
-- is replaced by the original stream.
{-# RULES "stream/unstream" forall s. stream (unstream s) = s #-}
-- | Map every element of the outer stream to an inner stream and emit the
-- elements of the inner streams one after another.
--
-- The state of the result is @Left so@ while the outer stream is being
-- advanced, and @Right (so :!: inner)@ while an inner stream is being drained,
-- where @so@ is the outer state to resume from afterwards.
concatMap :: (a -> Stream b) -> Stream a -> Stream b
concatMap f (Stream nexto s0) = Stream next (Left s0) where
  -- Outer phase: step the outer stream.
  next (Left !so) = case nexto so of
    -- New outer element: switch to draining @f x@, emitting nothing yet.
    Yield x so' -> Skip (Right (so' :!: f x))
    -- The outer stream is finished, so the whole stream is finished.
    Done -> Done
    Skip so' -> Skip (Left so')
  -- Inner phase: step the current inner stream.
  next (Right (!so :!: Stream nexti si)) = case nexti si of
    Yield y si' -> Yield y (Right (so :!: Stream nexti si'))
    -- The inner stream is finished: go back to the outer stream.
    Done -> Skip (Left so)
    Skip si' -> Skip (Right (so :!: Stream nexti si'))
-- Phase-controlled inlining: not inlined before simplifier phase 0
-- (presumably so that the "concatMap" rewrite rule can match first -- confirm).
{-# INLINE [0] concatMap #-}
-- | Box with a single lazy field. 'concatMap'' stores the current outer
-- element as @L x@ inside the strict 'Tuple' state, so the strict field only
-- forces the box and not the element itself.
data Lazy a = L a
-- | Variant of 'concatMap' in which the inner stream is supplied as its two
-- components instead of as a 'Stream' value: @nexti x@ is the inner step
-- function for outer element @x@ and @f x@ is the inner initial state.
--
-- The state of the result is @Left so@ while the outer stream is being
-- advanced, and @Right (so :!: L x :!: si)@ while the inner stream for @x@ is
-- being drained. ':!:' has no fixity declaration in this file, so the state
-- nests to the left: @(so :!: L x) :!: si@.
concatMap' :: (a -> s -> Step s b) -> (a -> s) -> Stream a -> Stream b
concatMap' nexti f (Stream nexto s0) = Stream next (Left s0) where
  -- Outer phase: step the outer stream.
  next (Left !so) = case nexto so of
    -- New outer element: remember it (lazily, inside 'L') together with the
    -- initial inner state, emitting nothing yet.
    Yield x so' -> Skip (Right (so' :!: L x :!: f x))
    -- The outer stream is finished, so the whole stream is finished.
    Done -> Done
    Skip so' -> Skip (Left so')
  -- Inner phase: step the inner stream that belongs to @x@.
  next (Right (so :!: L x :!: si)) = case nexti x si of
    Yield y si' -> Yield y (Right (so :!: L x :!: si'))
    -- The inner stream is finished: go back to the outer stream.
    Done -> Skip (Left so)
    Skip si' -> Skip (Right (so :!: L x :!: si'))
{-# INLINE concatMap' #-}
-- Rewrite rule: when the function passed to 'concatMap' is a lambda that
-- builds its 'Stream' directly from a step function and an initial state,
-- use 'concatMap'' on those two components instead.
{-# RULES "concatMap" forall next f. concatMap (\x -> Stream (next x) (f x)) = concatMap' next f #-}
-- | Stream of the 'Int's from @x@ up to and including @y@, in increasing
-- order. The stream is empty when @x > y@.
--
-- Defined as a lambda with no arguments on the left-hand side, so the INLINE
-- pragma applies to the unapplied name.
eftInt :: Int -> Int -> Stream Int
eftInt = \x y ->
  let
    -- The state is the next value to produce; stop once it exceeds @y@.
    next i | i <= y = Yield i (i + 1)
           | otherwise = Done
  in Stream next x
{-# INLINE eftInt #-}
-- | Stream of triples @(i, j, k)@ where @i@ ranges over @0 .. m `div` 3@,
-- @j@ ranges over @i .. (m-1) `div` 2@ for each @i@, and @k = m - (i+j)@,
-- so the three components of every triple add up to @m@.
--
-- NOTE(review): both Core dumps quoted in this report give
-- @three_partitions :: Int -> [(Int, Int, Int)]@ and contain a list-building
-- @go@ loop, so the compiled version was presumably wrapped in 'unstream'.
-- Confirm which signature the reproducer is meant to have.
three_partitions :: Int -> Stream (Int,Int,Int)
three_partitions m =
  concatMap
    (\i ->
      concatMap
        -- @k@ is forced before the singleton inner stream is built.
        (\j -> let !k = m - (i+j) in stream [(i,j,k)])
        (eftInt i ((m-1) `div` 2)))
    (eftInt 0 (m `div` 3))
Compiling with !12723 and the default SpecConstr settings yields the following Core:
-- Core dump quoted by the report as the result of compiling the reproducer
-- without the extra SpecConstr flags. It contains three $s-prefixed bindings
-- (presumably the specialisations SpecConstr created) next to the original
-- worker go_a1gN, which they still call.
three_partitions :: Int -> [(Int, Int, Int)]
three_partitions
= \ (m_aU9 :: Int) ->
let {
-- Upper bound of the inner loop; it stays a boxed Int and is unpacked again
-- with `case y_s1p5 of` inside the loop bodies below.
y_s1p5 :: Int
y_s1p5
= case m_aU9 of { I# x_a1nS ->
I# (uncheckedIShiftRA# (-# x_a1nS 1#) 1#)
} } in
case m_aU9 of { I# x_a1oi ->
let {
c0#_a1pY :: Int#
c0#_a1pY = <# x_a1oi 0# } in
let {
-- Unboxed upper bound of the outer loop.
wild1_X1 :: Int#
wild1_X1 = -# (quotInt# (+# x_a1oi c0#_a1pY) 3#) c0#_a1pY } in
case <=# 0# wild1_X1 of {
__DEFAULT -> [];
1# ->
letrec {
-- Emits the triple sc4 and continues through the unspecialised go_a1gN,
-- rebuilding a boxed Right (...) state for that call. The three cases below
-- unpack their Int arguments, but the unboxed fields they bind (ipv_s1rf,
-- ipv1_s1rh, y1_a1nI) are not used in the body.
$sgo_s1qO
:: Int
-> Int#
-> Int
-> Int
-> (Int, Int, Int)
-> [(Int, Int, Int)]
-> [(Int, Int, Int)]
$sgo_s1qO
= \ (sc_s1qH :: Int)
(sc1_s1qI :: Int#)
(sc2_s1qJ :: Int)
(sc3_s1qK :: Int)
(sc4_s1qL :: (Int, Int, Int))
(sc5_s1qM :: [(Int, Int, Int)]) ->
case sc_s1qH of sc6_X2 { I# ipv_s1rf ->
case sc2_s1qJ of sc7_X3 { I# ipv1_s1rh ->
case sc3_s1qK of wild2_a1nH { I# y1_a1nI ->
: sc4_s1qL
(go_a1gN
(Right
(:!:
(:!: sc6_X2 (L (I# sc1_s1qI)))
(Right (:!: (:!: sc7_X3 (L wild2_a1nH)) (L sc5_s1qM))))))
}
}
};
-- Inner-loop test: while sc2 <= y it calls $sgo_s1qO, allocating I# boxes and
-- the result triple for the call; otherwise it returns to the outer loop via
-- $sgo2_s1qN.
$sgo1_s1qF :: Int -> Int# -> Int# -> [(Int, Int, Int)]
$sgo1_s1qF
= \ (sc_s1qC :: Int) (sc1_s1qD :: Int#) (sc2_s1qE :: Int#) ->
case sc_s1qC of sc3_X2 { I# ipv_s1rl ->
case y_s1p5 of { I# y1_X9 ->
case <=# sc2_s1qE y1_X9 of {
__DEFAULT -> $sgo2_s1qN sc3_X2;
1# ->
$sgo_s1qO
sc3_X2
sc1_s1qD
(I# (+# sc2_s1qE 1#))
(I# sc2_s1qE)
(I# sc1_s1qD, I# sc2_s1qE, I# (-# x_a1oi (+# sc1_s1qD sc2_s1qE)))
[]
}
}
};
-- Outer-loop test: returns [] once the counter exceeds wild1_X1, otherwise
-- starts the inner loop for that counter value.
$sgo2_s1qN :: Int -> [(Int, Int, Int)]
$sgo2_s1qN
= \ (sc_s1qG :: Int) ->
case sc_s1qG of { I# ipv_s1rj ->
case <=# ipv_s1rj wild1_X1 of {
__DEFAULT -> [];
1# -> $sgo1_s1qF (I# (+# ipv_s1rj 1#)) ipv_s1rj ipv_s1rj
}
};
-- Unspecialised worker: takes the complete Either/Tuple/Lazy stream state as
-- a single boxed argument and takes it apart with nested cases.
go_a1gN
:: Either
Int
(Tuple
(Tuple Int (Lazy Int))
(Either
Int (Tuple (Tuple Int (Lazy Int)) (Lazy [(Int, Int, Int)]))))
-> [(Int, Int, Int)]
go_a1gN
= \ (s_aE6
:: Either
Int
(Tuple
(Tuple Int (Lazy Int))
(Either
Int (Tuple (Tuple Int (Lazy Int)) (Lazy [(Int, Int, Int)]))))) ->
case s_aE6 of {
Left so_aEt ->
case so_aEt of { I# x1_s1q8 ->
case <=# x1_s1q8 wild1_X1 of {
__DEFAULT -> [];
1# -> $sgo1_s1qF (I# (+# x1_s1q8 1#)) x1_s1q8 x1_s1q8
}
};
Right ds_d1lw ->
case ds_d1lw of { :!: ds1_s1qb si_s1qc ->
case ds1_s1qb of wild3_s1qe { :!: so_s1qf ds2_s1qg ->
case ds2_s1qg of { L x1_s1qj ->
case x1_s1qj of conrep_a10E { I# ipv_s1pI ->
case si_s1qc of {
Left so1_aEt ->
case so1_aEt of wild6_s1ql { I# x2_s1qm ->
case y_s1p5 of { I# y1_X9 ->
case <=# x2_s1qm y1_X9 of {
__DEFAULT -> $sgo2_s1qN so_s1qf;
1# ->
$sgo_s1qO
so_s1qf
ipv_s1pI
(I# (+# x2_s1qm 1#))
wild6_s1ql
(conrep_a10E, wild6_s1ql, I# (-# x_a1oi (+# ipv_s1pI x2_s1qm)))
[]
}
}
};
Right ds3_X8 ->
case ds3_X8 of { :!: ds4_s1qp si1_s1qq ->
case ds4_s1qp of wild7_s1qs { :!: so1_s1qt ds5_s1qu ->
case ds5_s1qu of { L x2_s1qx ->
case si1_s1qq of { L ds6_s1qA ->
case x2_s1qx of { I# y1_a1nI ->
case ds6_s1qA of {
[] -> go_a1gN (Right (:!: wild3_s1qe (Left so1_s1qt)));
: x3_aDY xs_aDZ ->
: x3_aDY
(go_a1gN
(Right (:!: wild3_s1qe (Right (:!: wild7_s1qs (L xs_aDZ))))))
}
}
}
}
}
}
}
}
}
}
}
}; } in
$sgo1_s1qF (I# 1#) 0# 0#
}
}
Expected behavior
Compiling with -fspec-constr-keen -fspec-constr-count=4
yields:
-- Core dump quoted by the report as the result of compiling with
-- -fspec-constr-keen -fspec-constr-count=4. Only one recursive binding
-- remains, and no Either/Tuple/Lazy state value is built or inspected.
three_partitions :: Int -> [(Int, Int, Int)]
three_partitions
= \ (m_aTJ :: Int) ->
let {
-- Upper bound of the inner loop; still a boxed Int that is unpacked inside
-- the loop.
y_s1oF :: Int
y_s1oF
= case m_aTJ of { I# x_a1ns ->
I# (uncheckedIShiftRA# (-# x_a1ns 1#) 1#)
} } in
case m_aTJ of { I# x_a1nS ->
let {
c0#_a1py :: Int#
c0#_a1py = <# x_a1nS 0# } in
let {
-- Unboxed upper bound of the outer loop.
wild1_X1 :: Int#
wild1_X1 = -# (quotInt# (+# x_a1nS c0#_a1py) 3#) c0#_a1py } in
case <=# 0# wild1_X1 of {
__DEFAULT -> [];
1# ->
letrec {
-- Single loop over three unboxed Int# arguments: sc is the next outer counter
-- value, sc1 the current outer value and sc2 the current inner value.
-- While sc2 <= y it conses a triple and increments sc2; otherwise it either
-- stops (sc > wild1_X1) or restarts the inner loop at the next outer value.
$sgo_s1qf :: Int# -> Int# -> Int# -> [(Int, Int, Int)]
$sgo_s1qf
= \ (sc_s1qc :: Int#) (sc1_s1qd :: Int#) (sc2_s1qe :: Int#) ->
case y_s1oF of { I# y1_X9 ->
case <=# sc2_s1qe y1_X9 of {
__DEFAULT ->
case <=# sc_s1qc wild1_X1 of {
__DEFAULT -> [];
1# -> $sgo_s1qf (+# sc_s1qc 1#) sc_s1qc sc_s1qc
};
1# ->
: (I# sc1_s1qd, I# sc2_s1qe, I# (-# x_a1nS (+# sc1_s1qd sc2_s1qe)))
($sgo_s1qf sc_s1qc sc1_s1qd (+# sc2_s1qe 1#))
}
}; } in
$sgo_s1qf 1# 0# 0#
}
}
This version is 2.75 times faster on the input 10000 on my machine.
Environment
- GHC version used: 9.11.20240604 (based on bb40244e)