Commit fccd6043 authored by tharris's avatar tharris
Browse files

[project @ 2005-05-27 14:55:32 by tharris]

Add STM array-shuffle test
parent ba9dee2e
...@@ -108,3 +108,5 @@ test('conc049', only_compiler_types(['ghc']), compile_and_run, ['-package stm']) ...@@ -108,3 +108,5 @@ test('conc049', only_compiler_types(['ghc']), compile_and_run, ['-package stm'])
test('conc050', skip, compile_and_run, ['-package stm']) test('conc050', skip, compile_and_run, ['-package stm'])
test('conc051', normal, compile_and_run, ['']) test('conc051', normal, compile_and_run, [''])
test('conc052', normal, compile_and_run, ['-package stm'])
-- STM stress test
{-# OPTIONS -fffi #-}
module Main (main) where
import Foreign
import Control.Concurrent
import Control.Exception
import GHC.Conc -- Control.Concurrent.STM
import System.Random
import Data.Array
import Data.List
import GHC.Conc ( unsafeIOToSTM )
import Control.Monad ( when )
import System.IO
import System.IO.Unsafe
import System.Environment
import Foreign.C
-- | The number of array elements
n_elems :: Int
n_elems = 20
-- | The number of threads swapping elements
n_threads :: Int
n_threads = 2
-- | The number of swaps for each thread to perform
iterations :: Int
iterations = 20000
type Elements = Array Int (TVar Int)
thread :: TVar Int -> Elements -> IO ()
thread done elements = loop iterations
where loop 0 = atomically $ do x <- readTVar done; writeTVar done (x+1)
loop n = do
i1 <- randomRIO (1,n_elems)
i2 <- randomRIO (1,n_elems)
let e1 = elements ! i1
let e2 = elements ! i2
atomically $ do
e1_v <- readTVar e1
e2_v <- readTVar e2
writeTVar e1 e2_v
writeTVar e2 e1_v
loop (n-1)
await_end :: TVar Int -> IO ()
await_end done = atomically $ do x <- readTVar done
if (x == n_threads) then return () else retry
main = do
Foreign.newStablePtr stdout
setStdGen (read "526454551 6356")
let init_vals = [1..n_elems] -- take n_elems
tvars <- atomically $ mapM newTVar init_vals
let elements = listArray (1,n_elems) tvars
done <- atomically (newTVar 0)
sequence [ forkIO (thread done elements) | id <- [1..n_threads] ]
await_end done
fin_vals <- mapM (\t -> atomically $ readTVar t) (elems elements)
putStr("Before: ")
mapM (\v -> putStr ((show v) ++ " " )) init_vals
putStr("\nAfter: ")
mapM (\v -> putStr ((show v) ++ " " )) (sort fin_vals)
putStr("\n")
if ((sort fin_vals) == init_vals) then return () else throwDyn "Mismatch"
Before: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
After: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment