From cd8d45fb8c4132a5bf56b140a40d9f37f04cfd56 Mon Sep 17 00:00:00 2001
From: Johan Tibell <johan.tibell@gmail.com>
Date: Fri, 24 Aug 2012 16:56:12 -0700
Subject: [PATCH] Force the components of returned pairs

Some functions, like partition, return a pair of values. Before this
change these functions would do almost no work and return immediately,
due to suspending most of the work in closures. This could cause space
leaks.

Closes #14.
---
 Data/IntMap/Base.hs   |  92 +++++++++++++++++++++--------------
 Data/IntMap/Strict.hs |  98 ++++++++++++++++++++-----------------
 Data/IntSet/Base.hs   | 110 ++++++++++++++++++++++++++----------------
 Data/Map/Base.hs      |  92 ++++++++++++++++++++---------------
 Data/Map/Strict.hs    |  50 ++++++++++---------
 Data/Set/Base.hs      |  36 +++++++++-----
 Data/StrictPair.hs    |  13 +++--
 benchmarks/Map.hs     |   1 +
 8 files changed, 290 insertions(+), 202 deletions(-)

diff --git a/Data/IntMap/Base.hs b/Data/IntMap/Base.hs
index edf2dd03..b3f8864a 100644
--- a/Data/IntMap/Base.hs
+++ b/Data/IntMap/Base.hs
@@ -222,6 +222,8 @@ import Control.Applicative (Applicative(pure,(<*>)),(<$>))
 import Control.Monad ( liftM )
 import Control.DeepSeq (NFData(rnf))
 
+import Data.StrictPair
+
 #if __GLASGOW_HASKELL__
 import Text.Read
 import Data.Data (Data(..), mkNoRepType)
@@ -1402,16 +1404,18 @@ partition p m
 -- > partitionWithKey (\ k _ -> k > 7) (fromList [(5,"a"), (3,"b")]) == (empty, fromList [(3, "b"), (5, "a")])
 
 partitionWithKey :: (Key -> a -> Bool) -> IntMap a -> (IntMap a,IntMap a)
-partitionWithKey predicate t
-  = case t of
-      Bin p m l r
-        -> let (l1,l2) = partitionWithKey predicate l
-               (r1,r2) = partitionWithKey predicate r
-           in (bin p m l1 r1, bin p m l2 r2)
-      Tip k x
-        | predicate k x -> (t,Nil)
-        | otherwise     -> (Nil,t)
-      Nil -> (Nil,Nil)
+partitionWithKey predicate0 t0 = toPair $ go predicate0 t0
+  where
+    go predicate t
+      = case t of
+          Bin p m l r
+            -> let (l1 :*: l2) = go predicate l
+                   (r1 :*: r2) = go predicate r
+               in bin p m l1 r1 :*: bin p m l2 r2
+          Tip k x
+            | predicate k x -> (t :*: Nil)
+            | otherwise     -> (Nil :*: t)
+          Nil -> (Nil :*: Nil)
 
 -- | /O(n)/. Map values and collect the 'Just' results.
 --
@@ -1457,15 +1461,17 @@ mapEither f m
 -- >     == (empty, fromList [(1,"x"), (3,"b"), (5,"a"), (7,"z")])
 
 mapEitherWithKey :: (Key -> a -> Either b c) -> IntMap a -> (IntMap b, IntMap c)
-mapEitherWithKey f (Bin p m l r)
-  = (bin p m l1 r1, bin p m l2 r2)
+mapEitherWithKey f0 t0 = toPair $ go f0 t0
   where
-    (l1,l2) = mapEitherWithKey f l
-    (r1,r2) = mapEitherWithKey f r
-mapEitherWithKey f (Tip k x) = case f k x of
-  Left y  -> (Tip k y, Nil)
-  Right z -> (Nil, Tip k z)
-mapEitherWithKey _ Nil = (Nil, Nil)
+    go f (Bin p m l r)
+      = bin p m l1 r1 :*: bin p m l2 r2
+      where
+        (l1 :*: l2) = go f l
+        (r1 :*: r2) = go f r
+    go f (Tip k x) = case f k x of
+      Left y  -> (Tip k y :*: Nil)
+      Right z -> (Nil :*: Tip k z)
+    go _ Nil = (Nil :*: Nil)
 
 -- | /O(min(n,W))/. The expression (@'split' k map@) is a pair @(map1,map2)@
 -- where all keys in @map1@ are lower than @k@ and all keys in
@@ -1479,18 +1485,23 @@ mapEitherWithKey _ Nil = (Nil, Nil)
 
 split :: Key -> IntMap a -> (IntMap a, IntMap a)
 split k t =
-  case t of Bin _ m l r | m < 0 -> if k >= 0 -- handle negative numbers.
-                                      then case go k l of (lt, gt) -> (union r lt, gt)
-                                      else case go k r of (lt, gt) -> (lt, union gt l)
-            _ -> go k t
+  case t of
+      Bin _ m l r
+          | m < 0 -> if k >= 0 -- handle negative numbers.
+                     then case go k l of (lt :*: gt) -> let lt' = union r lt 
+                                                        in lt' `seq` (lt', gt)
+                     else case go k r of (lt :*: gt) -> let gt' = union gt l
+                                                        in gt' `seq` (lt, gt')
+      _ -> case go k t of
+          (lt :*: gt) -> (lt, gt)
   where
-    go k' t'@(Bin p m l r) | nomatch k' p m = if k' > p then (t', Nil) else (Nil, t')
-                           | zero k' m = case go k' l of (lt, gt) -> (lt, union gt r)
-                           | otherwise = case go k' r of (lt, gt) -> (union l lt, gt)
-    go k' t'@(Tip ky _) | k' > ky   = (t', Nil)
-                        | k' < ky   = (Nil, t')
-                        | otherwise = (Nil, Nil)
-    go _ Nil = (Nil, Nil)
+    go k' t'@(Bin p m l r) | nomatch k' p m = if k' > p then t' :*: Nil else Nil :*: t'
+                           | zero k' m = case go k' l of (lt :*: gt) -> lt :*: union gt r
+                           | otherwise = case go k' r of (lt :*: gt) -> union l lt :*: gt
+    go k' t'@(Tip ky _) | k' > ky   = (t' :*: Nil)
+                        | k' < ky   = (Nil :*: t')
+                        | otherwise = (Nil :*: Nil)
+    go _ Nil = (Nil :*: Nil)
 
 -- | /O(min(n,W))/. Performs a 'split' but also returns whether the pivot
 -- key was found in the original map.
@@ -1503,14 +1514,23 @@ split k t =
 
 splitLookup :: Key -> IntMap a -> (IntMap a, Maybe a, IntMap a)
 splitLookup k t =
-  case t of Bin _ m l r | m < 0 -> if k >= 0 -- handle negative numbers.
-                                      then case go k l of (lt, fnd, gt) -> (union r lt, fnd, gt)
-                                      else case go k r of (lt, fnd, gt) -> (lt, fnd, union gt l)
-            _ -> go k t
+  case t of
+      Bin _ m l r
+          | m < 0 -> if k >= 0 -- handle negative numbers.
+                     then case go k l of
+                         (lt, fnd, gt) -> let lt' = union r lt
+                                          in lt' `seq` (lt', fnd, gt)
+                     else case go k r of
+                         (lt, fnd, gt) -> let gt' = union gt l
+                                          in gt' `seq` (lt, fnd, gt')
+      _ -> go k t
   where
-    go k' t'@(Bin p m l r) | nomatch k' p m = if k' > p then (t', Nothing, Nil) else (Nil, Nothing, t')
-                           | zero k' m = case go k' l of (lt, fnd, gt) -> (lt, fnd, union gt r)
-                           | otherwise = case go k' r of (lt, fnd, gt) -> (union l lt, fnd, gt)
+    go k' t'@(Bin p m l r)
+        | nomatch k' p m = if k' > p then (t', Nothing, Nil) else (Nil, Nothing, t')
+        | zero k' m      = case go k' l of
+            (lt, fnd, gt) -> let gt' = union gt r in gt' `seq` (lt, fnd, gt')
+        | otherwise      = case go k' r of
+            (lt, fnd, gt) -> let lt' = union l lt in lt' `seq` (lt', fnd, gt)
     go k' t'@(Tip ky y) | k' > ky   = (t', Nothing, Nil)
                         | k' < ky   = (Nil, Nothing, t')
                         | otherwise = (Nil, Just y, Nil)
diff --git a/Data/IntMap/Strict.hs b/Data/IntMap/Strict.hs
index b338df1a..20c72d5a 100644
--- a/Data/IntMap/Strict.hs
+++ b/Data/IntMap/Strict.hs
@@ -390,16 +390,18 @@ insertWithKey f k x t = k `seq` x `seq`
 -- > insertLookup 7 "x" (fromList [(5,"a"), (3,"b")]) == (Nothing,  fromList [(3, "b"), (5, "a"), (7, "x")])
 
 insertLookupWithKey :: (Key -> a -> a -> a) -> Key -> a -> IntMap a -> (Maybe a, IntMap a)
-insertLookupWithKey f k x t = k `seq` x `seq`
-  case t of
-    Bin p m l r
-      | nomatch k p m -> Nothing `strictPair` join k (Tip k x) p t
-      | zero k m      -> let (found,l') = insertLookupWithKey f k x l in (found `strictPair` Bin p m l' r)
-      | otherwise     -> let (found,r') = insertLookupWithKey f k x r in (found `strictPair` Bin p m l r')
-    Tip ky y
-      | k==ky         -> (Just y `strictPair` (Tip k $! f k x y))
-      | otherwise     -> (Nothing `strictPair` join k (Tip k x) ky t)
-    Nil -> Nothing `strictPair` Tip k x
+insertLookupWithKey f0 k0 x0 t0 = k0 `seq` x0 `seq` toPair $ go f0 k0 x0 t0
+  where
+    go f k x t =
+      case t of
+        Bin p m l r
+          | nomatch k p m -> Nothing :*: join k (Tip k x) p t
+          | zero k m      -> let (found :*: l') = go f k x l in (found :*: Bin p m l' r)
+          | otherwise     -> let (found :*: r') = go f k x r in (found :*: Bin p m l r')
+        Tip ky y
+          | k==ky         -> (Just y :*: (Tip k $! f k x y))
+          | otherwise     -> (Nothing :*: join k (Tip k x) ky t)
+        Nil -> Nothing :*: Tip k x
 
 
 {--------------------------------------------------------------------
@@ -475,18 +477,20 @@ updateWithKey f k t = k `seq`
 -- > updateLookupWithKey f 3 (fromList [(5,"a"), (3,"b")]) == (Just "b", singleton 5 "a")
 
 updateLookupWithKey ::  (Key -> a -> Maybe a) -> Key -> IntMap a -> (Maybe a,IntMap a)
-updateLookupWithKey f k t = k `seq`
-  case t of
-    Bin p m l r
-      | nomatch k p m -> (Nothing, t)
-      | zero k m      -> let (found,l') = updateLookupWithKey f k l in (found `strictPair` bin p m l' r)
-      | otherwise     -> let (found,r') = updateLookupWithKey f k r in (found `strictPair` bin p m l r')
-    Tip ky y
-      | k==ky         -> case f k y of
-                           Just y' -> y' `seq` (Just y `strictPair` Tip ky y')
-                           Nothing -> (Just y, Nil)
-      | otherwise     -> (Nothing,t)
-    Nil -> (Nothing,Nil)
+updateLookupWithKey f0 k0 t0 = k0 `seq` toPair $ go f0 k0 t0
+  where
+    go f k t =
+      case t of
+        Bin p m l r
+          | nomatch k p m -> (Nothing :*: t)
+          | zero k m      -> let (found :*: l') = go f k l in (found :*: bin p m l' r)
+          | otherwise     -> let (found :*: r') = go f k r in (found :*: bin p m l r')
+        Tip ky y
+          | k==ky         -> case f k y of
+                               Just y' -> y' `seq` (Just y :*: Tip ky y')
+                               Nothing -> (Just y :*: Nil)
+          | otherwise     -> (Nothing :*: t)
+        Nil -> (Nothing :*: Nil)
 
 
 
@@ -743,24 +747,28 @@ mapAccumWithKey f a t
 -- the accumulating argument and the both elements of the
 -- result of the function.
 mapAccumL :: (a -> Key -> b -> (a,c)) -> a -> IntMap b -> (a,IntMap c)
-mapAccumL f a t
-  = case t of
-      Bin p m l r -> let (a1,l') = mapAccumL f a l
-                         (a2,r') = mapAccumL f a1 r
-                     in (a2 `strictPair` Bin p m l' r')
-      Tip k x     -> let (a',x') = f a k x in x' `seq` (a' `strictPair` Tip k x')
-      Nil         -> (a `strictPair` Nil)
+mapAccumL f0 a0 t0 = toPair $ go f0 a0 t0
+  where
+    go f a t
+      = case t of
+          Bin p m l r -> let (a1 :*: l') = go f a l
+                             (a2 :*: r') = go f a1 r
+                         in (a2 :*: Bin p m l' r')
+          Tip k x     -> let (a',x') = f a k x in x' `seq` (a' :*: Tip k x')
+          Nil         -> (a :*: Nil)
 
 -- | /O(n)/. The function @'mapAccumR'@ threads an accumulating
 -- argument through the map in descending order of keys.
 mapAccumRWithKey :: (a -> Key -> b -> (a,c)) -> a -> IntMap b -> (a,IntMap c)
-mapAccumRWithKey f a t
-  = case t of
-      Bin p m l r -> let (a1,r') = mapAccumRWithKey f a r
-                         (a2,l') = mapAccumRWithKey f a1 l
-                     in (a2 `strictPair` Bin p m l' r')
-      Tip k x     -> let (a',x') = f a k x in x' `seq` (a' `strictPair` Tip k x')
-      Nil         -> (a `strictPair` Nil)
+mapAccumRWithKey f0 a0 t0 = toPair $ go f0 a0 t0
+  where
+    go f a t
+      = case t of
+          Bin p m l r -> let (a1 :*: r') = go f a r
+                             (a2 :*: l') = go f a1 l
+                         in (a2 :*: Bin p m l' r')
+          Tip k x     -> let (a',x') = f a k x in x' `seq` (a' :*: Tip k x')
+          Nil         -> (a :*: Nil)
 
 -- | /O(n*log n)/.
 -- @'mapKeysWith' c f s@ is the map obtained by applying @f@ to each key of @s@.
@@ -822,15 +830,17 @@ mapEither f m
 -- >     == (empty, fromList [(1,"x"), (3,"b"), (5,"a"), (7,"z")])
 
 mapEitherWithKey :: (Key -> a -> Either b c) -> IntMap a -> (IntMap b, IntMap c)
-mapEitherWithKey f (Bin p m l r)
-  = bin p m l1 r1 `strictPair` bin p m l2 r2
+mapEitherWithKey f0 t0 = toPair $ go f0 t0
   where
-    (l1,l2) = mapEitherWithKey f l
-    (r1,r2) = mapEitherWithKey f r
-mapEitherWithKey f (Tip k x) = case f k x of
-  Left y  -> y `seq` (Tip k y, Nil)
-  Right z -> z `seq` (Nil, Tip k z)
-mapEitherWithKey _ Nil = (Nil, Nil)
+    go f (Bin p m l r)
+      = bin p m l1 r1 :*: bin p m l2 r2
+      where
+        (l1 :*: l2) = go f l
+        (r1 :*: r2) = go f r
+    go f (Tip k x) = case f k x of
+      Left y  -> y `seq` (Tip k y :*: Nil)
+      Right z -> z `seq` (Nil :*: Tip k z)
+    go _ Nil = (Nil :*: Nil)
 
 {--------------------------------------------------------------------
   Conversions
diff --git a/Data/IntSet/Base.hs b/Data/IntSet/Base.hs
index 056aa52c..7e7c1a78 100644
--- a/Data/IntSet/Base.hs
+++ b/Data/IntSet/Base.hs
@@ -168,6 +168,8 @@ import Data.Maybe (fromMaybe)
 import Data.Typeable
 import Control.DeepSeq (NFData)
 
+import Data.StrictPair
+
 #if __GLASGOW_HASKELL__
 import Text.Read
 import Data.Data (Data(..), mkNoRepType)
@@ -655,19 +657,21 @@ filter predicate t
 
 -- | /O(n)/. partition the set according to some predicate.
 partition :: (Int -> Bool) -> IntSet -> (IntSet,IntSet)
-partition predicate t
-  = case t of
-      Bin p m l r
-        -> let (l1,l2) = partition predicate l
-               (r1,r2) = partition predicate r
-           in (bin p m l1 r1, bin p m l2 r2)
-      Tip kx bm
-        -> let bm1 = foldl'Bits 0 (bitPred kx) 0 bm
-           in  (tip kx bm1, tip kx (bm `xor` bm1))
-      Nil -> (Nil,Nil)
-  where bitPred kx bm bi | predicate (kx + bi) = bm .|. bitmapOfSuffix bi
-                         | otherwise           = bm
-        {-# INLINE bitPred #-}
+partition predicate0 t0 = toPair $ go predicate0 t0
+  where
+    go predicate t
+      = case t of
+          Bin p m l r
+            -> let (l1 :*: l2) = go predicate l
+                   (r1 :*: r2) = go predicate r
+               in bin p m l1 r1 :*: bin p m l2 r2
+          Tip kx bm
+            -> let bm1 = foldl'Bits 0 (bitPred kx) 0 bm
+               in  tip kx bm1 :*: tip kx (bm `xor` bm1)
+          Nil -> (Nil :*: Nil)
+      where bitPred kx bm bi | predicate (kx + bi) = bm .|. bitmapOfSuffix bi
+                             | otherwise           = bm
+            {-# INLINE bitPred #-}
 
 
 -- | /O(min(n,W))/. The expression (@'split' x set@) is a pair @(set1,set2)@
@@ -677,41 +681,65 @@ partition predicate t
 -- > split 3 (fromList [1..5]) == (fromList [1,2], fromList [4,5])
 split :: Int -> IntSet -> (IntSet,IntSet)
 split x t =
-  case t of Bin _ m l r | m < 0 -> if x >= 0 then case go x l of (lt, gt) -> (union lt r, gt)
-                                             else case go x r of (lt, gt) -> (lt, union gt l)
-            _ -> go x t
+  case t of
+      Bin _ m l r
+          | m < 0 -> if x >= 0  -- handle negative numbers.
+                     then case go x l of (lt :*: gt) -> let lt' = union lt r
+                                                        in lt' `seq` (lt', gt)
+                     else case go x r of (lt :*: gt) -> let gt' = union gt l
+                                                        in gt' `seq` (lt, gt')
+      _ -> case go x t of
+          (lt :*: gt) -> (lt, gt)
   where
-    go x' t'@(Bin p m l r) | match x' p m = if zero x' m then case go x' l of (lt, gt) -> (lt, union gt r)
-                                                         else case go x' r of (lt, gt) -> (union lt l, gt)
-                           | otherwise   = if x' < p then (Nil, t')
-                                                     else (t', Nil)
-    go x' t'@(Tip kx' bm) | kx' > x'          = (Nil, t')
-                            -- equivalent to kx' > prefixOf x'
-                          | kx' < prefixOf x' = (t', Nil)
-                          | otherwise = (tip kx' (bm .&. lowerBitmap), tip kx' (bm .&. higherBitmap))
-                              where lowerBitmap = bitmapOf x' - 1
-                                    higherBitmap = complement (lowerBitmap + bitmapOf x')
-    go _ Nil = (Nil, Nil)
+    go !x' t'@(Bin p m l r)
+        | match x' p m = if zero x' m
+                         then case go x' l of
+                             (lt :*: gt) -> lt :*: union gt r
+                         else case go x' r of
+                             (lt :*: gt) -> union lt l :*: gt
+        | otherwise   = if x' < p then (Nil :*: t')
+                        else (t' :*: Nil)
+    go x' t'@(Tip kx' bm)
+        | kx' > x'          = (Nil :*: t')
+          -- equivalent to kx' > prefixOf x'
+        | kx' < prefixOf x' = (t' :*: Nil)
+        | otherwise = tip kx' (bm .&. lowerBitmap) :*: tip kx' (bm .&. higherBitmap)
+            where lowerBitmap = bitmapOf x' - 1
+                  higherBitmap = complement (lowerBitmap + bitmapOf x')
+    go _ Nil = (Nil :*: Nil)
 
 -- | /O(min(n,W))/. Performs a 'split' but also returns whether the pivot
 -- element was found in the original set.
 splitMember :: Int -> IntSet -> (IntSet,Bool,IntSet)
 splitMember x t =
-  case t of Bin _ m l r | m < 0 -> if x >= 0 then case go x l of (lt, fnd, gt) -> (union lt r, fnd, gt)
-                                             else case go x r of (lt, fnd, gt) -> (lt, fnd, union gt l)
-            _ -> go x t
+  case t of
+      Bin _ m l r | m < 0 -> if x >= 0
+                             then case go x l of
+                                 (lt, fnd, gt) -> let lt' = union lt r
+                                                  in lt' `seq` (lt', fnd, gt)
+                             else case go x r of
+                                 (lt, fnd, gt) -> let gt' = union gt l
+                                                  in gt' `seq` (lt, fnd, gt')
+      _ -> go x t
   where
-    go x' t'@(Bin p m l r) | match x' p m = if zero x' m then case go x' l of (lt, fnd, gt) -> (lt, fnd, union gt r)
-                                                         else case go x' r of (lt, fnd, gt) -> (union lt l, fnd, gt)
-                           | otherwise   = if x' < p then (Nil, False, t')
-                                                     else (t', False, Nil)
-    go x' t'@(Tip kx' bm) | kx' > x'          = (Nil, False, t')
-                            -- equivalent to kx' > prefixOf x'
-                          | kx' < prefixOf x' = (t', False, Nil)
-                          | otherwise = (tip kx' (bm .&. lowerBitmap), (bm .&. bitmapOfx') /= 0, tip kx' (bm .&. higherBitmap))
-                              where bitmapOfx' = bitmapOf x'
-                                    lowerBitmap = bitmapOfx' - 1
-                                    higherBitmap = complement (lowerBitmap + bitmapOfx')
+    go x' t'@(Bin p m l r)
+        | match x' p m = if zero x' m
+                         then case go x' l of
+                             (lt, fnd, gt) -> (lt, fnd, union gt r)
+                         else case go x' r of
+                             (lt, fnd, gt) -> (union lt l, fnd, gt)
+        | otherwise   = if x' < p then (Nil, False, t') else (t', False, Nil)
+    go x' t'@(Tip kx' bm)
+        | kx' > x'          = (Nil, False, t')
+          -- equivalent to kx' > prefixOf x'
+        | kx' < prefixOf x' = (t', False, Nil)
+        | otherwise = let lt = tip kx' (bm .&. lowerBitmap)
+                          found = (bm .&. bitmapOfx') /= 0
+                          gt = tip kx' (bm .&. higherBitmap)
+                      in lt `seq` found `seq` gt `seq` (lt, found, gt)
+            where bitmapOfx' = bitmapOf x'
+                  lowerBitmap = bitmapOfx' - 1
+                  higherBitmap = complement (lowerBitmap + bitmapOfx')
     go _ Nil = (Nil, False, Nil)
 
 
diff --git a/Data/Map/Base.hs b/Data/Map/Base.hs
index 7fa79f60..dc71e9f5 100644
--- a/Data/Map/Base.hs
+++ b/Data/Map/Base.hs
@@ -1548,13 +1548,15 @@ partition p m
 -- > partitionWithKey (\ k _ -> k > 7) (fromList [(5,"a"), (3,"b")]) == (empty, fromList [(3, "b"), (5, "a")])
 
 partitionWithKey :: (k -> a -> Bool) -> Map k a -> (Map k a,Map k a)
-partitionWithKey _ Tip = (Tip,Tip)
-partitionWithKey p (Bin _ kx x l r)
-  | p kx x    = (join kx x l1 r1,merge l2 r2)
-  | otherwise = (merge l1 r1,join kx x l2 r2)
+partitionWithKey p0 t0 = toPair $ go p0 t0
   where
-    (l1,l2) = partitionWithKey p l
-    (r1,r2) = partitionWithKey p r
+    go _ Tip = (Tip :*: Tip)
+    go p (Bin _ kx x l r)
+      | p kx x    = join kx x l1 r1 :*: merge l2 r2
+      | otherwise = merge l1 r1 :*: join kx x l2 r2
+      where
+        (l1 :*: l2) = go p l
+        (r1 :*: r2) = go p r
 
 -- | /O(n)/. Map values and collect the 'Just' results.
 --
@@ -1598,13 +1600,15 @@ mapEither f m
 -- >     == (empty, fromList [(1,"x"), (3,"b"), (5,"a"), (7,"z")])
 
 mapEitherWithKey :: (k -> a -> Either b c) -> Map k a -> (Map k b, Map k c)
-mapEitherWithKey _ Tip = (Tip, Tip)
-mapEitherWithKey f (Bin _ kx x l r) = case f kx x of
-  Left y  -> (join kx y l1 r1, merge l2 r2)
-  Right z -> (merge l1 r1, join kx z l2 r2)
- where
-    (l1,l2) = mapEitherWithKey f l
-    (r1,r2) = mapEitherWithKey f r
+mapEitherWithKey f0 t0 = toPair $ go f0 t0
+  where
+    go _ Tip = (Tip :*: Tip)
+    go f (Bin _ kx x l r) = case f kx x of
+      Left y  -> join kx y l1 r1 :*: merge l2 r2
+      Right z -> merge l1 r1 :*: join kx z l2 r2
+     where
+        (l1 :*: l2) = go f l
+        (r1 :*: r2) = go f r
 
 {--------------------------------------------------------------------
   Mapping
@@ -2145,23 +2149,27 @@ trim (JustS lk) (JustS hk) t = middle lk hk t  where middle lo hi (Bin _ k _ _ r
 
 -- See Note: Type of local 'go' function
 trimLookupLo :: Ord k => k -> MaybeS k -> Map k a -> (Maybe a, Map k a)
-trimLookupLo lk NothingS t = greater lk t
-  where greater :: Ord k => k -> Map k a -> (Maybe a, Map k a)
-        greater lo t'@(Bin _ kx x l r) = case compare lo kx of LT -> lookup lo l `strictPair` t'
-                                                               EQ -> (Just x, r)
-                                                               GT -> greater lo r
-        greater _ Tip = (Nothing, Tip)
-trimLookupLo lk (JustS hk) t = middle lk hk t
-  where middle :: Ord k => k -> k -> Map k a -> (Maybe a, Map k a)
-        middle lo hi t'@(Bin _ kx x l r) = case compare lo kx of LT | kx < hi -> lookup lo l `strictPair` t'
-                                                                    | otherwise -> middle lo hi l
-                                                                 EQ -> Just x `strictPair` lesser hi r
-                                                                 GT -> middle lo hi r
-        middle _ _ Tip = (Nothing, Tip)
-
-        lesser :: Ord k => k -> Map k a -> Map k a
-        lesser hi (Bin _ k _ l _) | k >= hi = lesser hi l
-        lesser _ t' = t'
+trimLookupLo lk0 mhk0 t0 = toPair $ go lk0 mhk0 t0
+  where
+    go lk NothingS t = greater lk t
+      where greater :: Ord k => k -> Map k a -> StrictPair (Maybe a) (Map k a)
+            greater lo t'@(Bin _ kx x l r) = case compare lo kx of
+                LT -> lookup lo l :*: t'
+                EQ -> (Just x :*: r)
+                GT -> greater lo r
+            greater _ Tip = (Nothing :*: Tip)
+    go lk (JustS hk) t = middle lk hk t
+      where middle :: Ord k => k -> k -> Map k a -> StrictPair (Maybe a) (Map k a)
+            middle lo hi t'@(Bin _ kx x l r) = case compare lo kx of
+                LT | kx < hi -> lookup lo l :*: t'
+                   | otherwise -> middle lo hi l
+                EQ -> Just x :*: lesser hi r
+                GT -> middle lo hi r
+            middle _ _ Tip = (Nothing :*: Tip)
+
+            lesser :: Ord k => k -> Map k a -> Map k a
+            lesser hi (Bin _ k _ l _) | k >= hi = lesser hi l
+            lesser _ t' = t'
 #if __GLASGOW_HASKELL__ >= 700
 {-# INLINABLE trimLookupLo #-}
 #endif
@@ -2209,13 +2217,15 @@ filterLt (JustS b) t = filter' b t
 -- > split 6 (fromList [(5,"a"), (3,"b")]) == (fromList [(3,"b"), (5,"a")], empty)
 
 split :: Ord k => k -> Map k a -> (Map k a,Map k a)
-split k t = k `seq`
-  case t of
-    Tip            -> (Tip, Tip)
-    Bin _ kx x l r -> case compare k kx of
-      LT -> let (lt,gt) = split k l in (lt,join kx x gt r)
-      GT -> let (lt,gt) = split k r in (join kx x l lt,gt)
-      EQ -> (l,r)
+split k0 t0 = k0 `seq` toPair $ go k0 t0
+  where
+    go k t =
+      case t of
+        Tip            -> (Tip :*: Tip)
+        Bin _ kx x l r -> case compare k kx of
+          LT -> let (lt :*: gt) = go k l in lt :*: join kx x gt r
+          GT -> let (lt :*: gt) = go k r in join kx x l lt :*: gt
+          EQ -> (l :*: r)
 #if __GLASGOW_HASKELL__ >= 700
 {-# INLINABLE split #-}
 #endif
@@ -2234,8 +2244,12 @@ splitLookup k t = k `seq`
   case t of
     Tip            -> (Tip,Nothing,Tip)
     Bin _ kx x l r -> case compare k kx of
-      LT -> let (lt,z,gt) = splitLookup k l in (lt,z,join kx x gt r)
-      GT -> let (lt,z,gt) = splitLookup k r in (join kx x l lt,z,gt)
+      LT -> let (lt,z,gt) = splitLookup k l
+                gt' = join kx x gt r
+            in gt' `seq` (lt,z,gt')
+      GT -> let (lt,z,gt) = splitLookup k r
+                lt' = join kx x l lt
+            in lt' `seq` (lt',z,gt)
       EQ -> (l,Just x,r)
 #if __GLASGOW_HASKELL__ >= 700
 {-# INLINABLE splitLookup #-}
diff --git a/Data/Map/Strict.hs b/Data/Map/Strict.hs
index 5e2912d7..de1a3666 100644
--- a/Data/Map/Strict.hs
+++ b/Data/Map/Strict.hs
@@ -434,19 +434,19 @@ insertWithKey = go
 -- See Map.Base.Note: Type of local 'go' function
 insertLookupWithKey :: Ord k => (k -> a -> a -> a) -> k -> a -> Map k a
                     -> (Maybe a, Map k a)
-insertLookupWithKey = go
+insertLookupWithKey f0 kx0 x0 t0 = toPair $ go f0 kx0 x0 t0
   where
-    go :: Ord k => (k -> a -> a -> a) -> k -> a -> Map k a -> (Maybe a, Map k a)
+    go :: Ord k => (k -> a -> a -> a) -> k -> a -> Map k a -> StrictPair (Maybe a) (Map k a)
     STRICT_2_3_OF_4(go)
-    go _ kx x Tip = Nothing `strictPair` singleton kx x
+    go _ kx x Tip = Nothing :*: singleton kx x
     go f kx x (Bin sy ky y l r) =
         case compare kx ky of
-            LT -> let (found, l') = go f kx x l
-                  in found `strictPair` balanceL ky y l' r
-            GT -> let (found, r') = go f kx x r
-                  in found `strictPair` balanceR ky y l r'
+            LT -> let (found :*: l') = go f kx x l
+                  in found :*: balanceL ky y l' r
+            GT -> let (found :*: r') = go f kx x r
+                  in found :*: balanceR ky y l r'
             EQ -> let x' = f kx x y
-                  in x' `seq` (Just y `strictPair` Bin sy kx x' l r)
+                  in x' `seq` (Just y :*: Bin sy kx x' l r)
 #if __GLASGOW_HASKELL__ >= 700
 {-# INLINABLE insertLookupWithKey #-}
 #else
@@ -547,20 +547,20 @@ updateWithKey = go
 
 -- See Map.Base.Note: Type of local 'go' function
 updateLookupWithKey :: Ord k => (k -> a -> Maybe a) -> k -> Map k a -> (Maybe a,Map k a)
-updateLookupWithKey = go
+updateLookupWithKey f0 k0 t0 = toPair $ go f0 k0 t0
  where
-   go :: Ord k => (k -> a -> Maybe a) -> k -> Map k a -> (Maybe a,Map k a)
+   go :: Ord k => (k -> a -> Maybe a) -> k -> Map k a -> StrictPair (Maybe a) (Map k a)
    STRICT_2_OF_3(go)
-   go _ _ Tip = (Nothing,Tip)
+   go _ _ Tip = (Nothing :*: Tip)
    go f k (Bin sx kx x l r) =
           case compare k kx of
-               LT -> let (found,l') = go f k l
-                     in found `strictPair` balanceR kx x l' r
-               GT -> let (found,r') = go f k r
-                     in found `strictPair` balanceL kx x l r'
+               LT -> let (found :*: l') = go f k l
+                     in found :*: balanceR kx x l' r
+               GT -> let (found :*: r') = go f k r
+                     in found :*: balanceL kx x l r'
                EQ -> case f kx x of
-                       Just x' -> x' `seq` (Just x' `strictPair` Bin sx kx x' l r)
-                       Nothing -> (Just x,glue l r)
+                       Just x' -> x' `seq` (Just x' :*: Bin sx kx x' l r)
+                       Nothing -> (Just x :*: glue l r)
 #if __GLASGOW_HASKELL__ >= 700
 {-# INLINABLE updateLookupWithKey #-}
 #else
@@ -899,13 +899,15 @@ mapEither f m
 -- >     == (empty, fromList [(1,"x"), (3,"b"), (5,"a"), (7,"z")])
 
 mapEitherWithKey :: (k -> a -> Either b c) -> Map k a -> (Map k b, Map k c)
-mapEitherWithKey _ Tip = (Tip, Tip)
-mapEitherWithKey f (Bin _ kx x l r) = case f kx x of
-  Left y  -> y `seq` (join kx y l1 r1 `strictPair` merge l2 r2)
-  Right z -> z `seq` (merge l1 r1 `strictPair` join kx z l2 r2)
- where
-    (l1,l2) = mapEitherWithKey f l
-    (r1,r2) = mapEitherWithKey f r
+mapEitherWithKey f0 t0 = toPair $ go f0 t0
+  where
+    go _ Tip = (Tip :*: Tip)
+    go f (Bin _ kx x l r) = case f kx x of
+      Left y  -> y `seq` (join kx y l1 r1 :*: merge l2 r2)
+      Right z -> z `seq` (merge l1 r1 :*: join kx z l2 r2)
+     where
+        (l1 :*: l2) = go f l
+        (r1 :*: r2) = go f r
 
 {--------------------------------------------------------------------
   Mapping
diff --git a/Data/Set/Base.hs b/Data/Set/Base.hs
index 0ac9ee0b..600f3d2a 100644
--- a/Data/Set/Base.hs
+++ b/Data/Set/Base.hs
@@ -182,6 +182,8 @@ import qualified Data.Foldable as Foldable
 import Data.Typeable
 import Control.DeepSeq (NFData(rnf))
 
+import Data.StrictPair
+
 #if __GLASGOW_HASKELL__
 import GHC.Exts ( build )
 import Text.Read
@@ -623,11 +625,13 @@ filter p (Bin _ x l r)
 -- the predicate and one with all elements that don't satisfy the predicate.
 -- See also 'split'.
 partition :: (a -> Bool) -> Set a -> (Set a,Set a)
-partition _ Tip = (Tip, Tip)
-partition p (Bin _ x l r) = case (partition p l, partition p r) of
-  ((l1, l2), (r1, r2))
-    | p x       -> (join x l1 r1, merge l2 r2)
-    | otherwise -> (merge l1 r1, join x l2 r2)
+partition p0 t0 = toPair $ go p0 t0
+  where
+    go _ Tip = (Tip :*: Tip)
+    go p (Bin _ x l r) = case (go p l, go p r) of
+      ((l1 :*: l2), (r1 :*: r2))
+        | p x       -> join x l1 r1 :*: merge l2 r2
+        | otherwise -> merge l1 r1 :*: join x l2 r2
 
 {----------------------------------------------------------------------
   Map
@@ -958,12 +962,14 @@ filterLt (JustS b) t = filter' b t
 -- where @set1@ comprises the elements of @set@ less than @x@ and @set2@
 -- comprises the elements of @set@ greater than @x@.
 split :: Ord a => a -> Set a -> (Set a,Set a)
-split _ Tip = (Tip,Tip)
-split x (Bin _ y l r)
-  = case compare x y of
-      LT -> let (lt,gt) = split x l in (lt,join y gt r)
-      GT -> let (lt,gt) = split x r in (join y l lt,gt)
-      EQ -> (l,r)
+split x0 t0 = toPair $ go x0 t0
+  where
+    go _ Tip = (Tip :*: Tip)
+    go x (Bin _ y l r)
+      = case compare x y of
+          LT -> let (lt :*: gt) = go x l in (lt :*: join y gt r)
+          GT -> let (lt :*: gt) = go x r in (join y l lt :*: gt)
+          EQ -> (l :*: r)
 #if __GLASGOW_HASKELL__ >= 700
 {-# INLINABLE split #-}
 #endif
@@ -974,8 +980,12 @@ splitMember :: Ord a => a -> Set a -> (Set a,Bool,Set a)
 splitMember _ Tip = (Tip, False, Tip)
 splitMember x (Bin _ y l r)
    = case compare x y of
-       LT -> let (lt, found, gt) = splitMember x l in (lt, found, join y gt r)
-       GT -> let (lt, found, gt) = splitMember x r in (join y l lt, found, gt)
+       LT -> let (lt, found, gt) = splitMember x l
+                 gt' = join y gt r
+             in gt' `seq` (lt, found, gt')
+       GT -> let (lt, found, gt) = splitMember x r
+                 lt' = join y l lt
+             in lt' `seq` (lt', found, gt)
        EQ -> (l, True, r)
 #if __GLASGOW_HASKELL__ >= 700
 {-# INLINABLE splitMember #-}
diff --git a/Data/StrictPair.hs b/Data/StrictPair.hs
index 55148217..0e06534a 100644
--- a/Data/StrictPair.hs
+++ b/Data/StrictPair.hs
@@ -1,6 +1,9 @@
-module Data.StrictPair (strictPair) where
+module Data.StrictPair (StrictPair(..), toPair) where
 
--- | Evaluate both argument to WHNF and create a pair of the result.
-strictPair :: a -> b -> (a, b)
-strictPair x y = x `seq` y `seq` (x, y)
-{-# INLINE strictPair #-}
+-- | Same as regular Haskell pairs, but (x :*: _|_) = (_|_ :*: y) =
+-- _|_
+data StrictPair a b = !a :*: !b
+
+toPair :: StrictPair a b -> (a, b)
+toPair (x :*: y) = (x, y)
+{-# INLINE toPair #-}
\ No newline at end of file
diff --git a/benchmarks/Map.hs b/benchmarks/Map.hs
index 1efe2457..70a7cc91 100644
--- a/benchmarks/Map.hs
+++ b/benchmarks/Map.hs
@@ -57,6 +57,7 @@ main = do
         , bench "union" $ whnf (M.union m_even) m_odd
         , bench "difference" $ whnf (M.difference m) m_even
         , bench "intersection" $ whnf (M.intersection m) m_even
+        , bench "split" $ whnf (M.split (bound `div` 2)) m
         ]
   where
     bound = 2^10
-- 
GitLab