From 8754b08f24a146eea6ed71fc79016fb012a0521c Mon Sep 17 00:00:00 2001
From: Benjamin Maurer <maurer.benjamin@gmail.com>
Date: Wed, 14 Sep 2022 09:47:25 +0200
Subject: [PATCH] Make Phi-funs not count as uses in next-use-distance update.

Next-use-distance analysis would count Phis as uses,
so when updating NUD, distance of all Phi args ended up the same.
This does not line-up with how Spiller looks at loop headers
and so we end up spilling and reloading values live at the loop header.

This change makes it so that initial NUD calculation still works
(does not loop infinitely bc. of increasing distance bc. of loop exit
weight), but update works as expected.
---
 compiler/GHC/CmmToAsm/SSA/NextUseDistance.hs | 77 +++++++++++++++++---
 compiler/GHC/CmmToAsm/SSA/Spill.hs           | 25 ++-----
 compiler/GHC/CmmToAsm/SSA/Utils.hs           | 12 ++-
 3 files changed, 86 insertions(+), 28 deletions(-)

diff --git a/compiler/GHC/CmmToAsm/SSA/NextUseDistance.hs b/compiler/GHC/CmmToAsm/SSA/NextUseDistance.hs
index 81206bbec74c..c9577a9094eb 100644
--- a/compiler/GHC/CmmToAsm/SSA/NextUseDistance.hs
+++ b/compiler/GHC/CmmToAsm/SSA/NextUseDistance.hs
@@ -12,7 +12,7 @@ module GHC.CmmToAsm.SSA.NextUseDistance (
     nextUseDistanceToLiveness,
     distanceWeights,
     getPhiArgsForBlock,
-    targetPhiInfo,
+    targetPhiMapping,
 
     redoNextUseAnalysis,
 
@@ -27,14 +27,15 @@ import GHC.Prelude
 
 import GHC.CmmToAsm.SSA
 import GHC.CmmToAsm.SSA.LivenessTypes
+import GHC.CmmToAsm.SSA.Utils
 
 import GHC.CmmToAsm.CFG
 import GHC.CmmToAsm.Instr
 import GHC.CmmToAsm.Reg.Liveness
 import GHC.CmmToAsm.Reg.Target
-import GHC.CmmToAsm.Types
+import GHC.CmmToAsm.Reg.Utils
 
-import GHC.Cmm (GenCmmDecl(..))
+import GHC.Cmm (GenCmmDecl(..), GenBasicBlock(BasicBlock), blockId)
 import GHC.Cmm.BlockId
 import GHC.Cmm.Dataflow.Collections
 import GHC.Cmm.Dataflow.Label
@@ -53,7 +54,7 @@ import GHC.Utils.Misc
 import GHC.Utils.Outputable
 import GHC.Utils.Panic
 
-import Data.List (find, transpose, mapAccumL, mapAccumR)
+import Data.List (find, transpose, mapAccumL, mapAccumR, elemIndex)
 import Data.Maybe
 
 -- Debug only
@@ -212,6 +213,35 @@ distanceWeights cfg loopLvls bid = listToUFM succLvls
 
 -- Need an easy way to check if LR reaches Phi-arg in successor for
 -- deaths at branch instructions.
+--
+-- Note: this was changed last-minute, bc. I realized I needed different
+-- things for Next-Use-Analysis and update of NUDs.
+targetPhiMapping
+    :: Instruction instr
+    => ReverseCFG
+    -> UniqFM BlockId (LiveSsaBasicBlock instr)
+    -- ^ Block table
+    -> BlockId
+    -- ^ Source block (edge from)
+    -> BlockId
+    -- ^ Destination block (edge to)
+    -> RegMap Unique
+    -- ^ Map from phi def to arg
+
+targetPhiMapping rcfg blks srcBid dstBid
+ = let  -- TODO: copied from SSA.Spill for testing
+        mCol        = elemIndex srcBid $ getPredecessors rcfg dstBid
+        phis        = fromMaybe [] $ ssaBBPhiFuns <$> lookupUFM blks dstBid
+        assocDefs i = zipWith (\d r -> (d, getUnique r)) (map phiDef phis)
+                    $ (transpose $ map phiArgs phis) !! i
+        -- Map def to arg (unique)
+        phiMap      = toRegMap
+                    $ listToUFM
+                    $ maybe [] assocDefs mCol
+   in   phiMap
+
+
+-- Get all phi defs and args for target.
 targetPhiInfo
     :: Instruction instr
     => ReverseCFG
@@ -228,6 +258,7 @@ targetPhiInfo rcfg blks srcBid dstBid
  = let  mBlk    = lookupUFM blks dstBid
         phiArgs = maybe emptyUniqSet (getPhiArgsForBlock rcfg srcBid) mBlk
         phiDefs = maybe [] ssaBBPhiDefs mBlk
+
    in   (phiDefs, phiArgs)
 
 
@@ -386,14 +417,15 @@ nextUseAnalysis_bwd platform weights targetPhiArgs globalNus nextUses li@(LiveIn
      nextUses_br  = incNudMap $ plusUFM_C min liveReachingTargets nextUses'
 
      liveReachingTargets
-                  = liveInTargetsOnly weights targetPhiArgs globalNus targets
+                  = initialCrossBranchNUDs weights targetPhiArgs globalNus targets
 
      r_dying_br   = unionUniqSets
                 (nudMapToRegSet $ liveReachingTargets `minusUFM` nextUses)
                 (mkUniqSet r_dying)
 
 
-liveInTargetsOnly
+-- | Get Next-Use-Distances across branch, with loop weight added.
+initialCrossBranchNUDs
     :: BlockFM Int
        -- ^ Successor id to weight - add weight to loop exit edges
     -> (BlockId -> ([VirtualReg], RegSet))
@@ -405,7 +437,7 @@ liveInTargetsOnly
     -> NudMap
        -- ^ (Minimum) Next use distances from branch to targets.
 
-liveInTargetsOnly weights targetPhiArgs globalNus targets
+initialCrossBranchNUDs weights targetPhiArgs globalNus targets
      = addUseNudMap liveInTargets
      $ nonDetEltsUniqSet phiArgTargets
      -- Order doesn't matter. See Note [Unique Determinism and code generation]
@@ -428,6 +460,32 @@ liveInTargetsOnly weights targetPhiArgs globalNus targets
             $ concatMap (map getUnique) phiDefs
 
 
+-- | Map live-out vregs to Phi args and get their next use distances.
+getNUDsInTarget
+    :: BlockFM Int
+       -- ^ Successor id to weight - add weight to loop exit edges
+    -> (BlockId -> RegMap Unique)
+        -- ^ Function from BlockId to Phi defs and Phi args.
+    -> GlobalNextUses
+       -- ^ Map of blocks to next use distances .
+    -> [BlockId]
+       -- ^ Branch targets.
+    -> NudMap
+       -- ^ (Minimum) Next use distances from branch to targets.
+
+getNUDsInTarget weights targetPhiArgs globalNus targets
+     = listToUFM
+        $ map (\(NUD r d) -> let r' = renameVReg r (lookupUFM phiMap r)
+                             in  (r', NUD r' d))
+        $ nonDetEltsUFM $ succNuds
+ where
+     phiMap = plusUFMList $ map targetPhiArgs targets
+
+     succNuds = mergeNUDs $ weightedTargetNud globalNus weights targets
+
+
+-- | Get Next-Use-Distance maps for targets, with weights added
+--   for loop exit edges.
 weightedTargetNud :: GlobalNextUses -> BlockFM Int -> [BlockId]
           -> [NudMap]
 
@@ -467,7 +525,7 @@ updateNextUseDists
     :: Instruction instr
     => Platform
     -> BlockFM Int
-    -> (BlockId -> ([VirtualReg], RegSet))
+    -> (BlockId -> RegMap Unique)
     -> GlobalNextUses
     -> NudMap
     -> [LiveInstr instr]
@@ -506,7 +564,7 @@ updateNextUseDists platform weights targetPhis gnud nextUses insns
 
             -- Min next use distances at branch targets.
             liveInTargets
-                = liveInTargetsOnly weights targetPhis gnud targets
+                = getNUDsInTarget weights targetPhis gnud targets
 
             -- Merge the next use distances updated with this instruction's uses
             -- and then add anything left of interest from branch targets.
@@ -639,7 +697,6 @@ patchInstr instr reg new
  = patchReg reg (RegVirtual new) instr
 
 
--- Taken from Spill.hs
 patchReg
     :: Instruction instr
     => Reg -> Reg -> instr -> instr
diff --git a/compiler/GHC/CmmToAsm/SSA/Spill.hs b/compiler/GHC/CmmToAsm/SSA/Spill.hs
index 5d1e1fdb537f..34637e0de9cc 100644
--- a/compiler/GHC/CmmToAsm/SSA/Spill.hs
+++ b/compiler/GHC/CmmToAsm/SSA/Spill.hs
@@ -19,6 +19,7 @@ import GHC.Prelude
 import GHC.CmmToAsm.SSA
 import GHC.CmmToAsm.SSA.NextUseDistance
 import GHC.CmmToAsm.SSA.FixupBlocks
+import GHC.CmmToAsm.SSA.Utils
 
 import GHC.CmmToAsm.CFG
 import GHC.CmmToAsm.Config
@@ -60,6 +61,7 @@ import Data.Ord
 
 -- DEBUG
 -- import GHC.Utils.Trace
+-- import Debug.Trace (traceM)
 
 
 regSpillAll
@@ -301,7 +303,7 @@ initInRegsSets loops cfg rcfg avail entryIds blkTbl bid
     patchSet    = mapUniqSet (\r -> renameVReg r $ lookupUFM phiMap r)
 
     spillBlk = do
-        blk'    <- spillBlock avail edgeWeights (targetPhiInfo rcfg blkTbl bid) blk
+        blk'    <- spillBlock avail edgeWeights (targetPhiMapping rcfg blkTbl bid) blk
         return $ addToUFM blkTbl bid blk'
      where  blk         = lookupWithDefaultUFM blkTbl impossible bid
             impossible  = pprPanic "SSA.Spill.initInRegsSets: Block does not exist:"
@@ -564,7 +566,7 @@ spillBlock
     -- ^ Registers per register class.
     -> (BlockId -> BlockFM Int)
     -- ^ Block to execution frequency (weight).
-    -> (BlockId -> ([VirtualReg], RegSet))
+    -> (BlockId -> RegMap Unique)
     -- ^ Which Phi functions are reached from this block.
     -> LiveSsaBasicBlock instr
     -- ^ Block with instructions.
@@ -610,7 +612,7 @@ minAlgorithm
     => BlockId
     -> BlockFM Int
         -- ^ Edge weights from this block to successors.
-    -> (BlockId -> ([VirtualReg], RegSet))
+    -> (BlockId -> RegMap Unique)
         -- ^ Get Phi defs and args for target block.
     -> (RegClass -> Int)
         -- ^ Registers per register class.
@@ -772,7 +774,9 @@ minAlgorithm bid weights phiTargets avail (IPS nextUses inRegs spilled spillmap)
 
         stateFinal  = IPS nextUsesFinal inRegsFinal spilledFinal spillmapW
 
-        liveBr t    = (snd $ phiTargets t) `unionUniqSets`
+        targetPhiArgs = -- (\x -> pprTrace "targetPhiArgs in " (ppr bid <> ppr x) x) .
+            mkUniqSet . nonDetEltsUFM . mapUFM (RegVirtual . VirtualRegI) . phiTargets
+        liveBr t    = (targetPhiArgs t) `unionUniqSets`
                     (fromMaybe emptyRegSet $ nudMapToRegSet <$> lookupUFM gnud t)
 
         liveInTargets
@@ -914,13 +918,6 @@ mkReloadForReg bid sm reg   = LiveInstr (RELOAD lookupStackSlot reg)
                                     \no slot defined for spilled reg" (ppr reg <> text " in " <> ppr bid)
 
 
--- | Swap virtual reg's unique, if present.
-renameVReg :: Reg -> Maybe Unique -> Reg
-renameVReg rr@(RegReal _) _ = rr
-renameVReg r Nothing = r
-renameVReg (RegVirtual vr) (Just u) = RegVirtual $ renameVirtualReg u vr
-
-
 -- SpillMap --
 
 data SpillMap
@@ -1075,12 +1072,6 @@ getIncomingState bid = do
     return res
 
 
-hasIncomingStateFrom :: BlockId -> SpillM (BlockSet)
-hasIncomingStateFrom bid = do
-    mInRegs     <- gets $ \s -> lookupUFM (spsIncomingState s) bid
-    return $ maybe emptyUniqSet (mkUniqSet . map fstOf3) mInRegs
-
-
 takeIncomingStates :: BlockId -> SpillM ([(BlockId, PartitionedRegSet, RegSet)])
 takeIncomingStates to = do
     inStates     <- fromMaybe [] <$> (gets $ \s -> lookupUFM (spsIncomingState s) to)
diff --git a/compiler/GHC/CmmToAsm/SSA/Utils.hs b/compiler/GHC/CmmToAsm/SSA/Utils.hs
index 43bcfb4cdafc..0393b36c729e 100644
--- a/compiler/GHC/CmmToAsm/SSA/Utils.hs
+++ b/compiler/GHC/CmmToAsm/SSA/Utils.hs
@@ -3,7 +3,8 @@
 --
 
 module GHC.CmmToAsm.SSA.Utils (
-    mkLoopInfos
+    mkLoopInfos,
+    renameVReg
 ) where
 
 import GHC.Prelude
@@ -13,6 +14,8 @@ import GHC.CmmToAsm.SSA
 import GHC.CmmToAsm.CFG
 
 import GHC.Cmm (GenCmmDecl(..))
+import GHC.Platform.Reg
+import GHC.Types.Unique
 
 
 mkLoopInfos
@@ -26,3 +29,10 @@ mkLoopInfos _ cmmProc@(CmmProc _ _ _ (BlkTbl [] _)) = (Nothing, cmmProc)
 
 mkLoopInfos cfg cmmProc@(CmmProc _ _ _ (BlkTbl (entry : _) _))
  = (Just $ loopInfo cfg entry, cmmProc)
+
+
+-- | Swap virtual reg's unique, if present.
+renameVReg :: Reg -> Maybe Unique -> Reg
+renameVReg rr@(RegReal _) _ = rr
+renameVReg r Nothing = r
+renameVReg (RegVirtual vr) (Just u) = RegVirtual $ renameVirtualReg u vr
\ No newline at end of file
-- 
GitLab