From 7ec3240eabc246b38ff0dc24af0be4535d7b57cf Mon Sep 17 00:00:00 2001 From: Michael Snoyman <michael@snoyman.com> Date: Sun, 28 Jan 2018 15:27:52 +0200 Subject: [PATCH] Fix general bracket for ExceptT/ErrorT (#60) * Demonstrate broken ExceptT instance for MonadMask * Simplify generalBracket so it works for ExceptT The previous type signature was in fact invalid, as it did not allow for a valid instance for ExceptT (et al). In particular, in the case of a non-exceptional Left result, neither the release nor cleanup functions could be used, since: * No result value was available for release * No SomeException value was available for cleanup It appears that this less pleasing version of generalBracket is the only one that works for all the types we care about. Furthermore: once we accept this, we are now forced to perform some discarding of updated state and Monoid written values in the StateT, WriterT, and RWST instances. This seems inherent to making things compatible with ExceptT. An alternative to this is to simply remove the ExceptT and ErrorT instances, but that's contrary to what many users want it seems. * Doc cleanups @RyanGIScott's review --- exceptions.cabal | 3 +- src/Control/Monad/Catch.hs | 81 +++++++++++++++--------------- src/Control/Monad/Catch/Pure.hs | 4 +- tests/Control/Monad/Catch/Tests.hs | 17 ++++++- 4 files changed, 61 insertions(+), 44 deletions(-) diff --git a/exceptions.cabal b/exceptions.cabal index e7728da..5b24105 100644 --- a/exceptions.cabal +++ b/exceptions.cabal @@ -62,7 +62,8 @@ test-suite exceptions-tests template-haskell, transformers, transformers-compat, - mtl, + mtl >= 2.2, test-framework >= 0.8 && < 0.9, + test-framework-hunit >= 0.3 && < 0.4, test-framework-quickcheck2 >= 0.3 && < 0.4, QuickCheck >= 2.5 && < 2.12 diff --git a/src/Control/Monad/Catch.hs b/src/Control/Monad/Catch.hs index 1099e00..3677ebf 100644 --- a/src/Control/Monad/Catch.hs +++ b/src/Control/Monad/Catch.hs @@ -180,24 +180,24 @@ class MonadCatch m => MonadMask m where -- | A generalized version of the standard bracket function which allows -- distinguishing different exit cases. Instead of providing it a single - -- cleanup action, this function takes two different actions: one for the + -- release action, this function takes two different actions: one for the -- case of a successful run of the inner function, and one in the case of an - -- exception. The former function is provided the acquired value and the - -- inner function's result, and returns a new result value. The exception - -- cleanup function is provided both the acquired value and the exception - -- that was thrown. + -- exception. The former function is provided the acquired value, while + -- the exception release function is provided both the acquired value and + -- the exception that was thrown. The result values of both of these + -- functions are ignored. -- -- @since 0.9.0 generalBracket :: m a -- ^ acquire some resource - -> (a -> b -> m c) - -- ^ cleanup, no exception thrown - -> (a -> SomeException -> m ignored) - -- ^ cleanup, some exception thrown; the exception will be rethrown + -> (a -> m ignored1) + -- ^ release, no exception thrown + -> (a -> SomeException -> m ignored2) + -- ^ release, some exception thrown; the exception will be rethrown -> (a -> m b) -- ^ inner action to perform with the resource - -> m c + -> m b instance MonadThrow [] where throwM _ = [] @@ -218,7 +218,8 @@ instance MonadMask IO where result <- unmasked (use resource) `catch` \e -> do _ <- cleanup resource e throwM e - release resource result + _ <- release resource + return result instance MonadThrow STM where throwM = STM.throwSTM @@ -245,7 +246,9 @@ instance e ~ SomeException => MonadMask (Either e) where Right resource -> case use resource of Left e -> cleanup resource e >> Left e - Right result -> release resource result + Right result -> do + _ <- release resource + return result instance MonadThrow m => MonadThrow (IdentityT m) where throwM e = lift $ throwM e @@ -263,7 +266,7 @@ instance MonadMask m => MonadMask (IdentityT m) where generalBracket acquire release cleanup use = IdentityT $ generalBracket (runIdentityT acquire) - (\resource b -> runIdentityT (release resource b)) + (runIdentityT . release) (\resource e -> runIdentityT (cleanup resource e)) (\resource -> runIdentityT (use resource)) @@ -283,7 +286,12 @@ instance MonadMask m => MonadMask (LazyS.StateT s m) where generalBracket acquire release cleanup use = LazyS.StateT $ \s0 -> generalBracket (LazyS.runStateT acquire s0) - (\(resource, _) (b1, s1) -> LazyS.runStateT (release resource b1) s1) + + -- Note that we're reverting to s1 here, the state after the + -- acquire step, and _not_ getting the state from the successful + -- run of the inner action. This is because we may be on top of + -- something like ExceptT, where no updated state is available. + (\(resource, s1) -> LazyS.runStateT (release resource) s1) (\(resource, s1) e -> LazyS.runStateT (cleanup resource e) s1) (\(resource, s1) -> LazyS.runStateT (use resource) s1) @@ -303,7 +311,7 @@ instance MonadMask m => MonadMask (StrictS.StateT s m) where generalBracket acquire release cleanup use = StrictS.StateT $ \s0 -> generalBracket (StrictS.runStateT acquire s0) - (\(resource, _) (b1, s1) -> StrictS.runStateT (release resource b1) s1) + (\(resource, s1) -> StrictS.runStateT (release resource) s1) (\(resource, s1) e -> StrictS.runStateT (cleanup resource e) s1) (\(resource, s1) -> StrictS.runStateT (use resource) s1) @@ -323,7 +331,7 @@ instance MonadMask m => MonadMask (ReaderT r m) where generalBracket acquire release cleanup use = ReaderT $ \r -> generalBracket (runReaderT acquire r) - (\resource b -> runReaderT (release resource b) r) + (\resource -> runReaderT (release resource) r) (\resource e -> runReaderT (cleanup resource e) r) (\resource -> runReaderT (use resource) r) @@ -343,9 +351,9 @@ instance (MonadMask m, Monoid w) => MonadMask (StrictW.WriterT w m) where generalBracket acquire release cleanup use = StrictW.WriterT $ generalBracket (StrictW.runWriterT acquire) - (\(resource, _) (b1, w1) -> do - (b2, w2) <- StrictW.runWriterT (release resource b1) - return (b2, mappend w1 w2)) + -- NOTE: The updated writer values here are actually going to be + -- lost, as the return value of this cleanup is discarded + (StrictW.runWriterT . release . fst) (\(resource, w1) e -> do (a, w2) <- StrictW.runWriterT (cleanup resource e) return (a, mappend w1 w2)) @@ -369,9 +377,7 @@ instance (MonadMask m, Monoid w) => MonadMask (LazyW.WriterT w m) where generalBracket acquire release cleanup use = LazyW.WriterT $ generalBracket (LazyW.runWriterT acquire) - (\(resource, _) (b1, w1) -> do - (b2, w2) <- LazyW.runWriterT (release resource b1) - return (b2, mappend w1 w2)) + (LazyW.runWriterT . release . fst) (\(resource, w1) e -> do (a, w2) <- LazyW.runWriterT (cleanup resource e) return (a, mappend w1 w2)) @@ -395,9 +401,8 @@ instance (MonadMask m, Monoid w) => MonadMask (LazyRWS.RWST r w s m) where generalBracket acquire release cleanup use = LazyRWS.RWST $ \r s0 -> generalBracket (LazyRWS.runRWST acquire r s0) - (\(resource, _, _) (b1, s1, w1) -> do - (b2, s2, w2) <- LazyRWS.runRWST (release resource b1) r s1 - return (b2, s2, mappend w1 w2)) + -- All comments from StateT and WriterT apply here too + (\(resource, s1, _) -> LazyRWS.runRWST (release resource) r s1) (\(resource, s1, w1) e -> do (a, s2, w2) <- LazyRWS.runRWST (cleanup resource e) r s1 return (a, s2, mappend w1 w2)) @@ -421,9 +426,7 @@ instance (MonadMask m, Monoid w) => MonadMask (StrictRWS.RWST r w s m) where generalBracket acquire release cleanup use = StrictRWS.RWST $ \r s0 -> generalBracket (StrictRWS.runRWST acquire r s0) - (\(resource, _, _) (b1, s1, w1) -> do - (b2, s2, w2) <- StrictRWS.runRWST (release resource b1) r s1 - return (b2, s2, mappend w1 w2)) + (\(resource, s1, _) -> StrictRWS.runRWST (release resource) r s1) (\(resource, s1, w1) e -> do (a, s2, w2) <- StrictRWS.runRWST (cleanup resource e) r s1 return (a, s2, mappend w1 w2)) @@ -465,11 +468,10 @@ instance (Error e, MonadMask m) => MonadMask (ErrorT e m) where generalBracket acquire release cleanup use = ErrorT $ generalBracket (runErrorT acquire) - (\eresource eresult -> - case (eresource, eresult) of - (Left e, _) -> return $ Left e - (_, Left e) -> return $ Left e - (Right resource, Right result) -> runErrorT (release resource result)) + (\eresource -> + case eresource of + Left _ -> return () -- nothing to release, it didn't succeed + Right resource -> runErrorT (release resource) >> return ()) (\eresource e -> case eresource of Left _ -> throwM e @@ -497,11 +499,10 @@ instance MonadMask m => MonadMask (ExceptT e m) where generalBracket acquire release cleanup use = ExceptT $ generalBracket (runExceptT acquire) - (\eresource eresult -> - case (eresource, eresult) of - (Left e, _) -> return $ Left e - (_, Left e) -> return $ Left e - (Right resource, Right result) -> runExceptT (release resource result)) + (\eresource -> + case eresource of + Left _ -> return () + Right resource -> runExceptT (release resource) >> return ()) (\eresource e -> case eresource of Left _ -> throwM e @@ -620,7 +621,7 @@ onException action handler = action `catchAll` \e -> handler >> throwM e bracket :: MonadMask m => m a -> (a -> m b) -> (a -> m c) -> m c bracket acquire release use = generalBracket acquire - (\a b -> release a >> return b) + release (\a _e -> release a) use @@ -639,6 +640,6 @@ finally action finalizer = bracket_ (return ()) finalizer action bracketOnError :: MonadMask m => m a -> (a -> m b) -> (a -> m c) -> m c bracketOnError acquire release use = generalBracket acquire - (\_ b -> return b) + (\_ -> return ()) (\a _e -> release a) use diff --git a/src/Control/Monad/Catch/Pure.hs b/src/Control/Monad/Catch/Pure.hs index f12be81..f17db1e 100644 --- a/src/Control/Monad/Catch/Pure.hs +++ b/src/Control/Monad/Catch/Pure.hs @@ -169,7 +169,9 @@ instance Monad m => MonadMask (CatchT m) where Left e -> do _ <- runCatchT (cleanup resource e) return $ Left e - Right result -> runCatchT (release resource result) + Right result -> do + _ <- runCatchT (release resource) + return $ Right result instance MonadState s m => MonadState s (CatchT m) where get = lift get diff --git a/tests/Control/Monad/Catch/Tests.hs b/tests/Control/Monad/Catch/Tests.hs index 619955a..9e98ea9 100644 --- a/tests/Control/Monad/Catch/Tests.hs +++ b/tests/Control/Monad/Catch/Tests.hs @@ -11,16 +11,21 @@ import Prelude hiding (catch) #endif import Control.Applicative ((<*>)) +import Control.Monad (unless) import Data.Data (Data, Typeable) +import Data.IORef (newIORef, writeIORef, readIORef) +import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.Identity (IdentityT(..)) import Control.Monad.Reader (ReaderT(..)) import Control.Monad.List (ListT(..)) import Control.Monad.Trans.Maybe (MaybeT(..)) import Control.Monad.Error (ErrorT(..)) +import Control.Monad.Except (ExceptT(..), runExceptT) import Control.Monad.STM (STM, atomically) --import Control.Monad.Cont (ContT(..)) import Test.Framework (Test, testGroup) +import Test.Framework.Providers.HUnit (testCase) import Test.Framework.Providers.QuickCheck2 (testProperty) import Test.QuickCheck (Property, once) import Test.QuickCheck.Monadic (monadic, run, assert) @@ -67,9 +72,11 @@ testCatchJust MSpec { mspecRunner } = monadic mspecRunner $ do tests :: Test tests = testGroup "Control.Monad.Catch.Tests" $ - [ mkMonadCatch + ([ mkMonadCatch , mkCatchJust - ] <*> mspecs + ] <*> mspecs) ++ + [ testCase "ExceptT+Left" exceptTLeft + ] where mspecs = [ MSpec "IO" io @@ -102,3 +109,9 @@ tests = testGroup "Control.Monad.Catch.Tests" $ mkTestType name test = \spec -> testProperty (name ++ " " ++ mspecName spec) $ once $ test spec + + exceptTLeft = do + ref <- newIORef False + Left () <- runExceptT $ ExceptT (return $ Left ()) `finally` lift (writeIORef ref True) + val <- readIORef ref + unless val $ error "Looks like cleanup didn't happen" -- GitLab