diff --git a/Data/IntMap/Base.hs b/Data/IntMap/Base.hs index edf2dd03fa6a925698c21a1ccd35ae21180a5743..b3f8864a11a8c4f3dfee441da04f382a6a214e1d 100644 --- a/Data/IntMap/Base.hs +++ b/Data/IntMap/Base.hs @@ -222,6 +222,8 @@ import Control.Applicative (Applicative(pure,(<*>)),(<$>)) import Control.Monad ( liftM ) import Control.DeepSeq (NFData(rnf)) +import Data.StrictPair + #if __GLASGOW_HASKELL__ import Text.Read import Data.Data (Data(..), mkNoRepType) @@ -1402,16 +1404,18 @@ partition p m -- > partitionWithKey (\ k _ -> k > 7) (fromList [(5,"a"), (3,"b")]) == (empty, fromList [(3, "b"), (5, "a")]) partitionWithKey :: (Key -> a -> Bool) -> IntMap a -> (IntMap a,IntMap a) -partitionWithKey predicate t - = case t of - Bin p m l r - -> let (l1,l2) = partitionWithKey predicate l - (r1,r2) = partitionWithKey predicate r - in (bin p m l1 r1, bin p m l2 r2) - Tip k x - | predicate k x -> (t,Nil) - | otherwise -> (Nil,t) - Nil -> (Nil,Nil) +partitionWithKey predicate0 t0 = toPair $ go predicate0 t0 + where + go predicate t + = case t of + Bin p m l r + -> let (l1 :*: l2) = go predicate l + (r1 :*: r2) = go predicate r + in bin p m l1 r1 :*: bin p m l2 r2 + Tip k x + | predicate k x -> (t :*: Nil) + | otherwise -> (Nil :*: t) + Nil -> (Nil :*: Nil) -- | /O(n)/. Map values and collect the 'Just' results. -- @@ -1457,15 +1461,17 @@ mapEither f m -- > == (empty, fromList [(1,"x"), (3,"b"), (5,"a"), (7,"z")]) mapEitherWithKey :: (Key -> a -> Either b c) -> IntMap a -> (IntMap b, IntMap c) -mapEitherWithKey f (Bin p m l r) - = (bin p m l1 r1, bin p m l2 r2) +mapEitherWithKey f0 t0 = toPair $ go f0 t0 where - (l1,l2) = mapEitherWithKey f l - (r1,r2) = mapEitherWithKey f r -mapEitherWithKey f (Tip k x) = case f k x of - Left y -> (Tip k y, Nil) - Right z -> (Nil, Tip k z) -mapEitherWithKey _ Nil = (Nil, Nil) + go f (Bin p m l r) + = bin p m l1 r1 :*: bin p m l2 r2 + where + (l1 :*: l2) = go f l + (r1 :*: r2) = go f r + go f (Tip k x) = case f k x of + Left y -> (Tip k y :*: Nil) + Right z -> (Nil :*: Tip k z) + go _ Nil = (Nil :*: Nil) -- | /O(min(n,W))/. The expression (@'split' k map@) is a pair @(map1,map2)@ -- where all keys in @map1@ are lower than @k@ and all keys in @@ -1479,18 +1485,23 @@ mapEitherWithKey _ Nil = (Nil, Nil) split :: Key -> IntMap a -> (IntMap a, IntMap a) split k t = - case t of Bin _ m l r | m < 0 -> if k >= 0 -- handle negative numbers. - then case go k l of (lt, gt) -> (union r lt, gt) - else case go k r of (lt, gt) -> (lt, union gt l) - _ -> go k t + case t of + Bin _ m l r + | m < 0 -> if k >= 0 -- handle negative numbers. + then case go k l of (lt :*: gt) -> let lt' = union r lt + in lt' `seq` (lt', gt) + else case go k r of (lt :*: gt) -> let gt' = union gt l + in gt' `seq` (lt, gt') + _ -> case go k t of + (lt :*: gt) -> (lt, gt) where - go k' t'@(Bin p m l r) | nomatch k' p m = if k' > p then (t', Nil) else (Nil, t') - | zero k' m = case go k' l of (lt, gt) -> (lt, union gt r) - | otherwise = case go k' r of (lt, gt) -> (union l lt, gt) - go k' t'@(Tip ky _) | k' > ky = (t', Nil) - | k' < ky = (Nil, t') - | otherwise = (Nil, Nil) - go _ Nil = (Nil, Nil) + go k' t'@(Bin p m l r) | nomatch k' p m = if k' > p then t' :*: Nil else Nil :*: t' + | zero k' m = case go k' l of (lt :*: gt) -> lt :*: union gt r + | otherwise = case go k' r of (lt :*: gt) -> union l lt :*: gt + go k' t'@(Tip ky _) | k' > ky = (t' :*: Nil) + | k' < ky = (Nil :*: t') + | otherwise = (Nil :*: Nil) + go _ Nil = (Nil :*: Nil) -- | /O(min(n,W))/. Performs a 'split' but also returns whether the pivot -- key was found in the original map. @@ -1503,14 +1514,23 @@ split k t = splitLookup :: Key -> IntMap a -> (IntMap a, Maybe a, IntMap a) splitLookup k t = - case t of Bin _ m l r | m < 0 -> if k >= 0 -- handle negative numbers. - then case go k l of (lt, fnd, gt) -> (union r lt, fnd, gt) - else case go k r of (lt, fnd, gt) -> (lt, fnd, union gt l) - _ -> go k t + case t of + Bin _ m l r + | m < 0 -> if k >= 0 -- handle negative numbers. + then case go k l of + (lt, fnd, gt) -> let lt' = union r lt + in lt' `seq` (lt', fnd, gt) + else case go k r of + (lt, fnd, gt) -> let gt' = union gt l + in gt' `seq` (lt, fnd, gt') + _ -> go k t where - go k' t'@(Bin p m l r) | nomatch k' p m = if k' > p then (t', Nothing, Nil) else (Nil, Nothing, t') - | zero k' m = case go k' l of (lt, fnd, gt) -> (lt, fnd, union gt r) - | otherwise = case go k' r of (lt, fnd, gt) -> (union l lt, fnd, gt) + go k' t'@(Bin p m l r) + | nomatch k' p m = if k' > p then (t', Nothing, Nil) else (Nil, Nothing, t') + | zero k' m = case go k' l of + (lt, fnd, gt) -> let gt' = union gt r in gt' `seq` (lt, fnd, gt') + | otherwise = case go k' r of + (lt, fnd, gt) -> let lt' = union l lt in lt' `seq` (lt', fnd, gt) go k' t'@(Tip ky y) | k' > ky = (t', Nothing, Nil) | k' < ky = (Nil, Nothing, t') | otherwise = (Nil, Just y, Nil) diff --git a/Data/IntMap/Strict.hs b/Data/IntMap/Strict.hs index b338df1a031c88424bdc7213f87d9709602b84dd..20c72d5ac5c9cb888484ae12053a2f2702974b6f 100644 --- a/Data/IntMap/Strict.hs +++ b/Data/IntMap/Strict.hs @@ -390,16 +390,18 @@ insertWithKey f k x t = k `seq` x `seq` -- > insertLookup 7 "x" (fromList [(5,"a"), (3,"b")]) == (Nothing, fromList [(3, "b"), (5, "a"), (7, "x")]) insertLookupWithKey :: (Key -> a -> a -> a) -> Key -> a -> IntMap a -> (Maybe a, IntMap a) -insertLookupWithKey f k x t = k `seq` x `seq` - case t of - Bin p m l r - | nomatch k p m -> Nothing `strictPair` join k (Tip k x) p t - | zero k m -> let (found,l') = insertLookupWithKey f k x l in (found `strictPair` Bin p m l' r) - | otherwise -> let (found,r') = insertLookupWithKey f k x r in (found `strictPair` Bin p m l r') - Tip ky y - | k==ky -> (Just y `strictPair` (Tip k $! f k x y)) - | otherwise -> (Nothing `strictPair` join k (Tip k x) ky t) - Nil -> Nothing `strictPair` Tip k x +insertLookupWithKey f0 k0 x0 t0 = k0 `seq` x0 `seq` toPair $ go f0 k0 x0 t0 + where + go f k x t = + case t of + Bin p m l r + | nomatch k p m -> Nothing :*: join k (Tip k x) p t + | zero k m -> let (found :*: l') = go f k x l in (found :*: Bin p m l' r) + | otherwise -> let (found :*: r') = go f k x r in (found :*: Bin p m l r') + Tip ky y + | k==ky -> (Just y :*: (Tip k $! f k x y)) + | otherwise -> (Nothing :*: join k (Tip k x) ky t) + Nil -> Nothing :*: Tip k x {-------------------------------------------------------------------- @@ -475,18 +477,20 @@ updateWithKey f k t = k `seq` -- > updateLookupWithKey f 3 (fromList [(5,"a"), (3,"b")]) == (Just "b", singleton 5 "a") updateLookupWithKey :: (Key -> a -> Maybe a) -> Key -> IntMap a -> (Maybe a,IntMap a) -updateLookupWithKey f k t = k `seq` - case t of - Bin p m l r - | nomatch k p m -> (Nothing, t) - | zero k m -> let (found,l') = updateLookupWithKey f k l in (found `strictPair` bin p m l' r) - | otherwise -> let (found,r') = updateLookupWithKey f k r in (found `strictPair` bin p m l r') - Tip ky y - | k==ky -> case f k y of - Just y' -> y' `seq` (Just y `strictPair` Tip ky y') - Nothing -> (Just y, Nil) - | otherwise -> (Nothing,t) - Nil -> (Nothing,Nil) +updateLookupWithKey f0 k0 t0 = k0 `seq` toPair $ go f0 k0 t0 + where + go f k t = + case t of + Bin p m l r + | nomatch k p m -> (Nothing :*: t) + | zero k m -> let (found :*: l') = go f k l in (found :*: bin p m l' r) + | otherwise -> let (found :*: r') = go f k r in (found :*: bin p m l r') + Tip ky y + | k==ky -> case f k y of + Just y' -> y' `seq` (Just y :*: Tip ky y') + Nothing -> (Just y :*: Nil) + | otherwise -> (Nothing :*: t) + Nil -> (Nothing :*: Nil) @@ -743,24 +747,28 @@ mapAccumWithKey f a t -- the accumulating argument and the both elements of the -- result of the function. mapAccumL :: (a -> Key -> b -> (a,c)) -> a -> IntMap b -> (a,IntMap c) -mapAccumL f a t - = case t of - Bin p m l r -> let (a1,l') = mapAccumL f a l - (a2,r') = mapAccumL f a1 r - in (a2 `strictPair` Bin p m l' r') - Tip k x -> let (a',x') = f a k x in x' `seq` (a' `strictPair` Tip k x') - Nil -> (a `strictPair` Nil) +mapAccumL f0 a0 t0 = toPair $ go f0 a0 t0 + where + go f a t + = case t of + Bin p m l r -> let (a1 :*: l') = go f a l + (a2 :*: r') = go f a1 r + in (a2 :*: Bin p m l' r') + Tip k x -> let (a',x') = f a k x in x' `seq` (a' :*: Tip k x') + Nil -> (a :*: Nil) -- | /O(n)/. The function @'mapAccumR'@ threads an accumulating -- argument through the map in descending order of keys. mapAccumRWithKey :: (a -> Key -> b -> (a,c)) -> a -> IntMap b -> (a,IntMap c) -mapAccumRWithKey f a t - = case t of - Bin p m l r -> let (a1,r') = mapAccumRWithKey f a r - (a2,l') = mapAccumRWithKey f a1 l - in (a2 `strictPair` Bin p m l' r') - Tip k x -> let (a',x') = f a k x in x' `seq` (a' `strictPair` Tip k x') - Nil -> (a `strictPair` Nil) +mapAccumRWithKey f0 a0 t0 = toPair $ go f0 a0 t0 + where + go f a t + = case t of + Bin p m l r -> let (a1 :*: r') = go f a r + (a2 :*: l') = go f a1 l + in (a2 :*: Bin p m l' r') + Tip k x -> let (a',x') = f a k x in x' `seq` (a' :*: Tip k x') + Nil -> (a :*: Nil) -- | /O(n*log n)/. -- @'mapKeysWith' c f s@ is the map obtained by applying @f@ to each key of @s@. @@ -822,15 +830,17 @@ mapEither f m -- > == (empty, fromList [(1,"x"), (3,"b"), (5,"a"), (7,"z")]) mapEitherWithKey :: (Key -> a -> Either b c) -> IntMap a -> (IntMap b, IntMap c) -mapEitherWithKey f (Bin p m l r) - = bin p m l1 r1 `strictPair` bin p m l2 r2 +mapEitherWithKey f0 t0 = toPair $ go f0 t0 where - (l1,l2) = mapEitherWithKey f l - (r1,r2) = mapEitherWithKey f r -mapEitherWithKey f (Tip k x) = case f k x of - Left y -> y `seq` (Tip k y, Nil) - Right z -> z `seq` (Nil, Tip k z) -mapEitherWithKey _ Nil = (Nil, Nil) + go f (Bin p m l r) + = bin p m l1 r1 :*: bin p m l2 r2 + where + (l1 :*: l2) = go f l + (r1 :*: r2) = go f r + go f (Tip k x) = case f k x of + Left y -> y `seq` (Tip k y :*: Nil) + Right z -> z `seq` (Nil :*: Tip k z) + go _ Nil = (Nil :*: Nil) {-------------------------------------------------------------------- Conversions diff --git a/Data/IntSet/Base.hs b/Data/IntSet/Base.hs index 056aa52cfb56c99f22091575e21a0bc6b795861f..7e7c1a78556f043b4b0e1335a3244a7c683fb947 100644 --- a/Data/IntSet/Base.hs +++ b/Data/IntSet/Base.hs @@ -168,6 +168,8 @@ import Data.Maybe (fromMaybe) import Data.Typeable import Control.DeepSeq (NFData) +import Data.StrictPair + #if __GLASGOW_HASKELL__ import Text.Read import Data.Data (Data(..), mkNoRepType) @@ -655,19 +657,21 @@ filter predicate t -- | /O(n)/. partition the set according to some predicate. partition :: (Int -> Bool) -> IntSet -> (IntSet,IntSet) -partition predicate t - = case t of - Bin p m l r - -> let (l1,l2) = partition predicate l - (r1,r2) = partition predicate r - in (bin p m l1 r1, bin p m l2 r2) - Tip kx bm - -> let bm1 = foldl'Bits 0 (bitPred kx) 0 bm - in (tip kx bm1, tip kx (bm `xor` bm1)) - Nil -> (Nil,Nil) - where bitPred kx bm bi | predicate (kx + bi) = bm .|. bitmapOfSuffix bi - | otherwise = bm - {-# INLINE bitPred #-} +partition predicate0 t0 = toPair $ go predicate0 t0 + where + go predicate t + = case t of + Bin p m l r + -> let (l1 :*: l2) = go predicate l + (r1 :*: r2) = go predicate r + in bin p m l1 r1 :*: bin p m l2 r2 + Tip kx bm + -> let bm1 = foldl'Bits 0 (bitPred kx) 0 bm + in tip kx bm1 :*: tip kx (bm `xor` bm1) + Nil -> (Nil :*: Nil) + where bitPred kx bm bi | predicate (kx + bi) = bm .|. bitmapOfSuffix bi + | otherwise = bm + {-# INLINE bitPred #-} -- | /O(min(n,W))/. The expression (@'split' x set@) is a pair @(set1,set2)@ @@ -677,41 +681,65 @@ partition predicate t -- > split 3 (fromList [1..5]) == (fromList [1,2], fromList [4,5]) split :: Int -> IntSet -> (IntSet,IntSet) split x t = - case t of Bin _ m l r | m < 0 -> if x >= 0 then case go x l of (lt, gt) -> (union lt r, gt) - else case go x r of (lt, gt) -> (lt, union gt l) - _ -> go x t + case t of + Bin _ m l r + | m < 0 -> if x >= 0 -- handle negative numbers. + then case go x l of (lt :*: gt) -> let lt' = union lt r + in lt' `seq` (lt', gt) + else case go x r of (lt :*: gt) -> let gt' = union gt l + in gt' `seq` (lt, gt') + _ -> case go x t of + (lt :*: gt) -> (lt, gt) where - go x' t'@(Bin p m l r) | match x' p m = if zero x' m then case go x' l of (lt, gt) -> (lt, union gt r) - else case go x' r of (lt, gt) -> (union lt l, gt) - | otherwise = if x' < p then (Nil, t') - else (t', Nil) - go x' t'@(Tip kx' bm) | kx' > x' = (Nil, t') - -- equivalent to kx' > prefixOf x' - | kx' < prefixOf x' = (t', Nil) - | otherwise = (tip kx' (bm .&. lowerBitmap), tip kx' (bm .&. higherBitmap)) - where lowerBitmap = bitmapOf x' - 1 - higherBitmap = complement (lowerBitmap + bitmapOf x') - go _ Nil = (Nil, Nil) + go !x' t'@(Bin p m l r) + | match x' p m = if zero x' m + then case go x' l of + (lt :*: gt) -> lt :*: union gt r + else case go x' r of + (lt :*: gt) -> union lt l :*: gt + | otherwise = if x' < p then (Nil :*: t') + else (t' :*: Nil) + go x' t'@(Tip kx' bm) + | kx' > x' = (Nil :*: t') + -- equivalent to kx' > prefixOf x' + | kx' < prefixOf x' = (t' :*: Nil) + | otherwise = tip kx' (bm .&. lowerBitmap) :*: tip kx' (bm .&. higherBitmap) + where lowerBitmap = bitmapOf x' - 1 + higherBitmap = complement (lowerBitmap + bitmapOf x') + go _ Nil = (Nil :*: Nil) -- | /O(min(n,W))/. Performs a 'split' but also returns whether the pivot -- element was found in the original set. splitMember :: Int -> IntSet -> (IntSet,Bool,IntSet) splitMember x t = - case t of Bin _ m l r | m < 0 -> if x >= 0 then case go x l of (lt, fnd, gt) -> (union lt r, fnd, gt) - else case go x r of (lt, fnd, gt) -> (lt, fnd, union gt l) - _ -> go x t + case t of + Bin _ m l r | m < 0 -> if x >= 0 + then case go x l of + (lt, fnd, gt) -> let lt' = union lt r + in lt' `seq` (lt', fnd, gt) + else case go x r of + (lt, fnd, gt) -> let gt' = union gt l + in gt' `seq` (lt, fnd, gt') + _ -> go x t where - go x' t'@(Bin p m l r) | match x' p m = if zero x' m then case go x' l of (lt, fnd, gt) -> (lt, fnd, union gt r) - else case go x' r of (lt, fnd, gt) -> (union lt l, fnd, gt) - | otherwise = if x' < p then (Nil, False, t') - else (t', False, Nil) - go x' t'@(Tip kx' bm) | kx' > x' = (Nil, False, t') - -- equivalent to kx' > prefixOf x' - | kx' < prefixOf x' = (t', False, Nil) - | otherwise = (tip kx' (bm .&. lowerBitmap), (bm .&. bitmapOfx') /= 0, tip kx' (bm .&. higherBitmap)) - where bitmapOfx' = bitmapOf x' - lowerBitmap = bitmapOfx' - 1 - higherBitmap = complement (lowerBitmap + bitmapOfx') + go x' t'@(Bin p m l r) + | match x' p m = if zero x' m + then case go x' l of + (lt, fnd, gt) -> (lt, fnd, union gt r) + else case go x' r of + (lt, fnd, gt) -> (union lt l, fnd, gt) + | otherwise = if x' < p then (Nil, False, t') else (t', False, Nil) + go x' t'@(Tip kx' bm) + | kx' > x' = (Nil, False, t') + -- equivalent to kx' > prefixOf x' + | kx' < prefixOf x' = (t', False, Nil) + | otherwise = let lt = tip kx' (bm .&. lowerBitmap) + found = (bm .&. bitmapOfx') /= 0 + gt = tip kx' (bm .&. higherBitmap) + in lt `seq` found `seq` gt `seq` (lt, found, gt) + where bitmapOfx' = bitmapOf x' + lowerBitmap = bitmapOfx' - 1 + higherBitmap = complement (lowerBitmap + bitmapOfx') go _ Nil = (Nil, False, Nil) diff --git a/Data/Map/Base.hs b/Data/Map/Base.hs index 7fa79f60ff2afb8305c12bc5fbe8618c10a95472..dc71e9f51a4ae744896696d2bcc26f19db932baa 100644 --- a/Data/Map/Base.hs +++ b/Data/Map/Base.hs @@ -1548,13 +1548,15 @@ partition p m -- > partitionWithKey (\ k _ -> k > 7) (fromList [(5,"a"), (3,"b")]) == (empty, fromList [(3, "b"), (5, "a")]) partitionWithKey :: (k -> a -> Bool) -> Map k a -> (Map k a,Map k a) -partitionWithKey _ Tip = (Tip,Tip) -partitionWithKey p (Bin _ kx x l r) - | p kx x = (join kx x l1 r1,merge l2 r2) - | otherwise = (merge l1 r1,join kx x l2 r2) +partitionWithKey p0 t0 = toPair $ go p0 t0 where - (l1,l2) = partitionWithKey p l - (r1,r2) = partitionWithKey p r + go _ Tip = (Tip :*: Tip) + go p (Bin _ kx x l r) + | p kx x = join kx x l1 r1 :*: merge l2 r2 + | otherwise = merge l1 r1 :*: join kx x l2 r2 + where + (l1 :*: l2) = go p l + (r1 :*: r2) = go p r -- | /O(n)/. Map values and collect the 'Just' results. -- @@ -1598,13 +1600,15 @@ mapEither f m -- > == (empty, fromList [(1,"x"), (3,"b"), (5,"a"), (7,"z")]) mapEitherWithKey :: (k -> a -> Either b c) -> Map k a -> (Map k b, Map k c) -mapEitherWithKey _ Tip = (Tip, Tip) -mapEitherWithKey f (Bin _ kx x l r) = case f kx x of - Left y -> (join kx y l1 r1, merge l2 r2) - Right z -> (merge l1 r1, join kx z l2 r2) - where - (l1,l2) = mapEitherWithKey f l - (r1,r2) = mapEitherWithKey f r +mapEitherWithKey f0 t0 = toPair $ go f0 t0 + where + go _ Tip = (Tip :*: Tip) + go f (Bin _ kx x l r) = case f kx x of + Left y -> join kx y l1 r1 :*: merge l2 r2 + Right z -> merge l1 r1 :*: join kx z l2 r2 + where + (l1 :*: l2) = go f l + (r1 :*: r2) = go f r {-------------------------------------------------------------------- Mapping @@ -2145,23 +2149,27 @@ trim (JustS lk) (JustS hk) t = middle lk hk t where middle lo hi (Bin _ k _ _ r -- See Note: Type of local 'go' function trimLookupLo :: Ord k => k -> MaybeS k -> Map k a -> (Maybe a, Map k a) -trimLookupLo lk NothingS t = greater lk t - where greater :: Ord k => k -> Map k a -> (Maybe a, Map k a) - greater lo t'@(Bin _ kx x l r) = case compare lo kx of LT -> lookup lo l `strictPair` t' - EQ -> (Just x, r) - GT -> greater lo r - greater _ Tip = (Nothing, Tip) -trimLookupLo lk (JustS hk) t = middle lk hk t - where middle :: Ord k => k -> k -> Map k a -> (Maybe a, Map k a) - middle lo hi t'@(Bin _ kx x l r) = case compare lo kx of LT | kx < hi -> lookup lo l `strictPair` t' - | otherwise -> middle lo hi l - EQ -> Just x `strictPair` lesser hi r - GT -> middle lo hi r - middle _ _ Tip = (Nothing, Tip) - - lesser :: Ord k => k -> Map k a -> Map k a - lesser hi (Bin _ k _ l _) | k >= hi = lesser hi l - lesser _ t' = t' +trimLookupLo lk0 mhk0 t0 = toPair $ go lk0 mhk0 t0 + where + go lk NothingS t = greater lk t + where greater :: Ord k => k -> Map k a -> StrictPair (Maybe a) (Map k a) + greater lo t'@(Bin _ kx x l r) = case compare lo kx of + LT -> lookup lo l :*: t' + EQ -> (Just x :*: r) + GT -> greater lo r + greater _ Tip = (Nothing :*: Tip) + go lk (JustS hk) t = middle lk hk t + where middle :: Ord k => k -> k -> Map k a -> StrictPair (Maybe a) (Map k a) + middle lo hi t'@(Bin _ kx x l r) = case compare lo kx of + LT | kx < hi -> lookup lo l :*: t' + | otherwise -> middle lo hi l + EQ -> Just x :*: lesser hi r + GT -> middle lo hi r + middle _ _ Tip = (Nothing :*: Tip) + + lesser :: Ord k => k -> Map k a -> Map k a + lesser hi (Bin _ k _ l _) | k >= hi = lesser hi l + lesser _ t' = t' #if __GLASGOW_HASKELL__ >= 700 {-# INLINABLE trimLookupLo #-} #endif @@ -2209,13 +2217,15 @@ filterLt (JustS b) t = filter' b t -- > split 6 (fromList [(5,"a"), (3,"b")]) == (fromList [(3,"b"), (5,"a")], empty) split :: Ord k => k -> Map k a -> (Map k a,Map k a) -split k t = k `seq` - case t of - Tip -> (Tip, Tip) - Bin _ kx x l r -> case compare k kx of - LT -> let (lt,gt) = split k l in (lt,join kx x gt r) - GT -> let (lt,gt) = split k r in (join kx x l lt,gt) - EQ -> (l,r) +split k0 t0 = k0 `seq` toPair $ go k0 t0 + where + go k t = + case t of + Tip -> (Tip :*: Tip) + Bin _ kx x l r -> case compare k kx of + LT -> let (lt :*: gt) = go k l in lt :*: join kx x gt r + GT -> let (lt :*: gt) = go k r in join kx x l lt :*: gt + EQ -> (l :*: r) #if __GLASGOW_HASKELL__ >= 700 {-# INLINABLE split #-} #endif @@ -2234,8 +2244,12 @@ splitLookup k t = k `seq` case t of Tip -> (Tip,Nothing,Tip) Bin _ kx x l r -> case compare k kx of - LT -> let (lt,z,gt) = splitLookup k l in (lt,z,join kx x gt r) - GT -> let (lt,z,gt) = splitLookup k r in (join kx x l lt,z,gt) + LT -> let (lt,z,gt) = splitLookup k l + gt' = join kx x gt r + in gt' `seq` (lt,z,gt') + GT -> let (lt,z,gt) = splitLookup k r + lt' = join kx x l lt + in lt' `seq` (lt',z,gt) EQ -> (l,Just x,r) #if __GLASGOW_HASKELL__ >= 700 {-# INLINABLE splitLookup #-} diff --git a/Data/Map/Strict.hs b/Data/Map/Strict.hs index 5e2912d781e7e1b1f0782c2985946b9a35d15f1e..de1a3666b7b286a247e2a1ae980f75225cebf1aa 100644 --- a/Data/Map/Strict.hs +++ b/Data/Map/Strict.hs @@ -434,19 +434,19 @@ insertWithKey = go -- See Map.Base.Note: Type of local 'go' function insertLookupWithKey :: Ord k => (k -> a -> a -> a) -> k -> a -> Map k a -> (Maybe a, Map k a) -insertLookupWithKey = go +insertLookupWithKey f0 kx0 x0 t0 = toPair $ go f0 kx0 x0 t0 where - go :: Ord k => (k -> a -> a -> a) -> k -> a -> Map k a -> (Maybe a, Map k a) + go :: Ord k => (k -> a -> a -> a) -> k -> a -> Map k a -> StrictPair (Maybe a) (Map k a) STRICT_2_3_OF_4(go) - go _ kx x Tip = Nothing `strictPair` singleton kx x + go _ kx x Tip = Nothing :*: singleton kx x go f kx x (Bin sy ky y l r) = case compare kx ky of - LT -> let (found, l') = go f kx x l - in found `strictPair` balanceL ky y l' r - GT -> let (found, r') = go f kx x r - in found `strictPair` balanceR ky y l r' + LT -> let (found :*: l') = go f kx x l + in found :*: balanceL ky y l' r + GT -> let (found :*: r') = go f kx x r + in found :*: balanceR ky y l r' EQ -> let x' = f kx x y - in x' `seq` (Just y `strictPair` Bin sy kx x' l r) + in x' `seq` (Just y :*: Bin sy kx x' l r) #if __GLASGOW_HASKELL__ >= 700 {-# INLINABLE insertLookupWithKey #-} #else @@ -547,20 +547,20 @@ updateWithKey = go -- See Map.Base.Note: Type of local 'go' function updateLookupWithKey :: Ord k => (k -> a -> Maybe a) -> k -> Map k a -> (Maybe a,Map k a) -updateLookupWithKey = go +updateLookupWithKey f0 k0 t0 = toPair $ go f0 k0 t0 where - go :: Ord k => (k -> a -> Maybe a) -> k -> Map k a -> (Maybe a,Map k a) + go :: Ord k => (k -> a -> Maybe a) -> k -> Map k a -> StrictPair (Maybe a) (Map k a) STRICT_2_OF_3(go) - go _ _ Tip = (Nothing,Tip) + go _ _ Tip = (Nothing :*: Tip) go f k (Bin sx kx x l r) = case compare k kx of - LT -> let (found,l') = go f k l - in found `strictPair` balanceR kx x l' r - GT -> let (found,r') = go f k r - in found `strictPair` balanceL kx x l r' + LT -> let (found :*: l') = go f k l + in found :*: balanceR kx x l' r + GT -> let (found :*: r') = go f k r + in found :*: balanceL kx x l r' EQ -> case f kx x of - Just x' -> x' `seq` (Just x' `strictPair` Bin sx kx x' l r) - Nothing -> (Just x,glue l r) + Just x' -> x' `seq` (Just x' :*: Bin sx kx x' l r) + Nothing -> (Just x :*: glue l r) #if __GLASGOW_HASKELL__ >= 700 {-# INLINABLE updateLookupWithKey #-} #else @@ -899,13 +899,15 @@ mapEither f m -- > == (empty, fromList [(1,"x"), (3,"b"), (5,"a"), (7,"z")]) mapEitherWithKey :: (k -> a -> Either b c) -> Map k a -> (Map k b, Map k c) -mapEitherWithKey _ Tip = (Tip, Tip) -mapEitherWithKey f (Bin _ kx x l r) = case f kx x of - Left y -> y `seq` (join kx y l1 r1 `strictPair` merge l2 r2) - Right z -> z `seq` (merge l1 r1 `strictPair` join kx z l2 r2) - where - (l1,l2) = mapEitherWithKey f l - (r1,r2) = mapEitherWithKey f r +mapEitherWithKey f0 t0 = toPair $ go f0 t0 + where + go _ Tip = (Tip :*: Tip) + go f (Bin _ kx x l r) = case f kx x of + Left y -> y `seq` (join kx y l1 r1 :*: merge l2 r2) + Right z -> z `seq` (merge l1 r1 :*: join kx z l2 r2) + where + (l1 :*: l2) = go f l + (r1 :*: r2) = go f r {-------------------------------------------------------------------- Mapping diff --git a/Data/Set/Base.hs b/Data/Set/Base.hs index 0ac9ee0b8afb11c63d7ce37e1dbc872e9f7f5e38..600f3d2a396de42f2333c77144fd8babbc39ebe6 100644 --- a/Data/Set/Base.hs +++ b/Data/Set/Base.hs @@ -182,6 +182,8 @@ import qualified Data.Foldable as Foldable import Data.Typeable import Control.DeepSeq (NFData(rnf)) +import Data.StrictPair + #if __GLASGOW_HASKELL__ import GHC.Exts ( build ) import Text.Read @@ -623,11 +625,13 @@ filter p (Bin _ x l r) -- the predicate and one with all elements that don't satisfy the predicate. -- See also 'split'. partition :: (a -> Bool) -> Set a -> (Set a,Set a) -partition _ Tip = (Tip, Tip) -partition p (Bin _ x l r) = case (partition p l, partition p r) of - ((l1, l2), (r1, r2)) - | p x -> (join x l1 r1, merge l2 r2) - | otherwise -> (merge l1 r1, join x l2 r2) +partition p0 t0 = toPair $ go p0 t0 + where + go _ Tip = (Tip :*: Tip) + go p (Bin _ x l r) = case (go p l, go p r) of + ((l1 :*: l2), (r1 :*: r2)) + | p x -> join x l1 r1 :*: merge l2 r2 + | otherwise -> merge l1 r1 :*: join x l2 r2 {---------------------------------------------------------------------- Map @@ -958,12 +962,14 @@ filterLt (JustS b) t = filter' b t -- where @set1@ comprises the elements of @set@ less than @x@ and @set2@ -- comprises the elements of @set@ greater than @x@. split :: Ord a => a -> Set a -> (Set a,Set a) -split _ Tip = (Tip,Tip) -split x (Bin _ y l r) - = case compare x y of - LT -> let (lt,gt) = split x l in (lt,join y gt r) - GT -> let (lt,gt) = split x r in (join y l lt,gt) - EQ -> (l,r) +split x0 t0 = toPair $ go x0 t0 + where + go _ Tip = (Tip :*: Tip) + go x (Bin _ y l r) + = case compare x y of + LT -> let (lt :*: gt) = go x l in (lt :*: join y gt r) + GT -> let (lt :*: gt) = go x r in (join y l lt :*: gt) + EQ -> (l :*: r) #if __GLASGOW_HASKELL__ >= 700 {-# INLINABLE split #-} #endif @@ -974,8 +980,12 @@ splitMember :: Ord a => a -> Set a -> (Set a,Bool,Set a) splitMember _ Tip = (Tip, False, Tip) splitMember x (Bin _ y l r) = case compare x y of - LT -> let (lt, found, gt) = splitMember x l in (lt, found, join y gt r) - GT -> let (lt, found, gt) = splitMember x r in (join y l lt, found, gt) + LT -> let (lt, found, gt) = splitMember x l + gt' = join y gt r + in gt' `seq` (lt, found, gt') + GT -> let (lt, found, gt) = splitMember x r + lt' = join y l lt + in lt' `seq` (lt', found, gt) EQ -> (l, True, r) #if __GLASGOW_HASKELL__ >= 700 {-# INLINABLE splitMember #-} diff --git a/Data/StrictPair.hs b/Data/StrictPair.hs index 551482179e621bbadc966ea8fcdfe5657356132a..0e06534a29b857853423efa03e39bdeeb362fac4 100644 --- a/Data/StrictPair.hs +++ b/Data/StrictPair.hs @@ -1,6 +1,9 @@ -module Data.StrictPair (strictPair) where +module Data.StrictPair (StrictPair(..), toPair) where --- | Evaluate both argument to WHNF and create a pair of the result. -strictPair :: a -> b -> (a, b) -strictPair x y = x `seq` y `seq` (x, y) -{-# INLINE strictPair #-} +-- | Same as regular Haskell pairs, but (x :*: _|_) = (_|_ :*: y) = +-- _|_ +data StrictPair a b = !a :*: !b + +toPair :: StrictPair a b -> (a, b) +toPair (x :*: y) = (x, y) +{-# INLINE toPair #-} \ No newline at end of file diff --git a/benchmarks/Map.hs b/benchmarks/Map.hs index 1efe2457147ef7ede8625244bdb574da3fca4b4b..70a7cc91815b6f2cac66c14f9e03a0ecd5d54c87 100644 --- a/benchmarks/Map.hs +++ b/benchmarks/Map.hs @@ -57,6 +57,7 @@ main = do , bench "union" $ whnf (M.union m_even) m_odd , bench "difference" $ whnf (M.difference m) m_even , bench "intersection" $ whnf (M.intersection m) m_even + , bench "split" $ whnf (M.split (bound `div` 2)) m ] where bound = 2^10