From 933e7bf5dbef908237ee50ce6b31f05f3bc98d3c Mon Sep 17 00:00:00 2001
From: Oleg Grenrus <oleg.grenrus@iki.fi>
Date: Tue, 19 Nov 2019 17:45:59 +0200
Subject: [PATCH] Make Async compat and use it to clean-up rawSystemStdInOut

---
 Cabal/Cabal.cabal                  |   1 +
 Cabal/Distribution/Compat/Async.hs | 139 +++++++++++++++++++++++++++++
 Cabal/Distribution/Simple/Utils.hs |  98 ++++++++++----------
 3 files changed, 188 insertions(+), 50 deletions(-)
 create mode 100644 Cabal/Distribution/Compat/Async.hs

diff --git a/Cabal/Cabal.cabal b/Cabal/Cabal.cabal
index 3827edab73..7bdfac8a14 100644
--- a/Cabal/Cabal.cabal
+++ b/Cabal/Cabal.cabal
@@ -534,6 +534,7 @@ library
     Distribution.Backpack.Id
     Distribution.Utils.UnionFind
     Distribution.Utils.Base62
+    Distribution.Compat.Async
     Distribution.Compat.CopyFile
     Distribution.Compat.GetShortPathName
     Distribution.Compat.MD5
diff --git a/Cabal/Distribution/Compat/Async.hs b/Cabal/Distribution/Compat/Async.hs
new file mode 100644
index 0000000000..896a91e99d
--- /dev/null
+++ b/Cabal/Distribution/Compat/Async.hs
@@ -0,0 +1,139 @@
+{-# LANGUAGE CPP                #-}
+{-# LANGUAGE DeriveDataTypeable #-}
+-- | 'Async', yet using 'MVar's.
+--
+-- Adopted from @async@ library
+-- Copyright (c) 2012, Simon Marlow
+-- Licensed under BSD-3-Clause
+--
+module Distribution.Compat.Async (
+    AsyncM,
+    withAsync, waitCatch,
+    wait, asyncThreadId,
+    cancel, uninterruptibleCancel, AsyncCancelled (..),
+    ) where
+
+import Control.Concurrent      (ThreadId, forkIO)
+import Control.Concurrent.MVar (MVar, newEmptyMVar, putMVar, readMVar)
+import Control.Exception
+       (BlockedIndefinitelyOnMVar (..), Exception (..), SomeException (..), catch, mask, throwIO,
+       throwTo, try, uninterruptibleMask_)
+import Control.Monad           (void)
+import Data.Typeable           (Typeable)
+import GHC.Exts                (inline)
+
+-- TODO: base?
+#if __GLASGOW_HASKELL__ >= 708
+import Control.Exception (asyncExceptionFromException, asyncExceptionToException)
+#endif
+
+-- | Async, but based on 'MVar', as we don't depend on @stm@.
+data AsyncM a = Async
+  { asyncThreadId :: {-# UNPACK #-} !ThreadId
+                  -- ^ Returns the 'ThreadId' of the thread running
+                  -- the given 'Async'.
+  , _asyncMVar    :: MVar (Either SomeException a)
+  }
+
+-- | Spawn an asynchronous action in a separate thread, and pass its
+-- @Async@ handle to the supplied function.  When the function returns
+-- or throws an exception, 'uninterruptibleCancel' is called on the @Async@.
+--
+-- > withAsync action inner = mask $ \restore -> do
+-- >   a <- async (restore action)
+-- >   restore (inner a) `finally` uninterruptibleCancel a
+--
+-- This is a useful variant of 'async' that ensures an @Async@ is
+-- never left running unintentionally.
+--
+-- Note: a reference to the child thread is kept alive until the call
+-- to `withAsync` returns, so nesting many `withAsync` calls requires
+-- linear memory.
+--
+withAsync :: IO a -> (AsyncM a -> IO b) -> IO b
+withAsync = inline withAsyncUsing forkIO
+
+withAsyncUsing :: (IO () -> IO ThreadId) -> IO a -> (AsyncM a -> IO b) -> IO b
+-- The bracket version works, but is slow.  We can do better by
+-- hand-coding it:
+withAsyncUsing doFork = \action inner -> do
+  var <- newEmptyMVar
+  mask $ \restore -> do
+    t <- doFork $ try (restore action) >>= putMVar var
+    let a = Async t var
+    r <- restore (inner a) `catchAll` \e -> do
+        uninterruptibleCancel a
+        throwIO e
+    uninterruptibleCancel a
+    return r
+
+-- | Wait for an asynchronous action to complete, and return its
+-- value.  If the asynchronous action threw an exception, then the
+-- exception is re-thrown by 'wait'.
+--
+-- > wait = atomically . waitSTM
+--
+{-# INLINE wait #-}
+wait :: AsyncM a -> IO a
+wait a = do
+    res <- waitCatch a
+    case res of
+        Left (SomeException e) -> throwIO e
+        Right x                -> return x
+
+-- | Wait for an asynchronous action to complete, and return either
+-- @Left e@ if the action raised an exception @e@, or @Right a@ if it
+-- returned a value @a@.
+--
+-- > waitCatch = atomically . waitCatchSTM
+--
+{-# INLINE waitCatch #-}
+waitCatch :: AsyncM a -> IO (Either SomeException a)
+waitCatch (Async _ var) = tryAgain (readMVar var)
+  where
+    -- See: https://github.com/simonmar/async/issues/14
+    tryAgain f = f `catch` \BlockedIndefinitelyOnMVar -> f
+
+catchAll :: IO a -> (SomeException -> IO a) -> IO a
+catchAll = catch
+
+-- | Cancel an asynchronous action by throwing the @AsyncCancelled@
+-- exception to it, and waiting for the `Async` thread to quit.
+-- Has no effect if the 'Async' has already completed.
+--
+-- > cancel a = throwTo (asyncThreadId a) AsyncCancelled <* waitCatch a
+--
+-- Note that 'cancel' will not terminate until the thread the 'Async'
+-- refers to has terminated. This means that 'cancel' will block for
+-- as long said thread blocks when receiving an asynchronous exception.
+--
+-- For example, it could block if:
+--
+-- * It's executing a foreign call, and thus cannot receive the asynchronous
+-- exception;
+-- * It's executing some cleanup handler after having received the exception,
+-- and the handler is blocking.
+{-# INLINE cancel #-}
+cancel :: AsyncM a -> IO ()
+cancel a@(Async t _) = do
+    throwTo t AsyncCancelled
+    void (waitCatch a)
+
+-- | The exception thrown by `cancel` to terminate a thread.
+data AsyncCancelled = AsyncCancelled
+  deriving (Show, Eq
+    , Typeable
+    )
+
+instance Exception AsyncCancelled where
+#if __GLASGOW_HASKELL__ >= 708
+  fromException = asyncExceptionFromException
+  toException = asyncExceptionToException
+#endif
+
+-- | Cancel an asynchronous action
+--
+-- This is a variant of `cancel`, but it is not interruptible.
+{-# INLINE uninterruptibleCancel #-}
+uninterruptibleCancel :: AsyncM a -> IO ()
+uninterruptibleCancel = uninterruptibleMask_ . cancel
diff --git a/Cabal/Distribution/Simple/Utils.hs b/Cabal/Distribution/Simple/Utils.hs
index e0079704e8..3a2e4f4e89 100644
--- a/Cabal/Distribution/Simple/Utils.hs
+++ b/Cabal/Distribution/Simple/Utils.hs
@@ -168,6 +168,7 @@ module Distribution.Simple.Utils (
 
 import Prelude ()
 import Distribution.Compat.Prelude
+import Control.Exception (SomeException)
 
 import Distribution.Utils.Generic
 import Distribution.Utils.IOData (IOData(..), IODataMode(..))
@@ -175,6 +176,7 @@ import qualified Distribution.Utils.IOData as IOData
 import Distribution.ModuleName as ModuleName
 import Distribution.System
 import Distribution.Version
+import Distribution.Compat.Async
 import Distribution.Compat.CopyFile
 import Distribution.Compat.Internal.TempFile
 import Distribution.Compat.Exception
@@ -200,8 +202,6 @@ import qualified Paths_Cabal (version)
 import Distribution.Pretty
 import Distribution.Parsec
 
-import Control.Concurrent.MVar
-    ( newEmptyMVar, putMVar, takeMVar )
 import Data.Typeable
     ( cast )
 import qualified Data.ByteString.Lazy as BS
@@ -227,8 +227,7 @@ import System.IO.Unsafe
 import qualified Control.Exception as Exception
 
 import Data.Time.Clock.POSIX (getPOSIXTime, POSIXTime)
-import Control.Exception (IOException, evaluate, throwIO)
-import Control.Concurrent (forkIO)
+import Control.Exception (IOException, evaluate, throwIO, fromException)
 import Numeric (showFFloat)
 import qualified System.Process as Process
          ( CreateProcess(..), StdStream(..), proc)
@@ -829,53 +828,52 @@ rawSystemStdInOut verbosity path args mcwd menv input outputMode = withFrozenCal
       -- fork off a couple threads to pull on the stderr and stdout
       -- so if the process writes to stderr we do not block.
 
-      err <- hGetContents errh
-
-      out <- IOData.hGetContents outh outputMode
-
-      mv <- newEmptyMVar
-      let force str = do
-            mberr <- Exception.try (evaluate (rnf str) >> return ())
-            putMVar mv (mberr :: Either IOError ())
-      _ <- forkIO $ force out
-      _ <- forkIO $ force err
-
-      -- push all the input, if any
-      case input of
-        Nothing -> return ()
-        Just inputData -> do
-          -- input mode depends on what the caller wants
-          IOData.hPutContents inh inputData
-          --TODO: this probably fails if the process refuses to consume
-          -- or if it closes stdin (eg if it exits)
-
-      -- wait for both to finish, in either order
-      mberr1 <- takeMVar mv
-      mberr2 <- takeMVar mv
-
-      -- wait for the program to terminate
-      exitcode <- waitForProcess pid
-      unless (exitcode == ExitSuccess) $
-        debug verbosity $ path ++ " returned " ++ show exitcode
-                       ++ if null err then "" else
-                          " with error message:\n" ++ err
-                       ++ case input of
-                            Nothing       -> ""
-                            Just d | IOData.null d -> ""
-                            Just (IODataText inp) -> "\nstdin input:\n" ++ inp
-                            Just (IODataBinary inp) -> "\nstdin input (binary):\n" ++ show inp
-
-      -- Check if we we hit an exception while consuming the output
-      -- (e.g. a text decoding error)
-      reportOutputIOError mberr1
-      reportOutputIOError mberr2
-
-      return (out, err, exitcode)
+      let force :: NFData a => a -> IO a
+          force str = do
+            evaluate (rnf str)
+            return str
+
+      withAsync (hGetContents errh >>= force) $ \errA -> withAsync (IOData.hGetContents outh outputMode >>= force) $ \outA -> do
+        -- push all the input, if any
+        case input of
+          Nothing        -> return ()
+          Just inputData -> do
+            -- input mode depends on what the caller wants
+            -- todo: ignoreSigPipe
+            IOData.hPutContents inh inputData
+            --TODO: this probably fails if the process refuses to consume
+            -- or if it closes stdin (eg if it exits)
+
+        -- wait for both to finish
+        mberr1 <- waitCatch outA
+        mberr2 <- waitCatch errA
+
+        err <- reportOutputIOError mberr2
+
+        -- wait for the program to terminate
+        exitcode <- waitForProcess pid
+
+        unless (exitcode == ExitSuccess) $
+          debug verbosity $ path ++ " returned " ++ show exitcode
+                         ++ if null err then "" else
+                            " with error message:\n" ++ err
+                         ++ case input of
+                              Nothing       -> ""
+                              Just d | IOData.null d -> ""
+                              Just (IODataText inp) -> "\nstdin input:\n" ++ inp
+                              Just (IODataBinary inp) -> "\nstdin input (binary):\n" ++ show inp
+
+        -- Check if we we hit an exception while consuming the output
+        -- (e.g. a text decoding error)
+        out <- reportOutputIOError mberr1
+
+        return (out, err, exitcode)
   where
-    reportOutputIOError :: Either IOError () -> NoCallStackIO ()
-    reportOutputIOError =
-      either (\e -> throwIO (ioeSetFileName e ("output of " ++ path)))
-             return
+    reportOutputIOError :: Either SomeException a -> NoCallStackIO a
+    reportOutputIOError (Right x) = return x
+    reportOutputIOError (Left exc) = case fromException exc of
+        Just ioe -> throwIO (ioeSetFileName ioe ("output of " ++ path))
+        Nothing  -> throwIO exc
 
 -- | Look for a program and try to find it's version number. It can accept
 -- either an absolute path or the name of a program binary, in which case we
-- 
GitLab