Skip to content

SpecConstr default settings are not sufficient for stream fusion

Summary

The three_partitions example fuses completely, but the resulting Core still contains some glaring missed call-pattern specialization opportunities. Unfortunately, I need to use both -fspec-constr-keen and -fspec-constr-count=4 to optimize it properly. Doing that speeds it up by a factor of 2.75!

Steps to reproduce

module T (three_partitions) where

import Prelude hiding (concatMap, Either (..), Maybe (..)) --, foldl', unlines, (++), map)

-- | A stream of @a@s: a step function paired with a strict seed state.
-- The state type @s@ is existentially quantified, so it is hidden from
-- consumers of the stream.
data Stream a = forall s. Stream (s -> Step s a) !s
-- | Outcome of running a step function once on a state of type @s@:
-- 'Yield' carries an element and the (strict) next state, 'Skip' carries
-- only the (strict) next state, and 'Done' carries nothing.
data Step s a = Yield a !s | Done | Skip !s

-- | Pair with both components strict, built with the infix constructor ':!:'.
data Tuple a b = !a :!: !b
-- | Local sum type with strict payloads; the Prelude's @Either@ is hidden by
-- the import above so this one is used instead.
data Either a b = Left !a | Right !b
-- | Local optional type with a strict payload; the Prelude's @Maybe@ is hidden
-- by the import above.
-- NOTE(review): not referenced by any definition visible in this snippet.
data Maybe a = Nothing | Just !a

-- | View a list as a 'Stream'. The not-yet-consumed part of the list is the
-- stream state, and each step pops one element off its front.
stream :: [a] -> Stream a
stream = Stream pop where
  -- A cons cell yields its head and keeps the tail as the next state;
  -- the empty list ends the stream.
  pop remaining = case remaining of
    y : ys -> Yield y ys
    [] -> Done
{-# INLINE [1] stream #-}

-- | Convert a 'Stream' into a list by repeatedly running its step function
-- from the seed state. The loop forces each state before stepping.
unstream :: Stream a -> [a]
unstream (Stream step start) = loop start where
  loop !st = case step st of
    Done -> []                     -- stream exhausted: close the list
    Skip st' -> loop st'           -- no element this step: keep going
    Yield a st' -> a : loop st'    -- emit the element, continue with the new state
{-# INLINE [1] unstream #-}

-- Rewrite rule: turning a stream into a list with 'unstream' and immediately
-- back into a stream with 'stream' is replaced by the original stream @s@.
-- Both functions are marked INLINE [1], i.e. they are not inlined before
-- phase 1, which leaves the earlier phases for this rule to match.
{-# RULES "stream/unstream" forall s. stream (unstream s) = s #-}

-- | Map every element of the outer stream to an inner stream and emit the
-- inner streams' elements one after another.
--
-- The combined state is either @Left outer@ (an inner stream still has to be
-- obtained from the outer stream) or @Right (outer :!: inner)@ (an inner
-- stream is currently being drained).
concatMap :: (a -> Stream b) -> Stream a -> Stream b
concatMap f (Stream stepOuter start) = Stream step (Left start) where
  -- No active inner stream: advance the outer stream.
  step (Left !outer) = case stepOuter outer of
    Done -> Done
    Skip outer' -> Skip (Left outer')
    Yield a outer' -> Skip (Right (outer' :!: f a))
  -- Draining an inner stream: advance it, falling back to the outer stream
  -- once it is exhausted.
  step (Right (!outer :!: Stream stepInner inner)) = case stepInner inner of
    Done -> Skip (Left outer)
    Skip inner' -> Skip (Right (outer :!: Stream stepInner inner'))
    Yield b inner' -> Yield b (Right (outer :!: Stream stepInner inner'))
{-# INLINE [0] concatMap #-}

-- | Single-field box whose field is lazy (no strictness annotation). Storing
-- @L x@ in a strict field such as those of 'Tuple' evaluates only the 'L'
-- constructor, not @x@ itself.
data Lazy a = L a

-- | Variant of 'concatMap' in which the inner stream is supplied in pieces:
-- @nexti x@ is the inner step function for outer element @x@ and @f x@ is its
-- seed state.
--
-- While an inner stream is being drained the state is
-- @Right ((outer :!: L x) :!: inner)@; the outer element is kept inside a
-- lazy 'L' box so that it can be passed to @nexti@ on every inner step.
concatMap' :: (a -> s -> Step s b) -> (a -> s) -> Stream a -> Stream b
concatMap' nexti f (Stream stepOuter start) = Stream step (Left start) where
  -- No active inner stream: advance the outer stream.
  step (Left !outer) = case stepOuter outer of
    Done -> Done
    Skip outer' -> Skip (Left outer')
    Yield a outer' -> Skip (Right ((outer' :!: L a) :!: f a))
  -- Draining the inner stream that belongs to outer element @a@.
  step (Right ((outer :!: L a) :!: inner)) = case nexti a inner of
    Done -> Skip (Left outer)
    Skip inner' -> Skip (Right ((outer :!: L a) :!: inner'))
    Yield b inner' -> Yield b (Right ((outer :!: L a) :!: inner'))
{-# INLINE concatMap' #-}

-- Rewrite rule: when the function passed to 'concatMap' is visibly a lambda
-- that builds a 'Stream' from a step function @next x@ and a seed @f x@,
-- call concatMap' with those two pieces instead.
{-# RULES "concatMap" forall next f. concatMap (\x -> Stream (next x) (f x)) = concatMap' next f #-}

-- | Stream of the 'Int's from the first argument up to and including the
-- second, in steps of one. The stream is empty when the first argument is
-- greater than the second.
eftInt :: Int -> Int -> Stream Int
eftInt = \lo hi ->
  let
    -- The state is the next candidate value; stop once it exceeds the bound.
    step n = if n <= hi then Yield n (n + 1) else Done
  in Stream step lo
{-# INLINE eftInt #-}

-- | Stream of triples @(i, j, k)@ with @k = m - (i + j)@, where @i@ ranges
-- over @0 .. m `div` 3@ and, for each @i@, @j@ ranges over
-- @i .. (m - 1) `div` 2@ (both ranges inclusive, via 'eftInt').
--
-- NOTE(review): the signature here returns a 'Stream', but both Core dumps
-- further down show @three_partitions :: Int -> [(Int, Int, Int)]@ and
-- contain a list-producing loop. Presumably the compiled version wrapped this
-- expression in 'unstream' -- confirm against the original reproducer.
three_partitions :: Int -> Stream (Int,Int,Int)
three_partitions m =
  concatMap
    (\i ->
      concatMap
        -- The bang forces @k@ before the singleton list is streamed.
        (\j -> let !k = m - (i+j) in stream [(i,j,k)])
        (eftInt i ((m-1) `div` 2)))
    (eftInt 0 (m `div` 3))
Compiling using !12723 yields:
three_partitions :: Int -> [(Int, Int, Int)]
three_partitions
  = \ (m_aU9 :: Int) ->
      let {
        y_s1p5 :: Int
        y_s1p5
          = case m_aU9 of { I# x_a1nS ->
            I# (uncheckedIShiftRA# (-# x_a1nS 1#) 1#)
            } } in
      case m_aU9 of { I# x_a1oi ->
      let {
        c0#_a1pY :: Int#
        c0#_a1pY = <# x_a1oi 0# } in
      let {
        wild1_X1 :: Int#
        wild1_X1 = -# (quotInt# (+# x_a1oi c0#_a1pY) 3#) c0#_a1pY } in
      case <=# 0# wild1_X1 of {
        __DEFAULT -> [];
        1# ->
          letrec {
            $sgo_s1qO
              :: Int
                 -> Int#
                 -> Int
                 -> Int
                 -> (Int, Int, Int)
                 -> [(Int, Int, Int)]
                 -> [(Int, Int, Int)]
            $sgo_s1qO
              = \ (sc_s1qH :: Int)
                  (sc1_s1qI :: Int#)
                  (sc2_s1qJ :: Int)
                  (sc3_s1qK :: Int)
                  (sc4_s1qL :: (Int, Int, Int))
                  (sc5_s1qM :: [(Int, Int, Int)]) ->
                  case sc_s1qH of sc6_X2 { I# ipv_s1rf ->
                  case sc2_s1qJ of sc7_X3 { I# ipv1_s1rh ->
                  case sc3_s1qK of wild2_a1nH { I# y1_a1nI ->
                  : sc4_s1qL
                    (go_a1gN
                       (Right
                          (:!:
                             (:!: sc6_X2 (L (I# sc1_s1qI)))
                             (Right (:!: (:!: sc7_X3 (L wild2_a1nH)) (L sc5_s1qM))))))
                  }
                  }
                  };
            $sgo1_s1qF :: Int -> Int# -> Int# -> [(Int, Int, Int)]
            $sgo1_s1qF
              = \ (sc_s1qC :: Int) (sc1_s1qD :: Int#) (sc2_s1qE :: Int#) ->
                  case sc_s1qC of sc3_X2 { I# ipv_s1rl ->
                  case y_s1p5 of { I# y1_X9 ->
                  case <=# sc2_s1qE y1_X9 of {
                    __DEFAULT -> $sgo2_s1qN sc3_X2;
                    1# ->
                      $sgo_s1qO
                        sc3_X2
                        sc1_s1qD
                        (I# (+# sc2_s1qE 1#))
                        (I# sc2_s1qE)
                        (I# sc1_s1qD, I# sc2_s1qE, I# (-# x_a1oi (+# sc1_s1qD sc2_s1qE)))
                        []
                  }
                  }
                  };
            $sgo2_s1qN :: Int -> [(Int, Int, Int)]
            $sgo2_s1qN
              = \ (sc_s1qG :: Int) ->
                  case sc_s1qG of { I# ipv_s1rj ->
                  case <=# ipv_s1rj wild1_X1 of {
                    __DEFAULT -> [];
                    1# -> $sgo1_s1qF (I# (+# ipv_s1rj 1#)) ipv_s1rj ipv_s1rj
                  }
                  };
            go_a1gN
              :: Either
                   Int
                   (Tuple
                      (Tuple Int (Lazy Int))
                      (Either
                         Int (Tuple (Tuple Int (Lazy Int)) (Lazy [(Int, Int, Int)]))))
                 -> [(Int, Int, Int)]
            go_a1gN
              = \ (s_aE6
                     :: Either
                          Int
                          (Tuple
                             (Tuple Int (Lazy Int))
                             (Either
                                Int (Tuple (Tuple Int (Lazy Int)) (Lazy [(Int, Int, Int)]))))) ->
                  case s_aE6 of {
                    Left so_aEt ->
                      case so_aEt of { I# x1_s1q8 ->
                      case <=# x1_s1q8 wild1_X1 of {
                        __DEFAULT -> [];
                        1# -> $sgo1_s1qF (I# (+# x1_s1q8 1#)) x1_s1q8 x1_s1q8
                      }
                      };
                    Right ds_d1lw ->
                      case ds_d1lw of { :!: ds1_s1qb si_s1qc ->
                      case ds1_s1qb of wild3_s1qe { :!: so_s1qf ds2_s1qg ->
                      case ds2_s1qg of { L x1_s1qj ->
                      case x1_s1qj of conrep_a10E { I# ipv_s1pI ->
                      case si_s1qc of {
                        Left so1_aEt ->
                          case so1_aEt of wild6_s1ql { I# x2_s1qm ->
                          case y_s1p5 of { I# y1_X9 ->
                          case <=# x2_s1qm y1_X9 of {
                            __DEFAULT -> $sgo2_s1qN so_s1qf;
                            1# ->
                              $sgo_s1qO
                                so_s1qf
                                ipv_s1pI
                                (I# (+# x2_s1qm 1#))
                                wild6_s1ql
                                (conrep_a10E, wild6_s1ql, I# (-# x_a1oi (+# ipv_s1pI x2_s1qm)))
                                []
                          }
                          }
                          };
                        Right ds3_X8 ->
                          case ds3_X8 of { :!: ds4_s1qp si1_s1qq ->
                          case ds4_s1qp of wild7_s1qs { :!: so1_s1qt ds5_s1qu ->
                          case ds5_s1qu of { L x2_s1qx ->
                          case si1_s1qq of { L ds6_s1qA ->
                          case x2_s1qx of { I# y1_a1nI ->
                          case ds6_s1qA of {
                            [] -> go_a1gN (Right (:!: wild3_s1qe (Left so1_s1qt)));
                            : x3_aDY xs_aDZ ->
                              : x3_aDY
                                (go_a1gN
                                   (Right (:!: wild3_s1qe (Right (:!: wild7_s1qs (L xs_aDZ))))))
                          }
                          }
                          }
                          }
                          }
                          }
                      }
                      }
                      }
                      }
                      }
                  }; } in
          $sgo1_s1qF (I# 1#) 0# 0#
      }
      }

Expected behavior

Compiling with -fspec-constr-keen -fspec-constr-count=4 yields:

three_partitions :: Int -> [(Int, Int, Int)]
three_partitions
  = \ (m_aTJ :: Int) ->
      let {
        y_s1oF :: Int
        y_s1oF
          = case m_aTJ of { I# x_a1ns ->
            I# (uncheckedIShiftRA# (-# x_a1ns 1#) 1#)
            } } in
      case m_aTJ of { I# x_a1nS ->
      let {
        c0#_a1py :: Int#
        c0#_a1py = <# x_a1nS 0# } in
      let {
        wild1_X1 :: Int#
        wild1_X1 = -# (quotInt# (+# x_a1nS c0#_a1py) 3#) c0#_a1py } in
      case <=# 0# wild1_X1 of {
        __DEFAULT -> [];
        1# ->
          letrec {
            $sgo_s1qf :: Int# -> Int# -> Int# -> [(Int, Int, Int)]
            $sgo_s1qf
              = \ (sc_s1qc :: Int#) (sc1_s1qd :: Int#) (sc2_s1qe :: Int#) ->
                  case y_s1oF of { I# y1_X9 ->
                  case <=# sc2_s1qe y1_X9 of {
                    __DEFAULT ->
                      case <=# sc_s1qc wild1_X1 of {
                        __DEFAULT -> [];
                        1# -> $sgo_s1qf (+# sc_s1qc 1#) sc_s1qc sc_s1qc
                      };
                    1# ->
                      : (I# sc1_s1qd, I# sc2_s1qe, I# (-# x_a1nS (+# sc1_s1qd sc2_s1qe)))
                        ($sgo_s1qf sc_s1qc sc1_s1qd (+# sc2_s1qe 1#))
                  }
                  }; } in
          $sgo_s1qf 1# 0# 0#
      }
      }

This is 2.75 times faster on the input 10000 on my machine.

Environment

  • GHC version used: 9.11.20240604 (based on bb40244e)
To upload designs, you'll need to enable LFS and have an admin enable hashed storage. More information