diff --git a/cbits/cbits.c b/cbits/cbits.c index 6fa8bc123fc75675b75b02a4924fb731461a067b..3166a050f721bb1b551269812ce16a005e8a0a80 100644 --- a/cbits/cbits.c +++ b/cbits/cbits.c @@ -127,11 +127,18 @@ _hs_text_decode_latin1(uint16_t *dest, const uint8_t const *src, * state0 != UTF8_ACCEPT, UTF8_REJECT * */ -const uint8_t * -_hs_text_decode_utf8_state(uint16_t *const dest, size_t *destoff, - const uint8_t **const src, - const uint8_t *const srcend, - uint32_t *codepoint0, uint32_t *state0) +#if defined(__GNUC__) || defined(__clang__) +static inline uint8_t const * +_hs_text_decode_utf8_int(uint16_t *const dest, size_t *destoff, + const uint8_t const **src, const uint8_t const *srcend, + uint32_t *codepoint0, uint32_t *state0) + __attribute((always_inline)); +#endif + +static inline uint8_t const * +_hs_text_decode_utf8_int(uint16_t *const dest, size_t *destoff, + const uint8_t const **src, const uint8_t const *srcend, + uint32_t *codepoint0, uint32_t *state0) { uint16_t *d = dest + *destoff; const uint8_t *s = *src, *last = *src; @@ -185,10 +192,6 @@ _hs_text_decode_utf8_state(uint16_t *const dest, size_t *destoff, last = s; } - /* Invalid encoding, back up to the errant character */ - if (state == UTF8_REJECT) - s -= 1; - *destoff = d - dest; *codepoint0 = codepoint; *state0 = state; @@ -197,6 +200,19 @@ _hs_text_decode_utf8_state(uint16_t *const dest, size_t *destoff, return s; } +uint8_t const * +_hs_text_decode_utf8_state(uint16_t *const dest, size_t *destoff, + const uint8_t const **src, + const uint8_t const *srcend, + uint32_t *codepoint0, uint32_t *state0) +{ + uint8_t const *ret = _hs_text_decode_utf8_int(dest, destoff, src, srcend, + codepoint0, state0); + if (*state0 == UTF8_REJECT) + ret -=1; + return ret; +} + /* * Helper to decode buffer and discard final decoder state */ @@ -206,5 +222,10 @@ _hs_text_decode_utf8(uint16_t *const dest, size_t *destoff, { uint32_t codepoint; uint32_t state = UTF8_ACCEPT; - return _hs_text_decode_utf8_state(dest, destoff, &src, srcend, &codepoint, &state); + uint8_t const *ret = _hs_text_decode_utf8_int(dest, destoff, &src, srcend, + &codepoint, &state); + /* Back up if we have an incomplete or invalid encoding */ + if (state != UTF8_ACCEPT) + ret -= 1; + return ret; } diff --git a/changelog b/changelog index e17e6abd2386335b0f38b72aa816180b7f74df89..f442690963a83b281e93a022ea11382d7e928c8a 100644 --- a/changelog +++ b/changelog @@ -1,3 +1,8 @@ +1.0.0.1 + + * decodeUtf8: Fixed a regression that caused us to incorrectly + identify truncated UTF-8 as valid (gh-61) + 1.0.0.0 * Added support for Unicode 6.3.0 to case conversion functions diff --git a/tests/Tests/Properties.hs b/tests/Tests/Properties.hs index 69d321893851a3f40b703c000556002dd1281a93..e34078bd6f3f0efc549daadd445bbf524af20dec 100644 --- a/tests/Tests/Properties.hs +++ b/tests/Tests/Properties.hs @@ -13,6 +13,7 @@ import Test.QuickCheck import Test.QuickCheck.Monadic import Text.Show.Functions () +import Control.Applicative ((<$>), (<*>)) import Control.Arrow ((***), second) import Control.Exception (catch) import Data.Char (chr, isDigit, isHexDigit, isLower, isSpace, isUpper, ord) @@ -110,23 +111,55 @@ t_utf8_incr = do E.Some t _ f' = f a in t : feedChunksOf n f' b --- This is a poor attempt to ensure that the error handling paths on --- decode are exercised in some way. Proper testing would be rather --- more involved. -t_utf8_err :: DecodeErr -> B.ByteString -> Property -t_utf8_err (DE _ de) bs = monadicIO $ do - l <- run $ let len = T.length (E.decodeUtf8With de bs) - in (len `seq` return (Right len)) `catch` - (\(e::UnicodeException) -> return (Left e)) - case l of - Left err -> assert $ length (show err) >= 0 - Right n -> assert $ n >= 0 +data Badness = Solo | Leading | Trailing + deriving (Eq, Show) + +instance Arbitrary Badness where + arbitrary = elements [Solo, Leading, Trailing] + +t_utf8_err :: Badness -> DecodeErr -> Property +t_utf8_err bad de = do + let gen = case bad of + Solo -> genInvalidUTF8 + Leading -> B.append <$> genInvalidUTF8 <*> genUTF8 + Trailing -> B.append <$> genUTF8 <*> genInvalidUTF8 + genUTF8 = E.encodeUtf8 <$> genUnicode + forAll gen $ \bs -> do + onErr <- genDecodeErr de + monadicIO $ do + l <- run $ let len = T.length (E.decodeUtf8With onErr bs) + in (len `seq` return (Right len)) `catch` + (\(e::UnicodeException) -> return (Left e)) + assert $ case l of + Left err -> length (show err) >= 0 + Right _ -> de /= Strict t_utf8_err' :: B.ByteString -> Property t_utf8_err' bs = monadicIO . assert $ case E.decodeUtf8' bs of Left err -> length (show err) >= 0 Right t -> T.length t >= 0 +genInvalidUTF8 :: Gen B.ByteString +genInvalidUTF8 = B.pack <$> oneof [ + -- invalid leading byte of a 2-byte sequence + (:) <$> choose (0xC0, 0xC1) <*> upTo 1 contByte + -- invalid leading byte of a 4-byte sequence + , (:) <$> choose (0xF5, 0xFF) <*> upTo 3 contByte + -- continuation bytes without a start byte + , listOf1 contByte + -- short 2-byte sequence + , (:[]) <$> choose (0xC2, 0xDF) + -- short 3-byte sequence + , (:) <$> choose (0xE0, 0xEF) <*> upTo 1 contByte + -- short 4-byte sequence + , (:) <$> choose (0xF0, 0xF4) <*> upTo 2 contByte + ] + where + contByte = (0x80 +) <$> choose (0, 0x3f) + upTo n gen = do + k <- choose (0,n) + vectorOf k gen + s_Eq s = (s==) `eq` ((S.streamList s==) . S.streamList) where _types = s :: String sf_Eq p s = diff --git a/tests/Tests/QuickCheckUtils.hs b/tests/Tests/QuickCheckUtils.hs index 1920142fa844351cd6f38c3b5496c98034119d86..b433a51228164a5c13c9c49bad6c88b45f8196b1 100644 --- a/tests/Tests/QuickCheckUtils.hs +++ b/tests/Tests/QuickCheckUtils.hs @@ -18,6 +18,7 @@ module Tests.QuickCheckUtils , integralRandomR , DecodeErr (..) + , genDecodeErr , Stringy (..) , eq @@ -194,16 +195,17 @@ integralRandomR (a,b) g = case randomR (fromIntegral a :: Integer, fromIntegral b :: Integer) g of (x,h) -> (fromIntegral x, h) -data DecodeErr = DE String T.OnDecodeError +data DecodeErr = Lenient | Ignore | Strict | Replace + deriving (Show, Eq) -instance Show DecodeErr where - show (DE d _) = "DE " ++ d +genDecodeErr :: DecodeErr -> Gen T.OnDecodeError +genDecodeErr Lenient = return T.lenientDecode +genDecodeErr Ignore = return T.ignore +genDecodeErr Strict = return T.strictDecode +genDecodeErr Replace = arbitrary instance Arbitrary DecodeErr where - arbitrary = oneof [ return $ DE "lenient" T.lenientDecode - , return $ DE "ignore" T.ignore - , return $ DE "strict" T.strictDecode - , DE "replace" `fmap` arbitrary ] + arbitrary = elements [Lenient, Ignore, Strict, Replace] class Stringy s where packS :: String -> s