Missed optimization
Please feel free to change the title to something that better describes the issue.
Summary
Here is some code that computes the sum of the vertices of a tree, given a root and a `neighbors` function that returns the children of each vertex.
```haskell
sum1 :: (Int -> [Int]) -> Int -> Int
sum1 neighbors root = go root 0 where
  f :: Int -> [Int -> Int] -> Int -> Int
  f x ks acc = foldl' (\acc' k -> k acc') (acc + x) ks
  go :: Int -> Int -> Int
  go x = f x (map go (neighbors x))
{-# NOINLINE sum1 #-}
```
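To make it concrete what `sum1` computes, here is a minimal sketch that compares it with plain structural recursion on a tiny tree. `sumTree` and `smallTree` are hypothetical helpers introduced only for illustration; `sum1` is the code above, reindented so the local bindings sit inside `where`.

```haskell
import Data.List (foldl')

-- sum1 from the report: an accumulator-passing traversal in which each
-- child subtree contributes a continuation of type Int -> Int.
sum1 :: (Int -> [Int]) -> Int -> Int
sum1 neighbors root = go root 0 where
  f :: Int -> [Int -> Int] -> Int -> Int
  f x ks acc = foldl' (\acc' k -> k acc') (acc + x) ks
  go :: Int -> Int -> Int
  go x = f x (map go (neighbors x))
{-# NOINLINE sum1 #-}

-- Hypothetical reference: the same sum as direct structural recursion.
sumTree :: (Int -> [Int]) -> Int -> Int
sumTree neighbors x = x + sum (map (sumTree neighbors) (neighbors x))

-- Hypothetical tiny tree: 1 -> [2,3], 2 -> [4,5], 3 -> [6,7]; 4..7 are leaves.
smallTree :: Int -> [Int]
smallTree x
  | x > 3 = []
  | otherwise = [2 * x, 2 * x + 1]
```

In GHCi, `sum1 smallTree 1` and `sumTree smallTree 1` both give 28, the sum of 1 through 7; `sum1` just threads the running total through the children in order instead of adding up the results afterwards.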
This might seem a bit unnatural, but it is a simplified example, so please excuse that.

I expect the `[Int -> Int]` list to be optimized away, but according to the Core that doesn't happen. However, if I eta-expand `go`, the problem disappears. The same happens if I merge `f` into `go` (`sum2` and `sum3` in the benchmarks below).
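A plausible explanation, which is our assumption and not something confirmed from the Core in this report: after `f` is inlined into `go`, the list `map go (neighbors x)` ends up bound outside the `\acc` lambda, so `go` keeps arity 1 and each `go x` allocates a function closure. Eta-expanding `go` would move `neighbors x` under the lambda and recompute it on every application of a shared `go x`; as we understand it, GHC avoids that kind of work duplication unless it can see the partial application is used at most once. The hypothetical `sumShared` and `sumEta` below are hand-written approximations of the two shapes, not actual GHC output.

```haskell
import Data.List (foldl')

-- Approximation (not real Core) of sum1's go after f is inlined: the list of
-- continuations is let-bound outside the \acc lambda, so `go x` is a
-- partial application that returns a function closure.
sumShared :: (Int -> [Int]) -> Int -> Int
sumShared neighbors root = go root 0
  where
    go :: Int -> Int -> Int
    go x =
      let ks = map go (neighbors x) -- shared by every use of this `go x`
      in \acc -> foldl' (\acc' k -> k acc') (acc + x) ks

-- Approximation of the eta-expanded shape (sum2/sum3): go takes both
-- arguments, and the list is built inside the saturated call.
sumEta :: (Int -> [Int]) -> Int -> Int
sumEta neighbors root = go root 0
  where
    go :: Int -> Int -> Int
    go x acc = foldl' (\acc' k -> k acc') (acc + x) (map go (neighbors x))
```

Both compute the same sum; they differ only in where the work sits relative to the `\acc` lambda. GHC may still transform these hand-written versions differently, so they illustrate source-level shapes only.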
Here are some benchmarks:
```haskell
import Criterion.Main
import Data.List

main :: IO ()
main = defaultMain
  [ bench "sum1" $ whnf (sum1 binTree) 1
  , bench "sum2" $ whnf (sum2 binTree) 1
  , bench "sum3" $ whnf (sum3 binTree) 1
  ]

binTree :: Int -> [Int]
binTree x
  | x > 1000000 = []
  | otherwise = [2*x + 1, 2*x + 2]

sum1 :: (Int -> [Int]) -> Int -> Int
sum1 neighbors root = go root 0 where
  f :: Int -> [Int -> Int] -> Int -> Int
  f x ks acc = foldl' (\acc' k -> k acc') (acc + x) ks
  go :: Int -> Int -> Int
  go x = f x (map go (neighbors x))
{-# NOINLINE sum1 #-}

sum2 :: (Int -> [Int]) -> Int -> Int
sum2 neighbors root = go root 0 where
  f :: Int -> [Int -> Int] -> Int -> Int
  f x ks acc = foldl' (\acc' k -> k acc') (acc + x) ks
  go :: Int -> Int -> Int
  go x eta1 = f x (map go (neighbors x)) eta1
{-# NOINLINE sum2 #-}

sum3 :: (Int -> [Int]) -> Int -> Int
sum3 neighbors root = go root 0 where
  go :: Int -> Int -> Int
  go x = \acc -> foldl' (\acc' k -> k acc') (acc + x) (map go (neighbors x))
{-# NOINLINE sum3 #-}
```
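Since the comparison only makes sense if the three variants compute the same value, a quick agreement check is worth running alongside the benchmark. This is a minimal sketch; `binTreeUpTo` is a hypothetical version of `binTree` with the `1000000` cutoff turned into a parameter so that small trees can be checked by hand, and `allAgree` is likewise a hypothetical helper.

```haskell
import Data.List (foldl')

sum1 :: (Int -> [Int]) -> Int -> Int
sum1 neighbors root = go root 0 where
  f :: Int -> [Int -> Int] -> Int -> Int
  f x ks acc = foldl' (\acc' k -> k acc') (acc + x) ks
  go :: Int -> Int -> Int
  go x = f x (map go (neighbors x))

sum2 :: (Int -> [Int]) -> Int -> Int
sum2 neighbors root = go root 0 where
  f :: Int -> [Int -> Int] -> Int -> Int
  f x ks acc = foldl' (\acc' k -> k acc') (acc + x) ks
  go :: Int -> Int -> Int
  go x eta1 = f x (map go (neighbors x)) eta1

sum3 :: (Int -> [Int]) -> Int -> Int
sum3 neighbors root = go root 0 where
  go :: Int -> Int -> Int
  go x = \acc -> foldl' (\acc' k -> k acc') (acc + x) (map go (neighbors x))

-- Hypothetical: binTree with a configurable cutoff instead of 1000000.
binTreeUpTo :: Int -> Int -> [Int]
binTreeUpTo limit x
  | x > limit = []
  | otherwise = [2*x + 1, 2*x + 2]

-- True when sum1, sum2 and sum3 return the same total for this tree and root.
allAgree :: Int -> Int -> Bool
allAgree limit root = s1 == sum2 nbrs root && s1 == sum3 nbrs root
  where
    nbrs = binTreeUpTo limit
    s1 = sum1 nbrs root
```

For example, `sum1 (binTreeUpTo 3) 1` visits 1, 3, 4, 7 and 8 and returns 23, and `allAgree 1000000 1` checks the exact tree used in the benchmark.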
With GHC 9.4.4 and `-O2`:

```
benchmarking sum1
time 49.07 ms (48.88 ms .. 49.31 ms)
1.000 R² (1.000 R² .. 1.000 R²)
mean 48.52 ms (48.25 ms .. 48.75 ms)
std dev 473.0 μs (371.5 μs .. 619.1 μs)
benchmarking sum2
time 6.521 ms (6.493 ms .. 6.539 ms)
1.000 R² (1.000 R² .. 1.000 R²)
mean 6.542 ms (6.538 ms .. 6.548 ms)
std dev 14.12 μs (10.22 μs .. 19.91 μs)
benchmarking sum3
time 6.565 ms (6.505 ms .. 6.647 ms)
1.000 R² (0.999 R² .. 1.000 R²)
mean 6.544 ms (6.531 ms .. 6.570 ms)
std dev 50.54 μs (26.96 μs .. 91.55 μs)
```
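For scale, the ratio of the reported mean times (our own arithmetic on the numbers above, not part of the Criterion output) puts `sum1` at roughly 7.4 times slower than either eta-expanded variant:

```latex
\frac{\bar t_{\mathrm{sum1}}}{\bar t_{\mathrm{sum2}}} = \frac{48.52\ \mathrm{ms}}{6.542\ \mathrm{ms}} \approx 7.42,
\qquad
\frac{\bar t_{\mathrm{sum1}}}{\bar t_{\mathrm{sum3}}} = \frac{48.52\ \mathrm{ms}}{6.544\ \mathrm{ms}} \approx 7.41
```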
Expected behavior
I expect `sum1` to be as efficient as `sum2` and `sum3`.
Environment
- GHC version used: 9.4.4
Optional:
- Operating System: Ubuntu
- System Architecture: x86_64
Edit: I have simplified the example a bit; the original example was:

```haskell
-- ScopedTypeVariables lets the local signatures refer to the outer `a`.
{-# LANGUAGE ScopedTypeVariables #-}
import Data.List (foldl')

depthSum1 :: forall a. (a -> [a]) -> a -> Int
depthSum1 neighbors root = go root 0 0 where
  f :: a -> [Int -> Int -> Int] -> Int -> Int -> Int
  f _ ks depth acc = foldl' (\acc' k -> k (depth+1) acc') (acc + depth) ks
  go :: a -> Int -> Int -> Int
  go x = f x (map go (neighbors x))
```
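By analogy with the `sum1` to `sum2` change, the original `depthSum1` can be eta-expanded by hand by giving `go` both of its `Int` arguments explicitly. This is a sketch that assumes the same workaround carries over; `depthSum2` is a hypothetical name and this variant is not benchmarked in this report. It computes the sum of the depths of all vertices, with the root at depth 0.

```haskell
{-# LANGUAGE ScopedTypeVariables #-}
import Data.List (foldl')

-- Hypothetical depthSum2: depthSum1 with go eta-expanded by hand, mirroring
-- the sum1 -> sum2 change. Returns the sum of all vertex depths (root = 0).
depthSum2 :: forall a. (a -> [a]) -> a -> Int
depthSum2 neighbors root = go root 0 0 where
  f :: a -> [Int -> Int -> Int] -> Int -> Int -> Int
  f _ ks depth acc = foldl' (\acc' k -> k (depth + 1) acc') (acc + depth) ks
  go :: a -> Int -> Int -> Int
  -- Both Int arguments are now explicit, as in sum2.
  go x depth acc = f x (map go (neighbors x)) depth acc
{-# NOINLINE depthSum2 #-}
```

On a complete binary tree of depth 2 (one root, two children, four grandchildren), the result is 0 + 2·1 + 4·2 = 10, whatever the vertex type is.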