Commit de1160be authored by Joachim Breitner's avatar Joachim Breitner

Refactor the story around switches (#10137)

This re-implements the code generation for case expressions at the Stg →
Cmm level, both for data type cases as well as for integral literal
cases. (Cases on float are still treated as before).

The goal is to allow for fancier strategies in implementing them, for a
cleaner separation of the strategy from the gritty details of Cmm, and
to run this later than the Common Block Optimization, allowing for one
way to attack #10124. The new module CmmSwitch contains a number of
notes explaining these changes. For example, it creates larger
consecutive jump tables than the previous code, if possible.

nofib shows little significant overall improvement of runtime. The
rather large wobbling comes from changes in the code block order
(see #8082, not much we can do about it). But the decrease in code size
alone makes this worthwhile.

```
        Program           Size    Allocs   Runtime   Elapsed  TotalMem
            Min          -1.8%      0.0%     -6.1%     -6.1%     -2.9%
            Max          -0.7%     +0.0%     +5.6%     +5.7%     +7.8%
 Geometric Mean          -1.4%     -0.0%     -0.3%     -0.3%     +0.0%
```

Compilation time increases slightly:
```
        -1 s.d.                -----            -2.0%
        +1 s.d.                -----            +2.5%
        Average                -----            +0.3%
```

The test case T783 regresses a lot, but it is the only one exhibiting
any regression. The cause is the changed order of branches in an
if-then-else tree, which makes the Hoopl data flow analysis traverse
the blocks in a suboptimal order. Reverting that gets rid of this
regression, but has a consistent, if only very small (+0.2%), negative
effect on runtime. So I conclude that this test is an extreme outlier
and no reason to change the code.

Differential Revision: https://phabricator.haskell.org/D720
parent e24f6381
...@@ -30,7 +30,7 @@ module Literal ...@@ -30,7 +30,7 @@ module Literal
, inIntRange, inWordRange, tARGET_MAX_INT, inCharRange , inIntRange, inWordRange, tARGET_MAX_INT, inCharRange
, isZeroLit , isZeroLit
, litFitsInChar , litFitsInChar
, onlyWithinBounds , litValue
-- ** Coercions -- ** Coercions
, word2IntLit, int2WordLit , word2IntLit, int2WordLit
...@@ -271,6 +271,17 @@ isZeroLit (MachFloat 0) = True ...@@ -271,6 +271,17 @@ isZeroLit (MachFloat 0) = True
isZeroLit (MachDouble 0) = True isZeroLit (MachDouble 0) = True
isZeroLit _ = False isZeroLit _ = False
-- | Extracts the numeric payload of a 'Literal' as an 'Integer'.
--
-- Defined for character literals (via the code point), the four fixed-size
-- integral literal forms and 'LitInteger'. Any other literal is a caller
-- error and triggers a panic that prints the offending literal.
litValue :: Literal -> Integer
litValue lit = case lit of
    MachChar   c   -> toInteger (ord c)
    MachInt    n   -> n
    MachInt64  n   -> n
    MachWord   n   -> n
    MachWord64 n   -> n
    LitInteger n _ -> n
    _              -> pprPanic "litValue" (ppr lit)
{- {-
Coercions Coercions
~~~~~~~~~ ~~~~~~~~~
...@@ -360,16 +371,6 @@ litIsLifted :: Literal -> Bool ...@@ -360,16 +371,6 @@ litIsLifted :: Literal -> Bool
litIsLifted (LitInteger {}) = True litIsLifted (LitInteger {}) = True
litIsLifted _ = False litIsLifted _ = False
-- | x `onlyWithinBounds` (l,h) is true if l <= y < h ==> x = y
onlyWithinBounds :: Literal -> (Literal, Literal) -> Bool
onlyWithinBounds (MachChar x) (MachChar l, MachChar h) = x == l && succ x == h
onlyWithinBounds (MachInt x) (MachInt l, MachInt h) = x == l && succ x == h
onlyWithinBounds (MachWord x) (MachWord l, MachWord h) = x == l && succ x == h
onlyWithinBounds (MachInt64 x) (MachInt64 l, MachInt64 h) = x == l && succ x == h
onlyWithinBounds (MachWord64 x) (MachWord64 l, MachWord64 h) = x == l && succ x == h
onlyWithinBounds _ _ = False
{- {-
Types Types
~~~~~ ~~~~~
......
...@@ -8,6 +8,7 @@ where ...@@ -8,6 +8,7 @@ where
import BlockId import BlockId
import Cmm import Cmm
import CmmUtils import CmmUtils
import CmmSwitch (eqSwitchTargetWith)
import CmmContFlowOpt import CmmContFlowOpt
import Prelude hiding (iterate, succ, unzip, zip) import Prelude hiding (iterate, succ, unzip, zip)
...@@ -203,13 +204,10 @@ eqLastWith eqBid (CmmCondBranch c1 t1 f1) (CmmCondBranch c2 t2 f2) = ...@@ -203,13 +204,10 @@ eqLastWith eqBid (CmmCondBranch c1 t1 f1) (CmmCondBranch c2 t2 f2) =
c1 == c2 && eqBid t1 t2 && eqBid f1 f2 c1 == c2 && eqBid t1 t2 && eqBid f1 f2
eqLastWith eqBid (CmmCall t1 c1 g1 a1 r1 u1) (CmmCall t2 c2 g2 a2 r2 u2) = eqLastWith eqBid (CmmCall t1 c1 g1 a1 r1 u1) (CmmCall t2 c2 g2 a2 r2 u2) =
t1 == t2 && eqMaybeWith eqBid c1 c2 && a1 == a2 && r1 == r2 && u1 == u2 && g1 == g2 t1 == t2 && eqMaybeWith eqBid c1 c2 && a1 == a2 && r1 == r2 && u1 == u2 && g1 == g2
eqLastWith eqBid (CmmSwitch e1 bs1) (CmmSwitch e2 bs2) = eqLastWith eqBid (CmmSwitch e1 ids1) (CmmSwitch e2 ids2) =
e1 == e2 && eqListWith (eqMaybeWith eqBid) bs1 bs2 e1 == e2 && eqSwitchTargetWith eqBid ids1 ids2
eqLastWith _ _ _ = False eqLastWith _ _ _ = False
eqListWith :: (a -> b -> Bool) -> [a] -> [b] -> Bool
eqListWith eltEq es es' = all (uncurry eltEq) (List.zip es es')
eqMaybeWith :: (a -> b -> Bool) -> Maybe a -> Maybe b -> Bool eqMaybeWith :: (a -> b -> Bool) -> Maybe a -> Maybe b -> Bool
eqMaybeWith eltEq (Just e) (Just e') = eltEq e e' eqMaybeWith eltEq (Just e) (Just e') = eltEq e e'
eqMaybeWith _ Nothing Nothing = True eqMaybeWith _ Nothing Nothing = True
......
...@@ -12,6 +12,7 @@ import Hoopl ...@@ -12,6 +12,7 @@ import Hoopl
import BlockId import BlockId
import Cmm import Cmm
import CmmUtils import CmmUtils
import CmmSwitch (mapSwitchTargets)
import Maybes import Maybes
import Panic import Panic
...@@ -355,7 +356,7 @@ replaceLabels env g ...@@ -355,7 +356,7 @@ replaceLabels env g
txnode :: CmmNode e x -> CmmNode e x txnode :: CmmNode e x -> CmmNode e x
txnode (CmmBranch bid) = CmmBranch (lookup bid) txnode (CmmBranch bid) = CmmBranch (lookup bid)
txnode (CmmCondBranch p t f) = mkCmmCondBranch (exp p) (lookup t) (lookup f) txnode (CmmCondBranch p t f) = mkCmmCondBranch (exp p) (lookup t) (lookup f)
txnode (CmmSwitch e arms) = CmmSwitch (exp e) (map (liftM lookup) arms) txnode (CmmSwitch e ids) = CmmSwitch (exp e) (mapSwitchTargets lookup ids)
txnode (CmmCall t k rg a res r) = CmmCall (exp t) (liftM lookup k) rg a res r txnode (CmmCall t k rg a res r) = CmmCall (exp t) (liftM lookup k) rg a res r
txnode fc@CmmForeignCall{} = fc{ args = map exp (args fc) txnode fc@CmmForeignCall{} = fc{ args = map exp (args fc)
, succ = lookup (succ fc) } , succ = lookup (succ fc) }
......
{-# LANGUAGE GADTs #-}
module CmmImplementSwitchPlans
( cmmImplementSwitchPlans
)
where
import Hoopl
import BlockId
import Cmm
import CmmUtils
import CmmSwitch
import UniqSupply
import DynFlags
--
-- This module replaces Switch statements as generated by the Stg -> Cmm
-- transformation, which might be huge and sparse and hence unsuitable for
-- assembly code, by proper constructs (if-then-else trees, dense jump tables).
--
-- The actual, abstract strategy is determined by createSwitchPlan in
-- CmmSwitch and returned as a SwitchPlan; here is just the implementation in
-- terms of Cmm code. See Note [Cmm Switches, the general plan] in CmmSwitch.
--
-- The split into two modules serves both to clearly separate concerns and
-- to keep SwitchTargets abstract: createSwitchPlan needs access to the
-- constructors of SwitchTargets, a data type exported abstractly by CmmSwitch.
--
-- | Walks over every block of a 'CmmGraph' and rewrites each 'CmmSwitch'
-- into a form suitable for code generation.
--
-- If the selected target supports switch statements itself (as reported by
-- 'targetSupportsSwitch'), the graph is returned unchanged. Otherwise every
-- block is passed through 'visitSwitches' and the graph is rebuilt from the
-- resulting blocks, keeping the original entry label.
cmmImplementSwitchPlans :: DynFlags -> CmmGraph -> UniqSM CmmGraph
cmmImplementSwitchPlans dflags graph =
    if targetSupportsSwitch (hscTarget dflags)
        then return graph
        else do
            perBlock <- mapM (visitSwitches dflags) (toBlockList graph)
            return (ofBlockList (g_entry graph) (concat perBlock))
-- | Rewrites a single block.
--
-- A block that ends in a 'CmmSwitch' is replaced by the same head and middle
-- followed by the tail produced by 'implementSwitchPlan', plus any auxiliary
-- blocks that the plan needed. Every other block is returned as a singleton
-- list, untouched.
visitSwitches :: DynFlags -> CmmBlock -> UniqSM [CmmBlock]
visitSwitches dflags block =
    case blockSplit block of
        (entry@(CmmEntry _ scope), middle, CmmSwitch expr ids) -> do
            -- Planning is pure; only the implementation needs fresh uniques.
            (switchTail, extraBlocks) <-
                implementSwitchPlan dflags scope expr (createSwitchPlan ids)
            let rebuilt = entry `blockJoinHead` middle `blockAppend` switchTail
            return (rebuilt : extraBlocks)
        _ -> return [block]
-- | Turns a 'SwitchPlan' into Cmm code.
--
-- The result is the tail block that replaces the original switch, together
-- with any freshly created blocks that the tail (transitively) branches to.
-- New blocks are given the tick scope passed in as the second argument.
implementSwitchPlan :: DynFlags -> CmmTickScope -> CmmExpr -> SwitchPlan -> UniqSM (Block CmmNode O C, [CmmBlock])
implementSwitchPlan dflags scope expr = tailFor
  where
    -- A tail block consisting of nothing but the given control transfer.
    endWith :: CmmNode O C -> Block CmmNode O C
    endWith = blockJoinTail emptyBlock

    -- The word-sized literal the scrutinee is compared against.
    wordLit n = CmmLit (mkWordCLit dflags n)

    -- Implement a plan as a tail block.
    tailFor plan = case plan of
        Unconditionally lbl ->
            return (endWith (CmmBranch lbl), [])

        JumpTable targets ->
            return (endWith (CmmSwitch expr targets), [])

        IfLT signed pivot lower upper -> do
            (lowerLbl, lowerBlocks) <- labelFor lower
            (upperLbl, upperBlocks) <- labelFor upper
            let cmp  = if signed then cmmSLtWord else cmmULtWord
                cond = cmp dflags expr (wordLit pivot)
            return ( endWith (CmmCondBranch cond lowerLbl upperLbl)
                   , lowerBlocks ++ upperBlocks )

        IfEqual val hit rest -> do
            (restLbl, restBlocks) <- labelFor rest
            -- The test is for inequality, so the "true" edge goes to the
            -- remaining plan and the "false" edge to the matching label.
            let cond = cmmNeWord dflags expr (wordLit val)
            return (endWith (CmmCondBranch cond restLbl hit), restBlocks)

    -- Implement a plan as a label to branch to. An unconditional plan already
    -- is a label; anything else is wrapped in a new block with a fresh id.
    labelFor (Unconditionally lbl) = return (lbl, [])
    labelFor plan = do
        bid <- fmap mkBlockId getUniqueM
        (planTail, more) <- tailFor plan
        return (bid, blockJoinHead (CmmEntry bid scope) planTail : more)
...@@ -14,13 +14,13 @@ import Hoopl ...@@ -14,13 +14,13 @@ import Hoopl
import Cmm import Cmm
import CmmUtils import CmmUtils
import CmmLive import CmmLive
import CmmSwitch (switchTargetsToList)
import PprCmm () import PprCmm ()
import BlockId import BlockId
import FastString import FastString
import Outputable import Outputable
import DynFlags import DynFlags
import Data.Maybe
import Control.Monad (liftM, ap) import Control.Monad (liftM, ap)
#if __GLASGOW_HASKELL__ < 709 #if __GLASGOW_HASKELL__ < 709
import Control.Applicative (Applicative(..)) import Control.Applicative (Applicative(..))
...@@ -171,9 +171,9 @@ lintCmmLast labels node = case node of ...@@ -171,9 +171,9 @@ lintCmmLast labels node = case node of
_ <- lintCmmExpr e _ <- lintCmmExpr e
checkCond dflags e checkCond dflags e
CmmSwitch e branches -> do CmmSwitch e ids -> do
dflags <- getDynFlags dflags <- getDynFlags
mapM_ checkTarget $ catMaybes branches mapM_ checkTarget $ switchTargetsToList ids
erep <- lintCmmExpr e erep <- lintCmmExpr e
if (erep `cmmEqType_ignoring_ptrhood` bWord dflags) if (erep `cmmEqType_ignoring_ptrhood` bWord dflags)
then return () then return ()
......
...@@ -23,6 +23,7 @@ module CmmNode ( ...@@ -23,6 +23,7 @@ module CmmNode (
import CodeGen.Platform import CodeGen.Platform
import CmmExpr import CmmExpr
import CmmSwitch
import DynFlags import DynFlags
import FastString import FastString
import ForeignCall import ForeignCall
...@@ -89,11 +90,10 @@ data CmmNode e x where ...@@ -89,11 +90,10 @@ data CmmNode e x where
cml_true, cml_false :: ULabel cml_true, cml_false :: ULabel
} -> CmmNode O C } -> CmmNode O C
CmmSwitch :: CmmExpr -> [Maybe Label] -> CmmNode O C -- Table branch CmmSwitch
-- The scrutinee is zero-based; :: CmmExpr -- Scrutinee, of some integral type
-- zero -> first block -> SwitchTargets -- Cases. See [Note SwitchTargets]
-- one -> second block etc -> CmmNode O C
-- Undefined outside range, and when there's a Nothing
CmmCall :: { -- A native call or tail call CmmCall :: { -- A native call or tail call
cml_target :: CmmExpr, -- never a CmmPrim to a CallishMachOp! cml_target :: CmmExpr, -- never a CmmPrim to a CallishMachOp!
...@@ -228,7 +228,7 @@ instance NonLocal CmmNode where ...@@ -228,7 +228,7 @@ instance NonLocal CmmNode where
successors (CmmBranch l) = [l] successors (CmmBranch l) = [l]
successors (CmmCondBranch {cml_true=t, cml_false=f}) = [f, t] -- meets layout constraint successors (CmmCondBranch {cml_true=t, cml_false=f}) = [f, t] -- meets layout constraint
successors (CmmSwitch _ ls) = catMaybes ls successors (CmmSwitch _ ids) = switchTargetsToList ids
successors (CmmCall {cml_cont=l}) = maybeToList l successors (CmmCall {cml_cont=l}) = maybeToList l
successors (CmmForeignCall {succ=l}) = [l] successors (CmmForeignCall {succ=l}) = [l]
...@@ -464,7 +464,7 @@ mapExp f (CmmStore addr e) = CmmStore (f addr) (f e) ...@@ -464,7 +464,7 @@ mapExp f (CmmStore addr e) = CmmStore (f addr) (f e)
mapExp f (CmmUnsafeForeignCall tgt fs as) = CmmUnsafeForeignCall (mapForeignTarget f tgt) fs (map f as) mapExp f (CmmUnsafeForeignCall tgt fs as) = CmmUnsafeForeignCall (mapForeignTarget f tgt) fs (map f as)
mapExp _ l@(CmmBranch _) = l mapExp _ l@(CmmBranch _) = l
mapExp f (CmmCondBranch e ti fi) = CmmCondBranch (f e) ti fi mapExp f (CmmCondBranch e ti fi) = CmmCondBranch (f e) ti fi
mapExp f (CmmSwitch e tbl) = CmmSwitch (f e) tbl mapExp f (CmmSwitch e ids) = CmmSwitch (f e) ids
mapExp f n@CmmCall {cml_target=tgt} = n{cml_target = f tgt} mapExp f n@CmmCall {cml_target=tgt} = n{cml_target = f tgt}
mapExp f (CmmForeignCall tgt fs as succ ret_args updfr intrbl) = CmmForeignCall (mapForeignTarget f tgt) fs (map f as) succ ret_args updfr intrbl mapExp f (CmmForeignCall tgt fs as succ ret_args updfr intrbl) = CmmForeignCall (mapForeignTarget f tgt) fs (map f as) succ ret_args updfr intrbl
...@@ -560,7 +560,7 @@ foldExpDeep f = foldExp (wrapRecExpf f) ...@@ -560,7 +560,7 @@ foldExpDeep f = foldExp (wrapRecExpf f)
mapSuccessors :: (Label -> Label) -> CmmNode O C -> CmmNode O C mapSuccessors :: (Label -> Label) -> CmmNode O C -> CmmNode O C
mapSuccessors f (CmmBranch bid) = CmmBranch (f bid) mapSuccessors f (CmmBranch bid) = CmmBranch (f bid)
mapSuccessors f (CmmCondBranch p y n) = CmmCondBranch p (f y) (f n) mapSuccessors f (CmmCondBranch p y n) = CmmCondBranch p (f y) (f n)
mapSuccessors f (CmmSwitch e arms) = CmmSwitch e (map (fmap f) arms) mapSuccessors f (CmmSwitch e ids) = CmmSwitch e (mapSwitchTargets f ids)
mapSuccessors _ n = n mapSuccessors _ n = n
-- ----------------------------------------------------------------------------- -- -----------------------------------------------------------------------------
......
...@@ -226,6 +226,7 @@ import CmmOpt ...@@ -226,6 +226,7 @@ import CmmOpt
import MkGraph import MkGraph
import Cmm import Cmm
import CmmUtils import CmmUtils
import CmmSwitch ( mkSwitchTargets )
import CmmInfo import CmmInfo
import BlockId import BlockId
import CmmLex import CmmLex
...@@ -258,6 +259,7 @@ import Data.Array ...@@ -258,6 +259,7 @@ import Data.Array
import Data.Char ( ord ) import Data.Char ( ord )
import System.Exit import System.Exit
import Data.Maybe import Data.Maybe
import qualified Data.Map as M
#include "HsVersions.h" #include "HsVersions.h"
} }
...@@ -676,24 +678,24 @@ globals :: { [GlobalReg] } ...@@ -676,24 +678,24 @@ globals :: { [GlobalReg] }
: GLOBALREG { [$1] } : GLOBALREG { [$1] }
| GLOBALREG ',' globals { $1 : $3 } | GLOBALREG ',' globals { $1 : $3 }
maybe_range :: { Maybe (Int,Int) } maybe_range :: { Maybe (Integer,Integer) }
: '[' INT '..' INT ']' { Just (fromIntegral $2, fromIntegral $4) } : '[' INT '..' INT ']' { Just ($2, $4) }
| {- empty -} { Nothing } | {- empty -} { Nothing }
arms :: { [CmmParse ([Int],Either BlockId (CmmParse ()))] } arms :: { [CmmParse ([Integer],Either BlockId (CmmParse ()))] }
: {- empty -} { [] } : {- empty -} { [] }
| arm arms { $1 : $2 } | arm arms { $1 : $2 }
arm :: { CmmParse ([Int],Either BlockId (CmmParse ())) } arm :: { CmmParse ([Integer],Either BlockId (CmmParse ())) }
: 'case' ints ':' arm_body { do b <- $4; return ($2, b) } : 'case' ints ':' arm_body { do b <- $4; return ($2, b) }
arm_body :: { CmmParse (Either BlockId (CmmParse ())) } arm_body :: { CmmParse (Either BlockId (CmmParse ())) }
: '{' body '}' { return (Right (withSourceNote $1 $3 $2)) } : '{' body '}' { return (Right (withSourceNote $1 $3 $2)) }
| 'goto' NAME ';' { do l <- lookupLabel $2; return (Left l) } | 'goto' NAME ';' { do l <- lookupLabel $2; return (Left l) }
ints :: { [Int] } ints :: { [Integer] }
: INT { [ fromIntegral $1 ] } : INT { [ $1 ] }
| INT ',' ints { fromIntegral $1 : $3 } | INT ',' ints { $1 : $3 }
default :: { Maybe (CmmParse ()) } default :: { Maybe (CmmParse ()) }
: 'default' ':' '{' body '}' { Just (withSourceNote $3 $5 $4) } : 'default' ':' '{' body '}' { Just (withSourceNote $3 $5 $4) }
...@@ -1307,7 +1309,9 @@ withSourceNote a b parse = do ...@@ -1307,7 +1309,9 @@ withSourceNote a b parse = do
-- optional range on the switch (eg. switch [0..7] {...}), or by -- optional range on the switch (eg. switch [0..7] {...}), or by
-- the minimum/maximum values from the branches. -- the minimum/maximum values from the branches.
doSwitch :: Maybe (Int,Int) -> CmmParse CmmExpr -> [([Int],Either BlockId (CmmParse ()))] doSwitch :: Maybe (Integer,Integer)
-> CmmParse CmmExpr
-> [([Integer],Either BlockId (CmmParse ()))]
-> Maybe (CmmParse ()) -> CmmParse () -> Maybe (CmmParse ()) -> CmmParse ()
doSwitch mb_range scrut arms deflt doSwitch mb_range scrut arms deflt
= do = do
...@@ -1319,22 +1323,16 @@ doSwitch mb_range scrut arms deflt ...@@ -1319,22 +1323,16 @@ doSwitch mb_range scrut arms deflt
-- Compile each case branch -- Compile each case branch
table_entries <- mapM emitArm arms table_entries <- mapM emitArm arms
let table = M.fromList (concat table_entries)
-- Construct the table dflags <- getDynFlags
let let range = fromMaybe (0, tARGET_MAX_WORD dflags) mb_range
all_entries = concat table_entries
ixs = map fst all_entries
(min,max)
| Just (l,u) <- mb_range = (l,u)
| otherwise = (minimum ixs, maximum ixs)
entries = elems (accumArray (\_ a -> Just a) dflt_entry (min,max)
all_entries)
expr <- scrut expr <- scrut
-- ToDo: check for out of range and jump to default if necessary -- ToDo: check for out of range and jump to default if necessary
emit (mkSwitch expr entries) emit $ mkSwitch expr (mkSwitchTargets False range dflt_entry table)
where where
emitArm :: ([Int],Either BlockId (CmmParse ())) -> CmmParse [(Int,BlockId)] emitArm :: ([Integer],Either BlockId (CmmParse ())) -> CmmParse [(Integer,BlockId)]
emitArm (ints,Left blockid) = return [ (i,blockid) | i <- ints ] emitArm (ints,Left blockid) = return [ (i,blockid) | i <- ints ]
emitArm (ints,Right code) = do emitArm (ints,Right code) = do
blockid <- forkLabelledCode code blockid <- forkLabelledCode code
......
...@@ -11,6 +11,7 @@ import Cmm ...@@ -11,6 +11,7 @@ import Cmm
import CmmLint import CmmLint
import CmmBuildInfoTables import CmmBuildInfoTables
import CmmCommonBlockElim import CmmCommonBlockElim
import CmmImplementSwitchPlans
import CmmProcPoint import CmmProcPoint
import CmmContFlowOpt import CmmContFlowOpt
import CmmLayoutStack import CmmLayoutStack
...@@ -71,6 +72,10 @@ cpsTop hsc_env proc = ...@@ -71,6 +72,10 @@ cpsTop hsc_env proc =
-- Any work storing block Labels must be performed _after_ -- Any work storing block Labels must be performed _after_
-- elimCommonBlocks -- elimCommonBlocks
g <- {-# SCC "createSwitchPlans" #-}
runUniqSM $ cmmImplementSwitchPlans dflags g
dump Opt_D_dump_cmm_switch "Post switch plan" g
----------- Proc points ------------------------------------------------- ----------- Proc points -------------------------------------------------
let call_pps = {-# SCC "callProcPoints" #-} callProcPoints g let call_pps = {-# SCC "callProcPoints" #-} callProcPoints g
proc_points <- proc_points <-
......
...@@ -18,6 +18,7 @@ import PprCmm () ...@@ -18,6 +18,7 @@ import PprCmm ()
import CmmUtils import CmmUtils
import CmmInfo import CmmInfo
import CmmLive (cmmGlobalLiveness) import CmmLive (cmmGlobalLiveness)
import CmmSwitch
import Data.List (sortBy) import Data.List (sortBy)
import Maybes import Maybes
import Control.Monad import Control.Monad
...@@ -295,7 +296,7 @@ splitAtProcPoints dflags entry_label callPPs procPoints procMap ...@@ -295,7 +296,7 @@ splitAtProcPoints dflags entry_label callPPs procPoints procMap
case lastNode block of case lastNode block of
CmmBranch id -> add_if_pp id rst CmmBranch id -> add_if_pp id rst
CmmCondBranch _ ti fi -> add_if_pp ti (add_if_pp fi rst) CmmCondBranch _ ti fi -> add_if_pp ti (add_if_pp fi rst)
CmmSwitch _ tbl -> foldr add_if_pp rst (catMaybes tbl) CmmSwitch _ ids -> foldr add_if_pp rst $ switchTargetsToList ids
_ -> rst _ -> rst
-- when jumping to a PP that has an info table, if -- when jumping to a PP that has an info table, if
...@@ -382,7 +383,7 @@ replaceBranches env cmmg ...@@ -382,7 +383,7 @@ replaceBranches env cmmg
last :: CmmNode O C -> CmmNode O C last :: CmmNode O C -> CmmNode O C
last (CmmBranch id) = CmmBranch (lookup id) last (CmmBranch id) = CmmBranch (lookup id)
last (CmmCondBranch e ti fi) = CmmCondBranch e (lookup ti) (lookup fi) last (CmmCondBranch e ti fi) = CmmCondBranch e (lookup ti) (lookup fi)
last (CmmSwitch e tbl) = CmmSwitch e (map (fmap lookup) tbl) last (CmmSwitch e ids) = CmmSwitch e (mapSwitchTargets lookup ids)
last l@(CmmCall {}) = l { cml_cont = Nothing } last l@(CmmCall {}) = l { cml_cont = Nothing }
-- NB. remove the continuation of a CmmCall, since this -- NB. remove the continuation of a CmmCall, since this
-- label will now be in a different CmmProc. Not only -- label will now be in a different CmmProc. Not only
......
{-# LANGUAGE GADTs #-}
module CmmSwitch (
SwitchTargets,
mkSwitchTargets,
switchTargetsCases, switchTargetsDefault, switchTargetsRange, switchTargetsSigned,
mapSwitchTargets, switchTargetsToTable, switchTargetsFallThrough,
switchTargetsToList, eqSwitchTargetWith,
SwitchPlan(..),
targetSupportsSwitch,
createSwitchPlan,
) where
import Outputable
import DynFlags
import Compiler.Hoopl (Label)
import Data.Maybe
import Data.List (groupBy)
import Data.Function (on)
import qualified Data.Map as M
-- Note [Cmm Switches, the general plan]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
--
-- Compiling a high-level switch statement, as it comes out of a STG case
-- expression, for example, allows for a surprising number of design decisions.
-- Therefore, we cleanly separated this from the Stg → Cmm transformation, as
-- well as from the actual code generation.
--
-- The overall plan is:
-- * The Stg → Cmm transformation creates a single `SwitchTargets` in
-- emitSwitch and emitCmmLitSwitch in StgCmmUtils.hs.
-- At this stage, they are unsuitable for code generation.
-- * A dedicated Cmm transformation (CmmImplementSwitchPlans) replaces these
-- switch statements with code that is suitable for code generation, i.e.
-- a nice balanced tree of decisions with dense jump tables in the leaves.
-- The actual planning of this tree is performed in pure code in createSwitchPlan
-- in this module. See Note [createSwitchPlan].
-- * The actual code generation will not do any further processing and
-- implements each CmmSwitch with a jump table.
--
-- When compiling to LLVM or C, CmmImplementSwitchPlans leaves the switch
-- statements alone, as we can turn a SwitchTargets value into a nice
-- switch-statement in LLVM resp. C, and leave the rest to the compiler.
--
-- See Note [CmmSwitch vs. CmmImplementSwitchPlans] for why the two modules
-- are separated.
-----------------------------------------------------------------------------
-- Magic Constants
--
-- There are a lot of heuristics here that depend on magic values where it is
-- hard to determine the "best" value (for whatever that means). These are the
-- magic values:
-- | Largest run of consecutive default entries tolerated inside one jump
-- table. A longer run causes the table to be split in two.
--
-- The value is 7 because splitting a jump table costs 7 words of additional
-- code (at least on x64, determined experimentally).
maxJumpTableHole :: Integer
maxJumpTableHole = 7
-- | Smallest number of entries for which a jump table is used. Anything
-- smaller is implemented using conditionals instead.
--
-- The value is 5 because an if-then-else tree of 4 values is nice and compact.
minJumpTableSize :: Int
minJumpTableSize = 5
-- | Smallest non-zero offset a jump table may start at; tables whose lowest
-- value is non-negative and below this start at 0 instead.
-- See Note [Jump Table Offset].
minJumpTableOffset :: Integer
minJumpTableOffset = 2
-----------------------------------------------------------------------------
-- Switch Targets
-- Note [SwitchTargets]:
-- ~~~~~~~~~~~~~~~~~~~~~
--
-- The branches of a switch are stored in a SwitchTargets, which consists of an
-- (optional) default jump target, and a map from values to jump targets.
--
-- If the default jump target is absent, the behaviour of the switch outside the
-- values of the map is undefined.
--
-- We use an Integer for the keys of the map so that it can be used in
-- switches on unsigned as well as signed integers.
--
-- The map must not be empty.
--
-- Before code generation, the table needs to be brought into a form where all
-- entries are non-negative, so that it can be compiled into a jump table.
-- See switchTargetsToTable.
-- | The alternatives of a 'CmmSwitch'.
--
-- A 'SwitchTargets' records, in this order: whether the scrutinee is
-- signed, the (inclusive) range of values it may take, an optional default
-- jump target, and a map from scrutinee values to jump targets.
data SwitchTargets
    = SwitchTargets
        Bool                    -- Is the scrutinee a signed value?
        (Integer, Integer)      -- Range of possible values
        (Maybe Label)           -- Default target, if any
        (M.Map Integer Label)   -- The explicit branches
    deriving (Show, Eq)
-- | Smart constructor for 'SwitchTargets'. It normalises its input so that
--
--  * the map has no keys outside the given (inclusive) range,
--  * the map has no entries whose target equals the default, and
--  * the default is dropped when the remaining entries cover every value
--    of the range.
mkSwitchTargets :: Bool -> (Integer, Integer) -> Maybe Label -> M.Map Integer Label -> SwitchTargets
mkSwitchTargets signed range@(lo, hi) mbdef ids =
    SwitchTargets signed range keptDefault normalised
  where
    -- Only keys within [lo, hi] survive.
    inRange = M.filterWithKey (\key _ -> lo <= key && key <= hi) ids

    -- Entries that merely repeat the default are redundant.
    normalised = case mbdef of
        Just dflt -> M.filter (/= dflt) inRange
        Nothing   -> inRange

    -- Number of distinct values the scrutinee can take.
    rangeSize = hi - lo + 1

    -- With one explicit entry per value the default can never be reached.
    keptDefault
        | toInteger (M.size normalised) == rangeSize = Nothing
        | otherwise                                  = mbdef
-- | Applies a relabelling function to every label held in a
-- 'SwitchTargets': the default target (if present) and all branch targets.
-- Signedness, range and keys are left as they are.
mapSwitchTargets :: (Label -> Label) -> SwitchTargets -> SwitchTargets
mapSwitchTargets f (SwitchTargets sgn rng dflt cases) =
    SwitchTargets sgn rng dflt' (M.map f cases)
  where
    dflt' = case dflt of
        Just lbl -> Just (f lbl)
        Nothing  -> Nothing
-- | The explicit (non-default) branches of a 'SwitchTargets', as
-- value/label pairs in ascending order of value.
switchTargetsCases :: SwitchTargets -> [(Integer, Label)]
switchTargetsCases targets = case targets of
    SwitchTargets _ _ _ cases -> M.toAscList cases
-- | The default jump target of a 'SwitchTargets', or 'Nothing' if it has
-- none.
switchTargetsDefault :: SwitchTargets -> Maybe Label
switchTargetsDefault targets = case targets of
    SwitchTargets _ _ dflt _ -> dflt
-- | The range of scrutinee values recorded in a 'SwitchTargets'.
switchTargetsRange :: SwitchTargets -> (Integer, Integer)
switchTargetsRange targets = case targets of
    SwitchTargets _ bounds _ _ -> bounds
-- | Whether the 'SwitchTargets' describes a switch on a signed value.
switchTargetsSigned :: SwitchTargets -> Bool
switchTargetsSigned targets = case targets of
    SwitchTargets isSigned _ _ _ -> isSigned
-- | Flattens a 'SwitchTargets' into a dense jump table usable for code
-- generation.
--
-- The first component is an offset to add to the scrutinee; after adding
-- it, the scrutinee is a 0-based index into the list in the second
-- component. Values without an explicit branch get the default target
-- (which may be 'Nothing').
--
-- The conversion of the offset from 'Integer' to 'Int' is a bit of a wart,
-- but works due to wrap-around arithmetic (as verified by the CmmSwitchTest
-- test case).
switchTargetsToTable :: SwitchTargets -> (Int, [Maybe Label])
switchTargetsToTable (SwitchTargets _ (lo, hi) mbdef branches) =
    (fromIntegral (negate base), map entryAt [base .. hi])
  where
    -- Explicit branch if there is one, the default otherwise.
    entryAt key = maybe mbdef Just (M.lookup key branches)

    -- First value covered by the table. See Note [Jump Table Offset]:
    -- a small non-negative lower bound is rounded down to 0.
    base
        | 0 <= lo, lo < minJumpTableOffset = 0
        | otherwise                        = lo
-- Note [Jump Table Offset]
-- ~~~~~~~~~~~~~~~~~~~~~~~~
--
-- Usually, the code for a jump table starting at x will first subtract x from
-- the value, to avoid a large amount of empty entries. But if x is very small,
-- the extra entries are no worse than the subtraction in terms of code size, and
-- not having to do the subtraction is quicker.
--
-- I.e. instead of
-- _u20N:
-- leaq -1(%r14),%rax
-- jmp *_n20R(,%rax,8)
-- _n20R:
-- .quad _c20p
-- .quad _c20q
-- do
-- _u20N:
-- jmp *_n20Q(,%r14,8)
--
-- _n20Q:
-- .quad 0
-- .quad _c20p
-- .quad _c20q
-- .quad _c20r
-- | The list of all labels occuring in