From 5546ad9edd507e478ad3e8d40b7e728bcfe6687b Mon Sep 17 00:00:00 2001
From: Matthew Pickering <matthewtpickering@gmail.com>
Date: Mon, 7 Aug 2023 09:55:47 +0100
Subject: [PATCH] Revert "AArch NCG: Pure refactor"

This reverts commit 00fb6e6b06598752414a0b9a92840fb6ca61338d.

See #23793
---
 compiler/GHC/CmmToAsm/AArch64/CodeGen.hs | 61 ++++++++++++++----------
 1 file changed, 35 insertions(+), 26 deletions(-)

diff --git a/compiler/GHC/CmmToAsm/AArch64/CodeGen.hs b/compiler/GHC/CmmToAsm/AArch64/CodeGen.hs
index 6bb0510ef256..052d649d7615 100644
--- a/compiler/GHC/CmmToAsm/AArch64/CodeGen.hs
+++ b/compiler/GHC/CmmToAsm/AArch64/CodeGen.hs
@@ -794,25 +794,33 @@ getRegister' config plat expr
     -- 0. TODO This should not exist! Rewrite: Reg +- 0 -> Reg
     CmmMachOp (MO_Add _) [expr'@(CmmReg (CmmGlobal _r)), CmmLit (CmmInt 0 _)] -> getRegister' config plat expr'
     CmmMachOp (MO_Sub _) [expr'@(CmmReg (CmmGlobal _r)), CmmLit (CmmInt 0 _)] -> getRegister' config plat expr'
-    -- Immediates are handled via `getArithImm` in the generic code path.
+    -- 1. Compute Reg +/- n directly.
+    --    For Add/Sub we can directly encode 12bits, or 12bits lsl #12.
+    CmmMachOp (MO_Add w) [(CmmReg reg), CmmLit (CmmInt n _)]
+      | n > 0 && n < 4096 -> return $ Any (intFormat w) (\d -> unitOL $ annExpr expr (ADD (OpReg w d) (OpReg w' r') (OpImm (ImmInteger n))))
+      -- TODO: 12bits lsl #12; e.g. lower 12 bits of n are 0; shift n >> 12, and set lsl to #12.
+      where w' = formatToWidth (cmmTypeFormat (cmmRegType reg))
+            r' = getRegisterReg plat reg
+    CmmMachOp (MO_Sub w) [(CmmReg reg), CmmLit (CmmInt n _)]
+      | n > 0 && n < 4096 -> return $ Any (intFormat w) (\d -> unitOL $ annExpr expr (SUB (OpReg w d) (OpReg w' r') (OpImm (ImmInteger n))))
+      -- TODO: 12bits lsl #12; e.g. lower 12 bits of n are 0; shift n >> 12, and set lsl to #12.
+      where w' = formatToWidth (cmmTypeFormat (cmmRegType reg))
+            r' = getRegisterReg plat reg
 
     CmmMachOp (MO_U_Quot w) [x, y] | w == W8 -> do
       (reg_x, _format_x, code_x) <- getSomeReg x
       (reg_y, _format_y, code_y) <- getSomeReg y
-      return $ Any (intFormat w) (\dst -> code_x `appOL` code_y `snocOL` annExpr expr (UXTB (OpReg w reg_x) (OpReg w reg_x)) `snocOL`
-                                                                        (UXTB (OpReg w reg_y) (OpReg w reg_y)) `snocOL`
-                                                                        (UDIV (OpReg w dst) (OpReg w reg_x) (OpReg w reg_y)))
+      return $ Any (intFormat w) (\dst -> code_x `appOL` code_y `snocOL` annExpr expr (UXTB (OpReg w reg_x) (OpReg w reg_x)) `snocOL` (UXTB (OpReg w reg_y) (OpReg w reg_y)) `snocOL` (UDIV (OpReg w dst) (OpReg w reg_x) (OpReg w reg_y)))
     CmmMachOp (MO_U_Quot w) [x, y] | w == W16 -> do
       (reg_x, _format_x, code_x) <- getSomeReg x
       (reg_y, _format_y, code_y) <- getSomeReg y
-      return $ Any (intFormat w) (\dst -> code_x `appOL` code_y `snocOL` annExpr expr (UXTH (OpReg w reg_x) (OpReg w reg_x)) `snocOL`
-                                                                        (UXTH (OpReg w reg_y) (OpReg w reg_y)) `snocOL`
-                                                                        (UDIV (OpReg w dst) (OpReg w reg_x) (OpReg w reg_y)))
+      return $ Any (intFormat w) (\dst -> code_x `appOL` code_y `snocOL` annExpr expr (UXTH (OpReg w reg_x) (OpReg w reg_x)) `snocOL` (UXTH (OpReg w reg_y) (OpReg w reg_y)) `snocOL` (UDIV (OpReg w dst) (OpReg w reg_x) (OpReg w reg_y)))
 
     -- 2. Shifts. x << n, x >> n.
-    CmmMachOp (MO_Shl w) [x, (CmmLit (CmmInt n _))]
-      | w == W32 || w == W64
-      , 0 <= n, n < fromIntegral (widthInBits w) -> do
+    CmmMachOp (MO_Shl w) [x, (CmmLit (CmmInt n _))] | w == W32, 0 <= n, n < 32 -> do
+      (reg_x, _format_x, code_x) <- getSomeReg x
+      return $ Any (intFormat w) (\dst -> code_x `snocOL` annExpr expr (LSL (OpReg w dst) (OpReg w reg_x) (OpImm (ImmInteger n))))
+    CmmMachOp (MO_Shl w) [x, (CmmLit (CmmInt n _))] | w == W64, 0 <= n, n < 64 -> do
       (reg_x, _format_x, code_x) <- getSomeReg x
       return $ Any (intFormat w) (\dst -> code_x `snocOL` annExpr expr (LSL (OpReg w dst) (OpReg w reg_x) (OpImm (ImmInteger n))))
 
@@ -822,8 +830,7 @@ getRegister' config plat expr
     CmmMachOp (MO_S_Shr w) [x, y] | w == W8 -> do
       (reg_x, _format_x, code_x) <- getSomeReg x
       (reg_y, _format_y, code_y) <- getSomeReg y
-      return $ Any (intFormat w) (\dst -> code_x `appOL` code_y `snocOL` annExpr expr (SXTB (OpReg w reg_x) (OpReg w reg_x)) `snocOL`
-                                                                         (ASR (OpReg w dst) (OpReg w reg_x) (OpReg w reg_y)))
+      return $ Any (intFormat w) (\dst -> code_x `appOL` code_y `snocOL` annExpr expr (SXTB (OpReg w reg_x) (OpReg w reg_x)) `snocOL` (ASR (OpReg w dst) (OpReg w reg_x) (OpReg w reg_y)))
 
     CmmMachOp (MO_S_Shr w) [x, (CmmLit (CmmInt n _))] | w == W16, 0 <= n, n < 16 -> do
       (reg_x, _format_x, code_x) <- getSomeReg x
@@ -831,23 +838,24 @@ getRegister' config plat expr
     CmmMachOp (MO_S_Shr w) [x, y] | w == W16 -> do
       (reg_x, _format_x, code_x) <- getSomeReg x
       (reg_y, _format_y, code_y) <- getSomeReg y
-      return $ Any (intFormat w) (\dst -> code_x `appOL` code_y `snocOL` annExpr expr (SXTH (OpReg w reg_x) (OpReg w reg_x)) `snocOL`
-                                                                         (ASR (OpReg w dst) (OpReg w reg_x) (OpReg w reg_y)))
+      return $ Any (intFormat w) (\dst -> code_x `appOL` code_y `snocOL` annExpr expr (SXTH (OpReg w reg_x) (OpReg w reg_x)) `snocOL` (ASR (OpReg w dst) (OpReg w reg_x) (OpReg w reg_y)))
 
-    CmmMachOp (MO_S_Shr w) [x, (CmmLit (CmmInt n _))]
-      | w == W32 || w == W64
-      , 0 <= n, n < fromIntegral (widthInBits w) -> do
+    CmmMachOp (MO_S_Shr w) [x, (CmmLit (CmmInt n _))] | w == W32, 0 <= n, n < 32 -> do
       (reg_x, _format_x, code_x) <- getSomeReg x
       return $ Any (intFormat w) (\dst -> code_x `snocOL` annExpr expr (ASR (OpReg w dst) (OpReg w reg_x) (OpImm (ImmInteger n))))
 
+    CmmMachOp (MO_S_Shr w) [x, (CmmLit (CmmInt n _))] | w == W64, 0 <= n, n < 64 -> do
+      (reg_x, _format_x, code_x) <- getSomeReg x
+      return $ Any (intFormat w) (\dst -> code_x `snocOL` annExpr expr (ASR (OpReg w dst) (OpReg w reg_x) (OpImm (ImmInteger n))))
+
+
     CmmMachOp (MO_U_Shr w) [x, (CmmLit (CmmInt n _))] | w == W8, 0 <= n, n < 8 -> do
       (reg_x, _format_x, code_x) <- getSomeReg x
       return $ Any (intFormat w) (\dst -> code_x `snocOL` annExpr expr (UBFX (OpReg w dst) (OpReg w reg_x) (OpImm (ImmInteger n)) (OpImm (ImmInteger (8-n)))))
     CmmMachOp (MO_U_Shr w) [x, y] | w == W8 -> do
       (reg_x, _format_x, code_x) <- getSomeReg x
       (reg_y, _format_y, code_y) <- getSomeReg y
-      return $ Any (intFormat w) (\dst -> code_x `appOL` code_y `snocOL` annExpr expr (UXTB (OpReg w reg_x) (OpReg w reg_x)) `snocOL`
-                                                                        (ASR (OpReg w dst) (OpReg w reg_x) (OpReg w reg_y)))
+      return $ Any (intFormat w) (\dst -> code_x `appOL` code_y `snocOL` annExpr expr (UXTB (OpReg w reg_x) (OpReg w reg_x)) `snocOL` (ASR (OpReg w dst) (OpReg w reg_x) (OpReg w reg_y)))
 
     CmmMachOp (MO_U_Shr w) [x, (CmmLit (CmmInt n _))] | w == W16, 0 <= n, n < 16 -> do
       (reg_x, _format_x, code_x) <- getSomeReg x
@@ -855,12 +863,13 @@ getRegister' config plat expr
     CmmMachOp (MO_U_Shr w) [x, y] | w == W16 -> do
       (reg_x, _format_x, code_x) <- getSomeReg x
       (reg_y, _format_y, code_y) <- getSomeReg y
-      return $ Any (intFormat w) (\dst -> code_x `appOL` code_y `snocOL` annExpr expr (UXTH (OpReg w reg_x) (OpReg w reg_x))
-                                                                `snocOL` (ASR (OpReg w dst) (OpReg w reg_x) (OpReg w reg_y)))
+      return $ Any (intFormat w) (\dst -> code_x `appOL` code_y `snocOL` annExpr expr (UXTH (OpReg w reg_x) (OpReg w reg_x)) `snocOL` (ASR (OpReg w dst) (OpReg w reg_x) (OpReg w reg_y)))
+
+    CmmMachOp (MO_U_Shr w) [x, (CmmLit (CmmInt n _))] | w == W32, 0 <= n, n < 32 -> do
+      (reg_x, _format_x, code_x) <- getSomeReg x
+      return $ Any (intFormat w) (\dst -> code_x `snocOL` annExpr expr (LSR (OpReg w dst) (OpReg w reg_x) (OpImm (ImmInteger n))))
 
-    CmmMachOp (MO_U_Shr w) [x, (CmmLit (CmmInt n _))]
-      | w == W32 || w == W64
-      , 0 <= n, n < fromIntegral (widthInBits w) -> do
+    CmmMachOp (MO_U_Shr w) [x, (CmmLit (CmmInt n _))] | w == W64, 0 <= n, n < 64 -> do
       (reg_x, _format_x, code_x) <- getSomeReg x
       return $ Any (intFormat w) (\dst -> code_x `snocOL` annExpr expr (LSR (OpReg w dst) (OpReg w reg_x) (OpImm (ImmInteger n))))
 
@@ -906,8 +915,8 @@ getRegister' config plat expr
           -- sign-extend both arguments to 32-bits.
           -- See Note [Signed arithmetic on AArch64].
           intOpImm :: Bool -> Width -> (Operand -> Operand -> Operand -> OrdList Instr) -> (Integer -> Width -> Maybe Operand) -> NatM (Register)
-          intOpImm {- is signed -} True  w op _encode_imm = intOp True w op
-          intOpImm                 False w op  encode_imm = do
+          intOpImm {- is signed -} True w op _encode_imm = intOp True w op
+          intOpImm False w op encode_imm = do
               -- compute x<m> <- x
               -- compute x<o> <- y
               -- <OP> x<n>, x<m>, x<o>
-- 
GitLab