Commit 04bc50b3 authored by Sylvain Henry's avatar Sylvain Henry Committed by Marge Bot
Browse files

Bignum: implement extended GCD (#18427)

parent 92daad24
......@@ -280,30 +280,32 @@ integer_gmp_mpn_gcd(mp_limb_t r[],
/* wraps mpz_gcdext()
*
* Set g={g0,gn} to the greatest common divisor of x={x0,xn} and
* y={y0,yn}, and in addition set s={s0,sn} to coefficient
* satisfying x*s + y*t = g.
*
* The g0 array is zero-padded (so that gn is fixed).
* y={y0,yn}, and in addition set s={s0,sn} and t={t0,tn} to
* coefficients satisfying x*s + y*t = g.
*
* g0 must have space for exactly gn=min(xn,yn) limbs.
* s0 must have space for at least yn limbs.
* t0 must have space for at least xn limbs.
*
* Actual sizes are returned by pointers.
*
* return value: signed 'sn' of s={s0,sn} where |sn| >= 1
*/
mp_size_t
integer_gmp_gcdext(mp_limb_t s0[], mp_limb_t g0[],
void
integer_gmp_gcdext(mp_limb_t s0[], int32_t * ssn,
mp_limb_t t0[], int32_t * stn,
mp_limb_t g0[], int32_t * gn,
const mp_limb_t x0[], const mp_size_t xn,
const mp_limb_t y0[], const mp_size_t yn)
{
const mp_size_t gn0 = mp_size_minabs(xn, yn);
const mpz_t x = CONST_MPZ_INIT(x0, mp_limb_zero_p(x0,xn) ? 0 : xn);
const mpz_t y = CONST_MPZ_INIT(y0, mp_limb_zero_p(y0,yn) ? 0 : yn);
mpz_t g, s;
mpz_t g, s, t;
mpz_init (g);
mpz_init (s);
mpz_init (t);
mpz_gcdext (g, s, NULL, x, y);
mpz_gcdext (g, s, t, x, y);
// g must be positive (0 <= gn).
// According to the docs for mpz_gcdext(), we have:
......@@ -311,28 +313,31 @@ integer_gmp_gcdext(mp_limb_t s0[], mp_limb_t g0[],
// --> g < min(|y|, |x|)
// --> gn <= min(yn, xn)
// <-> gn <= gn0
const mp_size_t gn = g[0]._mp_size;
assert(0 <= gn && gn <= gn0);
memset(g0, 0, gn0*sizeof(mp_limb_t));
memcpy(g0, g[0]._mp_d, gn*sizeof(mp_limb_t));
const mp_size_t gn0 = mp_size_minabs(xn, yn);
*gn = g[0]._mp_size;
assert(0 <= *gn && *gn <= gn0);
memcpy(g0, g[0]._mp_d, *gn * sizeof(mp_limb_t));
mpz_clear (g);
// According to the docs for mpz_gcdext(), we have:
// |s| < |y| / 2g
// --> |s| < |y| (note g > 0)
// --> sn <= yn
const mp_size_t ssn = s[0]._mp_size;
const mp_size_t sn = mp_size_abs(ssn);
*ssn = s[0]._mp_size;
const mp_size_t sn = mp_size_abs(*ssn);
assert(sn <= mp_size_abs(yn));
memcpy(s0, s[0]._mp_d, sn*sizeof(mp_limb_t));
mpz_clear (s);
if (!sn) {
s0[0] = 0;
return 1;
}
return ssn;
// According to the docs for mpz_gcdext(), we have:
// |t| < |x| / 2g
// --> |t| < |x| (note g > 0)
// --> st <= xn
*stn = t[0]._mp_size;
const mp_size_t tn = mp_size_abs(*stn);
assert(tn <= mp_size_abs(xn));
memcpy(t0, t[0]._mp_d, tn*sizeof(mp_limb_t));
mpz_clear (t);
}
/* Truncating (i.e. rounded towards zero) integer division-quotient of MPN */
......
......@@ -16,6 +16,7 @@ import GHC.Prim
import GHC.Types
import GHC.Num.WordArray
import GHC.Num.Primitives
import {-# SOURCE #-} GHC.Num.Integer
import qualified GHC.Num.Backend.Native as Native
import qualified GHC.Num.Backend.Selected as Other
......@@ -453,3 +454,18 @@ bignat_powmod_words b e m =
in case gr `eqWord#` nr of
1# -> gr
_ -> unexpectedValue_Word# (# #)
integer_gcde
:: Integer
-> Integer
-> (# Integer, Integer, Integer #)
integer_gcde a b =
let
!(# g0,x0,y0 #) = Other.integer_gcde a b
!(# g1,x1,y1 #) = Native.integer_gcde a b
in if isTrue# (integerEq# x0 x1
&&# integerEq# y0 y1
&&# integerEq# g0 g1)
then (# g0, x0, y0 #)
else case unexpectedValue of
!_ -> (# integerZero, integerZero, integerZero #)
......@@ -19,6 +19,7 @@ import GHC.Prim
import GHC.Types
import GHC.Num.WordArray
import GHC.Num.Primitives
import qualified GHC.Num.Backend.Native as Native
default ()
......@@ -579,3 +580,19 @@ bignat_powmod_words = ghc_bignat_powmod_words
foreign import ccall unsafe ghc_bignat_powmod_words
:: Word# -> Word# -> Word# -> Word#
-- | Return extended GCD of two non-zero integers.
--
-- I.e. integer_gcde a b returns (g,x,y) so that ax + by = g
--
-- Input: a and b are non zero.
-- Output: g must be > 0
--
integer_gcde
:: Integer
-> Integer
-> (# Integer, Integer, Integer #)
integer_gcde = Native.integer_gcde
-- for now we use Native's implementation. If some FFI backend user needs a
-- specific implementation, we'll need to determine a prototype to pass and
-- return BigNat signs and sizes via FFI.
......@@ -8,6 +8,9 @@
{-# LANGUAGE UnliftedFFITypes #-}
{-# LANGUAGE NegativeLiterals #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE LambdaCase #-}
{-# OPTIONS_GHC -Wno-name-shadowing #-}
-- | Backend based on the GNU GMP library.
--
......@@ -22,6 +25,9 @@ import GHC.Num.WordArray
import GHC.Num.Primitives
import GHC.Prim
import GHC.Types
import GHC.Magic (runRW#)
import {-# SOURCE #-} GHC.Num.Integer
import {-# SOURCE #-} GHC.Num.BigNat
default ()
......@@ -352,6 +358,70 @@ bignat_powmod r b e m s =
case ioInt# (integer_gmp_powm# r b (wordArraySize# b) e (wordArraySize# e) m (wordArraySize# m)) s of
(# s', n #) -> mwaSetSize# r (narrowGmpSize# n) s'
integer_gcde
:: Integer
-> Integer
-> (# Integer, Integer, Integer #)
integer_gcde a b = case runRW# io of (# _, a #) -> a
where
!(# sa, ba #) = integerToBigNatSign# a
!(# sb, bb #) = integerToBigNatSign# b
!sza = bigNatSize# ba
!szb = bigNatSize# bb
-- signed sizes of a and b
!ssa = case sa of
0# -> sza
_ -> negateInt# sza
!ssb = case sb of
0# -> szb
_ -> negateInt# szb
-- gcd(a,b) < min(a,b)
!g_init_sz = minI# sza szb
-- According to https://gmplib.org/manual/Number-Theoretic-Functions.html#index-mpz_005fgcdext
-- a*x + b*y = g
-- abs(x) < abs(b) / (2 g) < abs(b)
-- abs(y) < abs(a) / (2 g) < abs(a)
!x_init_sz = szb
!y_init_sz = sza
io s =
-- allocate output arrays
case newWordArray# g_init_sz s of { (# s, mbg #) ->
case newWordArray# x_init_sz s of { (# s, mbx #) ->
case newWordArray# y_init_sz s of { (# s, mby #) ->
-- allocate space to return sizes (3x4 = 12)
case newPinnedByteArray# 12# s of { (# s, mszs #) ->
case unsafeFreezeByteArray# mszs s of { (# s, szs #) ->
let !ssx_ptr = byteArrayContents# szs in
let !ssy_ptr = ssx_ptr `plusAddr#` 4# in
let !sg_ptr = ssy_ptr `plusAddr#` 4# in
-- call GMP
case ioVoid (integer_gmp_gcdext# mbx ssx_ptr mby ssy_ptr mbg sg_ptr ba ssa bb ssb) s of { s ->
-- read sizes
case readInt32OffAddr# ssx_ptr 0# s of { (# s, ssx #) ->
case readInt32OffAddr# ssy_ptr 0# s of { (# s, ssy #) ->
case readInt32OffAddr# sg_ptr 0# s of { (# s, sg #) ->
case touch# szs s of { s ->
-- shrink x, y and g to their actual sizes and freeze them
let !sx = absI# ssx in
let !sy = absI# ssy in
case mwaSetSize# mbx sx s of { s ->
case mwaSetSize# mby sy s of { s ->
case mwaSetSize# mbg sg s of { s ->
-- return x, y and g as Integer
case unsafeFreezeByteArray# mbx s of { (# s, bx #) ->
case unsafeFreezeByteArray# mby s of { (# s, by #) ->
case unsafeFreezeByteArray# mbg s of { (# s, bg #) ->
(# s, (# integerFromBigNat# bg
, integerFromBigNatSign# (ssx <# 0#) bx
, integerFromBigNatSign# (ssy <# 0#) by #) #)
}}}}}}}}}}}}}}}}
----------------------------------------------------------------------
-- FFI ccall imports
......@@ -366,10 +436,13 @@ foreign import ccall unsafe "integer_gmp_mpn_gcd"
c_mpn_gcd# :: MutableByteArray# s -> ByteArray# -> GmpSize#
-> ByteArray# -> GmpSize# -> IO GmpSize
foreign import ccall unsafe "integer_gmp_gcdext"
integer_gmp_gcdext# :: MutableByteArray# s -> MutableByteArray# s
-> ByteArray# -> GmpSize#
-> ByteArray# -> GmpSize# -> IO GmpSize
foreign import ccall unsafe "integer_gmp_gcdext" integer_gmp_gcdext#
:: MutableByteArray# s -> Addr#
-> MutableByteArray# s -> Addr#
-> MutableByteArray# s -> Addr#
-> ByteArray# -> GmpSize#
-> ByteArray# -> GmpSize#
-> IO ()
-- mp_limb_t mpn_add_1 (mp_limb_t *rp, const mp_limb_t *s1p, mp_size_t n,
-- mp_limb_t s2limb)
......
......@@ -17,9 +17,11 @@ module GHC.Num.Backend.Native where
#if defined(BIGNUM_NATIVE) || defined(BIGNUM_CHECK)
import {-# SOURCE #-} GHC.Num.BigNat
import {-# SOURCE #-} GHC.Num.Natural
import {-# SOURCE #-} GHC.Num.Integer
#else
import GHC.Num.BigNat
import GHC.Num.Natural
import GHC.Num.Integer
#endif
import GHC.Num.WordArray
import GHC.Num.Primitives
......@@ -717,3 +719,21 @@ bignat_powmod_words b e m =
bignat_powmod_word (wordArrayFromWord# b)
(wordArrayFromWord# e)
m
integer_gcde
:: Integer
-> Integer
-> (# Integer, Integer, Integer #)
integer_gcde a b = f (# a,integerOne,integerZero #) (# b,integerZero,integerOne #)
where
-- returned "g" must be positive
fix (# g, x, y #)
| integerIsNegative g = (# integerNegate g, integerNegate x, integerNegate y #)
| True = (# g,x,y #)
f old@(# old_g, old_s, old_t #) new@(# g, s, t #)
| integerIsZero g = fix old
| True = case integerQuotRem# old_g g of
!(# q, r #) -> f new (# r , old_s `integerSub` (q `integerMul` s)
, old_t `integerSub` (q `integerMul` t) #)
......@@ -10,6 +10,7 @@ import GHC.Prim
type BigNat# = WordArray#
data BigNat = BN# { unBigNat :: BigNat# }
bigNatSize# :: BigNat# -> Int#
bigNatSubUnsafe :: BigNat# -> BigNat# -> BigNat#
bigNatMulWord# :: BigNat# -> Word# -> BigNat#
bigNatRem :: BigNat# -> BigNat# -> BigNat#
......
......@@ -6,6 +6,7 @@
{-# LANGUAGE NegativeLiterals #-}
{-# LANGUAGE BinaryLiterals #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE LambdaCase #-}
-- |
-- Module : GHC.Num.Integer
......@@ -31,6 +32,7 @@ import GHC.Magic
import GHC.Num.Primitives
import GHC.Num.BigNat
import GHC.Num.Natural
import qualified GHC.Num.Backend as Backend
#if WORD_SIZE_IN_BITS < 64
import GHC.IntWord64
......@@ -113,6 +115,17 @@ integerFromBigNatSign# !sign !bn
| True
= integerFromBigNatNeg# bn
-- | Convert an Integer into a sign-bit and a BigNat
integerToBigNatSign# :: Integer -> (# Int#, BigNat# #)
integerToBigNatSign# = \case
IS x
| isTrue# (x >=# 0#)
-> (# 0#, bigNatFromWord# (int2Word# x) #)
| True
-> (# 1#, bigNatFromWord# (int2Word# (negateInt# x)) #)
IP x -> (# 0#, x #)
IN x -> (# 1#, x #)
-- | Convert an Integer into a BigNat.
--
-- Return 0 for negative Integers.
......@@ -853,7 +866,7 @@ integerDivMod# :: Integer -> Integer -> (# Integer, Integer #)
{-# NOINLINE integerDivMod# #-}
integerDivMod# !n !d
| isTrue# (integerSignum# r ==# negateInt# (integerSignum# d))
= let !q' = integerAdd q (IS -1#) -- TODO: optimize
= let !q' = integerSub q (IS 1#)
!r' = integerAdd r d
in (# q', r' #)
| True = qr
......@@ -1169,3 +1182,35 @@ integerFromByteArray# sz ba off e s = case bigNatFromByteArray# sz ba off e s of
integerFromByteArray :: Word# -> ByteArray# -> Word# -> Bool# -> Integer
integerFromByteArray sz ba off e = case runRW# (integerFromByteArray# sz ba off e) of
(# _, i #) -> i
-- | Get the extended GCD of two integers.
--
-- `integerGcde# a b` returns (# g,x,y #) where
-- * ax + by = g = |gcd a b|
integerGcde#
:: Integer
-> Integer
-> (# Integer, Integer, Integer #)
integerGcde# a b
| integerIsZero a && integerIsZero b = (# integerZero, integerZero, integerZero #)
| integerIsZero a = fix (# b , integerZero, integerOne #)
| integerIsZero b = fix (# a , integerOne, integerZero #)
| integerAbs a `integerEq` integerAbs b = fix (# b , integerZero, integerOne #)
| True = Backend.integer_gcde a b
where
-- returned "g" must be positive
fix (# g, x, y #)
| integerIsNegative g = (# integerNegate g, integerNegate x, integerNegate y #)
| True = (# g,x,y #)
-- | Get the extended GCD of two integers.
--
-- `integerGcde a b` returns (g,x,y) where
-- * ax + by = g = |gcd a b|
integerGcde
:: Integer
-> Integer
-> ( Integer, Integer, Integer)
integerGcde a b = case integerGcde# a b of
(# g,x,y #) -> (g,x,y)
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE MagicHash #-}
module GHC.Num.Integer where
import GHC.Types
import GHC.Prim
import {-# SOURCE #-} GHC.Num.BigNat
data Integer
integerZero :: Integer
integerOne :: Integer
integerEq# :: Integer -> Integer -> Int#
integerEq :: Integer -> Integer -> Bool
integerGt :: Integer -> Integer -> Bool
integerIsZero :: Integer -> Bool
integerIsNegative :: Integer -> Bool
integerSub :: Integer -> Integer -> Integer
integerMul :: Integer -> Integer -> Integer
integerNegate :: Integer -> Integer
integerDivMod# :: Integer -> Integer -> (# Integer, Integer #)
integerQuotRem# :: Integer -> Integer -> (# Integer, Integer #)
integerToBigNatSign# :: Integer -> (# Int#, BigNat# #)
integerFromBigNatSign# :: Int# -> BigNat# -> Integer
integerFromBigNat# :: BigNat# -> Integer
......@@ -29,6 +29,7 @@ module GHC.Integer.GMP.Internals
-- ** Additional 'Integer' operations
, gcdInteger
, gcdExtInteger
, lcmInteger
, sqrInteger
......@@ -170,6 +171,12 @@ isValidInteger# = I.integerCheck#
gcdInteger :: Integer -> Integer -> Integer
gcdInteger = I.integerGcd
{-# DEPRECATED gcdExtInteger "Use integerGcde instead" #-}
gcdExtInteger :: Integer -> Integer -> (# Integer, Integer #)
gcdExtInteger a b = case I.integerGcde# a b of
(# g, s, _t #) -> (# g, s #)
{-# DEPRECATED lcmInteger "Use integerLcm instead" #-}
lcmInteger :: Integer -> Integer -> Integer
lcmInteger = I.integerLcm
......
......@@ -5,11 +5,12 @@ test('integerConstantFolding', normal, makefile_test, ['integerConstantFolding']
test('fromToInteger', [], makefile_test, ['fromToInteger'])
test('IntegerConversionRules', [], makefile_test, ['IntegerConversionRules'])
test('gcdInteger', normal, compile_and_run, [''])
test('gcdeInteger', normal, compile_and_run, [''])
test('integerPowMod', [], compile_and_run, [''])
test('integerGcdExt', [omit_ways(['ghci'])], compile_and_run, [''])
# skip ghci as it doesn't support unboxed tuples
test('integerImportExport', [omit_ways(['ghci'])], compile_and_run, [''])
# Disable GMP only tests
#test('integerGcdExt', [omit_ways(['ghci'])], compile_and_run, [''])
#test('integerGmpInternals', [], compile_and_run, [''])
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiWayIf #-}
module Main (main) where
import GHC.Base
import GHC.Num.Integer
import Control.Monad
import System.Exit
main :: IO ()
main = do
let test a b = do
putStrLn $ "GCDE " ++ show a ++ " " ++ show b
let r@(g,x,y) = integerGcde a b
putStrLn $ " -> g = " ++ show g
putStrLn $ " -> x = " ++ show x
putStrLn $ " -> y = " ++ show y
let sign a | a >= 0 = 1
| otherwise = -1
let assert text cond term
| not cond = return ()
| term = return ()
| otherwise = do
putStrLn $ "FAILED: " ++ text
putStrLn $ "a*x + b*y = g"
putStrLn $ "a = " ++ show a
putStrLn $ "b = " ++ show b
putStrLn $ "x = " ++ show x
putStrLn $ "y = " ++ show y
putStrLn $ "g = " ++ show g
putStrLn $ "expected g = " ++ show (abs (integerGcd a b))
exitFailure
-- check properties
assert "g >= 0" True (g >= 0)
assert "a*x + b*y = g" True (a*x + b*y == g)
assert "g = abs (gcd a b)" True (g == abs (integerGcd a b))
if -- special cases
| a == 0 && b == 0 -> do
assert "a == 0 && b ==0 ==> g == 0" (a == 0 && b == 0) (g == 0)
| abs a == abs b -> do
assert "abs a == abs b ==> x == 0 && y == sign b && g == abs a"
(abs a == abs b) (x == 0 && y == sign b && g == abs a)
-- non special cases
| otherwise -> do
assert "b == 0 ==> x=sign a"
(b == 0)
(x == sign a)
assert "abs b == 2g ==> x=sign a"
(abs b == 2*g)
(x == sign a)
assert "b /= 0 ==> abs x <= abs b / 2*g"
(b /= 0)
(abs x <= abs b `div` 2 * g)
assert "a /= 0 ==> abs y <= abs a / 2*g"
(a /= 0)
(abs y <= abs a `div` 2 * g)
assert "a == 0 ==> y=sign b"
(a == 0)
(y == sign b)
assert "abs a == 2g ==> y==sign b"
(abs a == 2*g)
(y == sign b)
assert "x == 0 ==> g == abs b"
(x == 0)
(g == abs b)
nums =
[ 0
, 1
, 7
, 14
, 123
, 1230
, 123456789456789456789456789456789456789465789465456789465454645789
, 4 * 123456789456789456789456789456789456789465789465456789465454645789
, -1
, -123
, -123456789456789456789456789456789456789465789465456789465454645789
, 4567897897897897899789897897978978979789
, 2988348162058574136915891421498819466320163312926952423791023078876139
, 2351399303373464486466122544523690094744975233415544072992656881240319
, 5328841272400314897981163497728751426
, 32052182750761975518649228050096851724
]
forM_ nums $ \a ->
forM_ nums $ \b ->
test a b
-- see #15350
do
let a = 2
b = 2^65 + 1
test a b
test a (-b)
test (-a) b
test (-a) (-b)
test b a
test b (-a)
test (-b) a
test (-b) (-a)
This diff is collapsed.
......@@ -9,10 +9,10 @@ import Control.Monad
import GHC.Word
import GHC.Base
import qualified GHC.Integer.GMP.Internals as I
import qualified GHC.Num.Integer as I
gcdExtInteger :: Integer -> Integer -> (Integer, Integer)
gcdExtInteger a b = case I.gcdExtInteger a b of (# g, s #) -> (g, s)
gcdExtInteger a b = case I.integerGcde a b of ( g, s, _t ) -> (g, s)
main :: IO ()
main = do
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment