
Missed optimization

Please feel free to change the title to something that better describes the issue.

Summary

Here is some code that computes the sum of the vertices of a tree.

sum1 :: (Int -> [Int]) -> Int -> Int
sum1 neighbors root = go root 0 where
    f :: Int -> [Int -> Int] -> Int -> Int
    f x ks acc = foldl' (\acc' k -> k acc') (acc + x) ks
    go :: Int -> Int -> Int
    go x = f x (map go (neighbors x))
{-# NOINLINE sum1 #-}

This might seem a bit unnatural, but it is a simplified example, so please excuse that.
I expect the intermediate [Int -> Int] list to be optimized away, but according to the Core output that does not happen.

However, if I eta-expand go, the problem disappears. The same happens if I merge f into go.
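
For illustration, the shape one might hope the optimizer reaches can be written by hand. The sketch below is not taken from GHC's output; sumDirect, smallTree and agree are hypothetical names introduced here, and the small tree is only a sanity check that the direct version computes the same sum as sum1.

```haskell
import Data.List (foldl')

-- Original formulation from the report: builds a list of continuations.
sum1 :: (Int -> [Int]) -> Int -> Int
sum1 neighbors root = go root 0 where
    f :: Int -> [Int -> Int] -> Int -> Int
    f x ks acc = foldl' (\acc' k -> k acc') (acc + x) ks
    go :: Int -> Int -> Int
    go x = f x (map go (neighbors x))

-- Hand-written target (hypothetical): go is applied directly to each
-- neighbor, so no [Int -> Int] list is ever constructed.
sumDirect :: (Int -> [Int]) -> Int -> Int
sumDirect neighbors root = go root 0 where
    go :: Int -> Int -> Int
    go x acc = foldl' (\acc' y -> go y acc') (acc + x) (neighbors x)

-- Small binary tree with vertices 0..6 (hypothetical test input).
smallTree :: Int -> [Int]
smallTree x
    | x > 2     = []
    | otherwise = [2*x + 1, 2*x + 2]

-- True when both formulations compute the same sum on the small tree.
agree :: Bool
agree = sum1 smallTree 0 == sumDirect smallTree 0

main :: IO ()
main = print (sum1 smallTree 0, sumDirect smallTree 0, agree)
-- prints (21,21,True)
```

Whether GHC actually produces this form for sum2 and sum3 has to be confirmed from the Core; the sketch only shows that the list of functions is not needed to express the computation.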

Here are some benchmarks:

import Criterion.Main
import Data.List

main :: IO ()
main = defaultMain
    [ bench "sum1" $ whnf (sum1 binTree) 1
    , bench "sum2" $ whnf (sum2 binTree) 1
    , bench "sum3" $ whnf (sum3 binTree) 1
    ]

binTree :: Int -> [Int]
binTree x
    | x > 1000000 = []
    | otherwise   = [2*x + 1, 2*x + 2]

sum1 :: (Int -> [Int]) -> Int -> Int
sum1 neighbors root = go root 0 where
    f :: Int -> [Int -> Int] -> Int -> Int
    f x ks acc = foldl' (\acc' k -> k acc') (acc + x) ks
    go :: Int -> Int -> Int
    go x = f x (map go (neighbors x))
{-# NOINLINE sum1 #-}

-- Same as sum1, but with go eta-expanded.
sum2 :: (Int -> [Int]) -> Int -> Int
sum2 neighbors root = go root 0 where
    f :: Int -> [Int -> Int] -> Int -> Int
    f x ks acc = foldl' (\acc' k -> k acc') (acc + x) ks
    go :: Int -> Int -> Int
    go x eta1 = f x (map go (neighbors x)) eta1
{-# NOINLINE sum2 #-}

-- Same as sum1, but with f merged into go.
sum3 :: (Int -> [Int]) -> Int -> Int
sum3 neighbors root = go root 0 where
    go :: Int -> Int -> Int
    go x = \acc -> foldl' (\acc' k -> k acc') (acc + x) (map go (neighbors x))
{-# NOINLINE sum3 #-}

With GHC 9.4.4 and -O2:

benchmarking sum1
time                 49.07 ms   (48.88 ms .. 49.31 ms)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 48.52 ms   (48.25 ms .. 48.75 ms)
std dev              473.0 μs   (371.5 μs .. 619.1 μs)

benchmarking sum2
time                 6.521 ms   (6.493 ms .. 6.539 ms)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 6.542 ms   (6.538 ms .. 6.548 ms)
std dev              14.12 μs   (10.22 μs .. 19.91 μs)

benchmarking sum3
time                 6.565 ms   (6.505 ms .. 6.647 ms)
                     1.000 R²   (0.999 R² .. 1.000 R²)
mean                 6.544 ms   (6.531 ms .. 6.570 ms)
std dev              50.54 μs   (26.96 μs .. 91.55 μs)

Expected behavior

sum1 should be as efficient as sum2 and sum3.

Environment

  • GHC version used: 9.4.4

Optional:

  • Operating System: Ubuntu
  • System Architecture: x86_64

Edit: I have simplified the example a bit. The original example was:

-- Needs ScopedTypeVariables (the local signatures mention the outer a) and foldl' from Data.List.
depthSum1 :: forall a. (a -> [a]) -> a -> Int
depthSum1 neighbors root = go root 0 0 where
    f :: a -> [Int -> Int -> Int] -> Int -> Int -> Int
    f _ ks depth acc = foldl' (\acc' k -> k (depth+1) acc') (acc + depth) ks
    go :: a -> Int -> Int -> Int
    go x = f x (map go (neighbors x))
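
Applying the same eta-expansion that distinguishes sum2 from sum1 to this original example would look roughly like the sketch below. The names depthSum2 and smallTree are hypothetical, and the effect on this version has not been measured here; only the simplified sum1/sum2 pair was benchmarked above.

```haskell
{-# LANGUAGE ScopedTypeVariables #-}
import Data.List (foldl')

-- Original example with go eta-expanded (depthSum2 is a hypothetical name).
-- Sums the depth of every vertex reachable from the root.
depthSum2 :: forall a. (a -> [a]) -> a -> Int
depthSum2 neighbors root = go root 0 0 where
    f :: a -> [Int -> Int -> Int] -> Int -> Int -> Int
    f _ ks depth acc = foldl' (\acc' k -> k (depth+1) acc') (acc + depth) ks
    go :: a -> Int -> Int -> Int
    go x depth acc = f x (map go (neighbors x)) depth acc

-- Small binary tree with vertices 0..6 (hypothetical test input):
-- one vertex at depth 0, two at depth 1, four at depth 2.
smallTree :: Int -> [Int]
smallTree x
    | x > 2     = []
    | otherwise = [2*x + 1, 2*x + 2]

main :: IO ()
main = print (depthSum2 smallTree 0)
-- prints 10 (0*1 + 1*2 + 2*4)
```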
Edited by meooow