diff --git a/CHANGELOG.markdown b/CHANGELOG.markdown index b2d507a8560d6605ef6908fa781638acc7aedb9d..afca41a3f0bb22f2a2b4905d97a7303300e24805 100644 --- a/CHANGELOG.markdown +++ b/CHANGELOG.markdown @@ -3,6 +3,10 @@ * Split out `MonadMask` * Added `transformers` 0.4 support +0.5 +--- +* Added instances of `MonadThrow` for `ListT`, `MaybeT`, `ErrorT` and `ContT`. + 0.4 --- * Factored out a separate `MonadThrow`. diff --git a/exceptions.cabal b/exceptions.cabal index cceaafe30a09559cea978752f5129817446d77ad..0b115ea5fe8f2f9af08908267c6be26cae7ed460 100644 --- a/exceptions.cabal +++ b/exceptions.cabal @@ -34,6 +34,7 @@ source-repository head library build-depends: base >= 4.3 && < 5, + stm >= 2.2 && < 3, transformers >= 0.2 && < 0.5, mtl >= 2.0 && < 2.3 @@ -52,6 +53,7 @@ test-suite exceptions-tests type: exitcode-stdio-1.0 build-depends: base, + stm, transformers, mtl, test-framework >= 0.8 && < 0.9, diff --git a/src/Control/Monad/Catch.hs b/src/Control/Monad/Catch.hs index 5c26eb0ea1a6cccc3fb177b05752bfaf50d25e7b..262e3a487e47b783b80707fff5eee9558d2bd060 100644 --- a/src/Control/Monad/Catch.hs +++ b/src/Control/Monad/Catch.hs @@ -37,7 +37,7 @@ -- This is very similar to 'ErrorT' and 'MonadError', but based on features of -- "Control.Exception". In particular, it handles the complex case of -- asynchronous exceptions by including 'mask' in the typeclass. Note that the --- extensible extensions feature relies the RankNTypes language extension. +-- extensible extensions feature relies on the RankNTypes language extension. -------------------------------------------------------------------- module Control.Monad.Catch ( @@ -81,12 +81,14 @@ import Prelude hiding (catch, foldr) import Control.Exception (Exception(..), SomeException(..)) import qualified Control.Exception as ControlException +import qualified Control.Monad.STM as STM import qualified Control.Monad.Trans.RWS.Lazy as LazyRWS import qualified Control.Monad.Trans.RWS.Strict as StrictRWS import qualified Control.Monad.Trans.State.Lazy as LazyS import qualified Control.Monad.Trans.State.Strict as StrictS import qualified Control.Monad.Trans.Writer.Lazy as LazyW import qualified Control.Monad.Trans.Writer.Strict as StrictW +import Control.Monad.STM (STM) import Control.Monad.Trans.List (ListT(..), runListT) import Control.Monad.Trans.Maybe (MaybeT(..), runMaybeT) import Control.Monad.Trans.Error (ErrorT(..), Error, runErrorT) @@ -189,16 +191,23 @@ instance MonadMask IO where mask = ControlException.mask uninterruptibleMask = ControlException.uninterruptibleMask +instance MonadThrow STM where + throwM = STM.throwSTM +instance MonadCatch STM where + catch = STM.catchSTM + instance MonadThrow m => MonadThrow (IdentityT m) where throwM e = lift $ throwM e instance MonadCatch m => MonadCatch (IdentityT m) where catch (IdentityT m) f = IdentityT (catch m (runIdentityT . f)) instance MonadMask m => MonadMask (IdentityT m) where mask a = IdentityT $ mask $ \u -> runIdentityT (a $ q u) - where q u = IdentityT . u . runIdentityT + where q :: (m a -> m a) -> IdentityT m a -> IdentityT m a + q u = IdentityT . u . runIdentityT uninterruptibleMask a = IdentityT $ uninterruptibleMask $ \u -> runIdentityT (a $ q u) - where q u = IdentityT . u . runIdentityT + where q :: (m a -> m a) -> IdentityT m a -> IdentityT m a + q u = IdentityT . u . runIdentityT instance MonadThrow m => MonadThrow (LazyS.StateT s m) where throwM e = lift $ throwM e @@ -206,10 +215,12 @@ instance MonadCatch m => MonadCatch (LazyS.StateT s m) where catch = LazyS.liftCatch catch instance MonadMask m => MonadMask (LazyS.StateT s m) where mask a = LazyS.StateT $ \s -> mask $ \u -> LazyS.runStateT (a $ q u) s - where q u (LazyS.StateT b) = LazyS.StateT (u . b) + where q :: (m (a, s) -> m (a, s)) -> LazyS.StateT s m a -> LazyS.StateT s m a + q u (LazyS.StateT b) = LazyS.StateT (u . b) uninterruptibleMask a = LazyS.StateT $ \s -> uninterruptibleMask $ \u -> LazyS.runStateT (a $ q u) s - where q u (LazyS.StateT b) = LazyS.StateT (u . b) + where q :: (m (a, s) -> m (a, s)) -> LazyS.StateT s m a -> LazyS.StateT s m a + q u (LazyS.StateT b) = LazyS.StateT (u . b) instance MonadThrow m => MonadThrow (StrictS.StateT s m) where throwM e = lift $ throwM e @@ -217,10 +228,12 @@ instance MonadCatch m => MonadCatch (StrictS.StateT s m) where catch = StrictS.liftCatch catch instance MonadMask m => MonadMask (StrictS.StateT s m) where mask a = StrictS.StateT $ \s -> mask $ \u -> StrictS.runStateT (a $ q u) s - where q u (StrictS.StateT b) = StrictS.StateT (u . b) + where q :: (m (a, s) -> m (a, s)) -> StrictS.StateT s m a -> StrictS.StateT s m a + q u (StrictS.StateT b) = StrictS.StateT (u . b) uninterruptibleMask a = StrictS.StateT $ \s -> uninterruptibleMask $ \u -> StrictS.runStateT (a $ q u) s - where q u (StrictS.StateT b) = StrictS.StateT (u . b) + where q :: (m (a, s) -> m (a, s)) -> StrictS.StateT s m a -> StrictS.StateT s m a + q u (StrictS.StateT b) = StrictS.StateT (u . b) instance MonadThrow m => MonadThrow (ReaderT r m) where throwM e = lift $ throwM e @@ -228,10 +241,12 @@ instance MonadCatch m => MonadCatch (ReaderT r m) where catch (ReaderT m) c = ReaderT $ \r -> m r `catch` \e -> runReaderT (c e) r instance MonadMask m => MonadMask (ReaderT r m) where mask a = ReaderT $ \e -> mask $ \u -> runReaderT (a $ q u) e - where q u (ReaderT b) = ReaderT (u . b) + where q :: (m a -> m a) -> ReaderT e m a -> ReaderT e m a + q u (ReaderT b) = ReaderT (u . b) uninterruptibleMask a = ReaderT $ \e -> uninterruptibleMask $ \u -> runReaderT (a $ q u) e - where q u (ReaderT b) = ReaderT (u . b) + where q :: (m a -> m a) -> ReaderT e m a -> ReaderT e m a + q u (ReaderT b) = ReaderT (u . b) instance (MonadThrow m, Monoid w) => MonadThrow (StrictW.WriterT w m) where throwM e = lift $ throwM e @@ -239,10 +254,12 @@ instance (MonadCatch m, Monoid w) => MonadCatch (StrictW.WriterT w m) where catch (StrictW.WriterT m) h = StrictW.WriterT $ m `catch ` \e -> StrictW.runWriterT (h e) instance (MonadMask m, Monoid w) => MonadMask (StrictW.WriterT w m) where mask a = StrictW.WriterT $ mask $ \u -> StrictW.runWriterT (a $ q u) - where q u b = StrictW.WriterT $ u (StrictW.runWriterT b) + where q :: (m (a, w) -> m (a, w)) -> StrictW.WriterT w m a -> StrictW.WriterT w m a + q u b = StrictW.WriterT $ u (StrictW.runWriterT b) uninterruptibleMask a = StrictW.WriterT $ uninterruptibleMask $ \u -> StrictW.runWriterT (a $ q u) - where q u b = StrictW.WriterT $ u (StrictW.runWriterT b) + where q :: (m (a, w) -> m (a, w)) -> StrictW.WriterT w m a -> StrictW.WriterT w m a + q u b = StrictW.WriterT $ u (StrictW.runWriterT b) instance (MonadThrow m, Monoid w) => MonadThrow (LazyW.WriterT w m) where throwM e = lift $ throwM e @@ -250,10 +267,12 @@ instance (MonadCatch m, Monoid w) => MonadCatch (LazyW.WriterT w m) where catch (LazyW.WriterT m) h = LazyW.WriterT $ m `catch ` \e -> LazyW.runWriterT (h e) instance (MonadMask m, Monoid w) => MonadMask (LazyW.WriterT w m) where mask a = LazyW.WriterT $ mask $ \u -> LazyW.runWriterT (a $ q u) - where q u b = LazyW.WriterT $ u (LazyW.runWriterT b) + where q :: (m (a, w) -> m (a, w)) -> LazyW.WriterT w m a -> LazyW.WriterT w m a + q u b = LazyW.WriterT $ u (LazyW.runWriterT b) uninterruptibleMask a = LazyW.WriterT $ uninterruptibleMask $ \u -> LazyW.runWriterT (a $ q u) - where q u b = LazyW.WriterT $ u (LazyW.runWriterT b) + where q :: (m (a, w) -> m (a, w)) -> LazyW.WriterT w m a -> LazyW.WriterT w m a + q u b = LazyW.WriterT $ u (LazyW.runWriterT b) instance (MonadThrow m, Monoid w) => MonadThrow (LazyRWS.RWST r w s m) where throwM e = lift $ throwM e @@ -261,10 +280,12 @@ instance (MonadCatch m, Monoid w) => MonadCatch (LazyRWS.RWST r w s m) where catch (LazyRWS.RWST m) h = LazyRWS.RWST $ \r s -> m r s `catch` \e -> LazyRWS.runRWST (h e) r s instance (MonadMask m, Monoid w) => MonadMask (LazyRWS.RWST r w s m) where mask a = LazyRWS.RWST $ \r s -> mask $ \u -> LazyRWS.runRWST (a $ q u) r s - where q u (LazyRWS.RWST b) = LazyRWS.RWST $ \ r s -> u (b r s) + where q :: (m (a, s, w) -> m (a, s, w)) -> LazyRWS.RWST r w s m a -> LazyRWS.RWST r w s m a + q u (LazyRWS.RWST b) = LazyRWS.RWST $ \ r s -> u (b r s) uninterruptibleMask a = LazyRWS.RWST $ \r s -> uninterruptibleMask $ \u -> LazyRWS.runRWST (a $ q u) r s - where q u (LazyRWS.RWST b) = LazyRWS.RWST $ \ r s -> u (b r s) + where q :: (m (a, s, w) -> m (a, s, w)) -> LazyRWS.RWST r w s m a -> LazyRWS.RWST r w s m a + q u (LazyRWS.RWST b) = LazyRWS.RWST $ \ r s -> u (b r s) instance (MonadThrow m, Monoid w) => MonadThrow (StrictRWS.RWST r w s m) where throwM e = lift $ throwM e @@ -272,10 +293,12 @@ instance (MonadCatch m, Monoid w) => MonadCatch (StrictRWS.RWST r w s m) where catch (StrictRWS.RWST m) h = StrictRWS.RWST $ \r s -> m r s `catch` \e -> StrictRWS.runRWST (h e) r s instance (MonadMask m, Monoid w) => MonadMask (StrictRWS.RWST r w s m) where mask a = StrictRWS.RWST $ \r s -> mask $ \u -> StrictRWS.runRWST (a $ q u) r s - where q u (StrictRWS.RWST b) = StrictRWS.RWST $ \ r s -> u (b r s) + where q :: (m (a, s, w) -> m (a, s, w)) -> StrictRWS.RWST r w s m a -> StrictRWS.RWST r w s m a + q u (StrictRWS.RWST b) = StrictRWS.RWST $ \ r s -> u (b r s) uninterruptibleMask a = StrictRWS.RWST $ \r s -> uninterruptibleMask $ \u -> StrictRWS.runRWST (a $ q u) r s - where q u (StrictRWS.RWST b) = StrictRWS.RWST $ \ r s -> u (b r s) + where q :: (m (a, s, w) -> m (a, s, w)) -> StrictRWS.RWST r w s m a -> StrictRWS.RWST r w s m a + q u (StrictRWS.RWST b) = StrictRWS.RWST $ \ r s -> u (b r s) -- Transformers which are only instances of MonadThrow and MonadCatch, not MonadMask instance MonadThrow m => MonadThrow (ListT m) where diff --git a/tests/Control/Monad/Catch/Tests.hs b/tests/Control/Monad/Catch/Tests.hs index 9f55b8f2c1b982fc1e7b60f14528fc58d21ca7b9..6773b4ed22e99bb3853e998f2a46cc6ecc1086dd 100644 --- a/tests/Control/Monad/Catch/Tests.hs +++ b/tests/Control/Monad/Catch/Tests.hs @@ -18,6 +18,7 @@ import Control.Monad.Reader (ReaderT(..)) import Control.Monad.List (ListT(..)) import Control.Monad.Trans.Maybe (MaybeT(..)) import Control.Monad.Error (ErrorT(..)) +import Control.Monad.STM (STM, atomically) --import Control.Monad.Cont (ContT(..)) import Test.Framework (Test, testGroup) import Test.Framework.Providers.QuickCheck2 (testProperty) @@ -84,6 +85,7 @@ tests = testGroup "Control.Monad.Catch.Tests" $ , MSpec "ListT IO" $ \m -> io $ fmap (\[x] -> x) (runListT m) , MSpec "MaybeT IO" $ \m -> io $ fmap (maybe undefined id) (runMaybeT m) , MSpec "ErrorT IO" $ \m -> io $ fmap (either error id) (runErrorT m) + , MSpec "STM" $ io . atomically --, MSpec "ContT IO" $ \m -> io $ runContT m return , MSpec "CatchT Indentity" $ fromRight . runCatch