
Unexpected code duplication of join point continuations...

In the bytestring package, read-only IO loops over the content often masquerade as pure code. The upper layers of a given function may be pure, but they call code that ultimately ends up calling accursedUnutterablePerformIO.

It is sometimes useful to inline the I/O helper into the upper layer of the code, deferring heap allocation of intermediate results. What actually happens is that the calling code gets inlined into the recursive I/O loop as continuation join points.
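As a rough illustration of the pattern (not bytestring's actual code), the sketch below wraps a read-only IO loop over a strict ByteString in a pure function. It uses unsafeDupablePerformIO and unsafeUseAsCStringLen as stand-ins for bytestring's internal accursedUnutterablePerformIO and unsafeWithForeignPtr; the name countLeadingDigits is hypothetical. The INLINE pragma is what lets a caller's continuation end up inside the loop.

```haskell
{-# LANGUAGE BangPatterns #-}
import qualified Data.ByteString.Char8 as BC
import Data.ByteString (ByteString)
import Data.ByteString.Unsafe (unsafeUseAsCStringLen)
import Data.Word (Word8)
import Foreign.Ptr (Ptr, castPtr, plusPtr)
import Foreign.Storable (peek)
import System.IO.Unsafe (unsafeDupablePerformIO)

-- Hypothetical example: count the leading ASCII digits of a strict
-- ByteString. The interface is pure, but the work is done by an IO loop
-- that only ever reads the underlying buffer.
countLeadingDigits :: ByteString -> Int
countLeadingDigits bs = unsafeDupablePerformIO $
    unsafeUseAsCStringLen bs $ \(cstr, len) -> do
        let start = castPtr cstr :: Ptr Word8
            end   = start `plusPtr` len :: Ptr Word8
            -- Read-only loop: stops at the first non-digit or at the end.
            loop :: Ptr Word8 -> Int -> IO Int
            loop !ptr !n
                | ptr == end = pure n
                | otherwise  = do
                    w <- peek ptr
                    if w >= 0x30 && w <= 0x39
                        then loop (ptr `plusPtr` 1) (n + 1)
                        else pure n
        loop start 0
{-# INLINE countLeadingDigits #-}
```

For example, countLeadingDigits (BC.pack "123abc") evaluates to 3.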

Surprisingly to a naïve user, it is much too easy to end up with significant code duplication via innocuous-looking conditional expressions that lead to multiple code paths into or out of the loop.

The "paths in" code duplication:

The obvious case one should expect is:

if | cond1 -> ... do the io ... -- INLINEd as requested
   | cond2 -> ... do the io ... -- INLINEd again

Even if "INLINEd" here means that the calling code is copied as a continuation into two separate copies of the loop, each with its own join points, this is, in retrospect, I think, to be expected.
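A concrete, hypothetical version of that obvious case: digitPrefix stands in for an INLINE I/O helper (Data.ByteString.Char8.takeWhile is itself implemented with such a loop), and leadingNumber calls it from two branches. With optimisation on, this is the shape one would expect to produce two inlined copies of the helper's loop. Both names are illustrative and are not part of bytestring.

```haskell
import qualified Data.ByteString.Char8 as BC
import Data.ByteString (ByteString)
import Data.Char (isDigit)

-- Hypothetical INLINE helper standing in for a read-only I/O loop.
digitPrefix :: ByteString -> ByteString
digitPrefix = BC.takeWhile isDigit
{-# INLINE digitPrefix #-}

-- Two call sites of the INLINE helper, one per branch, mirroring the
-- multi-way if above: each branch is a separate "path in" to the loop.
leadingNumber :: Bool -> ByteString -> ByteString
leadingNumber skipSign s
    | skipSign, Just ('-', rest) <- BC.uncons s = digitPrefix rest
    | otherwise                                = digitPrefix s
```

For example, leadingNumber True (BC.pack "-42x") gives "42". Whether GHC really duplicates the loop for a given version can be confirmed by inspecting the output of -ddump-simpl.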

A less obvious "paths in" outcome arose when trying to eliminate code duplication from a loop that performs overflow-checked parsing of a decimal Int from a Lazy.ByteString (accumulated as a Word, to be negated if necessary). The initial state was:

readWord !q !r =
    \ !s -> consume s False 0
  where
    -- All done
    consume s@Empty !valid !acc
        = if valid then convert acc s else Nothing
    -- skip empty chunk
    consume (Chunk (BI.BS _ 0) cs) !valid !acc
        = consume cs valid acc
    -- process non-empty chunk
    consume s@(Chunk c@(BI.BS _ !len) cs) !valid !acc
        = case _digits q r c acc of
            Result used acc'
                | used <= 0 -- No more digits present
                  -> if valid then convert acc' s else Nothing
                | used < len -- valid input not entirely digits
                  -> let !c' = BU.unsafeDrop used c
                      in convert acc' $ Chunk c' cs
                | otherwise -- try to read more digits
                  -> consume cs True acc'
            Overflow -> Nothing

    convert !acc s =
        let !i = case r of
                -- See [Note @maxBound `quotRem` 10 == (q, 7)@]
                7 -> fromIntegral @Accum @a acc
                _ -> negate @a $ fromIntegral @Accum @a acc
         in Just (i, s)

This produced two copies of the loop, because there are two ways to end up recursively calling consume, and the empty-chunk case ended up in a separate copy of the loop. It therefore seemed logical to consolidate the code as follows:

readWord !q !r =
    \ !s -> consume s False 0
  where
    consume :: ByteString -> Bool -> Accum -> Maybe (a, ByteString)
    -- All done
    consume s@Empty !valid !acc
        = if valid then convert acc s else Nothing
    -- Process chunk
    consume s@(Chunk c@(BI.BS _ !len) cs) !valid !acc
        = case _digits q r c acc of
            Result used acc'
                | used == len -- try to read more digits
                  -> consume cs (valid || used > 0) acc'
                | used > 0 -- valid input not entirely digits
                  -> let !c' = BU.unsafeDrop used c
                      in convert acc' $ Chunk c' cs
                | otherwise -- No more digits present
                  -> if valid then convert acc' s else Nothing
            Overflow -> Nothing

    convert !acc s =
        let !i = case r of
                -- See [Note @maxBound `quotRem` 10 == (q, 7)@]
                7 -> fromIntegral @Accum @a acc
                _ -> negate @a $ fromIntegral @Accum @a acc
         in Just (i, s)

It now looks like there's only one path into calling _digits, but instead of reducing the number of copies of the loop from 2 to 1, we get an increase from 2 to 3! This is because in:

consume cs (valid || used > 0) acc'

Since consume is strict in valid, the evaluation of the argument valid || used > 0 floats out of the call, and we end up with something along the lines of:

    case valid of
        True -> consume cs True acc'
        _    -> case used > 0 of
                True -> consume cs True acc'
                _    -> consume cs False acc'

each of which then gets its own inlined copy of the loop. Oops!
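A source-level model of that transformation may make it easier to see. Here Call is a hypothetical stand-in for the recursive consume call that merely records its arguments; beforeSplit is the code as written, and afterSplit mirrors the shape described above, where (||) has been inlined and the strict argument is evaluated ahead of the call.

```haskell
-- Hypothetical stand-in for the recursive consume call: it just records
-- the validity flag and accumulator it was given.
data Call = Call Bool Int
    deriving (Eq, Show)

-- As written in the source: a single call site with a compound argument.
beforeSplit :: Bool -> Int -> Int -> Call
beforeSplit valid used acc = Call (valid || used > 0) acc

-- After the split: each branch ends in its own call site with a literal
-- flag, giving three call sites instead of one.
afterSplit :: Bool -> Int -> Int -> Call
afterSplit valid used acc = case valid of
    True  -> Call True acc
    False -> case used > 0 of
        True  -> Call True acc
        False -> Call False acc
```

Both functions compute the same Call; what changes is only the number of call sites, and in the Core each call site of consume received its own copy of the loop.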

One solution is to switch to a bitwise OR (.|. from Data.Bits) on an Int flag, rather than the logical ||:

readWord !q !r =
    \ !s -> consume s 0 0
  where
    consume :: ByteString -> Int -> Accum -> Maybe (a, ByteString)
    -- All done
    consume s@Empty !valid !acc
        = if valid /= 0 then convert acc s else Nothing
    -- Process chunk
    consume s@(Chunk c@(BI.BS _ !len) cs) !valid !acc
        = case _digits q r c acc of
            Result used acc'
                | used == len -- try to read more digits
                  -> consume cs (valid .|. used) acc'
                | used > 0 -- valid input not entirely digits
                  -> let !c' = BU.unsafeDrop used c
                      in convert acc' $ Chunk c' cs
                | otherwise -- No more digits present
                  -> if valid /= 0 then convert acc' s else Nothing
            Overflow -> Nothing

    convert !acc s =
        let !i = case r of
                -- See [Note @maxBound `quotRem` 10 == (q, 7)@]
                7 -> fromIntegral @Accum @a acc
                _ -> negate @a $ fromIntegral @Accum @a acc
         in Just (i, s)
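To check the chunk-handling logic of this version in isolation, here is a hypothetical pure model that can be loaded into GHCi. It assumes a pure stand-in, digitsModel, for the IO-based _digits (with the overflow guards folded into a single condition), fixes Accum to Word and the result type to Int64, and uses the Empty/Chunk constructors exported by Data.ByteString.Lazy.Internal. It is a sketch of the logic, not the bytestring code. Since the Int flag only ever accumulates non-negative byte counts, valid .|. used is nonzero exactly when valid /= 0 || used > 0.

```haskell
{-# LANGUAGE BangPatterns #-}
import Data.Bits ((.|.))
import Data.Int (Int64)
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as SC
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString.Lazy.Char8 as LC
import Data.ByteString.Lazy.Internal (ByteString (Empty, Chunk))

-- Result of scanning one chunk: digits consumed and the new accumulator.
data Result = Result !Int !Word | Overflow

-- Pure stand-in for _digits (the real one is an IO loop over the buffer).
digitsModel :: Word -> Word -> S.ByteString -> Word -> Result
digitsModel q r c = go 0
  where
    len = S.length c
    go !i !acc
        | i == len  = Result i acc
        | d > 9     = Result i acc
        | acc < q || acc == q && d <= r = go (i + 1) (acc * 10 + d)
        | otherwise = Overflow
      where
        -- Non-digit bytes wrap around to values far above 9.
        d = fromIntegral (S.index c i) - 0x30 :: Word

-- Model of the version above, with the Int validity flag.
readWordModel :: Word -> Word -> ByteString -> Maybe (Int64, ByteString)
readWordModel q r = \ !s -> consume s 0 0
  where
    consume :: ByteString -> Int -> Word -> Maybe (Int64, ByteString)
    consume s@Empty !valid !acc
        = if valid /= 0 then convert acc s else Nothing
    consume s@(Chunk c cs) !valid !acc
        = case digitsModel q r c acc of
            Result used acc'
                | used == S.length c -> consume cs (valid .|. used) acc'
                | used > 0 -> convert acc' (Chunk (S.drop used c) cs)
                | otherwise -> if valid /= 0 then convert acc' s else Nothing
            Overflow -> Nothing

    -- r == 7 selects the positive bound; anything else negates.
    convert :: Word -> ByteString -> Maybe (Int64, ByteString)
    convert !acc s =
        let !i = case r of
                7 -> fromIntegral acc
                _ -> negate (fromIntegral acc)
         in Just (i, s)
```

For example, with q = 922337203685477580 and r = 7, the chunks "12" and "34x" parse to Just (1234, "x"), the digits spanning the chunk boundary.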

A more surprising outcome was duplication of the continuations of each exit "path out" of the loop, where the caller's continuations get forked into essentially identical join points. The original loop code was:

_digits :: Accum -> Accum -> BI.ByteString -> Accum -> Result
{-# INLINE _digits #-}
_digits !q !r !(BI.BS !fp !len) = \ !acc ->
    BI.accursedUnutterablePerformIO $
        BI.unsafeWithForeignPtr fp $ \ptr -> do
            let end = ptr `plusPtr` len
            go ptr end ptr acc
  where
    go start end = loop
      where
        loop !ptr !acc | ptr == end
            = return $ Result (ptr `minusPtr` start) acc
        loop !ptr !acc = do
            !d <- fromDigit <$> peek ptr
            if | d > 9                             -- No more digits
                 -> return $ Result (ptr `minusPtr` start) acc
               | acc < q
                 -> loop (ptr `plusPtr` 1) (acc * 10 + d)
               | acc > q
                 -> return Overflow
               | d <= r
                 -> loop (ptr `plusPtr` 1) (acc * 10 + d)
               | otherwise
                 -> return Overflow
        --
        fromDigit = \w -> fromIntegral w - 0x30 -- i.e. w - '0'

With two ways to return a Result and two ways to return Overflow, each of these return sites created a separate join point for the caller's continuation on the return value.
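The q and r arguments appear to be the quotient and remainder of the largest permitted magnitude divided by 10. The referenced Note is not reproduced here, so this reading is an assumption based on its title and on the 7 / negate cases in convert. For Int64, the positive bound 2^63 - 1 gives r = 7 and the negative bound 2^63 gives r = 8, with the same q, which matches the 922337203685477580 constant in the Core. The sketch checks the loop's guard against exact Integer arithmetic; limits, positiveLimits, negativeLimits and safeStep are hypothetical names.

```haskell
import Data.Int (Int64)

-- Split a magnitude bound into (q, r) = bound `quotRem` 10.
limits :: Integer -> (Integer, Integer)
limits bound = bound `quotRem` 10

positiveLimits, negativeLimits :: (Integer, Integer)
positiveLimits = limits (toInteger (maxBound :: Int64))          -- (922337203685477580, 7)
negativeLimits = limits (negate (toInteger (minBound :: Int64))) -- (922337203685477580, 8)

-- The loop's guard: may digit d be appended to acc without exceeding
-- q * 10 + r?
safeStep :: (Integer, Integer) -> Integer -> Integer -> Bool
safeStep (q, r) acc d = acc < q || acc == q && d <= r
```

The guard never multiplies before checking, which is why it can run on a Word accumulator that would otherwise wrap around.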

In the generated Core, each copy of the loop then internally duplicates its exit continuations (exit/exit1 are identical, as are exit2/exit3):

-- RHS size: {terms: 533, types: 314, coercions: 0, joins: 12/20}
$w$j :: Word# -> ByteString -> Maybe (Int64, ByteString)
$w$j
  = \ (ww :: Word#) (ww1 :: ByteString) ->
      joinrec {
        $s$wconsume :: Word# -> ByteString -> Maybe (Int64, ByteString)
        $s$wconsume (sc :: Word#) (sc1 :: ByteString)
          = case sc1 of wild {
              Empty ->
                case ww of {
                  __DEFAULT -> Just (I64# (negateInt# (word2Int# sc)), Empty);
                  7## -> Just (I64# (word2Int# sc), Empty)
                };
              Chunk dt dt1 dt2 cs ->
                case dt2 of ds {
                  __DEFAULT ->
                    let {
                      end :: Addr#
                      end = plusAddr# dt ds } in
                    join {
                      exit :: State# RealWorld -> Maybe (Int64, ByteString)
                      exit (ipv :: State# RealWorld)
                        = case touch# dt1 ipv of { __DEFAULT -> Nothing } } in
                    join {
                      exit1 :: State# RealWorld -> Maybe (Int64, ByteString)
                      exit1 (ipv :: State# RealWorld)
                        = case touch# dt1 ipv of { __DEFAULT -> Nothing } } in
                    join {
                      exit2
                        :: Addr# -> Word# -> State# RealWorld -> Maybe (Int64, ByteString)
                      exit2 (ww2 :: Addr#) (ww3 :: Word#) (ipv :: State# RealWorld)
                        = case touch# dt1 ipv of { __DEFAULT ->
                          let {
                            dt3 :: Int#
                            dt3 = minusAddr# ww2 dt } in
                          case <=# dt3 0# of {
                            __DEFAULT ->
                              case <# dt3 ds of {
                                __DEFAULT -> jump $s$wconsume ww3 cs;
                                1# ->
                                  case ww of {
                                    __DEFAULT ->
                                      Just
                                        (I64# (negateInt# (word2Int# ww3)),
                                         Chunk (plusAddr# dt dt3) dt1 (-# ds dt3) cs);
                                    7## ->
                                      Just
                                        (I64# (word2Int# ww3),
                                         Chunk (plusAddr# dt dt3) dt1 (-# ds dt3) cs)
                                  }
                              };
                            1# ->
                              case ww of {
                                __DEFAULT -> Just (I64# (negateInt# (word2Int# ww3)), wild);
                                7## -> Just (I64# (word2Int# ww3), wild)
                              }
                          }
                          } } in
                    join {
                      exit3
                        :: Addr# -> Word# -> State# RealWorld -> Maybe (Int64, ByteString)
                      exit3 (ww2 :: Addr#) (ww3 :: Word#) (w :: State# RealWorld)
                        = case touch# dt1 w of { __DEFAULT ->
                          let {
                            dt3 :: Int#
                            dt3 = minusAddr# ww2 dt } in
                          case <=# dt3 0# of {
                            __DEFAULT ->
                              case <# dt3 ds of {
                                __DEFAULT -> jump $s$wconsume ww3 cs;
                                1# ->
                                  case ww of {
                                    __DEFAULT ->
                                      Just
                                        (I64# (negateInt# (word2Int# ww3)),
                                         Chunk (plusAddr# dt dt3) dt1 (-# ds dt3) cs);
                                    7## ->
                                      Just
                                        (I64# (word2Int# ww3),
                                         Chunk (plusAddr# dt dt3) dt1 (-# ds dt3) cs)
                                  }
                              };
                            1# ->
                              case ww of {
                                __DEFAULT -> Just (I64# (negateInt# (word2Int# ww3)), wild);
                                7## -> Just (I64# (word2Int# ww3), wild)
                              }
                          }
                          } } in
                    joinrec {
                      $wloop
                        :: Addr# -> Word# -> State# RealWorld -> Maybe (Int64, ByteString)
                      $wloop (ww2 :: Addr#) (ww3 :: Word#) (w :: State# RealWorld)
                        = case eqAddr# ww2 end of {
                            __DEFAULT ->
                              case readWord8OffAddr# ww2 0# w of { (# ipv, ipv1 #) ->
                              let {
                                ipv2 :: Word#
                                ipv2 = minusWord# (word8ToWord# ipv1) 48## } in
                              case gtWord# ipv2 9## of {
                                __DEFAULT ->
                                  case ltWord# ww3 922337203685477580## of {
                                    __DEFAULT ->
                                      case gtWord# ww3 922337203685477580## of {
                                        __DEFAULT ->
                                          case leWord# ipv2 ww of {
                                            __DEFAULT -> jump exit ipv;
                                            1# ->
                                              jump $wloop
                                                (plusAddr# ww2 1#)
                                                (plusWord# (timesWord# ww3 10##) ipv2)
                                                ipv
                                          };
                                        1# -> jump exit1 ipv
                                      };
                                    1# ->
                                      jump $wloop
                                        (plusAddr# ww2 1#)
                                        (plusWord# (timesWord# ww3 10##) ipv2)
                                        ipv
                                  };
                                1# -> jump exit2 ww2 ww3 ipv
                              }
                              };
                            1# -> jump exit3 ww2 ww3 w
                          }; } in
                    jump $wloop dt sc realWorld#;
                  0# -> jump $s$wconsume sc cs
                }
            }; } in
      joinrec {
        $s$wconsume1 :: Word# -> ByteString -> Maybe (Int64, ByteString)
        $s$wconsume1 (sc :: Word#) (sc1 :: ByteString)
          = case sc1 of {
              Empty -> Nothing;
              Chunk dt dt1 dt2 cs ->
                case dt2 of ds {
                  __DEFAULT ->
                    let {
                      end :: Addr#
                      end = plusAddr# dt ds } in
                    join {
                      exit :: State# RealWorld -> Maybe (Int64, ByteString)
                      exit (ipv :: State# RealWorld)
                        = case touch# dt1 ipv of { __DEFAULT -> Nothing } } in
                    join {
                      exit1 :: State# RealWorld -> Maybe (Int64, ByteString)
                      exit1 (ipv :: State# RealWorld)
                        = case touch# dt1 ipv of { __DEFAULT -> Nothing } } in
                    join {
                      exit2
                        :: Addr# -> Word# -> State# RealWorld -> Maybe (Int64, ByteString)
                      exit2 (ww2 :: Addr#) (ww3 :: Word#) (ipv :: State# RealWorld)
                        = case touch# dt1 ipv of { __DEFAULT ->
                          let {
                            dt3 :: Int#
                            dt3 = minusAddr# ww2 dt } in
                          case <=# dt3 0# of {
                            __DEFAULT ->
                              case <# dt3 ds of {
                                __DEFAULT -> jump $s$wconsume ww3 cs;
                                1# ->
                                  case ww of {
                                    __DEFAULT ->
                                      Just
                                        (I64# (negateInt# (word2Int# ww3)),
                                         Chunk (plusAddr# dt dt3) dt1 (-# ds dt3) cs);
                                    7## ->
                                      Just
                                        (I64# (word2Int# ww3),
                                         Chunk (plusAddr# dt dt3) dt1 (-# ds dt3) cs)
                                  }
                              };
                            1# -> Nothing
                          }
                          } } in
                    join {
                      exit3
                        :: Addr# -> Word# -> State# RealWorld -> Maybe (Int64, ByteString)
                      exit3 (ww2 :: Addr#) (ww3 :: Word#) (w :: State# RealWorld)
                        = case touch# dt1 w of { __DEFAULT ->
                          let {
                            dt3 :: Int#
                            dt3 = minusAddr# ww2 dt } in
                          case <=# dt3 0# of {
                            __DEFAULT ->
                              case <# dt3 ds of {
                                __DEFAULT -> jump $s$wconsume ww3 cs;
                                1# ->
                                  case ww of {
                                    __DEFAULT ->
                                      Just
                                        (I64# (negateInt# (word2Int# ww3)),
                                         Chunk (plusAddr# dt dt3) dt1 (-# ds dt3) cs);
                                    7## ->
                                      Just
                                        (I64# (word2Int# ww3),
                                         Chunk (plusAddr# dt dt3) dt1 (-# ds dt3) cs)
                                  }
                              };
                            1# -> Nothing
                          }
                          } } in
                    joinrec {
                      $wloop
                        :: Addr# -> Word# -> State# RealWorld -> Maybe (Int64, ByteString)
                      $wloop (ww2 :: Addr#) (ww3 :: Word#) (w :: State# RealWorld)
                        = case eqAddr# ww2 end of {
                            __DEFAULT ->
                              case readWord8OffAddr# ww2 0# w of { (# ipv, ipv1 #) ->
                              let {
                                ipv2 :: Word#
                                ipv2 = minusWord# (word8ToWord# ipv1) 48## } in
                              case gtWord# ipv2 9## of {
                                __DEFAULT ->
                                  case ltWord# ww3 922337203685477580## of {
                                    __DEFAULT ->
                                      case gtWord# ww3 922337203685477580## of {
                                        __DEFAULT ->
                                          case leWord# ipv2 ww of {
                                            __DEFAULT -> jump exit ipv;
                                            1# ->
                                              jump $wloop
                                                (plusAddr# ww2 1#)
                                                (plusWord# (timesWord# ww3 10##) ipv2)
                                                ipv
                                          };
                                        1# -> jump exit1 ipv
                                      };
                                    1# ->
                                      jump $wloop
                                        (plusAddr# ww2 1#)
                                        (plusWord# (timesWord# ww3 10##) ipv2)
                                        ipv
                                  };
                                1# -> jump exit2 ww2 ww3 ipv
                              }
                              };
                            1# -> jump exit3 ww2 ww3 w
                          }; } in
                    jump $wloop dt sc realWorld#;
                  0# -> jump $s$wconsume1 sc cs
                }
            }; } in
      jump $s$wconsume1 0## ww1

So to eliminate the duplicate exit paths, the _digits loop had to be refactored as follows:

_digits :: Accum -> Accum -> BI.ByteString -> Accum -> Result
{-# INLINE _digits #-}
_digits !q !r !(BI.BS !fp !len) = \ !acc ->
    BI.accursedUnutterablePerformIO $
        BI.unsafeWithForeignPtr fp $ \ptr -> do
            let end = ptr `plusPtr` len
            go ptr end ptr acc
  where
    go start end = loop
      where
        loop !ptr !acc = getDigit >>= \ !d ->
            if | d > 9
                 -> return $ Result (ptr `minusPtr` start) acc
               | acc < q || acc == q && d <= r
                 -> loop (ptr `plusPtr` 1) (acc * 10 + d)
               | otherwise
                 -> return Overflow
          where
            fromDigit = \w -> fromIntegral w - 0x30 -- i.e. w - '0'
            --
            getDigit | ptr /= end = fromDigit <$> peek ptr
                     | otherwise  = pure 10  -- End of input
            {-# NOINLINE getDigit #-}

This required the somewhat artificial getDigit to have a NOINLINE annotation, to avoid the code paths being split anyway.
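As a sanity check that this refactoring preserves behaviour, the sketch below models both versions of the digit loop over a list of bytes instead of a buffer, with getDigit reporting end of input as the sentinel digit 10. It says nothing about the Core GHC generates (NOINLINE is irrelevant to a semantic model and is omitted); digitsOriginal, digitsRefactored and the list representation are assumptions made for illustration.

```haskell
{-# LANGUAGE BangPatterns, MultiWayIf #-}
import Data.Word (Word8)

data Result = Result !Int !Word | Overflow
    deriving (Eq, Show)

-- Same as the original fromDigit: w - '0', wrapping for non-digits.
fromDigit :: Word8 -> Word
fromDigit w = fromIntegral w - 0x30

-- Original shape: explicit end check, two Result exits, two Overflow exits.
digitsOriginal :: Word -> Word -> [Word8] -> Word -> Result
digitsOriginal q r = go 0
  where
    go !n [] !acc = Result n acc
    go !n (w : ws) !acc
        | d > 9     = Result n acc
        | acc < q   = go (n + 1) ws (acc * 10 + d)
        | acc > q   = Overflow
        | d <= r    = go (n + 1) ws (acc * 10 + d)
        | otherwise = Overflow
      where
        d = fromDigit w

-- Refactored shape: one getDigit, end of input reported as digit 10,
-- leaving a single exit per result constructor.
digitsRefactored :: Word -> Word -> [Word8] -> Word -> Result
digitsRefactored q r = go 0
  where
    go !n ws !acc =
        let (d, rest) = getDigit ws
         in if | d > 9 -> Result n acc
               | acc < q || acc == q && d <= r -> go (n + 1) rest (acc * 10 + d)
               | otherwise -> Overflow
    getDigit (w : ws) = (fromDigit w, ws)
    getDigit []       = (10, [])
```

The two guard chains agree because, once acc < q and acc > q are both ruled out, acc == q must hold.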

The new Core output is much better (exit2/exit3 are still identical, but they are too small to really matter, and may be optimised out in the assembly):

-- RHS size: {terms: 247, types: 174, coercions: 0, joins: 8/11}
$w$j :: Word# -> ByteString -> Maybe (Int64, ByteString)
$w$j
  = \ (ww :: Word#) (ww1 :: ByteString) ->
      join {
        exit :: Int# -> Word# -> Maybe (Int64, ByteString)
        exit (ww2 :: Int#) (ww3 :: Word#)
          = case ww2 of {
              __DEFAULT ->
                case ww of {
                  __DEFAULT -> Just (I64# (negateInt# (word2Int# ww3)), Empty);
                  7## -> Just (I64# (word2Int# ww3), Empty)
                };
              0# -> Nothing
            } } in
      join {
        exit1
          :: Int#
             -> ByteString
             -> Addr#
             -> ForeignPtrContents
             -> Int#
             -> ByteString
             -> Int#
             -> Word#
             -> Maybe (Int64, ByteString)
        exit1 (ww2 :: Int#)
              (wild :: ByteString)
              (dt :: Addr#)
              (dt1 :: ForeignPtrContents)
              (dt2 :: Int#)
              (cs :: ByteString)
              (dt3 :: Int#)
              (dt4 :: Word#)
          = case ># dt3 0# of {
              __DEFAULT ->
                case ww2 of {
                  __DEFAULT ->
                    case ww of {
                      __DEFAULT -> Just (I64# (negateInt# (word2Int# dt4)), wild);
                      7## -> Just (I64# (word2Int# dt4), wild)
                    };
                  0# -> Nothing
                };
              1# ->
                case ww of {
                  __DEFAULT ->
                    Just
                      (I64# (negateInt# (word2Int# dt4)),
                       Chunk (plusAddr# dt dt3) dt1 (-# dt2 dt3) cs);
                  7## ->
                    Just
                      (I64# (word2Int# dt4),
                       Chunk (plusAddr# dt dt3) dt1 (-# dt2 dt3) cs)
                }
            } } in
      joinrec {
        $wconsume
          :: ByteString -> Int# -> Word# -> Maybe (Int64, ByteString)
        $wconsume (w :: ByteString) (ww2 :: Int#) (ww3 :: Word#)
          = case w of wild {
              Empty -> jump exit ww2 ww3;
              Chunk dt dt1 dt2 cs ->
                let {
                  end :: Addr#
                  end = plusAddr# dt dt2 } in
                join {
                  exit2 :: State# RealWorld -> Maybe (Int64, ByteString)
                  exit2 (ipv :: State# RealWorld)
                    = case touch# dt1 ipv of { __DEFAULT -> Nothing } } in
                join {
                  exit3 :: State# RealWorld -> Maybe (Int64, ByteString)
                  exit3 (ipv :: State# RealWorld)
                    = case touch# dt1 ipv of { __DEFAULT -> Nothing } } in
                join {
                  exit4
                    :: Addr# -> Word# -> State# RealWorld -> Maybe (Int64, ByteString)
                  exit4 (ww4 :: Addr#) (ww5 :: Word#) (ipv :: State# RealWorld)
                    = case touch# dt1 ipv of { __DEFAULT ->
                      let {
                        dt3 :: Int#
                        dt3 = minusAddr# ww4 dt } in
                      case ==# dt3 dt2 of {
                        __DEFAULT -> jump exit1 ww2 wild dt dt1 dt2 cs dt3 ww5;
                        1# -> jump $wconsume cs (orI# ww2 dt3) ww5
                      }
                      } } in
                joinrec {
                  $wloop
                    :: Addr# -> Word# -> State# RealWorld -> Maybe (Int64, ByteString)
                  $wloop (ww4 :: Addr#) (ww5 :: Word#) (w1 :: State# RealWorld)
                    = join {
                        getDigit :: State# RealWorld -> Maybe (Int64, ByteString)
                        getDigit (eta :: State# RealWorld)
                          = case eqAddr# ww4 end of {
                              __DEFAULT ->
                                case readWord8OffAddr# ww4 0# eta of { (# ipv, ipv1 #) ->
                                let {
                                  ipv2 :: Word#
                                  ipv2 = minusWord# (word8ToWord# ipv1) 48## } in
                                case gtWord# ipv2 9## of {
                                  __DEFAULT ->
                                    case ltWord# ww5 922337203685477580## of {
                                      __DEFAULT ->
                                        case ww5 of {
                                          __DEFAULT -> jump exit2 ipv;
                                          922337203685477580## ->
                                            case leWord# ipv2 ww of {
                                              __DEFAULT -> jump exit3 ipv;
                                              1# ->
                                                jump $wloop
                                                  (plusAddr# ww4 1#)
                                                  (plusWord# 9223372036854775800## ipv2)
                                                  ipv
                                            }
                                        };
                                      1# ->
                                        jump $wloop
                                          (plusAddr# ww4 1#)
                                          (plusWord# (timesWord# ww5 10##) ipv2)
                                          ipv
                                    };
                                  1# -> jump exit4 ww4 ww5 ipv
                                }
                                };
                              1# -> jump exit4 ww4 ww5 eta
                            } } in
                      jump getDigit w1; } in
                jump $wloop dt ww3 realWorld#
            }; } in
      jump $wconsume ww1 0# 0##

but I'm somewhat reluctant to saddle the code with subtle optimisations; they are not easily maintained...

Edited by vdukhovni