Commit ffaf10f9 authored by Bodigrim's avatar Bodigrim
Browse files

Use AVX/SSE instructions for length/take/drop

parent f2b1fc7e
/*
* Copyright (c) 2021 Andrew Lelechenko <andrew.lelechenko@gmail.com>
*/
#include <string.h>
#include <stdint.h>
#include <sys/types.h>
#ifdef __x86_64__
#include <emmintrin.h>
#include <xmmintrin.h>
#include <immintrin.h>
#include <cpuid.h>
#endif
#include <stdbool.h>
#ifndef __STDC_NO_ATOMICS__
#include <stdatomic.h>
#endif
bool has_avx512_vl_bw() {
#ifdef __x86_64__
uint32_t eax = 0, ebx = 0, ecx = 0, edx = 0;
__get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx);
// https://en.wikipedia.org/wiki/CPUID#EAX=7,_ECX=0:_Extended_Features
const bool has_avx512_bw = ebx & (1 << 30);
const bool has_avx512_vl = ebx & (1 << 31);
// printf("cpuid=%d=cpuid\n", has_avx512_bw && has_avx512_vl);
return has_avx512_bw && has_avx512_vl;
#else
return false;
#endif
}
/*
measure_off_naive / measure_off_avx / measure_off_sse
take a UTF-8 sequence between src and srcend, and a number of characters cnt.
If the sequence is long enough to contain cnt characters, then return how many bytes
remained unconsumed. Otherwise, if the sequence is shorter, return
negated count of lacking characters. Cf. _hs_text_measure_off below.
*/
inline const ssize_t measure_off_naive(const uint8_t *src, const uint8_t *srcend, size_t cnt)
{
// Count leading bytes in 8 byte sequence
while (src < srcend - 7){
uint64_t w64;
memcpy(&w64, src, sizeof(uint64_t));
size_t leads = __builtin_popcountll(((w64 << 1) | ~w64) & 0x8080808080808080ULL);
if (cnt < leads) break;
cnt-= leads;
src+= 8;
}
// Skip until next leading byte
while (src < srcend){
uint8_t w8 = *src;
if ((int8_t)w8 >= -0x40) break;
src++;
}
// Finish up with tail
while (src < srcend && cnt > 0){
uint8_t leadByte = *src++;
cnt--;
src+= (leadByte >= 0xc0) + (leadByte >= 0xe0) + (leadByte >= 0xf0);
}
return cnt == 0 ? (ssize_t)(srcend - src) : (ssize_t)(- cnt);
}
#ifdef __x86_64__
__attribute__((target("avx512vl,avx512bw")))
const ssize_t measure_off_avx(const uint8_t *src, const uint8_t *srcend, size_t cnt)
{
while (src < srcend - 63){
__m512i w512 = _mm512_loadu_si512((__m512i *)src);
// Which bytes are either < 128 or >= 192?
uint64_t mask = _mm512_cmpgt_epi8_mask(w512, _mm512_set1_epi8(0xBF));
size_t leads = __builtin_popcountll(mask);
if (cnt < leads) break;
cnt-= leads;
src+= 64;
}
// Cannot proceed to measure_off_sse, because of AVX-SSE transition penalties
// https://software.intel.com/content/www/us/en/develop/articles/avoiding-avx-sse-transition-penalties.html
if (src < srcend - 31){
__m256i w256 = _mm256_loadu_si256((__m256i *)src);
uint32_t mask = _mm256_cmpgt_epi8_mask(w256, _mm256_set1_epi8(0xBF));
size_t leads = __builtin_popcountl(mask);
if (cnt >= leads){
cnt-= leads;
src+= 32;
}
}
if (src < srcend - 15){
__m128i w128 = _mm_maskz_loadu_epi16(0xFF, (__m128i *)src); // not _mm_loadu_si128; and GCC does not have _mm_loadu_epi16
uint16_t mask = _mm_cmpgt_epi8_mask(w128, _mm_set1_epi8(0xBF)); // not _mm_movemask_epi8
size_t leads = __builtin_popcountl(mask);
if (cnt >= leads){
cnt-= leads;
src+= 16;
}
}
return measure_off_naive(src, srcend, cnt);
}
#endif
const ssize_t measure_off_sse(const uint8_t *src, const uint8_t *srcend, size_t cnt)
{
#ifdef __x86_64__
while (src < srcend - 15){
__m128i w128 = _mm_loadu_si128((__m128i *)src);
// Which bytes are either < 128 or >= 192?
uint16_t mask = _mm_movemask_epi8(_mm_cmpgt_epi8(w128, _mm_set1_epi8(0xBF)));
size_t leads = __builtin_popcount(mask);
if (cnt < leads) break;
cnt-= leads;
src+= 16;
}
#endif
return measure_off_naive(src, srcend, cnt);
}
typedef const ssize_t (*measure_off_t) (const uint8_t*, const uint8_t*, size_t);
/*
_hs_text_measure_off takes a UTF-8 encoded buffer, specified by (src, off, len),
and a number of code points (aka characters) cnt. If the buffer is long enough
to contain cnt characters, then _hs_text_measure_off returns a non-negative number,
measuring their size in code units (aka bytes). If the buffer is shorter,
_hs_text_measure_off returns a non-positive number, which is a negated total count
of characters available in the buffer. If len = 0 or cnt = 0, this function returns 0
as well.
This scheme allows us to implement both take/drop and length with the same C function.
The input buffer (src, off, len) must be a valid UTF-8 sequence,
this condition is not checked.
*/
ssize_t _hs_text_measure_off(const uint8_t *src, size_t off, size_t len, size_t cnt) {
static _Atomic measure_off_t s_impl = (measure_off_t)NULL;
measure_off_t impl = atomic_load_explicit(&s_impl, memory_order_relaxed);
if (!impl) {
#ifdef __x86_64__
impl = has_avx512_vl_bw() ? measure_off_avx : measure_off_sse;
#else
impl = measure_off_sse;
#endif
atomic_store_explicit(&s_impl, impl, memory_order_relaxed);
}
ssize_t ret = (*impl)(src + off, src + off + len, cnt);
return ret >= 0 ? ((ssize_t)len - ret) : (- (cnt + ret));
}
{-# LANGUAGE BangPatterns, CPP, MagicHash, Rank2Types, UnboxedTuples, TypeFamilies #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE UnliftedFFITypes #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
......@@ -196,13 +197,15 @@ module Data.Text
-- * Low level operations
, copy
, unpackCString#
, measureOff
) where
import Prelude (Char, Bool(..), Int, Maybe(..), String,
Eq(..), Ord(..), Ordering(..), (++),
Read(..),
(&&), (||), (+), (-), (.), ($), ($!), (>>),
not, return, otherwise, quot)
not, return, otherwise, quot, IO)
import Control.DeepSeq (NFData(rnf))
#if defined(ASSERTS)
import Control.Exception (assert)
......@@ -230,7 +233,7 @@ import Data.Text.Internal (Text(..), empty, firstf, mul, safe, text)
import Data.Text.Show (singleton, unpack, unpackCString#)
import qualified Prelude as P
import Data.Text.Unsafe (Iter(..), iter, iter_, lengthWord8, reverseIter,
reverseIter_, unsafeHead, unsafeTail)
reverseIter_, unsafeHead, unsafeTail, unsafeDupablePerformIO)
import Data.Text.Internal.Search (indices)
#if defined(__HADDOCK__)
import Data.ByteString (ByteString)
......@@ -238,11 +241,13 @@ import qualified Data.Text.Lazy as L
import Data.Int (Int64)
#endif
import Data.Word (Word8)
import GHC.Base (eqInt, neInt, gtInt, geInt, ltInt, leInt)
import Foreign.C.Types
import GHC.Base (eqInt, neInt, gtInt, geInt, ltInt, leInt, ByteArray#)
import qualified GHC.Exts as Exts
import qualified Language.Haskell.TH.Lib as TH
import qualified Language.Haskell.TH.Syntax as TH
import Text.Printf (PrintfArg, formatArg, formatString)
import System.Posix.Types (CSsize(..))
-- $setup
-- >>> import Data.Text
......@@ -538,7 +543,7 @@ length ::
HasCallStack =>
#endif
Text -> Int
length t = S.length (stream t)
length = P.negate . measureOff P.maxBound
{-# INLINE [1] length #-}
-- length needs to be phased after the compareN/length rules otherwise
-- it may inline before the rules have an opportunity to fire.
......@@ -1069,15 +1074,25 @@ take :: Int -> Text -> Text
take n t@(Text arr off len)
| n <= 0 = empty
| n >= len = t
| otherwise = text arr off (iterN n t)
| otherwise = let m = measureOff n t in if m >= 0 then text arr off m else t
{-# INLINE [1] take #-}
iterN :: Int -> Text -> Int
iterN n t@(Text _arr _off len) = loop 0 0
where loop !i !cnt
| i >= len || cnt >= n = i
| otherwise = loop (i+d) (cnt+1)
where d = iter_ t i
-- | /O(n)/ If @t@ is long enough to contain @n@ characters, 'measureOff' @n@ @t@
-- returns a non-negative number, measuring their size in 'Word8'. Otherwise,
-- if @t@ is shorter, return a non-positive number, which is a negated total count
-- of 'Char' available in @t@. If @t@ is empty or @n = 0@, return 0.
--
-- This function is used to implement 'take', 'drop', 'splitAt' and 'length'
-- and is useful on its own in streaming and parsing libraries.
measureOff :: Int -> Text -> Int
measureOff !n (Text (A.ByteArray arr) off len) = if len == 0 then 0 else
cSsizeToInt $ unsafeDupablePerformIO $
c_measure_off arr (intToCSize off) (intToCSize len) (intToCSize n)
-- | The input buffer (arr :: ByteArray#, off :: CSize, len :: CSize)
-- must specify a valid UTF-8 sequence, this condition is not checked.
foreign import ccall unsafe "_hs_text_measure_off" c_measure_off
:: ByteArray# -> CSize -> CSize -> CSize -> IO CSsize
-- | /O(n)/ 'takeEnd' @n@ @t@ returns the suffix remaining after
-- taking @n@ characters from the end of @t@.
......@@ -1110,8 +1125,8 @@ drop :: Int -> Text -> Text
drop n t@(Text arr off len)
| n <= 0 = t
| n >= len = empty
| otherwise = text arr (off+i) (len-i)
where i = iterN n t
| otherwise = if m >= 0 then text arr (off+m) (len-m) else mempty
where m = measureOff n t
{-# INLINE [1] drop #-}
-- | /O(n)/ 'dropEnd' @n@ @t@ returns the prefix remaining after
......@@ -1219,8 +1234,8 @@ splitAt :: Int -> Text -> (Text, Text)
splitAt n t@(Text arr off len)
| n <= 0 = (empty, t)
| n >= len = (t, empty)
| otherwise = let k = iterN n t
in (text arr off k, text arr (off+k) (len-k))
| otherwise = let m = measureOff n t in
if m >= 0 then (text arr off m, text arr (off+m) (len-m)) else (t, mempty)
-- | /O(n)/ 'span', applied to a predicate @p@ and text @t@, returns
-- a pair whose first element is the longest prefix (possibly empty)
......@@ -1786,6 +1801,11 @@ copy (Text arr off len) = Text (A.run go) 0 len
A.copyI len marr 0 arr off
return marr
intToCSize :: Int -> CSize
intToCSize = P.fromIntegral
cSsizeToInt :: CSsize -> Int
cSsizeToInt = P.fromIntegral
-------------------------------------------------
-- NOTE: the named chunk below used by doctest;
......
......@@ -203,6 +203,7 @@ import Prelude (Char, Bool(..), Maybe(..), String,
error, flip, fmap, fromIntegral, not, otherwise, quot)
import qualified Prelude as P
import Control.DeepSeq (NFData(..))
import Data.Bits (finiteBitSize)
import Data.Int (Int64)
import qualified Data.List as L
import Data.Char (isSpace)
......@@ -972,10 +973,15 @@ take i t0 = take' i t0
take' :: Int64 -> Text -> Text
take' 0 _ = Empty
take' _ Empty = Empty
take' n (Chunk t ts)
| n < len = Chunk (T.take (int64ToInt n) t) Empty
| otherwise = Chunk t (take' (n - len) ts)
where len = intToInt64 (T.length t)
take' n (Chunk t@(T.Text arr off _) ts)
| finiteBitSize (0 :: P.Int) == 64, m <- T.measureOff (int64ToInt n) t =
if m >= 0
then fromStrict (T.Text arr off m)
else Chunk t (take' (n + intToInt64 m) ts)
| n < l = Chunk (T.take (int64ToInt n) t) Empty
| otherwise = Chunk t (take' (n - l) ts)
where l = intToInt64 (T.length t)
{-# INLINE [1] take #-}
-- | /O(n)/ 'takeEnd' @n@ @t@ returns the suffix remaining after
......@@ -1009,10 +1015,15 @@ drop i t0
drop' :: Int64 -> Text -> Text
drop' 0 ts = ts
drop' _ Empty = Empty
drop' n (Chunk t ts)
| n < len = Chunk (T.drop (int64ToInt n) t) ts
| otherwise = drop' (n - len) ts
where len = intToInt64 (T.length t)
drop' n (Chunk t@(T.Text arr off len) ts)
| finiteBitSize (0 :: P.Int) == 64, m <- T.measureOff (int64ToInt n) t =
if m >= 0
then chunk (T.Text arr (off + m) (len - m)) ts
else drop' (n + intToInt64 m) ts
| n < l = Chunk (T.drop (int64ToInt n) t) ts
| otherwise = drop' (n - l) ts
where l = intToInt64 (T.length t)
{-# INLINE [1] drop #-}
-- | /O(n)/ 'dropEnd' @n@ @t@ returns the prefix remaining after
......
......@@ -65,6 +65,7 @@ flag developer
library
c-sources: cbits/cbits.c
cbits/measure_off.c
cbits/utils.c
include-dirs: include
hs-source-dirs: src
......@@ -131,6 +132,10 @@ library
ghc-options: -fno-ignore-asserts
cpp-options: -DASSERTS
-- https://gitlab.haskell.org/ghc/ghc/-/issues/19900
if os(windows)
extra-libraries: gcc_s
default-language: Haskell2010
default-extensions:
NondecreasingIndentation
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment