Missed loop fusion optimisation
Summary
I've written a boxed and an unboxed version of the same program, but for some reason GHC is only able to properly optimise the boxed version. In particular, the program has two loops (`step'` and `loop`) which fuse in the boxed version but stay separate in the unboxed version.
For more background info, see this GitHub issue.
Steps to reproduce
Boxed:
module Boxed (test) where
import Data.Primitive.Array
import System.IO.Unsafe (unsafePerformIO)
-- | Result of one step of a traversal: 'Yield' carries an element together
-- with the next state, and 'Done' signals that there is nothing left.
data Step s a = Yield a s | Done
-- | Placeholder passed to 'newArray' as the initial value of every slot.
-- It is 'undefined', so forcing a slot that was never overwritten throws.
uninitialised = undefined
-- | Boxed reproducer.
--
-- Scans the @n@ elements of @oldArr@ starting at index @off@, keeps those that
-- are greater than 10, and writes each kept element plus one into a freshly
-- allocated array of length @n@.
--
-- Returns @(0, count, newArray)@ where @count@ is the number of elements
-- written. Slots at index @count@ and above still hold 'uninitialised'.
--
-- 'unsafePerformIO' is used only to run the array mutation: the mutable array
-- is allocated inside this call and is frozen before it is returned, so it
-- never escapes in mutable form.
test :: Int -> Int -> Array Double -> (Int, Int, Array Double)
test off n oldArr = unsafePerformIO $ do
  newArr <- newArray n uninitialised
  let
    -- Inner loop: advance from read index i to the next element > 10.
    -- Yields that element together with the following read index, or
    -- 'Done' once i has reached n.
    step' i
      | i >= n = Done
      | otherwise =
        let x = indexArray oldArr (off + i) in
        if x > 10
          then Yield x (i + 1)
          else step' (i + 1)
    -- Outer loop: i is the read index handed to step', j is the write
    -- index into newArr.
    loop i j = do
      case step' i of
        Yield x s' -> do
          writeArray newArr j (x + 1)
          loop s' (j + 1)
        Done -> do
          out <- unsafeFreezeArray newArr
          return (0, j, out)
  loop 0 0
Unboxed:
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}
module Unboxed (test) where
import GHC.Exts
import GHC.IO
-- | Result of one step of a traversal: 'Yield' carries an element together
-- with the next state, and 'Done' signals that there is nothing left.
data Step s a = Yield a s | Done
-- | Placeholder passed to 'newArray#' as the initial value of every slot.
-- It is 'undefined', so forcing a slot that was never overwritten throws.
uninitialised = undefined
-- | Unboxed reproducer: the same algorithm as the boxed version, written
-- directly against primops with explicit state-token threading.
--
-- Scans the @n@ elements of @oldArr@ starting at index @off@, keeps those that
-- are greater than 10.0, and writes each kept element plus one into a freshly
-- allocated array of length @n@.
--
-- Returns @(# 0#, count, newArray #)@ where @count@ is the number of elements
-- written. Slots at index @count@ and above still hold 'uninitialised'.
test :: Int# -> Int# -> Array# Double -> (# Int#, Int#, Array# Double #)
test off n oldArr = runRW# $ \s0 ->
  case newArray# n uninitialised s0
    of { (# s1, newArr #) ->
  let
    -- Inner loop: advance from read index i to the next element > 10.0.
    -- Yields that element (re-boxed) together with the following read
    -- index (boxed as an Int), or 'Done' once i has reached n.
    step' i
      | isTrue# (i >=# n) = Done
      | otherwise =
        let (# D# x #) = indexArray# oldArr (off +# i) in
        if isTrue# (x >## 10.0##)
          then Yield (D# x) (I# (i +# 1#))
          else step' (i +# 1#)
    -- Outer loop: i is the read index handed to step', j is the write
    -- index into newArr, s2 is the current state token.
    loop i j s2 =
      case step' i of
        Yield x (I# s') ->
          case writeArray# newArr j (x + 1) s2
            of { s3 ->
          loop s' (j +# 1#) s3
          }
        Done ->
          case unsafeFreezeArray# newArr s2
            of { (# s3, out #) ->
          (# 0#, j, out #)
          }
  in
    loop 0# 0# s1
  }
Expected behavior
I expect both programs to produce very similar Core in which the two loops are fused and the intermediate `Yield` and `Done` constructors are eliminated.
That does happen for the boxed version:
-- Optimised Core of the worker for the boxed 'test'. Only the length argument
-- (ww) has been unboxed; the offset (w) and the input array (w1) are still
-- boxed and are unpacked inside the loop below.
$wtest :: Int -> Int# -> Array Double -> (Int, Int, Array Double)
$wtest
= \ (w :: Int) (ww :: Int#) (w1 :: Array Double) ->
runRW#
(\ (s :: State# RealWorld) ->
case noDuplicate# s of s' { __DEFAULT ->
case newArray# ww uninitialised (s' `cast` <Co:3>) of
{ (# ipv, ipv1 #) ->
-- The single fused loop: sc is the state token, sc1 the write index into the
-- new array (ipv1), sc2 the read index that is compared against ww.
joinrec {
$s$wloop
:: State# RealWorld -> Int# -> Int# -> (Int, Int, Array Double)
$s$wloop (sc :: State# RealWorld) (sc1 :: Int#) (sc2 :: Int#)
= join {
-- Shared exit: freeze the output array and build the result tuple.
$j :: (Int, Int, Array Double)
$j
= case unsafeFreezeArray# ipv1 (sc `cast` <Co:3>) of
{ (# ipv2, ipv3 #) ->
lazy (test1, I# sc1, Array ipv3)
} } in
case >=# sc2 ww of {
__DEFAULT ->
case w of { I# x ->
case w1 of { Array ds2 ->
case indexArray# ds2 (+# x sc2) of { (# ipv2 #) ->
case ipv2 of { D# x1 ->
case >## x1 10.0## of {
__DEFAULT ->
-- step' survives only as a local join point: it jumps to itself, jumps back to
-- $s$wloop after the write, or jumps to $j. No Yield/Done value is built.
joinrec {
$wstep' :: Int# -> (Int, Int, Array Double)
$wstep' (ww1 :: Int#)
= case >=# ww1 ww of {
__DEFAULT ->
case indexArray# ds2 (+# x ww1) of { (# ipv3 #) ->
case ipv3 of { D# x2 ->
case >## x2 10.0## of {
__DEFAULT -> jump $wstep' (+# ww1 1#);
1# ->
case writeArray#
ipv1 sc1 (D# (+## x2 1.0##)) (sc `cast` <Co:3>)
of s'#
{ __DEFAULT ->
jump $s$wloop (s'# `cast` <Co:2>) (+# sc1 1#) (+# ww1 1#)
}
} } };
1# -> jump $j
}; } in
jump $wstep' (+# sc2 1#);
1# ->
case writeArray# ipv1 sc1 (D# (+## x1 1.0##)) (sc `cast` <Co:3>)
of s'#
{ __DEFAULT ->
jump $s$wloop (s'# `cast` <Co:2>) (+# sc1 1#) (+# sc2 1#)
}
} } } } };
1# -> jump $j
}; } in
jump $s$wloop (ipv `cast` <Co:2>) 0# 0#
} })
But the unboxed version keeps the loops separate and doesn't eliminate `Yield` and `Done`:
-- Optimised Core of the unboxed 'test'.
test
= \ off n oldArr ->
runRW#
(\ s0 ->
case newArray# n uninitialised s0 of { (# ipv, ipv1 #) ->
-- step' is still an ordinary letrec-bound function (not a join point) and
-- returns Yield/Done values to its caller.
letrec {
step'
= \ i ->
case >=# i n of {
__DEFAULT ->
case indexArray# oldArr (+# off i) of { (# ipv2 #) ->
case ipv2 of wild { D# x ->
case >## x 10.0## of {
__DEFAULT -> step' (+# i 1#);
1# -> Yield wild (I# (+# i 1#))
}
}
};
1# -> Done
}; } in
-- Exit join point: freeze the output array and build the result tuple.
join {
exit j s2
= case unsafeFreezeArray# ipv1 s2 of { (# ipv2, ipv3 #) ->
(# 0#, j, ipv3 #)
} } in
-- loop calls step' and then scrutinises the Yield/Done it returned, so the two
-- loops remain separate and the boxed I#/D# payloads are unpacked again here.
joinrec {
loop i j s2
= case step' i of {
Yield x ds1 ->
case ds1 of { I# s' ->
case writeArray#
ipv1 j (case x of { D# x1 -> D# (+## x1 1.0##) }) s2
of s3
{ __DEFAULT ->
jump loop s' (+# j 1#) s3
}
};
Done -> jump exit j s2
}; } in
jump loop 0# 0# ipv
})
Environment
- GHC versions used: 9.2.3 and 9.4.2