From e10e97526bcf04acb726f851bd904c410c4acbf7 Mon Sep 17 00:00:00 2001
From: David Feuer <David.Feuer@gmail.com>
Date: Wed, 16 Sep 2020 12:11:43 -0400
Subject: [PATCH] Add MArray TArray e IO instance

Closes #35
---
 Control/Concurrent/STM/TArray.hs | 22 +++++++++++++++-------
 1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/Control/Concurrent/STM/TArray.hs b/Control/Concurrent/STM/TArray.hs
index 1d26c21..73e9d54 100644
--- a/Control/Concurrent/STM/TArray.hs
+++ b/Control/Concurrent/STM/TArray.hs
@@ -23,15 +23,16 @@ module Control.Concurrent.STM.TArray (
 ) where
 
 import Data.Array (Array, bounds)
-import Data.Array.Base (listArray, arrEleBottom, unsafeAt, MArray(..),
+import Data.Array.Base (listArray, unsafeAt, MArray(..),
                         IArray(numElements))
 import Data.Ix (rangeSize)
 import Data.Typeable (Typeable)
-import Control.Concurrent.STM.TVar (TVar, newTVar, readTVar, writeTVar)
+import Control.Concurrent.STM.TVar (TVar, newTVar, readTVar, writeTVar
+                                   , newTVarIO, readTVarIO)
 #ifdef __GLASGOW_HASKELL__
-import GHC.Conc (STM)
+import GHC.Conc (STM, atomically)
 #else
-import Control.Sequential.STM (STM)
+import Control.Sequential.STM (STM, atomically)
 #endif
 
 -- |TArray is a transactional array, supporting the usual 'MArray'
@@ -48,13 +49,20 @@ instance MArray TArray e STM where
     newArray b e = do
         a <- rep (rangeSize b) (newTVar e)
         return $ TArray (listArray b a)
-    newArray_ b = do
-        a <- rep (rangeSize b) (newTVar arrEleBottom)
-        return $ TArray (listArray b a)
     unsafeRead (TArray a) i = readTVar $ unsafeAt a i
     unsafeWrite (TArray a) i e = writeTVar (unsafeAt a i) e
     getNumElements (TArray a) = return (numElements a)
 
+-- | Writes are slow in `IO`.
+instance MArray TArray e IO where
+    getBounds (TArray a) = return (bounds a)
+    newArray b e = do
+        a <- rep (rangeSize b) (newTVarIO e)
+        return $ TArray (listArray b a)
+    unsafeRead (TArray a) i = readTVarIO $ unsafeAt a i
+    unsafeWrite (TArray a) i e = atomically $ writeTVar (unsafeAt a i) e
+    getNumElements (TArray a) = return (numElements a)
+
 -- | Like 'replicateM' but uses an accumulator to prevent stack overflows.
 -- Unlike 'replicateM' the returned list is in reversed order.
 -- This doesn't matter though since this function is only used to create
-- 
GitLab