Unexpected code duplication of join point continuations...
In the bytestring package, read-only IO loops over the content often masquerade as pure code. The upper layers of a given function may be pure, but they call code that ultimately ends up calling accursedUnutterablePerformIO.
It is sometimes useful to inline the I/O helper into the upper layers of the code, deferring heap allocation of intermediate results. What actually happens is that the calling code gets inlined into the recursive I/O loop as join-point continuations.
Surprisingly, to a naïve user at least, it is all too easy to end up with significant code duplication via innocuous-looking conditional expressions that lead to multiple code paths into or out of the loop.
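For readers less familiar with the pattern, here is a minimal, self-contained sketch of a function that is pure on the outside but runs a read-only I/O loop over a buffer. It uses only base. unsafeDupablePerformIO stands in for bytestring's accursedUnutterablePerformIO, and a list copied into a temporary array stands in for a ByteString's payload. leadingDigits is a hypothetical name, not part of bytestring.

```haskell
{-# LANGUAGE BangPatterns #-}

import Data.Word (Word8)
import Foreign.Marshal.Array (withArrayLen)
import Foreign.Ptr (Ptr, plusPtr)
import Foreign.Storable (peek)
import System.IO.Unsafe (unsafeDupablePerformIO)

-- | Hypothetical example: count the leading ASCII digits of a buffer.
-- The type is pure, but the work is a read-only IO loop over a pointer.
-- (bytestring uses accursedUnutterablePerformIO for this; the sketch uses
-- the less dangerous unsafeDupablePerformIO from base instead.)
leadingDigits :: [Word8] -> Int
leadingDigits bytes = unsafeDupablePerformIO $
  -- Copy the bytes into a temporary array, standing in for a ByteString.
  withArrayLen bytes $ \len ptr -> loop ptr (ptr `plusPtr` len) 0
  where
    loop :: Ptr Word8 -> Ptr Word8 -> Int -> IO Int
    loop !p end !n
      | p == end  = return n                    -- end of buffer
      | otherwise = do
          w <- peek p
          if w >= 0x30 && w <= 0x39             -- '0' .. '9'
            then loop (p `plusPtr` 1) end (n + 1)
            else return n                       -- first non-digit
```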
The "paths in" code duplication:
The obvious case one should expect is:
```haskell
if | cond1 -> ... do the io ...  -- INLINEd as requested
   | cond2 -> ... do the io ...  -- INLINEd again
```
Even though in this case "INLINEd" means that the calling code is copied as a continuation into two separate copies of the loop, each with its own join points, this is, I think, to be expected in retrospect.
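Here is a toy illustration of the two shapes. It assumes a hypothetical INLINE-d helper, countDigits, in place of the real I/O loop. Both functions compute the same result and differ only in how many call sites the inlined loop has. Whether GHC actually keeps the single-call-site form as one copy is not guaranteed, so it is worth confirming by inspecting the Core (e.g. with -ddump-simpl).

```haskell
{-# LANGUAGE MultiWayIf #-}

import Data.Char (isDigit)

-- Hypothetical stand-in for an INLINE-d loop such as _digits:
-- it just counts the leading decimal digits of a String.
countDigits :: String -> Int
countDigits = length . takeWhile isDigit
{-# INLINE countDigits #-}

-- Two call sites: when countDigits is inlined, each branch can get its
-- own copy of the loop, together with whatever consumes its result.
twoCallSites :: Bool -> String -> String -> Int
twoCallSites c xs ys =
  if | c         -> countDigits xs
     | otherwise -> countDigits ys

-- One call site: the conditional only selects the input. GHC may still
-- split this again (e.g. via case-of-case), so check the Core.
oneCallSite :: Bool -> String -> String -> Int
oneCallSite c xs ys = countDigits (if c then xs else ys)
```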
A less obvious "paths in" outcome arose when trying to eliminate code duplication from a loop that parses a decimal Int (accumulated as a Word, to be negated if necessary) with overflow detection from a Lazy.ByteString. The initial code was:
```haskell
readWord !q !r =
    \ !s -> consume s False 0
  where
    -- All done
    consume s@Empty !valid !acc
        = if valid then convert acc s else Nothing
    -- skip empty chunk
    consume (Chunk (BI.BS _ 0) cs) !valid !acc
        = consume cs valid acc
    -- process non-empty chunk
    consume s@(Chunk c@(BI.BS _ !len) cs) !valid !acc
        = case _digits q r c acc of
            Result used acc'
                | used <= 0 -- No more digits present
                  -> if valid then convert acc' s else Nothing
                | used < len -- valid input not entirely digits
                  -> let !c' = BU.unsafeDrop used c
                      in convert acc' $ Chunk c' cs
                | otherwise -- try to read more digits
                  -> consume cs True acc'
            Overflow -> Nothing
    convert !acc s =
        let !i = case r of
                -- See [Note @maxBound `quotRem` 10 == (q, 7)@]
                7 -> fromIntegral @Accum @a acc
                _ -> negate @a $ fromIntegral @Accum @a acc
         in Just (i, s)
```
This produced two copies of the loop, because there are two ways to end up recursively calling consume, and the empty-chunk case ended up in a separate copy of the loop. It therefore seemed logical to consolidate the code as follows:
```haskell
readWord !q !r =
    \ !s -> consume s False 0
  where
    consume :: ByteString -> Bool -> Accum -> Maybe (a, ByteString)
    -- All done
    consume s@Empty !valid !acc
        = if valid then convert acc s else Nothing
    -- Process chunk
    consume s@(Chunk c@(BI.BS _ !len) cs) !valid !acc
        = case _digits q r c acc of
            Result used acc'
                | used == len -- try to read more digits
                  -> consume cs (valid || used > 0) acc'
                | used > 0 -- valid input not entirely digits
                  -> let !c' = BU.unsafeDrop used c
                      in convert acc' $ Chunk c' cs
                | otherwise -- No more digits present
                  -> if valid then convert acc' s else Nothing
            Overflow -> Nothing
    convert !acc s =
        let !i = case r of
                -- See [Note @maxBound `quotRem` 10 == (q, 7)@]
                7 -> fromIntegral @Accum @a acc
                _ -> negate @a $ fromIntegral @Accum @a acc
         in Just (i, s)
```
It now looks like there is only one path into the call to _digits, but instead of reducing the number of loop definitions from 2 to 1, we get an increase from 2 to 3! This is because in:
```haskell
consume cs (valid || used > 0) acc'
```
The argument valid || used > 0 floats out, and we end up with something along the lines of:
```haskell
case valid of
  True -> consume cs True acc'
  _    -> case used > 0 of
            True -> consume cs True acc'
            _    -> consume cs False acc'
```
Each of these calls then gets inlined together with its own copy of the loop. Oops!
One solution is to switch from the logical || to the bitwise .|. operation, tracking validity in an Int rather than a Bool:
```haskell
readWord !q !r =
    \ !s -> consume s 0 0
  where
    consume :: ByteString -> Int -> Accum -> Maybe (a, ByteString)
    -- All done
    consume s@Empty !valid !acc
        = if valid /= 0 then convert acc s else Nothing
    -- Process chunk
    consume s@(Chunk c@(BI.BS _ !len) cs) !valid !acc
        = case _digits q r c acc of
            Result used acc'
                | used == len -- try to read more digits
                  -> consume cs (valid .|. used) acc'
                | used > 0 -- valid input not entirely digits
                  -> let !c' = BU.unsafeDrop used c
                      in convert acc' $ Chunk c' cs
                | otherwise -- No more digits present
                  -> if valid /= 0 then convert acc' s else Nothing
            Overflow -> Nothing
    convert !acc s =
        let !i = case r of
                -- See [Note @maxBound `quotRem` 10 == (q, 7)@]
                7 -> fromIntegral @Accum @a acc
                _ -> negate @a $ fromIntegral @Accum @a acc
         in Just (i, s)
```
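The parsing logic can also be looked at separately from the Core-level concerns. Below is a pure model of the bitwise-flag version, and it is only a sketch under stated assumptions. The lazy ByteString is modelled as a list of chunks of bytes, Accum is assumed to be Word64 (matching the Word# in the Core dumps), and digits/readWord mirror _digits/readWord in structure only. The key fact behind the rewrite is that x .|. y is zero exactly when both x and y are zero. The Int flag therefore carries the same information as valid || used > 0, but without a branch for GHC to float out.

```haskell
{-# LANGUAGE BangPatterns #-}

import Data.Bits ((.|.))
import Data.Int (Int64)
import Data.Word (Word8, Word64)

-- Assumed model types: a lazy ByteString is a list of strict chunks,
-- and each chunk is a list of bytes.
type Accum  = Word64
type Chunks = [[Word8]]

data Result = Result !Int !Accum   -- bytes consumed, new accumulator
            | Overflow
  deriving (Eq, Show)

-- Model of _digits: scan the leading digits of one chunk. Appending d to
-- acc stays within the limit q * 10 + r exactly when
-- acc < q, or acc == q and d <= r.
digits :: Accum -> Accum -> [Word8] -> Accum -> Result
digits q r = go 0
  where
    go !used [] !acc = Result used acc
    go !used (w : ws) !acc
      | d > 9     = Result used acc              -- no more digits
      | acc < q   = go (used + 1) ws (acc * 10 + d)
      | acc > q   = Overflow
      | d <= r    = go (used + 1) ws (acc * 10 + d)
      | otherwise = Overflow
      where d = fromIntegral w - 0x30            -- w - '0', wraps if w < '0'

-- Model of the bitwise-flag readWord: valid is an Int updated with (.|.),
-- so it is non-zero exactly when an earlier chunk contributed a digit.
-- r == 7 means a non-negative result (maxBound `quotRem` 10 == (q, 7)).
readWord :: Accum -> Accum -> Chunks -> Maybe (Int64, Chunks)
readWord q r = \s -> consume s 0 0
  where
    consume :: Chunks -> Int -> Accum -> Maybe (Int64, Chunks)
    consume [] !valid !acc = if valid /= 0 then convert acc [] else Nothing
    consume s@(c : cs) !valid !acc =
      case digits q r c acc of
        Result used acc'
          | used == length c -> consume cs (valid .|. used) acc'  -- whole chunk was digits
          | used > 0         -> convert acc' (drop used c : cs)   -- digits end mid-chunk
          | otherwise        -> if valid /= 0 then convert acc' s else Nothing
        Overflow -> Nothing
    convert :: Accum -> Chunks -> Maybe (Int64, Chunks)
    convert acc s =
      let i = case r of
                7 -> fromIntegral acc
                _ -> negate (fromIntegral acc)
      in Just (i, s)
```

For an Int64 result, q is 922337203685477580 in both cases. r is 7 for the positive limit and 8 for the magnitude of minBound. Converting that magnitude to Int64 and negating it yields minBound again, as required.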
A more surprising outcome was the duplication of the continuations of each exit ("path out") of the loop, where the caller's continuations get forked into essentially identical join points. The original loop code was:

```haskell
_digits :: Accum -> Accum -> BI.ByteString -> Accum -> Result
{-# INLINE _digits #-}
_digits !q !r !(BI.BS !fp !len) = \ !acc ->
    BI.accursedUnutterablePerformIO $
        BI.unsafeWithForeignPtr fp $ \ptr -> do
            let end = ptr `plusPtr` len
            go ptr end ptr acc
  where
    go start end = loop
      where
        loop !ptr !acc | ptr == end
            = return $ Result (ptr `minusPtr` start) acc
        loop !ptr !acc = do
            !d <- fromDigit <$> peek ptr
            if | d > 9 -- No more digits
                 -> return $ Result (ptr `minusPtr` start) acc
               | acc < q
                 -> loop (ptr `plusPtr` 1) (acc * 10 + d)
               | acc > q
                 -> return Overflow
               | d <= r
                 -> loop (ptr `plusPtr` 1) (acc * 10 + d)
               | otherwise
                 -> return Overflow
    --
    fromDigit = \w -> fromIntegral w - 0x30 -- i.e. w - '0'
```
With two ways to return a Result and two ways to return Overflow, each of these created a separate join point for the corresponding continuation on the return value.
In the generated Core, each copy of the loop then internally duplicates the exit2/exit3 continuations:
```
-- RHS size: {terms: 533, types: 314, coercions: 0, joins: 12/20}
$w$j :: Word# -> ByteString -> Maybe (Int64, ByteString)
$w$j
  = \ (ww :: Word#) (ww1 :: ByteString) ->
      joinrec {
        $s$wconsume :: Word# -> ByteString -> Maybe (Int64, ByteString)
        $s$wconsume (sc :: Word#) (sc1 :: ByteString)
          = case sc1 of wild {
              Empty ->
                case ww of {
                  __DEFAULT -> Just (I64# (negateInt# (word2Int# sc)), Empty);
                  7## -> Just (I64# (word2Int# sc), Empty)
                };
              Chunk dt dt1 dt2 cs ->
                case dt2 of ds {
                  __DEFAULT ->
                    let {
                      end :: Addr#
                      end = plusAddr# dt ds } in
                    join {
                      exit :: State# RealWorld -> Maybe (Int64, ByteString)
                      exit (ipv :: State# RealWorld)
                        = case touch# dt1 ipv of { __DEFAULT -> Nothing } } in
                    join {
                      exit1 :: State# RealWorld -> Maybe (Int64, ByteString)
                      exit1 (ipv :: State# RealWorld)
                        = case touch# dt1 ipv of { __DEFAULT -> Nothing } } in
                    join {
                      exit2
                        :: Addr# -> Word# -> State# RealWorld -> Maybe (Int64, ByteString)
                      exit2 (ww2 :: Addr#) (ww3 :: Word#) (ipv :: State# RealWorld)
                        = case touch# dt1 ipv of { __DEFAULT ->
                          let {
                            dt3 :: Int#
                            dt3 = minusAddr# ww2 dt } in
                          case <=# dt3 0# of {
                            __DEFAULT ->
                              case <# dt3 ds of {
                                __DEFAULT -> jump $s$wconsume ww3 cs;
                                1# ->
                                  case ww of {
                                    __DEFAULT ->
                                      Just
                                        (I64# (negateInt# (word2Int# ww3)),
                                         Chunk (plusAddr# dt dt3) dt1 (-# ds dt3) cs);
                                    7## ->
                                      Just
                                        (I64# (word2Int# ww3),
                                         Chunk (plusAddr# dt dt3) dt1 (-# ds dt3) cs)
                                  }
                              };
                            1# ->
                              case ww of {
                                __DEFAULT -> Just (I64# (negateInt# (word2Int# ww3)), wild);
                                7## -> Just (I64# (word2Int# ww3), wild)
                              }
                          }
                          } } in
                    join {
                      exit3
                        :: Addr# -> Word# -> State# RealWorld -> Maybe (Int64, ByteString)
                      exit3 (ww2 :: Addr#) (ww3 :: Word#) (w :: State# RealWorld)
                        = case touch# dt1 w of { __DEFAULT ->
                          let {
                            dt3 :: Int#
                            dt3 = minusAddr# ww2 dt } in
                          case <=# dt3 0# of {
                            __DEFAULT ->
                              case <# dt3 ds of {
                                __DEFAULT -> jump $s$wconsume ww3 cs;
                                1# ->
                                  case ww of {
                                    __DEFAULT ->
                                      Just
                                        (I64# (negateInt# (word2Int# ww3)),
                                         Chunk (plusAddr# dt dt3) dt1 (-# ds dt3) cs);
                                    7## ->
                                      Just
                                        (I64# (word2Int# ww3),
                                         Chunk (plusAddr# dt dt3) dt1 (-# ds dt3) cs)
                                  }
                              };
                            1# ->
                              case ww of {
                                __DEFAULT -> Just (I64# (negateInt# (word2Int# ww3)), wild);
                                7## -> Just (I64# (word2Int# ww3), wild)
                              }
                          }
                          } } in
                    joinrec {
                      $wloop
                        :: Addr# -> Word# -> State# RealWorld -> Maybe (Int64, ByteString)
                      $wloop (ww2 :: Addr#) (ww3 :: Word#) (w :: State# RealWorld)
                        = case eqAddr# ww2 end of {
                            __DEFAULT ->
                              case readWord8OffAddr# ww2 0# w of { (# ipv, ipv1 #) ->
                              let {
                                ipv2 :: Word#
                                ipv2 = minusWord# (word8ToWord# ipv1) 48## } in
                              case gtWord# ipv2 9## of {
                                __DEFAULT ->
                                  case ltWord# ww3 922337203685477580## of {
                                    __DEFAULT ->
                                      case gtWord# ww3 922337203685477580## of {
                                        __DEFAULT ->
                                          case leWord# ipv2 ww of {
                                            __DEFAULT -> jump exit ipv;
                                            1# ->
                                              jump $wloop
                                                (plusAddr# ww2 1#)
                                                (plusWord# (timesWord# ww3 10##) ipv2)
                                                ipv
                                          };
                                        1# -> jump exit1 ipv
                                      };
                                    1# ->
                                      jump $wloop
                                        (plusAddr# ww2 1#)
                                        (plusWord# (timesWord# ww3 10##) ipv2)
                                        ipv
                                  };
                                1# -> jump exit2 ww2 ww3 ipv
                              }
                              };
                            1# -> jump exit3 ww2 ww3 w
                          }; } in
                    jump $wloop dt sc realWorld#;
                  0# -> jump $s$wconsume sc cs
                }
            }; } in
      joinrec {
        $s$wconsume1 :: Word# -> ByteString -> Maybe (Int64, ByteString)
        $s$wconsume1 (sc :: Word#) (sc1 :: ByteString)
          = case sc1 of {
              Empty -> Nothing;
              Chunk dt dt1 dt2 cs ->
                case dt2 of ds {
                  __DEFAULT ->
                    let {
                      end :: Addr#
                      end = plusAddr# dt ds } in
                    join {
                      exit :: State# RealWorld -> Maybe (Int64, ByteString)
                      exit (ipv :: State# RealWorld)
                        = case touch# dt1 ipv of { __DEFAULT -> Nothing } } in
                    join {
                      exit1 :: State# RealWorld -> Maybe (Int64, ByteString)
                      exit1 (ipv :: State# RealWorld)
                        = case touch# dt1 ipv of { __DEFAULT -> Nothing } } in
                    join {
                      exit2
                        :: Addr# -> Word# -> State# RealWorld -> Maybe (Int64, ByteString)
                      exit2 (ww2 :: Addr#) (ww3 :: Word#) (ipv :: State# RealWorld)
                        = case touch# dt1 ipv of { __DEFAULT ->
                          let {
                            dt3 :: Int#
                            dt3 = minusAddr# ww2 dt } in
                          case <=# dt3 0# of {
                            __DEFAULT ->
                              case <# dt3 ds of {
                                __DEFAULT -> jump $s$wconsume ww3 cs;
                                1# ->
                                  case ww of {
                                    __DEFAULT ->
                                      Just
                                        (I64# (negateInt# (word2Int# ww3)),
                                         Chunk (plusAddr# dt dt3) dt1 (-# ds dt3) cs);
                                    7## ->
                                      Just
                                        (I64# (word2Int# ww3),
                                         Chunk (plusAddr# dt dt3) dt1 (-# ds dt3) cs)
                                  }
                              };
                            1# -> Nothing
                          }
                          } } in
                    join {
                      exit3
                        :: Addr# -> Word# -> State# RealWorld -> Maybe (Int64, ByteString)
                      exit3 (ww2 :: Addr#) (ww3 :: Word#) (w :: State# RealWorld)
                        = case touch# dt1 w of { __DEFAULT ->
                          let {
                            dt3 :: Int#
                            dt3 = minusAddr# ww2 dt } in
                          case <=# dt3 0# of {
                            __DEFAULT ->
                              case <# dt3 ds of {
                                __DEFAULT -> jump $s$wconsume ww3 cs;
                                1# ->
                                  case ww of {
                                    __DEFAULT ->
                                      Just
                                        (I64# (negateInt# (word2Int# ww3)),
                                         Chunk (plusAddr# dt dt3) dt1 (-# ds dt3) cs);
                                    7## ->
                                      Just
                                        (I64# (word2Int# ww3),
                                         Chunk (plusAddr# dt dt3) dt1 (-# ds dt3) cs)
                                  }
                              };
                            1# -> Nothing
                          }
                          } } in
                    joinrec {
                      $wloop
                        :: Addr# -> Word# -> State# RealWorld -> Maybe (Int64, ByteString)
                      $wloop (ww2 :: Addr#) (ww3 :: Word#) (w :: State# RealWorld)
                        = case eqAddr# ww2 end of {
                            __DEFAULT ->
                              case readWord8OffAddr# ww2 0# w of { (# ipv, ipv1 #) ->
                              let {
                                ipv2 :: Word#
                                ipv2 = minusWord# (word8ToWord# ipv1) 48## } in
                              case gtWord# ipv2 9## of {
                                __DEFAULT ->
                                  case ltWord# ww3 922337203685477580## of {
                                    __DEFAULT ->
                                      case gtWord# ww3 922337203685477580## of {
                                        __DEFAULT ->
                                          case leWord# ipv2 ww of {
                                            __DEFAULT -> jump exit ipv;
                                            1# ->
                                              jump $wloop
                                                (plusAddr# ww2 1#)
                                                (plusWord# (timesWord# ww3 10##) ipv2)
                                                ipv
                                          };
                                        1# -> jump exit1 ipv
                                      };
                                    1# ->
                                      jump $wloop
                                        (plusAddr# ww2 1#)
                                        (plusWord# (timesWord# ww3 10##) ipv2)
                                        ipv
                                  };
                                1# -> jump exit2 ww2 ww3 ipv
                              }
                              };
                            1# -> jump exit3 ww2 ww3 w
                          }; } in
                    jump $wloop dt sc realWorld#;
                  0# -> jump $s$wconsume1 sc cs
                }
            }; } in
      jump $s$wconsume1 0## ww1
```
So to eliminate duplicate exit paths, the above had to be refactored as follows:
```haskell
_digits :: Accum -> Accum -> BI.ByteString -> Accum -> Result
{-# INLINE _digits #-}
_digits !q !r !(BI.BS !fp !len) = \ !acc ->
    BI.accursedUnutterablePerformIO $
        BI.unsafeWithForeignPtr fp $ \ptr -> do
            let end = ptr `plusPtr` len
            go ptr end ptr acc
  where
    go start end = loop
      where
        loop !ptr !acc = getDigit >>= \ !d ->
            if | d > 9
                 -> return $ Result (ptr `minusPtr` start) acc
               | acc < q || acc == q && d <= r
                 -> loop (ptr `plusPtr` 1) (acc * 10 + d)
               | otherwise
                 -> return Overflow
          where
            fromDigit = \w -> fromIntegral w - 0x30 -- i.e. w - '0'
            --
            getDigit | ptr /= end = fromDigit <$> peek ptr
                     | otherwise  = pure 10 -- End of input
            {-# NOINLINE getDigit #-}
```
This required the somewhat artificial getDigit to have a NOINLINE annotation, to avoid the code paths being split anyway.
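The control flow of the refactored loop can be modelled purely. This lets us check that folding end-of-input into the sentinel value 10, and merging the two overflow tests into acc < q || acc == q && d <= r, does not change the results. The model rests on three assumptions: a list of bytes stands in for the pointer range, the accumulator is Word64, and reference is a hypothetical Integer-based oracle. The NOINLINE pragma only affects GHC's optimiser, so it has no counterpart here.

```haskell
{-# LANGUAGE BangPatterns, MultiWayIf #-}

import Data.Word (Word8, Word64)

data Result = Result !Int !Word64 | Overflow
  deriving (Eq, Show)

fromDigit :: Word8 -> Word64
fromDigit w = fromIntegral w - 0x30   -- w - '0', wraps if w < '0'

-- Pure model of the refactored _digits: end of input yields the
-- out-of-range "digit" 10, so one d > 9 test covers both ways of
-- returning a Result, and a single guard covers both Overflow cases.
digitsSentinel :: Word64 -> Word64 -> [Word8] -> Word64 -> Result
digitsSentinel q r = loop 0
  where
    loop !used ws !acc =
      let (d, rest) = getDigit ws
      in if | d > 9
              -> Result used acc
            | acc < q || acc == q && d <= r
              -> loop (used + 1) rest (acc * 10 + d)
            | otherwise
              -> Overflow
    getDigit (w : ws) = (fromDigit w, ws)
    getDigit []       = (10, [])          -- end of input

-- Reference using unbounded Integer arithmetic: accept the leading
-- digits unless the resulting value exceeds q * 10 + r.
reference :: Word64 -> Word64 -> [Word8] -> Word64 -> Result
reference q r ws acc
  | value > limit = Overflow
  | otherwise     = Result (length ds) (fromInteger value)
  where
    ds    = takeWhile (<= 9) (map fromDigit ws)
    value = foldl (\a d -> a * 10 + toInteger d) (toInteger acc) ds
    limit = toInteger q * 10 + toInteger r
```

The two agree because the accumulated value never decreases as digits are appended. So the final value exceeds the limit exactly when some intermediate step does, and the loop reports Overflow at the first such step.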
The new Core output is much better (exit2 and exit3 are still identical, but they are too small to really matter, and may be optimised out in the assembly):
```
-- RHS size: {terms: 247, types: 174, coercions: 0, joins: 8/11}
$w$j :: Word# -> ByteString -> Maybe (Int64, ByteString)
$w$j
  = \ (ww :: Word#) (ww1 :: ByteString) ->
      join {
        exit :: Int# -> Word# -> Maybe (Int64, ByteString)
        exit (ww2 :: Int#) (ww3 :: Word#)
          = case ww2 of {
              __DEFAULT ->
                case ww of {
                  __DEFAULT -> Just (I64# (negateInt# (word2Int# ww3)), Empty);
                  7## -> Just (I64# (word2Int# ww3), Empty)
                };
              0# -> Nothing
            } } in
      join {
        exit1
          :: Int#
             -> ByteString
             -> Addr#
             -> ForeignPtrContents
             -> Int#
             -> ByteString
             -> Int#
             -> Word#
             -> Maybe (Int64, ByteString)
        exit1 (ww2 :: Int#)
              (wild :: ByteString)
              (dt :: Addr#)
              (dt1 :: ForeignPtrContents)
              (dt2 :: Int#)
              (cs :: ByteString)
              (dt3 :: Int#)
              (dt4 :: Word#)
          = case ># dt3 0# of {
              __DEFAULT ->
                case ww2 of {
                  __DEFAULT ->
                    case ww of {
                      __DEFAULT -> Just (I64# (negateInt# (word2Int# dt4)), wild);
                      7## -> Just (I64# (word2Int# dt4), wild)
                    };
                  0# -> Nothing
                };
              1# ->
                case ww of {
                  __DEFAULT ->
                    Just
                      (I64# (negateInt# (word2Int# dt4)),
                       Chunk (plusAddr# dt dt3) dt1 (-# dt2 dt3) cs);
                  7## ->
                    Just
                      (I64# (word2Int# dt4),
                       Chunk (plusAddr# dt dt3) dt1 (-# dt2 dt3) cs)
                }
            } } in
      joinrec {
        $wconsume
          :: ByteString -> Int# -> Word# -> Maybe (Int64, ByteString)
        $wconsume (w :: ByteString) (ww2 :: Int#) (ww3 :: Word#)
          = case w of wild {
              Empty -> jump exit ww2 ww3;
              Chunk dt dt1 dt2 cs ->
                let {
                  end :: Addr#
                  end = plusAddr# dt dt2 } in
                join {
                  exit2 :: State# RealWorld -> Maybe (Int64, ByteString)
                  exit2 (ipv :: State# RealWorld)
                    = case touch# dt1 ipv of { __DEFAULT -> Nothing } } in
                join {
                  exit3 :: State# RealWorld -> Maybe (Int64, ByteString)
                  exit3 (ipv :: State# RealWorld)
                    = case touch# dt1 ipv of { __DEFAULT -> Nothing } } in
                join {
                  exit4
                    :: Addr# -> Word# -> State# RealWorld -> Maybe (Int64, ByteString)
                  exit4 (ww4 :: Addr#) (ww5 :: Word#) (ipv :: State# RealWorld)
                    = case touch# dt1 ipv of { __DEFAULT ->
                      let {
                        dt3 :: Int#
                        dt3 = minusAddr# ww4 dt } in
                      case ==# dt3 dt2 of {
                        __DEFAULT -> jump exit1 ww2 wild dt dt1 dt2 cs dt3 ww5;
                        1# -> jump $wconsume cs (orI# ww2 dt3) ww5
                      }
                      } } in
                joinrec {
                  $wloop
                    :: Addr# -> Word# -> State# RealWorld -> Maybe (Int64, ByteString)
                  $wloop (ww4 :: Addr#) (ww5 :: Word#) (w1 :: State# RealWorld)
                    = join {
                        getDigit :: State# RealWorld -> Maybe (Int64, ByteString)
                        getDigit (eta :: State# RealWorld)
                          = case eqAddr# ww4 end of {
                              __DEFAULT ->
                                case readWord8OffAddr# ww4 0# eta of { (# ipv, ipv1 #) ->
                                let {
                                  ipv2 :: Word#
                                  ipv2 = minusWord# (word8ToWord# ipv1) 48## } in
                                case gtWord# ipv2 9## of {
                                  __DEFAULT ->
                                    case ltWord# ww5 922337203685477580## of {
                                      __DEFAULT ->
                                        case ww5 of {
                                          __DEFAULT -> jump exit2 ipv;
                                          922337203685477580## ->
                                            case leWord# ipv2 ww of {
                                              __DEFAULT -> jump exit3 ipv;
                                              1# ->
                                                jump $wloop
                                                  (plusAddr# ww4 1#)
                                                  (plusWord# 9223372036854775800## ipv2)
                                                  ipv
                                            }
                                        };
                                      1# ->
                                        jump $wloop
                                          (plusAddr# ww4 1#)
                                          (plusWord# (timesWord# ww5 10##) ipv2)
                                          ipv
                                    };
                                  1# -> jump exit4 ww4 ww5 ipv
                                }
                                };
                              1# -> jump exit4 ww4 ww5 eta
                            } } in
                      jump getDigit w1; } in
                jump $wloop dt ww3 realWorld#
            }; } in
      jump $wconsume ww1 0# 0##
```
But I'm somewhat reluctant to saddle the code with such subtle optimisations, as they are not easily maintained...