From 3e23fa2a60d1139dbd80c7ccce69469fbe990e34 Mon Sep 17 00:00:00 2001
From: Michael Snoyman <michael@snoyman.com>
Date: Tue, 6 Jun 2017 11:54:01 +0300
Subject: [PATCH] Add generalBracket to MonadMask

For motivation, see:
https://www.fpcomplete.com/blog/2017/02/monadmask-vs-monadbracket
---
 CHANGELOG.markdown              |   5 ++
 exceptions.cabal                |   2 +-
 src/Control/Monad/Catch.hs      | 136 +++++++++++++++++++++++++++++---
 src/Control/Monad/Catch/Pure.hs |  11 +++
 4 files changed, 140 insertions(+), 14 deletions(-)

diff --git a/CHANGELOG.markdown b/CHANGELOG.markdown
index d314d42..b37a182 100644
--- a/CHANGELOG.markdown
+++ b/CHANGELOG.markdown
@@ -1,3 +1,8 @@
+0.9.0
+-----
+* Add `generalBracket` to the `MonadMask` typeclass, allowing more
+  valid instances
+
 0.8.3
 -----
 * `MonadCatch` and `MonadMask` instances for `Either SomeException`
diff --git a/exceptions.cabal b/exceptions.cabal
index 63eebbc..f2d812e 100644
--- a/exceptions.cabal
+++ b/exceptions.cabal
@@ -1,6 +1,6 @@
 name:          exceptions
 category:      Control, Exceptions, Monad
-version:       0.8.3
+version:       0.9.0
 cabal-version: >= 1.8
 license:       BSD3
 license-file:  LICENSE
diff --git a/src/Control/Monad/Catch.hs b/src/Control/Monad/Catch.hs
index 546876c..94162fb 100644
--- a/src/Control/Monad/Catch.hs
+++ b/src/Control/Monad/Catch.hs
@@ -148,11 +148,10 @@ class MonadThrow m => MonadCatch m where
   -- 'ControlException.catch'.
   catch :: Exception e => m a -> (e -> m a) -> m a
 
--- | A class for monads which provide for the ability to account for all
--- possible exit points from a computation, and to mask asynchronous
--- exceptions. Continuation-based monads, and stacks such as @ErrorT e IO@
--- which provide for multiple failure modes, are invalid instances of this
--- class.
+-- | A class for monads which provide for the ability to account for
+-- all possible exit points from a computation, and to mask
+-- asynchronous exceptions. Continuation-based monads are invalid
+-- instances of this class.
 --
 -- Note that this package /does/ provide a @MonadMask@ instance for @CatchT@.
 -- This instance is /only/ valid if the base monad provides no ability to
@@ -179,6 +178,21 @@ class MonadCatch m => MonadMask m where
   -- and/or unkillable.
   uninterruptibleMask :: ((forall a. m a -> m a) -> m b) -> m b
 
+  -- | A generalized version of the standard bracket function which
+  -- allows distinguishing different exit cases.
+  --
+  -- @since 0.8.4
+  generalBracket
+    :: m a
+    -- ^ acquire some resource
+    -> (a -> b -> m b)
+    -- ^ cleanup, no exception thrown
+    -> (a -> SomeException -> m ignored)
+    -- ^ cleanup, some exception thrown. The exception will be rethrown
+    -> (a -> m b)
+    -- ^ inner action to perform with the resource
+    -> m b
+
 instance MonadThrow [] where
   throwM _ = []
 instance MonadThrow Maybe where
@@ -193,6 +207,12 @@ instance MonadCatch IO where
 instance MonadMask IO where
   mask = ControlException.mask
   uninterruptibleMask = ControlException.uninterruptibleMask
+  generalBracket acquire release cleanup use = mask $ \unmasked -> do
+    resource <- acquire
+    result <- unmasked (use resource) `catch` \e -> do
+      _ <- cleanup resource e
+      throwM e
+    release resource result
 
 instance MonadThrow STM where
   throwM = STM.throwSTM
@@ -213,6 +233,14 @@ instance e ~ SomeException => MonadMask (Either e) where
   mask f = f id
   uninterruptibleMask f = f id
 
+  generalBracket acquire release cleanup use =
+    case acquire of
+      Left e -> Left e
+      Right resource ->
+        case use resource of
+          Left e -> cleanup resource e >> Left e
+          Right result -> release resource result >> return result
+
 instance MonadThrow m => MonadThrow (IdentityT m) where
   throwM e = lift $ throwM e
 instance MonadCatch m => MonadCatch (IdentityT m) where
@@ -226,6 +254,13 @@ instance MonadMask m => MonadMask (IdentityT m) where
       where q :: (m a -> m a) -> IdentityT m a -> IdentityT m a
             q u = IdentityT . u . runIdentityT
 
+  generalBracket acquire release cleanup use = IdentityT $
+    generalBracket
+      (runIdentityT acquire)
+      (\resource b -> runIdentityT (release resource b))
+      (\resource e -> runIdentityT (cleanup resource e))
+      (\resource -> runIdentityT (use resource))
+
 instance MonadThrow m => MonadThrow (LazyS.StateT s m) where
   throwM e = lift $ throwM e
 instance MonadCatch m => MonadCatch (LazyS.StateT s m) where
@@ -239,6 +274,13 @@ instance MonadMask m => MonadMask (LazyS.StateT s m) where
       where q :: (m (a, s) -> m (a, s)) -> LazyS.StateT s m a -> LazyS.StateT s m a
             q u (LazyS.StateT b) = LazyS.StateT (u . b)
 
+  generalBracket acquire release cleanup use = LazyS.StateT $ \s0 ->
+    generalBracket
+      (LazyS.runStateT acquire s0)
+      (\(resource, _) (b1, s1) -> LazyS.runStateT (release resource b1) s1)
+      (\(resource, s1) e -> LazyS.runStateT (cleanup resource e) s1)
+      (\(resource, s1) -> LazyS.runStateT (use resource) s1)
+
 instance MonadThrow m => MonadThrow (StrictS.StateT s m) where
   throwM e = lift $ throwM e
 instance MonadCatch m => MonadCatch (StrictS.StateT s m) where
@@ -252,6 +294,13 @@ instance MonadMask m => MonadMask (StrictS.StateT s m) where
       where q :: (m (a, s) -> m (a, s)) -> StrictS.StateT s m a -> StrictS.StateT s m a
             q u (StrictS.StateT b) = StrictS.StateT (u . b)
 
+  generalBracket acquire release cleanup use = StrictS.StateT $ \s0 ->
+    generalBracket
+      (StrictS.runStateT acquire s0)
+      (\(resource, _) (b1, s1) -> StrictS.runStateT (release resource b1) s1)
+      (\(resource, s1) e -> StrictS.runStateT (cleanup resource e) s1)
+      (\(resource, s1) -> StrictS.runStateT (use resource) s1)
+
 instance MonadThrow m => MonadThrow (ReaderT r m) where
   throwM e = lift $ throwM e
 instance MonadCatch m => MonadCatch (ReaderT r m) where
@@ -265,6 +314,13 @@ instance MonadMask m => MonadMask (ReaderT r m) where
       where q :: (m a -> m a) -> ReaderT e m a -> ReaderT e m a
             q u (ReaderT b) = ReaderT (u . b)
 
+  generalBracket acquire release cleanup use = ReaderT $ \r ->
+    generalBracket
+      (runReaderT acquire r)
+      (\resource b -> runReaderT (release resource b) r)
+      (\resource e -> runReaderT (cleanup resource e) r)
+      (\resource -> runReaderT (use resource) r)
+
 instance (MonadThrow m, Monoid w) => MonadThrow (StrictW.WriterT w m) where
   throwM e = lift $ throwM e
 instance (MonadCatch m, Monoid w) => MonadCatch (StrictW.WriterT w m) where
@@ -278,6 +334,19 @@ instance (MonadMask m, Monoid w) => MonadMask (StrictW.WriterT w m) where
       where q :: (m (a, w) -> m (a, w)) -> StrictW.WriterT w m a -> StrictW.WriterT w m a
             q u b = StrictW.WriterT $ u (StrictW.runWriterT b)
 
+  generalBracket acquire release cleanup use = StrictW.WriterT $
+    generalBracket
+      (StrictW.runWriterT acquire)
+      (\(resource, _) (b1, w1) -> do
+        (b2, w2) <- StrictW.runWriterT (release resource b1)
+        return (b2, mappend w1 w2))
+      (\(resource, w1) e -> do
+        (a, w2) <- StrictW.runWriterT (cleanup resource e)
+        return (a, mappend w1 w2))
+      (\(resource, w1) -> do
+        (a, w2) <- StrictW.runWriterT (use resource)
+        return (a, mappend w1 w2))
+
 instance (MonadThrow m, Monoid w) => MonadThrow (LazyW.WriterT w m) where
   throwM e = lift $ throwM e
 instance (MonadCatch m, Monoid w) => MonadCatch (LazyW.WriterT w m) where
@@ -291,6 +360,19 @@ instance (MonadMask m, Monoid w) => MonadMask (LazyW.WriterT w m) where
       where q :: (m (a, w) -> m (a, w)) -> LazyW.WriterT w m a -> LazyW.WriterT w m a
             q u b = LazyW.WriterT $ u (LazyW.runWriterT b)
 
+  generalBracket acquire release cleanup use = LazyW.WriterT $
+    generalBracket
+      (LazyW.runWriterT acquire)
+      (\(resource, _) (b1, w1) -> do
+        (b2, w2) <- LazyW.runWriterT (release resource b1)
+        return (b2, mappend w1 w2))
+      (\(resource, w1) e -> do
+        (a, w2) <- LazyW.runWriterT (cleanup resource e)
+        return (a, mappend w1 w2))
+      (\(resource, w1) -> do
+        (a, w2) <- LazyW.runWriterT (use resource)
+        return (a, mappend w1 w2))
+
 instance (MonadThrow m, Monoid w) => MonadThrow (LazyRWS.RWST r w s m) where
   throwM e = lift $ throwM e
 instance (MonadCatch m, Monoid w) => MonadCatch (LazyRWS.RWST r w s m) where
@@ -304,6 +386,19 @@ instance (MonadMask m, Monoid w) => MonadMask (LazyRWS.RWST r w s m) where
       where q :: (m (a, s, w) -> m (a, s, w)) -> LazyRWS.RWST r w s m a -> LazyRWS.RWST r w s m a
             q u (LazyRWS.RWST b) = LazyRWS.RWST $ \ r s -> u (b r s)
 
+  generalBracket acquire release cleanup use = LazyRWS.RWST $ \r s0 ->
+    generalBracket
+      (LazyRWS.runRWST acquire r s0)
+      (\(resource, _, _) (b1, s1, w1) -> do
+        (b2, s2, w2) <- LazyRWS.runRWST (release resource b1) r s1
+        return (b2, s2, mappend w1 w2))
+      (\(resource, s1, w1) e -> do
+        (a, s2, w2) <- LazyRWS.runRWST (cleanup resource e) r s1
+        return (a, s2, mappend w1 w2))
+      (\(resource, s1, w1) -> do
+        (a, s2, w2) <- LazyRWS.runRWST (use resource) r s1
+        return (a, s2, mappend w1 w2))
+
 instance (MonadThrow m, Monoid w) => MonadThrow (StrictRWS.RWST r w s m) where
   throwM e = lift $ throwM e
 instance (MonadCatch m, Monoid w) => MonadCatch (StrictRWS.RWST r w s m) where
@@ -317,6 +412,19 @@ instance (MonadMask m, Monoid w) => MonadMask (StrictRWS.RWST r w s m) where
       where q :: (m (a, s, w) -> m (a, s, w)) -> StrictRWS.RWST r w s m a -> StrictRWS.RWST r w s m a
             q u (StrictRWS.RWST b) = StrictRWS.RWST $ \ r s -> u (b r s)
 
+  generalBracket acquire release cleanup use = StrictRWS.RWST $ \r s0 ->
+    generalBracket
+      (StrictRWS.runRWST acquire r s0)
+      (\(resource, _, _) (b1, s1, w1) -> do
+        (b2, s2, w2) <- StrictRWS.runRWST (release resource b1) r s1
+        return (b2, s2, mappend w1 w2))
+      (\(resource, s1, w1) e -> do
+        (a, s2, w2) <- StrictRWS.runRWST (cleanup resource e) r s1
+        return (a, s2, mappend w1 w2))
+      (\(resource, s1, w1) -> do
+        (a, s2, w2) <- StrictRWS.runRWST (use resource) r s1
+        return (a, s2, mappend w1 w2))
+
 -- Transformers which are only instances of MonadThrow and MonadCatch, not MonadMask
 instance MonadThrow m => MonadThrow (ListT m) where
   throwM = lift . throwM
@@ -448,11 +556,11 @@ onException action handler = action `catchAll` \e -> handler >> throwM e
 -- If an exception occurs during the use, the release still happens before the
 -- exception is rethrown.
 bracket :: MonadMask m => m a -> (a -> m b) -> (a -> m c) -> m c
-bracket acquire release use = mask $ \unmasked -> do
-  resource <- acquire
-  result <- unmasked (use resource) `onException` release resource
-  _ <- release resource
-  return result
+bracket acquire release use = generalBracket
+  acquire
+  (\a b -> release a >> return b)
+  (\a _e -> release a)
+  use
 
 -- | Version of 'bracket' without any value being passed to the second and
 -- third actions.
@@ -467,6 +575,8 @@ finally action finalizer = bracket_ (return ()) finalizer action
 -- | Like 'bracket', but only performs the final action if there was an
 -- exception raised by the in-between computation.
 bracketOnError :: MonadMask m => m a -> (a -> m b) -> (a -> m c) -> m c
-bracketOnError acquire release use = mask $ \unmasked -> do
-  resource <- acquire
-  unmasked (use resource) `onException` release resource
+bracketOnError acquire release use = generalBracket
+  acquire
+  (\_ b -> return b)
+  (\a _e -> release a)
+  use
diff --git a/src/Control/Monad/Catch/Pure.hs b/src/Control/Monad/Catch/Pure.hs
index 2f62cb7..f12be81 100644
--- a/src/Control/Monad/Catch/Pure.hs
+++ b/src/Control/Monad/Catch/Pure.hs
@@ -159,6 +159,17 @@ instance Monad m => MonadCatch (CatchT m) where
 instance Monad m => MonadMask (CatchT m) where
   mask a = a id
   uninterruptibleMask a = a id
+  generalBracket acquire release cleanup use = CatchT $ do
+    eresource <- runCatchT acquire
+    case eresource of
+      Left e -> return $ Left e
+      Right resource -> do
+        eresult <- runCatchT (use resource)
+        case eresult of
+          Left e -> do
+            _ <- runCatchT (cleanup resource e)
+            return $ Left e
+          Right result -> runCatchT (release resource result)
 
 instance MonadState s m => MonadState s (CatchT m) where
   get = lift get
-- 
GitLab