diff --git a/compiler/GHC/CmmToAsm/X86/CodeGen.hs b/compiler/GHC/CmmToAsm/X86/CodeGen.hs index a8fb7b67208375150a97e30dd70505dfec5d9a0a..103a5f639ea0fcda426a4a19cbb1329f101f5659 100644 --- a/compiler/GHC/CmmToAsm/X86/CodeGen.hs +++ b/compiler/GHC/CmmToAsm/X86/CodeGen.hs @@ -3424,7 +3424,6 @@ genFMA3Code :: Width -> FMASign -> CmmExpr -> CmmExpr -> CmmExpr -> NatM Register genFMA3Code w signs x y z = do - -- For the FMA instruction, we want to compute x * y + z -- -- There are three possible instructions we could emit: @@ -3445,17 +3444,45 @@ genFMA3Code w signs x y z = do -- -- Currently we follow neither of these optimisations, -- opting to always use fmadd213 for simplicity. + -- + -- We would like to compute the result directly into the requested register. + -- To do so we must first compute `x` into the destination register. This is + -- only possible if the other arguments don't use the destination register. + -- We check for this and if there is a conflict we move the result only after + -- the computation. See #24496 for how this went wrong in the past. let rep = floatFormat w (y_reg, y_code) <- getNonClobberedReg y - (z_reg, z_code) <- getNonClobberedReg z + (z_op, z_code) <- getNonClobberedOperand z x_code <- getAnyReg x + x_tmp <- getNewRegNat rep let fma213 = FMA3 rep signs FMA213 - code dst - = y_code `appOL` + + code, code_direct, code_mov :: Reg -> InstrBlock + -- Ideal: Compute the result directly into dst + code_direct dst = x_code dst `snocOL` + fma213 z_op y_reg dst + -- Fallback: Compute the result into a tmp reg and then move it.
+ code_mov dst = x_code x_tmp `snocOL` + fma213 z_op y_reg x_tmp `snocOL` + MOV rep (OpReg x_tmp) (OpReg dst) + + code dst = + y_code `appOL` z_code `appOL` - x_code dst `snocOL` - fma213 (OpReg z_reg) y_reg dst + ( if arg_regs_conflict then code_mov dst else code_direct dst ) + + where + + arg_regs_conflict = + y_reg == dst || + case z_op of + OpReg z_reg -> z_reg == dst + OpAddr amode -> dst `elem` addrModeRegs amode + OpImm {} -> False + + -- NB: Computing the result into a desired register using Any can be tricky. + -- So for now, we keep it simple. (See #24496). return (Any rep code) ----------- diff --git a/testsuite/tests/primops/should_run/T24496.hs b/testsuite/tests/primops/should_run/T24496.hs new file mode 100644 index 0000000000000000000000000000000000000000..d7085397dbd694636430fb379f7d57440962e573 --- /dev/null +++ b/testsuite/tests/primops/should_run/T24496.hs @@ -0,0 +1,18 @@ +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE UnboxedTuples #-} +import GHC.Exts + +twoProductFloat# :: Float# -> Float# -> (# Float#, Float# #) +twoProductFloat# x y = let !r = x `timesFloat#` y + in (# r, fmsubFloat# x y r #) +{-# NOINLINE twoProductFloat# #-} + +twoProductDouble# :: Double# -> Double# -> (# Double#, Double# #) +twoProductDouble# x y = let !r = x *## y + in (# r, fmsubDouble# x y r #) +{-# NOINLINE twoProductDouble# #-} + +main :: IO () +main = do + print $ case twoProductFloat# 2.0# 3.0# of (# r, s #) -> (F# r, F# s) + print $ case twoProductDouble# 2.0## 3.0## of (# r, s #) -> (D# r, D# s) diff --git a/testsuite/tests/primops/should_run/T24496.stdout b/testsuite/tests/primops/should_run/T24496.stdout new file mode 100644 index 0000000000000000000000000000000000000000..167c41d6cfbf5c53601d3125701028889351c9b9 --- /dev/null +++ b/testsuite/tests/primops/should_run/T24496.stdout @@ -0,0 +1,2 @@ +(6.0,0.0) +(6.0,0.0) diff --git a/testsuite/tests/primops/should_run/all.T b/testsuite/tests/primops/should_run/all.T index 
5c088e6a18db6983d98c95702c3c2f8c52613df9..672946562af61dd181ce2400833f52a4d47b8952 100644 --- a/testsuite/tests/primops/should_run/all.T +++ b/testsuite/tests/primops/should_run/all.T @@ -77,3 +77,10 @@ test('FMA_ConstantFold' test('T21624', normal, compile_and_run, ['']) test('T23071', ignore_stdout, compile_and_run, ['']) test('T22710', normal, compile_and_run, ['']) +test('T24496' + , [ when(have_cpu_feature('fma'), extra_hc_opts('-mfma')) + , js_skip # JS backend doesn't have an FMA implementation + , when(arch('wasm32'), skip) + , when(have_llvm(), extra_ways(["optllvm"])) + ] + , compile_and_run, ['-O'])