Commit 9aa73892 authored by Ben Gamari's avatar Ben Gamari Committed by Ben Gamari

cmm/CBE: Use foldLocalRegsDefd

Simonpj suggested this as a follow-on to #14226 to avoid code
duplication. This also gives us the ability to CBE cases involving
foreign calls for free.

Test Plan: Validate

Reviewers: austin, simonmar, simonpj

Reviewed By: simonpj

Subscribers: michalt, simonpj, rwbarton, thomie

GHC Trac Issues: #14226

Differential Revision: https://phabricator.haskell.org/D3999
parent ddb38b51
......@@ -24,6 +24,7 @@ import qualified Data.List as List
import Data.Word
import qualified Data.Map as M
import Outputable
import DynFlags (DynFlags)
import UniqFM
import UniqDFM
import qualified TrieMap as TM
......@@ -59,11 +60,11 @@ import Control.Arrow (first, second)
-- rightfully complained: #10397
-- TODO: Use optimization fuel
elimCommonBlocks :: CmmGraph -> CmmGraph
elimCommonBlocks g = replaceLabels env $ copyTicks env g
elimCommonBlocks :: DynFlags -> CmmGraph -> CmmGraph
elimCommonBlocks dflags g = replaceLabels env $ copyTicks env g
where
env = iterate mapEmpty blocks_with_key
groups = groupByInt hash_block (postorderDfs g)
env = iterate dflags mapEmpty blocks_with_key
groups = groupByInt (hash_block dflags) (postorderDfs g)
blocks_with_key = [ [ (successors b, [b]) | b <- bs] | bs <- groups]
-- Invariant: The blocks in the list are pairwise distinct
......@@ -73,42 +74,47 @@ type Key = [Label]
type Subst = LabelMap BlockId
-- The outer list groups by hash. We retain this grouping throughout.
iterate :: Subst -> [[(Key, DistinctBlocks)]] -> Subst
iterate subst blocks
iterate :: DynFlags -> Subst -> [[(Key, DistinctBlocks)]] -> Subst
iterate dflags subst blocks
| mapNull new_substs = subst
| otherwise = iterate subst' updated_blocks
| otherwise = iterate dflags subst' updated_blocks
where
grouped_blocks :: [[(Key, [DistinctBlocks])]]
grouped_blocks = map groupByLabel blocks
merged_blocks :: [[(Key, DistinctBlocks)]]
(new_substs, merged_blocks) = List.mapAccumL (List.mapAccumL go) mapEmpty grouped_blocks
(new_substs, merged_blocks) =
List.mapAccumL (List.mapAccumL go) mapEmpty grouped_blocks
where
go !new_subst1 (k,dbs) = (new_subst1 `mapUnion` new_subst2, (k,db))
where
(new_subst2, db) = mergeBlockList subst dbs
(new_subst2, db) = mergeBlockList dflags subst dbs
subst' = subst `mapUnion` new_substs
updated_blocks = map (map (first (map (lookupBid subst')))) merged_blocks
mergeBlocks :: Subst -> DistinctBlocks -> DistinctBlocks -> (Subst, DistinctBlocks)
mergeBlocks subst existing new = go new
mergeBlocks :: DynFlags -> Subst
-> DistinctBlocks -> DistinctBlocks
-> (Subst, DistinctBlocks)
mergeBlocks dflags subst existing new = go new
where
go [] = (mapEmpty, existing)
go (b:bs) = case List.find (eqBlockBodyWith (eqBid subst) b) existing of
-- This block is a duplicate. Drop it, and add it to the substitution
Just b' -> first (mapInsert (entryLabel b) (entryLabel b')) $ go bs
-- This block is not a duplicate, keep it.
Nothing -> second (b:) $ go bs
mergeBlockList :: Subst -> [DistinctBlocks] -> (Subst, DistinctBlocks)
mergeBlockList _ [] = pprPanic "mergeBlockList" empty
mergeBlockList subst (b:bs) = go mapEmpty b bs
go (b:bs) =
case List.find (eqBlockBodyWith dflags (eqBid subst) b) existing of
-- This block is a duplicate. Drop it, and add it to the substitution
Just b' -> first (mapInsert (entryLabel b) (entryLabel b')) $ go bs
-- This block is not a duplicate, keep it.
Nothing -> second (b:) $ go bs
mergeBlockList :: DynFlags -> Subst -> [DistinctBlocks]
-> (Subst, DistinctBlocks)
mergeBlockList _ _ [] = pprPanic "mergeBlockList" empty
mergeBlockList dflags subst (b:bs) = go mapEmpty b bs
where
go !new_subst1 b [] = (new_subst1, b)
go !new_subst1 b1 (b2:bs) = go new_subst b bs
where
(new_subst2, b) = mergeBlocks subst b1 b2
(new_subst2, b) = mergeBlocks dflags subst b1 b2
new_subst = new_subst1 `mapUnion` new_subst2
......@@ -175,8 +181,8 @@ data HashEnv = HashEnv { localRegHashEnv :: !(LocalRegEnv DeBruijn)
, nextIndex :: !DeBruijn
}
hash_block :: CmmBlock -> HashCode
hash_block block =
hash_block :: DynFlags -> CmmBlock -> HashCode
hash_block dflags block =
--pprTrace "hash_block" (ppr (entryLabel block) $$ ppr hash)
hash
where hash_fst _ (env, h) = (env, h)
......@@ -196,20 +202,24 @@ hash_block block =
hash_node :: HashEnv -> CmmNode O x -> (HashEnv, Word32)
hash_node env n =
case n of
n | dont_care n -> pure_ 0 -- don't care
CmmAssign (CmmLocal r) e -> (bind_local_reg r env, hash_e env e)
CmmAssign r e -> pure_ $ hash_reg env r + hash_e env e
CmmStore e e' -> pure_ $ hash_e env e + hash_e env e'
CmmUnsafeForeignCall t _ as
-> pure_ $ hash_tgt env t + hash_list (hash_e env) as
CmmBranch _ -> pure_ 23 -- NB. ignore the label
CmmCondBranch p _ _ _ -> pure_ $ hash_e env p
CmmCall e _ _ _ _ _ -> pure_ $ hash_e env e
CmmForeignCall t _ _ _ _ _ _ -> pure_ $ hash_tgt env t
CmmSwitch e _ -> pure_ $ hash_e env e
_ -> error "hash_node: unknown Cmm node!"
where pure_ x = (env, x)
(env', hash)
where
hash =
case n of
n | dont_care n -> 0 -- don't care
-- don't include register as it is a binding occurrence
CmmAssign (CmmLocal _) e -> hash_e env e
CmmAssign r e -> hash_reg env r + hash_e env e
CmmStore e e' -> hash_e env e + hash_e env e'
CmmUnsafeForeignCall t _ as
-> hash_tgt env t + hash_list (hash_e env) as
CmmBranch _ -> 23 -- NB. ignore the label
CmmCondBranch p _ _ _ -> hash_e env p
CmmCall e _ _ _ _ _ -> hash_e env e
CmmForeignCall t _ _ _ _ _ _ -> hash_tgt env t
CmmSwitch e _ -> hash_e env e
_ -> error "hash_node: unknown Cmm node!"
env' = foldLocalRegsDefd dflags (flip bind_local_reg) env n
hash_reg :: HashEnv -> CmmReg -> Word32
hash_reg env (CmmLocal localReg)
......@@ -281,38 +291,45 @@ type LocalRegMapping = LocalRegEnv LocalReg
-- CmmStackSlot and CmmBlock, so we have to use a special equality for
-- these.
--
eqMiddleWith :: (BlockId -> BlockId -> Bool)
eqMiddleWith :: DynFlags
-> (BlockId -> BlockId -> Bool)
-> LocalRegMapping
-> CmmNode O O -> CmmNode O O
-> (LocalRegMapping, Bool)
eqMiddleWith eqBid env a b =
eqMiddleWith dflags eqBid env a b =
case (a, b) of
(CmmAssign (CmmLocal r1) e1, CmmAssign (CmmLocal r2) e2) ->
-- registers aren't compared since they are binding occurrences
(CmmAssign (CmmLocal _) e1, CmmAssign (CmmLocal _) e2) ->
let eq = eqExprWith eqBid env e1 e2
env' = addToUFM env r1 r2
in (env', eq)
(CmmAssign r1 e1, CmmAssign r2 e2) ->
let eq = r1 == r2
&& eqExprWith eqBid env e1 e2
in (env, eq)
in (env', eq)
(CmmStore l1 r1, CmmStore l2 r2) ->
let eq = eqExprWith eqBid env l1 l2
&& eqExprWith eqBid env r1 r2
in (env, eq)
in (env', eq)
(CmmUnsafeForeignCall t1 r1 a1, CmmUnsafeForeignCall t2 r2 a2) ->
-- result registers aren't compared since they are binding occurrences
(CmmUnsafeForeignCall t1 _ a1, CmmUnsafeForeignCall t2 _ a2) ->
let eq = t1 == t2
&& r1 == r2
&& and (zipWith (eqExprWith eqBid env) a1 a2)
in (env, eq)
in (env', eq)
_ -> (env, False)
where
env' = List.foldl' (\acc (ra,rb) -> addToUFM acc ra rb) emptyUFM
$ List.zip defd_a defd_b
defd_a = foldLocalRegsDefd dflags (flip (:)) [] a
defd_b = foldLocalRegsDefd dflags (flip (:)) [] b
eqExprWith :: (BlockId -> BlockId -> Bool)
-> LocalRegMapping
-> CmmExpr -> CmmExpr -> Bool
-> CmmExpr -> CmmExpr
-> Bool
eqExprWith eqBid env = eq
where
CmmLit l1 `eq` CmmLit l2 = eqLit l1 l2
......@@ -340,47 +357,50 @@ eqExprWith eqBid env = eq
-- Equality on the body of a block, modulo a function mapping block
-- IDs to block IDs.
eqBlockBodyWith :: (BlockId -> BlockId -> Bool) -> CmmBlock -> CmmBlock -> Bool
eqBlockBodyWith eqBid block block'
eqBlockBodyWith :: DynFlags
-> (BlockId -> BlockId -> Bool)
-> CmmBlock -> CmmBlock -> Bool
eqBlockBodyWith dflags eqBid block block'
{-
| equal = pprTrace "equal" (vcat [ppr block, ppr block']) True
| otherwise = pprTrace "not equal" (vcat [ppr block, ppr block']) False
-}
= equal_go emptyUFM nodes nodes'
= equal
where (_,m,l) = blockSplit block
nodes = filter (not . dont_care) (blockToList m)
(_,m',l') = blockSplit block'
nodes' = filter (not . dont_care) (blockToList m')
-- Compare middle nodes, accumulating a local register mapping as we go.
-- We also must ensure that the lists are of equal length. Finally,
-- compare the last nodes.
equal_go :: LocalRegMapping -> [CmmNode O O] -> [CmmNode O O] -> Bool
equal_go acc (a:as) (b:bs)
| let (acc', eq) = eqMiddleWith eqBid acc a b
, eq
= equal_go acc' as bs
equal_go acc [] [] = eqLastWith eqBid acc l l'
equal_go _ _ _ = False
(env_mid, eqs_mid) =
List.mapAccumL (\acc (a,b) -> eqMiddleWith dflags eqBid acc a b)
emptyUFM
(List.zip nodes nodes')
equal = and eqs_mid && eqLastWith eqBid env_mid l l'
eqLastWith :: (BlockId -> BlockId -> Bool) -> LocalRegMapping
-> CmmNode O C -> CmmNode O C -> Bool
eqLastWith eqBid env a b =
case (a, b) of
(CmmBranch bid1, CmmBranch bid2) ->
eqBid bid1 bid2
(CmmCondBranch c1 t1 f1 l1, CmmCondBranch c2 t2 f2 l2) ->
eqExprWith eqBid env c1 c2
&& l1 == l2 && eqBid t1 t2 && eqBid f1 f2
(CmmCall t1 c1 g1 a1 r1 u1, CmmCall t2 c2 g2 a2 r2 u2) ->
eqExprWith eqBid env t1 t2
&& eqMaybeWith eqBid c1 c2
&& a1 == a2 && r1 == r2 && u1 == u2 && g1 == g2
(CmmSwitch e1 ids1, CmmSwitch e2 ids2) ->
eqExprWith eqBid env e1 e2
&& eqSwitchTargetWith eqBid ids1 ids2
_ -> False
case (a, b) of
(CmmBranch bid1, CmmBranch bid2) -> eqBid bid1 bid2
(CmmCondBranch c1 t1 f1 l1, CmmCondBranch c2 t2 f2 l2) ->
eqExprWith eqBid env c1 c2 && l1 == l2 && eqBid t1 t2 && eqBid f1 f2
(CmmCall t1 c1 g1 a1 r1 u1, CmmCall t2 c2 g2 a2 r2 u2) ->
t1 == t2
&& eqMaybeWith eqBid c1 c2
&& a1 == a2 && r1 == r2 && u1 == u2 && g1 == g2
(CmmSwitch e1 ids1, CmmSwitch e2 ids2) ->
eqExprWith eqBid env e1 e2 && eqSwitchTargetWith eqBid ids1 ids2
-- result registers aren't compared since they are binding occurrences
(CmmForeignCall t1 _ a1 s1 ret_args1 ret_off1 intrbl1,
CmmForeignCall t2 _ a2 s2 ret_args2 ret_off2 intrbl2) ->
t1 == t2
&& and (zipWith (eqExprWith eqBid env) a1 a2)
&& s1 == s2
&& ret_args1 == ret_args2
&& ret_off1 == ret_off2
&& intrbl1 == intrbl2
_ -> False
eqMaybeWith :: (a -> b -> Bool) -> Maybe a -> Maybe b -> Bool
eqMaybeWith eltEq (Just e) (Just e') = eltEq e e'
......
......@@ -68,7 +68,7 @@ cpsTop hsc_env proc =
----------- Eliminate common blocks -------------------------------------
g <- {-# SCC "elimCommonBlocks" #-}
condPass Opt_CmmElimCommonBlocks elimCommonBlocks g
condPass Opt_CmmElimCommonBlocks (elimCommonBlocks dflags) g
Opt_D_dump_cmm_cbe "Post common block elimination"
-- Any work storing block Labels must be performed _after_
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment