diff --git a/Control/Concurrent/STM/TArray.hs b/Control/Concurrent/STM/TArray.hs index 0755dd1ad0092917557fa1171d83276a2393173d..4ac2db4f9e81326d09b7dbbef2b6cf274e719d3b 100644 --- a/Control/Concurrent/STM/TArray.hs +++ b/Control/Concurrent/STM/TArray.hs @@ -4,6 +4,12 @@ {-# LANGUAGE Trustworthy #-} #endif +#define HAS_UNLIFTED_ARRAY defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 904 + +#if HAS_UNLIFTED_ARRAY +{-# LANGUAGE MagicHash, UnboxedTuples #-} +#endif + ----------------------------------------------------------------------------- -- | -- Module : Control.Concurrent.STM.TArray @@ -14,7 +20,7 @@ -- Stability : experimental -- Portability : non-portable (requires STM) -- --- TArrays: transactional arrays, for use in the STM monad +-- TArrays: transactional arrays, for use in the STM monad. -- ----------------------------------------------------------------------------- @@ -22,32 +28,79 @@ module Control.Concurrent.STM.TArray ( TArray ) where -import Data.Array (Array, bounds) -import Data.Array.Base (listArray, unsafeAt, MArray(..), - IArray(numElements)) -import Data.Ix (rangeSize) +import Control.Monad.STM (STM, atomically) import Data.Typeable (Typeable) -import Control.Concurrent.STM.TVar (TVar, newTVar, newTVarIO, readTVar, readTVarIO, writeTVar) -#ifdef __GLASGOW_HASKELL__ -import GHC.Conc (STM, atomically) +#if HAS_UNLIFTED_ARRAY +import Control.Concurrent.STM.TVar (readTVar, readTVarIO, writeTVar) +import Data.Array.Base (safeRangeSize, MArray(..)) +import Data.Ix (Ix) +import GHC.Conc (STM(..), TVar(..)) +import GHC.Exts +import GHC.IO (IO(..)) #else -import Control.Sequential.STM (STM, atomically) +import Control.Concurrent.STM.TVar (TVar, newTVar, newTVarIO, readTVar, readTVarIO, writeTVar) +import Data.Array (Array, bounds, listArray) +import Data.Array.Base (safeRangeSize, unsafeAt, MArray(..), IArray(numElements)) #endif --- |TArray is a transactional array, supporting the usual 'MArray' +-- | 'TArray' is a transactional array, supporting the usual 'MArray' -- interface for mutable arrays. -- --- It is currently implemented as @Array ix (TVar e)@, --- but it may be replaced by a more efficient implementation in the future --- (the interface will remain the same, however). --- +-- It is conceptually implemented as @Array i (TVar e)@. +#if HAS_UNLIFTED_ARRAY +data TArray i e = TArray + !i -- lower bound + !i -- upper bound + !Int -- size + (Array# (TVar# RealWorld e)) + deriving (Typeable) + +instance (Eq i, Eq e) => Eq (TArray i e) where + (TArray l1 u1 n1 arr1#) == (TArray l2 u2 n2 arr2#) = + -- each `TArray` has its own `TVar`s, so it's sufficient to compare the first one + if n1 == 0 then n2 == 0 else l1 == l2 && u1 == u2 && isTrue# (sameTVar# (unsafeFirstT arr1#) (unsafeFirstT arr2#)) + where + unsafeFirstT :: Array# (TVar# RealWorld e) -> TVar# RealWorld e + unsafeFirstT arr# = case indexArray# arr# 0# of (# e #) -> e + +newTArray# :: Ix i => (i, i) -> e -> State# RealWorld -> (# State# RealWorld, TArray i e #) +newTArray# b@(l, u) e = \s1# -> + case safeRangeSize b of + n@(I# n#) -> case newTVar# e s1# of + (# s2#, initial_tvar# #) -> case newArray# n# initial_tvar# s2# of + (# s3#, marr# #) -> + let go i# = \s4# -> case newTVar# e s4# of + (# s5#, tvar# #) -> case writeArray# marr# i# tvar# s5# of + s6# -> if isTrue# (i# ==# n# -# 1#) then s6# else go (i# +# 1#) s6# + in case unsafeFreezeArray# marr# (if n <= 1 then s3# else go 1# s3#) of + (# s7#, arr# #) -> (# s7#, TArray l u n arr# #) + +instance MArray TArray e STM where + getBounds (TArray l u _ _) = return (l, u) + getNumElements (TArray _ _ n _) = return n + newArray b e = STM $ newTArray# b e + unsafeRead (TArray _ _ _ arr#) (I# i#) = case indexArray# arr# i# of + (# tvar# #) -> readTVar (TVar tvar#) + unsafeWrite (TArray _ _ _ arr#) (I# i#) e = case indexArray# arr# i# of + (# tvar# #) -> writeTVar (TVar tvar#) e + +-- | Writes are slow in `IO`. +instance MArray TArray e IO where + getBounds (TArray l u _ _) = return (l, u) + getNumElements (TArray _ _ n _) = return n + newArray b e = IO $ newTArray# b e + unsafeRead (TArray _ _ _ arr#) (I# i#) = case indexArray# arr# i# of + (# tvar# #) -> readTVarIO (TVar tvar#) + unsafeWrite (TArray _ _ _ arr#) (I# i#) e = case indexArray# arr# i# of + (# tvar# #) -> atomically $ writeTVar (TVar tvar#) e +#else newtype TArray i e = TArray (Array i (TVar e)) deriving (Eq, Typeable) instance MArray TArray e STM where getBounds (TArray a) = return (bounds a) getNumElements (TArray a) = return (numElements a) newArray b e = do - a <- rep (rangeSize b) (newTVar e) + a <- rep (safeRangeSize b) (newTVar e) return $ TArray (listArray b a) unsafeRead (TArray a) i = readTVar $ unsafeAt a i unsafeWrite (TArray a) i e = writeTVar (unsafeAt a i) e @@ -59,15 +112,15 @@ instance MArray TArray e IO where getBounds (TArray a) = return (bounds a) getNumElements (TArray a) = return (numElements a) newArray b e = do - a <- rep (rangeSize b) (newTVarIO e) + a <- rep (safeRangeSize b) (newTVarIO e) return $ TArray (listArray b a) unsafeRead (TArray a) i = readTVarIO $ unsafeAt a i unsafeWrite (TArray a) i e = atomically $ writeTVar (unsafeAt a i) e {-# INLINE newArray #-} --- | Like 'replicateM' but uses an accumulator to prevent stack overflows. --- Unlike 'replicateM' the returned list is in reversed order. +-- | Like 'replicateM', but uses an accumulator to prevent stack overflows. +-- Unlike 'replicateM', the returned list is in reversed order. -- This doesn't matter though since this function is only used to create -- arrays with identical elements. rep :: Monad m => Int -> m a -> m [a] @@ -76,4 +129,5 @@ rep n m = go n [] go 0 xs = return xs go i xs = do x <- m - go (i-1) (x:xs) + go (i - 1) (x : xs) +#endif diff --git a/changelog.md b/changelog.md index 4ae61747e8e91b20ea0988093c8dfdb29038fd14..c16002bf90b41508e88f7b00b81f863ecb1377e0 100644 --- a/changelog.md +++ b/changelog.md @@ -3,6 +3,8 @@ ## Upcoming * Fix strictness of `stateTVar` ([#69](https://github.com/haskell/stm/pull/69)) + * Rewrite `TBQueue` to use arrays ([#70](https://github.com/haskell/stm/pull/70)) + * Use unlifted `Array#` for `TArray` ([#66](https://github.com/haskell/stm/pull/66)) ## 2.5.1.0 *Aug 2022*