Commit 5b98a38a authored by Sebastian Graf's avatar Sebastian Graf

Make `UniqDSet` a newtype

Summary:
This brings the situation of `UniqDSet` in line with `UniqSet`.

@dfeuer said in D3146#92820 that he would do this, but probably
never got around to it.

Validated locally.

Reviewers: AndreasK, mpickering, bgamari, dfeuer, simonpj

Reviewed By: simonpj

Subscribers: simonpj, rwbarton, carter, dfeuer

GHC Trac Issues: #15879, #13114

Differential Revision: https://phabricator.haskell.org/D5313
parent 0f2ac24c
......@@ -961,7 +961,7 @@ renameHoleUnitId' pkg_map env uid =
IndefUnitId{ indefUnitIdComponentId = cid
, indefUnitIdInsts = insts
, indefUnitIdFreeHoles = fh })
-> if isNullUFM (intersectUFM_C const (udfmToUfm fh) env)
-> if isNullUFM (intersectUFM_C const (udfmToUfm (getUniqDSet fh)) env)
then uid
-- Functorially apply the substitution to the instantiation,
-- then check the 'PackageConfigMap' to see if there is
......
......@@ -268,7 +268,7 @@ dVarSetIntersectVarSet = uniqDSetIntersectUniqSet
-- | True if empty intersection
disjointDVarSet :: DVarSet -> DVarSet -> Bool
disjointDVarSet s1 s2 = disjointUDFM s1 s2
disjointDVarSet s1 s2 = disjointUDFM (getUniqDSet s1) (getUniqDSet s2)
-- | True if non-empty intersection
intersectsDVarSet :: DVarSet -> DVarSet -> Bool
......@@ -290,10 +290,10 @@ foldDVarSet :: (Var -> a -> a) -> a -> DVarSet -> a
foldDVarSet = foldUniqDSet
anyDVarSet :: (Var -> Bool) -> DVarSet -> Bool
anyDVarSet = anyUDFM
anyDVarSet p = anyUDFM p . getUniqDSet
allDVarSet :: (Var -> Bool) -> DVarSet -> Bool
allDVarSet = allUDFM
allDVarSet p = allUDFM p . getUniqDSet
filterDVarSet :: (Var -> Bool) -> DVarSet -> DVarSet
filterDVarSet = filterUniqDSet
......@@ -318,7 +318,7 @@ extendDVarSetList = addListToUniqDSet
-- | Convert a DVarSet to a VarSet by forgeting the order of insertion
dVarSetToVarSet :: DVarSet -> VarSet
dVarSetToVarSet = unsafeUFMToUniqSet . udfmToUfm
dVarSetToVarSet = unsafeUFMToUniqSet . udfmToUfm . getUniqDSet
-- | transCloVarSet for DVarSet
transCloDVarSet :: (DVarSet -> DVarSet)
......
......@@ -52,8 +52,7 @@ import SrcLoc
import ListSetOps( assocMaybe )
import Data.List
import Util
import UniqDFM
import UniqSet
import UniqDSet
data DsCmdEnv = DsCmdEnv {
arr_id, compose_id, first_id, app_id, choice_id, loop_id :: CoreExpr
......@@ -379,7 +378,7 @@ dsCmd ids local_vars stack_ty res_ty
res_ty
core_make_arg
core_arrow,
exprFreeIdsDSet core_arg `udfmIntersectUFM` (getUniqSet local_vars))
exprFreeIdsDSet core_arg `uniqDSetIntersectUniqSet` local_vars)
-- D, xs |- fun :: a t1 t2
-- D, xs |- arg :: t1
......@@ -408,7 +407,7 @@ dsCmd ids local_vars stack_ty res_ty
core_make_pair
(do_app ids arg_ty res_ty),
(exprsFreeIdsDSet [core_arrow, core_arg])
`udfmIntersectUFM` getUniqSet local_vars)
`uniqDSetIntersectUniqSet` local_vars)
-- D; ys |-a cmd : (t,stk) --> t'
-- D, xs |- exp :: t
......@@ -441,7 +440,7 @@ dsCmd ids local_vars stack_ty res_ty (HsCmdApp _ cmd arg) env_ids = do
core_map
core_cmd,
free_vars `unionDVarSet`
(exprFreeIdsDSet core_arg `udfmIntersectUFM` getUniqSet local_vars))
(exprFreeIdsDSet core_arg `uniqDSetIntersectUniqSet` local_vars))
-- D; ys |-a cmd : stk t'
-- -----------------------------------------------
......@@ -479,7 +478,7 @@ dsCmd ids local_vars stack_ty res_ty
-- match the old environment and stack against the input
select_code <- matchEnvStack env_ids stack_id param_code
return (do_premap ids in_ty in_ty' res_ty select_code core_body,
free_vars `udfmMinusUFM` getUniqSet pat_vars)
free_vars `uniqDSetMinusUniqSet` pat_vars)
dsCmd ids local_vars stack_ty res_ty (HsCmdPar _ cmd) env_ids
= dsLCmd ids local_vars stack_ty res_ty cmd env_ids
......@@ -511,7 +510,7 @@ dsCmd ids local_vars stack_ty res_ty (HsCmdIf _ mb_fun cond then_cmd else_cmd)
then_ty = envStackType then_ids stack_ty
else_ty = envStackType else_ids stack_ty
sum_ty = mkTyConApp either_con [then_ty, else_ty]
fvs_cond = exprFreeIdsDSet core_cond `udfmIntersectUFM` getUniqSet local_vars
fvs_cond = exprFreeIdsDSet core_cond `uniqDSetIntersectUniqSet` local_vars
core_left = mk_left_expr then_ty else_ty (buildEnvStack then_ids stack_id)
core_right = mk_right_expr then_ty else_ty (buildEnvStack else_ids stack_id)
......@@ -611,7 +610,7 @@ dsCmd ids local_vars stack_ty res_ty
core_matches <- matchEnvStack env_ids stack_id core_body
return (do_premap ids in_ty sum_ty res_ty core_matches core_choices,
exprFreeIdsDSet core_body `udfmIntersectUFM` getUniqSet local_vars)
exprFreeIdsDSet core_body `uniqDSetIntersectUniqSet` local_vars)
-- D; ys |-a cmd : stk --> t
-- ----------------------------------
......@@ -637,7 +636,7 @@ dsCmd ids local_vars stack_ty res_ty (HsCmdLet _ lbinds@(L _ binds) body)
res_ty
core_map
core_body,
exprFreeIdsDSet core_binds `udfmIntersectUFM` getUniqSet local_vars)
exprFreeIdsDSet core_binds `uniqDSetIntersectUniqSet` local_vars)
-- D; xs |-a ss : t
-- ----------------------------------
......@@ -892,7 +891,7 @@ dsCmdStmt ids local_vars out_ids (BindStmt _ pat cmd _ _) env_ids = do
do_compose ids before_c_ty after_c_ty out_ty
(do_first ids in_ty1 pat_ty in_ty2 core_cmd) $
do_arr ids after_c_ty out_ty proj_expr,
fv_cmd `unionDVarSet` (mkDVarSet out_ids `udfmMinusUFM` getUniqSet pat_vars))
fv_cmd `unionDVarSet` (mkDVarSet out_ids `uniqDSetMinusUniqSet` pat_vars))
-- D; xs' |-a do { ss } : t
-- --------------------------------------
......@@ -909,7 +908,7 @@ dsCmdStmt ids local_vars out_ids (LetStmt _ binds) env_ids = do
(mkBigCoreVarTupTy env_ids)
(mkBigCoreVarTupTy out_ids)
core_map,
exprFreeIdsDSet core_binds `udfmIntersectUFM` getUniqSet local_vars)
exprFreeIdsDSet core_binds `uniqDSetIntersectUniqSet` local_vars)
-- D; ys |-a do { ss; returnA -< ((xs1), (ys2)) } : ...
-- D; xs' |-a do { ss' } : t
......@@ -1029,7 +1028,7 @@ dsRecCmd ids local_vars stmts later_ids later_rets rec_ids rec_rets = do
rec_id <- newSysLocalDs rec_ty
let
env1_id_set = fv_stmts `udfmMinusUFM` getUniqSet rec_id_set
env1_id_set = fv_stmts `uniqDSetMinusUniqSet` rec_id_set
env1_ids = dVarSetElems env1_id_set
env1_ty = mkBigCoreVarTupTy env1_ids
in_pair_ty = mkCorePairTy env1_ty rec_ty
......
......@@ -82,6 +82,7 @@ import IdInfo
import Var
import VarSet
import UniqSet ( nonDetFoldUniqSet )
import UniqDSet ( getUniqDSet )
import VarEnv
import Literal ( litIsTrivial )
import Demand ( StrictSig, Demand, isStrictDmd, splitStrictSig, increaseStrictSigArity )
......@@ -1404,7 +1405,7 @@ isFunction (_, AnnLam b e) | isId b = True
isFunction _ = False
countFreeIds :: DVarSet -> Int
countFreeIds = nonDetFoldUDFM add 0
countFreeIds = nonDetFoldUDFM add 0 . getUniqDSet
-- It's OK to use nonDetFoldUDFM here because we're just counting things.
where
add :: Var -> Int -> Int
......
......@@ -46,12 +46,13 @@ module UniqDFM (
intersectUDFM, udfmIntersectUFM,
intersectsUDFM,
disjointUDFM, disjointUdfmUfm,
equalKeysUDFM,
minusUDFM,
listToUDFM,
udfmMinusUFM,
partitionUDFM,
anyUDFM, allUDFM,
pprUDFM,
pprUniqDFM, pprUDFM,
udfmToList,
udfmToUfm,
......@@ -66,6 +67,7 @@ import Outputable
import qualified Data.IntMap as M
import Data.Data
import Data.Functor.Classes (Eq1 (..))
import Data.List (sortBy)
import Data.Function (on)
import qualified Data.Semigroup as Semi
......@@ -288,6 +290,10 @@ udfmToList (UDFM m _i) =
[ (getUnique k, taggedFst v)
| (k, v) <- sortBy (compare `on` (taggedSnd . snd)) $ M.toList m ]
-- Determines whether two 'UniqDFM's contain the same keys.
equalKeysUDFM :: UniqDFM a -> UniqDFM b -> Bool
equalKeysUDFM (UDFM m1 _) (UDFM m2 _) = liftEq (\_ _ -> True) m1 m2
isNullUDFM :: UniqDFM elt -> Bool
isNullUDFM (UDFM m _) = M.null m
......
......@@ -3,14 +3,19 @@
-- |
-- Specialised deterministic sets, for things with @Uniques@
--
-- Based on @UniqDFMs@ (as you would expect).
-- Based on 'UniqDFM's (as you would expect).
-- See Note [Deterministic UniqFM] in UniqDFM for explanation why we need it.
--
-- Basically, the things need to be in class @Uniquable@.
-- Basically, the things need to be in class 'Uniquable'.
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE DeriveDataTypeable #-}
module UniqDSet (
-- * Unique set type
UniqDSet, -- type synonym for UniqFM a
getUniqDSet,
pprUniqDSet,
-- ** Manipulating these sets
delOneFromUniqDSet, delListFromUniqDSet,
......@@ -21,7 +26,6 @@ module UniqDSet (
unionUniqDSets, unionManyUniqDSets,
minusUniqDSet, uniqDSetMinusUniqSet,
intersectUniqDSets, uniqDSetIntersectUniqSet,
intersectsUniqDSets,
foldUniqDSet,
elementOfUniqDSet,
filterUniqDSet,
......@@ -34,76 +38,99 @@ module UniqDSet (
import GhcPrelude
import Outputable
import UniqDFM
import UniqSet
import Unique
type UniqDSet a = UniqDFM a
import Data.Coerce
import Data.Data
import qualified Data.Semigroup as Semi
-- See Note [UniqSet invariant] in UniqSet.hs for why we want a newtype here.
-- Beyond preserving invariants, we may also want to 'override' typeclass
-- instances.
newtype UniqDSet a = UniqDSet {getUniqDSet' :: UniqDFM a}
deriving (Data, Semi.Semigroup, Monoid)
emptyUniqDSet :: UniqDSet a
emptyUniqDSet = emptyUDFM
emptyUniqDSet = UniqDSet emptyUDFM
unitUniqDSet :: Uniquable a => a -> UniqDSet a
unitUniqDSet x = unitUDFM x x
unitUniqDSet x = UniqDSet (unitUDFM x x)
mkUniqDSet :: Uniquable a => [a] -> UniqDSet a
mkUniqDSet :: Uniquable a => [a] -> UniqDSet a
mkUniqDSet = foldl' addOneToUniqDSet emptyUniqDSet
-- The new element always goes to the right of existing ones.
addOneToUniqDSet :: Uniquable a => UniqDSet a -> a -> UniqDSet a
addOneToUniqDSet set x = addToUDFM set x x
addOneToUniqDSet (UniqDSet set) x = UniqDSet (addToUDFM set x x)
addListToUniqDSet :: Uniquable a => UniqDSet a -> [a] -> UniqDSet a
addListToUniqDSet = foldl' addOneToUniqDSet
delOneFromUniqDSet :: Uniquable a => UniqDSet a -> a -> UniqDSet a
delOneFromUniqDSet = delFromUDFM
delOneFromUniqDSet (UniqDSet s) = UniqDSet . delFromUDFM s
delListFromUniqDSet :: Uniquable a => UniqDSet a -> [a] -> UniqDSet a
delListFromUniqDSet = delListFromUDFM
delListFromUniqDSet (UniqDSet s) = UniqDSet . delListFromUDFM s
unionUniqDSets :: UniqDSet a -> UniqDSet a -> UniqDSet a
unionUniqDSets = plusUDFM
unionUniqDSets (UniqDSet s) (UniqDSet t) = UniqDSet (plusUDFM s t)
unionManyUniqDSets :: [UniqDSet a] -> UniqDSet a
unionManyUniqDSets [] = emptyUniqDSet
unionManyUniqDSets sets = foldr1 unionUniqDSets sets
minusUniqDSet :: UniqDSet a -> UniqDSet a -> UniqDSet a
minusUniqDSet = minusUDFM
minusUniqDSet (UniqDSet s) (UniqDSet t) = UniqDSet (minusUDFM s t)
uniqDSetMinusUniqSet :: UniqDSet a -> UniqSet b -> UniqDSet a
uniqDSetMinusUniqSet xs ys = udfmMinusUFM xs (getUniqSet ys)
uniqDSetMinusUniqSet xs ys
= UniqDSet (udfmMinusUFM (getUniqDSet xs) (getUniqSet ys))
intersectUniqDSets :: UniqDSet a -> UniqDSet a -> UniqDSet a
intersectUniqDSets = intersectUDFM
intersectUniqDSets (UniqDSet s) (UniqDSet t) = UniqDSet (intersectUDFM s t)
uniqDSetIntersectUniqSet :: UniqDSet a -> UniqSet b -> UniqDSet a
uniqDSetIntersectUniqSet xs ys = xs `udfmIntersectUFM` getUniqSet ys
intersectsUniqDSets :: UniqDSet a -> UniqDSet a -> Bool
intersectsUniqDSets = intersectsUDFM
uniqDSetIntersectUniqSet xs ys
= UniqDSet (udfmIntersectUFM (getUniqDSet xs) (getUniqSet ys))
foldUniqDSet :: (a -> b -> b) -> b -> UniqDSet a -> b
foldUniqDSet = foldUDFM
foldUniqDSet c n (UniqDSet s) = foldUDFM c n s
elementOfUniqDSet :: Uniquable a => a -> UniqDSet a -> Bool
elementOfUniqDSet = elemUDFM
elementOfUniqDSet k = elemUDFM k . getUniqDSet
filterUniqDSet :: (a -> Bool) -> UniqDSet a -> UniqDSet a
filterUniqDSet = filterUDFM
filterUniqDSet p (UniqDSet s) = UniqDSet (filterUDFM p s)
sizeUniqDSet :: UniqDSet a -> Int
sizeUniqDSet = sizeUDFM
sizeUniqDSet = sizeUDFM . getUniqDSet
isEmptyUniqDSet :: UniqDSet a -> Bool
isEmptyUniqDSet = isNullUDFM
isEmptyUniqDSet = isNullUDFM . getUniqDSet
lookupUniqDSet :: Uniquable a => UniqDSet a -> a -> Maybe a
lookupUniqDSet = lookupUDFM
lookupUniqDSet = lookupUDFM . getUniqDSet
uniqDSetToList :: UniqDSet a -> [a]
uniqDSetToList = eltsUDFM
uniqDSetToList = eltsUDFM . getUniqDSet
partitionUniqDSet :: (a -> Bool) -> UniqDSet a -> (UniqDSet a, UniqDSet a)
partitionUniqDSet = partitionUDFM
partitionUniqDSet p = coerce . partitionUDFM p . getUniqDSet
-- Two 'UniqDSet's are considered equal if they contain the same
-- uniques.
instance Eq (UniqDSet a) where
UniqDSet a == UniqDSet b = equalKeysUDFM a b
getUniqDSet :: UniqDSet a -> UniqDFM a
getUniqDSet = getUniqDSet'
instance Outputable a => Outputable (UniqDSet a) where
ppr = pprUniqDSet ppr
pprUniqDSet :: (a -> SDoc) -> UniqDSet a -> SDoc
pprUniqDSet f (UniqDSet s) = pprUniqDFM f s
......@@ -336,7 +336,7 @@ nonDetUFMToList (UFM m) = map (\(k, v) -> (getUnique k, v)) $ M.toList m
ufmToIntMap :: UniqFM elt -> M.IntMap elt
ufmToIntMap (UFM m) = m
-- Determines whether two 'UniqFm's contain the same keys.
-- Determines whether two 'UniqFM's contain the same keys.
equalKeysUFM :: UniqFM a -> UniqFM b -> Bool
equalKeysUFM (UFM m1) (UFM m2) = liftEq (\_ _ -> True) m1 m2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment