Python 中的整数平方根

在 python 或标准库中是否存在一个整数平方根?我希望它是精确的(例如,返回一个整数) ,如果没有解决方案,就吠叫。

那一刻,我自己也有一个天真的想法:

def isqrt(n):
i = int(math.sqrt(n) + 0.5)
if i**2 == n:
return i
raise ValueError('input was not a perfect square')

但是它很难看,我不相信它能用于大整数。我可以迭代遍历这些正方形,如果我已经超过了这个值,我就会放弃,但是我认为做这样的事情会有点慢。而且我想我可能是重造轮子,这样的东西肯定已经存在于巨蟒中了... ..。

86633 次浏览

试试这个条件(没有额外的计算) :

def isqrt(n):
i = math.sqrt(n)
if i != int(i):
raise ValueError('input was not a perfect square')
return i

如果需要返回一个 int(而不是后面带零的 float) ,那么要么指定第二个变量,要么计算两次 int(i)

看起来你可以这样检查:

if int(math.sqrt(n))**2 == n:
print n, 'is a perfect square'

更新:

正如您所指出的,对于 n的较大值,上述方法是失败的。对于那些看起来很有前途的人,下面是 Martin Guy@UKC 于1985年6月对示例 C 代码的改编,用于维基百科文章 平方根演算法中提到的相对简单的二进制数字逐位计算方法:

from math import ceil, log


def isqrt(n):
res = 0
bit = 4**int(ceil(log(n, 4))) if n else 0  # smallest power of 4 >= the argument
while bit:
if n >= res + bit:
n -= res + bit
res = (res >> 1) + bit
else:
res >>= 1
bit >>= 2
return res


if __name__ == '__main__':
from math import sqrt  # for comparison purposes


for i in range(17)+[2**53, (10**100+1)**2]:
is_perfect_sq = isqrt(i)**2 == i
print '{:21,d}:  math.sqrt={:12,.7G}, isqrt={:10,d} {}'.format(
i, sqrt(i), isqrt(i), '(perfect square)' if is_perfect_sq else '')

产出:

                    0:  math.sqrt=           0, isqrt=         0 (perfect square)
1:  math.sqrt=           1, isqrt=         1 (perfect square)
2:  math.sqrt=    1.414214, isqrt=         1
3:  math.sqrt=    1.732051, isqrt=         1
4:  math.sqrt=           2, isqrt=         2 (perfect square)
5:  math.sqrt=    2.236068, isqrt=         2
6:  math.sqrt=     2.44949, isqrt=         2
7:  math.sqrt=    2.645751, isqrt=         2
8:  math.sqrt=    2.828427, isqrt=         2
9:  math.sqrt=           3, isqrt=         3 (perfect square)
10:  math.sqrt=    3.162278, isqrt=         3
11:  math.sqrt=    3.316625, isqrt=         3
12:  math.sqrt=    3.464102, isqrt=         3
13:  math.sqrt=    3.605551, isqrt=         3
14:  math.sqrt=    3.741657, isqrt=         3
15:  math.sqrt=    3.872983, isqrt=         3
16:  math.sqrt=           4, isqrt=         4 (perfect square)
9,007,199,254,740,992:  math.sqrt=9.490627E+07, isqrt=94,906,265
100,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,020,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,001:  math.sqrt=      1E+100, isqrt=10,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,001 (perfect square)

您的函数在大输入时失败:

In [26]: isqrt((10**100+1)**2)


ValueError: input was not a perfect square

有一个 激活状态站点上的配方应该是更可靠的,因为它只使用整数数学。它基于早期的 StackOverflow 问题: 编写自己的平方根函数

一种选择是使用 decimal模块,并以足够精确的浮动方式进行:

import decimal


def isqrt(n):
nd = decimal.Decimal(n)
with decimal.localcontext() as ctx:
ctx.prec = n.bit_length()
i = int(nd.sqrt())
if i**2 != n:
raise ValueError('input was not a perfect square')
return i

我认为应该可行:

>>> isqrt(1)
1
>>> isqrt(7**14) == 7**7
True
>>> isqrt(11**1000) == 11**500
True
>>> isqrt(11**1000+1)
Traceback (most recent call last):
File "<ipython-input-121-e80953fb4d8e>", line 1, in <module>
isqrt(11**1000+1)
File "<ipython-input-100-dd91f704e2bd>", line 10, in isqrt
raise ValueError('input was not a perfect square')
ValueError: input was not a perfect square

在计算机上不能精确地表示浮动。您可以在 python 的 float 的精度范围内测试所需的接近设置 epsilon 的小值。

def isqrt(n):
epsilon = .00000000001
i = int(n**.5 + 0.5)
if abs(i**2 - n) < epsilon:
return i
raise ValueError('input was not a perfect square')

注意: 现在 stdlib 中有 math.isqrt,自 Python 3.8以来就可以使用。

牛顿的方法在整数上非常有效:

def isqrt(n):
x = n
y = (x + 1) // 2
while y < x:
x = y
y = (x + n // x) // 2
return x

这将返回 X * X不超过 N的最大整数 X。如果您想检查结果是否正好是平方根,只需执行乘法来检查 N是否是完全平方。

我将在 我的博客上讨论这个算法以及计算平方根的其他三个算法。

很抱歉这么晚才回复,我只是无意中看到了这个页面。如果将来有人访问这个页面,python 模块 gmpy2被设计用于处理非常大的输入,其中包括一个整数平方根函数。

例如:

>>> import gmpy2
>>> gmpy2.isqrt((10**100+1)**2)
mpz(10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001L)
>>> gmpy2.isqrt((10**100+1)**2 - 1)
mpz(10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000L)

当然,所有东西都有“ mpz”标签,但是 mpz 的标签与 int 的标签兼容:

>>> gmpy2.mpz(3)*4
mpz(12)


>>> int(gmpy2.mpz(12))
12

有关此方法的性能相对于此问题的其他答案的讨论,请参见 我的另一个答案

下载: https://code.google.com/p/gmpy/

Long-hand 平方根算法

有一种计算平方根的算法可以手工计算,比如长除法。该算法的每次迭代只产生结果平方根的一个数字,同时消耗您所寻找的平方根数字的两个数字。虽然算法的“长手”版本是以十进制表示的,但它可以在任何基础上工作,二进制是最简单的实现,也许是执行速度最快的(取决于底层的 bignum 表示)。

由于该算法对数字逐位运算,对任意大小的完全正方形产生精确的结果,对非完全正方形,可以产生任意数字的精度(小数点后右侧)。

“数学博士”网站上有两篇不错的文章解释了这个算法:

下面是 Python 中的一个实现:

def exact_sqrt(x):
"""Calculate the square root of an arbitrarily large integer.
 

The result of exact_sqrt(x) is a tuple (a, r) such that a**2 + r = x, where
a is the largest integer such that a**2 <= x, and r is the "remainder".  If
x is a perfect square, then r will be zero.
 

The algorithm used is the "long-hand square root" algorithm, as described at
http://mathforum.org/library/drmath/view/52656.html
 

Tobin Fricke 2014-04-23
Max Planck Institute for Gravitational Physics
Hannover, Germany
"""
    

N = 0   # Problem so far
a = 0   # Solution so far
    

# We'll process the number two bits at a time, starting at the MSB
L = x.bit_length()
L += (L % 2)          # Round up to the next even number
    

for i in xrange(L, -1, -1):
        

# Get the next group of two bits
n = (x >> (2*i)) & 0b11
        

# Check whether we can reduce the remainder
if ((N - a*a) << 2) + n >= (a<<2) + 1:
b = 1
else:
b = 0
        

a = (a << 1) | b   # Concatenate the next bit of the solution
N = (N << 2) | n   # Concatenate the next bit of the problem
    

return (a, N-a*a)

您可以很容易地修改这个函数,以进行额外的迭代来计算平方根的小数部分。我最感兴趣的是计算大型完全正方形的根。

我不知道这和“整数牛顿法”算法有什么不同。我怀疑牛顿的方法更快,因为它原则上可以在一次迭代中生成多位解,而“长手”算法每次迭代只生成一位解。

资料来源: https://gist.github.com/tobin/11233492

这里有一个非常简单的实现:

def i_sqrt(n):
i = n.bit_length() >> 1    # i = floor( (1 + floor(log_2(n))) / 2 )
m = 1 << i    # m = 2^i
#
# Fact: (2^(i + 1))^2 > n, so m has at least as many bits
# as the floor of the square root of n.
#
# Proof: (2^(i+1))^2 = 2^(2i + 2) >= 2^(floor(log_2(n)) + 2)
# >= 2^(ceil(log_2(n) + 1) >= 2^(log_2(n) + 1) > 2^(log_2(n)) = n. QED.
#
while m*m > n:
m >>= 1
i -= 1
for k in xrange(i-1, -1, -1):
x = m | (1 << k)
if x*x <= n:
m = x
return m

这只是二进制搜索。将值 m初始化为不超过平方根的2的最大幂,然后检查是否可以设置每个较小的位,同时保持结果不大于平方根。(按降序一次检查一个位。)

对于相当大的 n值(比如,大约 10**6000或者大约 20000位) ,这似乎是:

所有这些方法都能在这种大小的输入上成功,但是在我的机器上,这个函数大约需要1.5秒,而@Nibot 大约需要0.9秒,@user448810大约需要19秒,而 gmpy2内置方法只需要不到一毫秒(!).例如:

>>> import random
>>> import timeit
>>> import gmpy2
>>> r = random.getrandbits
>>> t = timeit.timeit
>>> t('i_sqrt(r(20000))', 'from __main__ import *', number = 5)/5. # This function
1.5102493192883117
>>> t('exact_sqrt(r(20000))', 'from __main__ import *', number = 5)/5. # Nibot
0.8952787937686366
>>> t('isqrt(r(20000))', 'from __main__ import *', number = 5)/5. # user448810
19.326695976676184
>>> t('gmpy2.isqrt(r(20000))', 'from __main__ import *', number = 5)/5. # gmpy2
0.0003599147067689046
>>> all(i_sqrt(n)==isqrt(n)==exact_sqrt(n)[0]==int(gmpy2.isqrt(n)) for n in (r(1500) for i in xrange(1500)))
True

这个函数可以很容易地推广,尽管它不是很好,因为我对 m的初始猜测没有这么精确:

def i_root(num, root, report_exactness = True):
i = num.bit_length() / root
m = 1 << i
while m ** root < num:
m <<= 1
i += 1
while m ** root > num:
m >>= 1
i -= 1
for k in xrange(i-1, -1, -1):
x = m | (1 << k)
if x ** root <= num:
m = x
if report_exactness:
return m, m ** root == num
return m

但是,请注意,gmpy2也有一个 i_root方法。

实际上,这种方法可以适用于任何(非负的,递增的)函数 f,以确定“ f的整数逆”。但是,要选择一个有效的 m初始值,您仍然需要了解 f的一些信息。

编辑: 感谢@Greggo 指出可以重写 i_sqrt函数以避免使用任何乘法运算。这将产生令人印象深刻的性能提升!

def improved_i_sqrt(n):
assert n >= 0
if n == 0:
return 0
i = n.bit_length() >> 1    # i = floor( (1 + floor(log_2(n))) / 2 )
m = 1 << i    # m = 2^i
#
# Fact: (2^(i + 1))^2 > n, so m has at least as many bits
# as the floor of the square root of n.
#
# Proof: (2^(i+1))^2 = 2^(2i + 2) >= 2^(floor(log_2(n)) + 2)
# >= 2^(ceil(log_2(n) + 1) >= 2^(log_2(n) + 1) > 2^(log_2(n)) = n. QED.
#
while (m << i) > n: # (m<<i) = m*(2^i) = m*m
m >>= 1
i -= 1
d = n - (m << i) # d = n-m^2
for k in xrange(i-1, -1, -1):
j = 1 << k
new_diff = d - (((m<<1) | j) << k) # n-(m+2^k)^2 = n-m^2-2*m*2^k-2^(2k)
if new_diff >= 0:
d = new_diff
m |= j
return m

注意,通过构造,m << 1k位没有设置,因此按位-或者可以用来实现 (m<<1) + (1<<k)的添加。最后,我把 (2*m*(2**k) + 2**(2*k))写成了 (((m<<1) | (1<<k)) << k),所以它是三个移位和一个位-或(后面跟着一个减法得到 new_diff)。也许还有更有效的方法得到它?无论如何,这比增加 m*m要好得多!与上述情况相比较:

>>> t('improved_i_sqrt(r(20000))', 'from __main__ import *', number = 5)/5.
0.10908999762373242
>>> all(improved_i_sqrt(n) == i_sqrt(n) for n in xrange(10**6))
True

我用一个循环比较了这里给出的不同方法:

for i in range (1000000): # 700 msec
r=int(123456781234567**0.5+0.5)
if r**2==123456781234567:rr=r
else:rr=-1

发现这一个是最快的,不需要数学导入。很长时间可能会失败,但看看这个

15241576832799734552675677489**0.5 = 123456781234567.0

更新: Python 3.8有一个 math.isqrt函数在标准库中!

我在小(0... 222)和大(250001)输入上对这里的每个(正确的)函数进行了基准测试。在这两种情况下,明显的赢家是 由 mathmandan 建议的 gmpy2.isqrt位居第一,其次是 Python 3.8的 math.isqrt位居第二,其次是 NPE 链接的 ActiveState 配方位居第三。ActiveState 配方有一大堆可以用移位代替的划分,这使得它更快一些(但仍然落后于本机函数) :

def isqrt(n):
if n > 0:
x = 1 << (n.bit_length() + 1 >> 1)
while True:
y = (x + n // x) >> 1
if y >= x:
return x
x = y
elif n == 0:
return 0
else:
raise ValueError("square root not defined for negative numbers")

基准结果:

(* 由于 gmpy2.isqrt返回一个 gmpy2.mpz对象,其行为大多与 int相似,但不完全相同,因此在某些情况下可能需要将其转换回 int。)

Python 默认的 math库有一个整数平方根函数:

math.isqrt(n)

返回非负整数 N的整数平方根。这是 N的精确平方根的底数,或者等效于 A2≤ n的最大整数。

下面的脚本提取整数平方根。它不使用除法,只使用位移,所以它相当快。它在反平方根上使用 牛顿的方法,正如 Wikipedia 文章 平方根倒数速算法中提到的,地震三竞技场使这种技术闻名于世。

算法计算 s = sqrt(Y)的策略如下。

  1. 在范围[1/4,1)内将参数 Y 减少到 y,即 y = Y/B,其中1/4 < = y < 1,其中 B 是2的偶数次幂,因此对于某个整数 k,B = 2**(2*k)。我们想找到 X,其中 x = X/B,和 x = 1/sqrt (y)。
  2. 使用二次 极小极大多项式极小极大多项式确定 X 的第一近似值。
  3. 用牛顿法精化 X。
  4. 计算 s = X*Y/(2**(3*k))

我们实际上不创建分数或执行任何除法。所有的算法都是用整数来完成的,我们使用位移来除以 B 的各种幂。

范围缩减使我们能够找到一个很好的初始近似值来反馈给牛顿方法。下面是二次极大极小多项式逼近区间[1/4,1]中的逆平方根的一个版本:

Minimax poly for 1/sqrt(x)

(对不起,为了遵守通常的惯例,我在这里颠倒了 x & y 的意思)。该近似的最大误差在0.0355 ~ = 1/28左右。下面的图表显示了这个错误:

Minimax poly error graph

使用这个多边形,我们的初始 x 至少以4或5位的精度开始。牛顿方法的每一个回合的精度翻倍,所以它不需要很多回合得到数千位,如果我们想要他们。


""" Integer square root


Uses no divisions, only shifts
"Quake" style algorithm,
i.e., Newton's method for 1 / sqrt(y)
Uses a quadratic minimax polynomial for the first approximation


Written by PM 2Ring 2022.01.23
"""


def int_sqrt(y):
if y < 0:
raise ValueError("int_sqrt arg must be >= 0, not %s" % y)
if y < 2:
return y


# print("\n*", y, "*")
# Range reduction.
# Find k such that 1/4 <= y/b < 1, where b = 2 ** (k*2)
j = y.bit_length()
# Round k*2 up to the next even number
k2 = j + (j & 1)
# k and some useful multiples
k = k2 >> 1
k3 = k2 + k
k6 = k3 << 1
kd = k6 + 1
# b cubed
b3 = 1 << k6


# Minimax approximation: x/b ~= 1 / sqrt(y/b)
x = (((463 * y * y) >> k2) - (896 * y) + (698 << k2)) >> 8
# print("   ", x, h)


# Newton's method for 1 / sqrt(y/b)
epsilon = 1 << k
for i in range(1, 99):
dx = x * (b3 - y * x * x) >> kd
x += dx
# print(f" {i}: {x} {dx}")
if abs(dx) <= epsilon:
break


# s == sqrt(y)
s = x * y >> k3
# Adjust if too low
ss = s + 1
return ss if ss * ss <= y else s


def test(lo, hi, step=1):
for y in range(lo, hi, step):
s = int_sqrt(y)
ss = s + 1
s2, ss2 = s * s, ss * ss
assert s2 <= y < ss2, (y, s2, ss2)
print("ok")


test(0, 100000, 1)

这个代码当然是 慢一点math.isqrtdecimal.Decimal.sqrt。其目的只是为了说明算法。如果它能在 C 语言中实现,那么它的速度会有多快呢。.


这是在 SageMathCell 服务器上运行的 现场版。设置 hi < = 0以计算并显示在 lo中设置的单个值的结果。你可以把表达式输入框,例如设置 hi到0和 lo2 * 10**100得到 sqrt(2) * 10**50

受到所有答案的启发,决定用纯 C + + 实现从这些答案中得到的几个最好的方法。众所周知,C + + 总是比 Python 快。

为了粘合 C + + 和 Python,我使用了 Cython。它允许使用 C + + 创建一个 Python 模块,然后直接从 Python 函数调用 C + + 函数。

作为补充,我不仅提供了采用 Python 的代码,还提供了纯 C + + 和测试。

下面是来自纯 C + + 测试的计时结果:

Test           'GMP', bits     64, time  0.000001 sec
Test 'AndersKaseorg', bits     64, time  0.000003 sec
Test    'Babylonian', bits     64, time  0.000006 sec
Test  'ChordTangent', bits     64, time  0.000018 sec


Test           'GMP', bits  50000, time  0.000118 sec
Test 'AndersKaseorg', bits  50000, time  0.002777 sec
Test    'Babylonian', bits  50000, time  0.003062 sec
Test  'ChordTangent', bits  50000, time  0.009120 sec

以及相同的 C + + 函数,但是采用了 Python 模块,它们具有计时功能:

Bits 50000
math.isqrt:   2.819 ms
gmpy2.isqrt:   0.166 ms
ISqrt_GMP:   0.252 ms
ISqrt_AndersKaseorg:   3.338 ms
ISqrt_Babylonian:   3.756 ms
ISqrt_ChordTangent:  10.564 ms

从某种意义上说,我的 Cython-C + + 是一个很好的框架,适合那些希望直接从 Python 编写和测试自己的 C + + 方法的人。

正如你在上面的例子中注意到的,我使用了以下方法:

  1. Isqrt ,从标准库实现。

  2. GMPY2.isqrt ,GMPY2库的实现。

  3. ISqrt _ GMP -与 GMPY2相同,但是在 Cython 模块中,我直接使用 C + + GMP 库(<gmpxx.h>)。

  4. ISqrt _ AndersKaseorg ,源自@AndersKaseorg 的 回答代码。

  5. ISqrt _ Babylonian ,取自 Wikipedia 文章的方法,即所谓的巴比伦方法。

  6. ISqrt _ ChordTangent ,这是我自己的方法,我称之为 Chord-Tangent,因为它使用 chord 和切线来迭代地缩短搜索间隔。这种方法在 我的另一篇文章中有详细的描述。这个方法很好,因为它不仅搜索平方根,而且还搜索任何 K 的 K 次根。我画了一个 小图片,显示了这个算法的细节。

关于编译 C + +/Cython 代码,我使用了 GMP库。您需要先安装它,在 Linux 下,通过 sudo apt install libgmp-dev很容易。

在 Windows 下最简单的就是安装非常棒的程序 VCPKG,这是软件包管理器,类似于 Linux 中的 APT。VCPKG 使用 视觉工作室从源代码编译所有包(不要忘记使用 VisualStudio 的 安装社区版本)。安装 VCPKG 后,可以按 vcpkg install gmp安装 GMP。也可以安装 MPIR,这是 GMP 的另一个分支,可以通过 vcpkg install mpir安装。

在 Windows 下安装 GMP 后,请编辑我的 Python 代码并替换路径以包含目录和库文件。VCPKG 在安装结束时应该显示你的 ZIP 文件路径与 GMP 库,有。Lib 和。H 档案。

您可能会注意到,在 Python 代码中,我还设计了特别方便的 cython_compile()函数,用于将任何 C + + 代码编译到 Python 模块中。这个函数非常好,因为它允许您轻松地将任何 C + + 代码插入 Python 中,可以多次重用。

如果你有任何问题或者建议,或者有什么东西在你的电脑上不能工作,请写评论。

下面首先展示 Python 代码,然后展示 C + + 代码。请参阅 C + + 代码上面的 Try it online!链接,以便在 GodBolt 服务器上在线运行代码。我完全可以从头开始运行这两个代码片段,不需要在其中进行任何编辑。

def cython_compile(srcs):
import json, hashlib, os, glob, importlib, sys, shutil, tempfile
srch = hashlib.sha256(json.dumps(srcs, sort_keys = True, ensure_ascii = True).encode('utf-8')).hexdigest().upper()[:12]
pdir = 'cyimp'
    

if len(glob.glob(f'{pdir}/cy{srch}*')) == 0:
class ChDir:
def __init__(self, newd):
self.newd = newd
def __enter__(self):
self.curd = os.getcwd()
os.chdir(self.newd)
return self
def __exit__(self, ext, exv, tb):
os.chdir(self.curd)


os.makedirs(pdir, exist_ok = True)
with tempfile.TemporaryDirectory(dir = pdir) as td, ChDir(str(td)) as chd:
os.makedirs(pdir, exist_ok = True)
                

for k, v in srcs.items():
with open(f'cys{srch}_{k}', 'wb') as f:
f.write(v.replace('{srch}', srch).encode('utf-8'))


import numpy as np
from setuptools import setup, Extension
from Cython.Build import cythonize


sys.argv += ['build_ext', '--inplace']
setup(
ext_modules = cythonize(
Extension(
f'{pdir}.cy{srch}', [f'cys{srch}_{k}' for k in filter(lambda e: e[e.rfind('.') + 1:] in ['pyx', 'c', 'cpp'], srcs.keys())],
depends = [f'cys{srch}_{k}' for k in filter(lambda e: e[e.rfind('.') + 1:] not in ['pyx', 'c', 'cpp'], srcs.keys())],
extra_compile_args = ['/O2', '/std:c++latest',
'/ID:/dev/_3party/vcpkg_bin/gmp/include/',
],
),
compiler_directives = {'language_level': 3, 'embedsignature': True},
annotate = True,
),
include_dirs = [np.get_include()],
)
del sys.argv[-2:]
for f in glob.glob(f'{pdir}/cy{srch}*'):
shutil.copy(f, f'./../')


print('Cython module:', f'cy{srch}')
return importlib.import_module(f'{pdir}.cy{srch}')


def cython_import():
srcs = {
'lib.h': """
#include <cstring>
#include <cstdint>
#include <stdexcept>
#include <tuple>
#include <iostream>
#include <string>
#include <type_traits>
#include <sstream>


#include <gmpxx.h>


#pragma comment(lib, "D:/dev/_3party/vcpkg_bin/gmp/lib/gmp.lib")


#define ASSERT_MSG(cond, msg) { if (!(cond)) throw std::runtime_error("Assertion (" #cond ") failed at line " + std::to_string(__LINE__) + "! Msg '" + std::string(msg) + "'."); }
#define ASSERT(cond) ASSERT_MSG(cond, "")
#define LN { std::cout << "LN " << __LINE__ << std::endl; }


using u32 = uint32_t;
using u64 = uint64_t;


template <typename T>
size_t BitLen(T n) {
if constexpr(std::is_same_v<std::decay_t<T>, mpz_class>)
return mpz_sizeinbase(n.get_mpz_t(), 2);
else {
size_t cnt = 0;
while (n >= (1ULL << 32)) {
cnt += 32;
n >>= 32;
}
while (n >= (1 << 8)) {
cnt += 8;
n >>= 8;
}
while (n) {
++cnt;
n >>= 1;
}
return cnt;
}
}


template <typename T>
T ISqrt_Babylonian(T const & y) {
// https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method
if (y <= 1)
return y;
T x = T(1) << (BitLen(y) / 2), a = 0, b = 0, limit = 3;
while (true) {
size_t constexpr loops = 3;
for (size_t i = 0; i < loops; ++i) {
if (i + 1 >= loops)
a = x;
b = y;
b /= x;
x += b;
x >>= 1;
}
if (b < a)
std::swap(a, b);
if (b - a > limit)
continue;
++b;
for (size_t i = 0; a <= b; ++a, ++i)
if (a * a > y) {
if (i == 0)
break;
else
return a - 1;
}
ASSERT(false);
}
}


template <typename T>
T ISqrt_AndersKaseorg(T const & n) {
// https://stackoverflow.com/a/53983683/941531
if (n > 0) {
T y = 0, x = T(1) << ((BitLen(n) + 1) >> 1);
while (true) {
y = (x + n / x) >> 1;
if (y >= x)
return x;
x = y;
}
} else if (n == 0)
return 0;
else
ASSERT_MSG(false, "square root not defined for negative numbers");
}


template <typename T>
T ISqrt_GMP(T const & y) {
// https://gmplib.org/manual/Integer-Roots
mpz_class r, n;
bool constexpr is_mpz = std::is_same_v<std::decay_t<T>, mpz_class>;
if constexpr(is_mpz)
n = y;
else {
static_assert(sizeof(T) <= 8);
n = u32(y >> 32);
n <<= 32;
n |= u32(y);
}
mpz_sqrt(r.get_mpz_t(), n.get_mpz_t());
if constexpr(is_mpz)
return r;
else
return (u64(mpz_get_ui(mpz_class(r >> 32).get_mpz_t())) << 32) | u64(mpz_get_ui(mpz_class(r & u32(-1)).get_mpz_t()));
}


template <typename T>
T KthRoot_ChordTangent(T const & n, size_t k = 2) {
// https://i.stack.imgur.com/et9O0.jpg
if (n <= 1)
return n;
auto KthPow = [&](auto const & x){
T y = x * x;
for (size_t i = 2; i < k; ++i)
y *= x;
return y;
};
auto KthPowDer = [&](auto const & x){
T y = x * u32(k);
for (size_t i = 1; i + 1 < k; ++i)
y *= x;
return y;
};
size_t root_bit_len = (BitLen(n) + k - 1) / k;
T   hi = T(1) << root_bit_len,
x_begin = hi >> 1, x_end = hi,
y_begin = KthPow(x_begin), y_end = KthPow(x_end),
x_mid = 0, y_mid = 0, x_n = 0, y_n = 0, tangent_x = 0, chord_x = 0;
for (size_t icycle = 0; icycle < (1 << 30); ++icycle) {
if (x_end <= x_begin + 2)
break;
if constexpr(0) { // Do Binary Search step if needed
x_mid = (x_begin + x_end) >> 1;
y_mid = KthPow(x_mid);
if (y_mid > n) {
x_end = x_mid; y_end = y_mid;
} else {
x_begin = x_mid; y_begin = y_mid;
}
}
// (y_end - y_begin) / (x_end - x_begin) = (n - y_begin) / (x_n - x_begin) ->
x_n = x_begin + (n - y_begin) * (x_end - x_begin) / (y_end - y_begin);
y_n = KthPow(x_n);
tangent_x = x_n + (n - y_n) / KthPowDer(x_n) + 1;
chord_x = x_n + (n - y_n) * (x_end - x_n) / (y_end - y_n);
//ASSERT(chord_x <= tangent_x);
x_begin = chord_x; x_end = tangent_x;
y_begin = KthPow(x_begin); y_end = KthPow(x_end);
//ASSERT(y_begin <= n);
//ASSERT(y_end > n);
}
for (size_t i = 0; x_begin <= x_end; ++x_begin, ++i)
if (x_begin * x_begin > n) {
if (i == 0)
break;
else
return x_begin - 1;
}
ASSERT(false);
return 0;
}


mpz_class FromLimbs(uint64_t * limbs, uint64_t * cnt) {
mpz_class r;
mpz_import(r.get_mpz_t(), *cnt, -1, 8, -1, 0, limbs);
return r;
}


void ToLimbs(mpz_class const & n, uint64_t * limbs, uint64_t * cnt) {
uint64_t cnt_before = *cnt;
size_t cnt_res = 0;
mpz_export(limbs, &cnt_res, -1, 8, -1, 0, n.get_mpz_t());
ASSERT(cnt_res <= cnt_before);
std::memset(limbs + cnt_res, 0, (cnt_before - cnt_res) * 8);
*cnt = cnt_res;
}


void ISqrt_GMP_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_GMP<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}


void ISqrt_AndersKaseorg_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_AndersKaseorg<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}


void ISqrt_Babylonian_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_Babylonian<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}


void ISqrt_ChordTangent_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(KthRoot_ChordTangent<mpz_class>(FromLimbs(limbs, cnt), 2), limbs, cnt);
}
""",
'main.pyx': r"""
# distutils: language = c++
# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION


import numpy as np
cimport numpy as np
cimport cython
from libc.stdint cimport *


cdef extern from "cys{srch}_lib.h" nogil:
void ISqrt_ChordTangent_Py(uint64_t * limbs, uint64_t * cnt);
void ISqrt_GMP_Py(uint64_t * limbs, uint64_t * cnt);
void ISqrt_AndersKaseorg_Py(uint64_t * limbs, uint64_t * cnt);
void ISqrt_Babylonian_Py(uint64_t * limbs, uint64_t * cnt);


@cython.boundscheck(False)
@cython.wraparound(False)
def ISqrt(method, n):
mask64 = (1 << 64) - 1
def ToLimbs():
return np.copy(np.frombuffer(n.to_bytes((n.bit_length() + 63) // 64 * 8, 'little'), dtype = np.uint64))
        

words = (n.bit_length() + 63) // 64
t = n
r = np.zeros((words,), dtype = np.uint64)
for i in range(words):
r[i] = np.uint64(t & mask64)
t >>= 64
return r
def FromLimbs(x):
return int.from_bytes(x.tobytes(), 'little')
        

n = 0
for i in range(x.shape[0]):
n |= int(x[i]) << (i * 64)
return n
n = ToLimbs()
cdef uint64_t[:] cn = n
cdef uint64_t ccnt = len(n)
cdef uint64_t cmethod = {'GMP': 0, 'AndersKaseorg': 1, 'Babylonian': 2, 'ChordTangent': 3}[method]
with nogil:
(ISqrt_GMP_Py if cmethod == 0 else ISqrt_AndersKaseorg_Py if cmethod == 1 else ISqrt_Babylonian_Py if cmethod == 2 else ISqrt_ChordTangent_Py)(
<uint64_t *>&cn[0], <uint64_t *>&ccnt
)
return FromLimbs(n[:ccnt])
""",
}
return cython_compile(srcs)


def main():
import math, gmpy2, timeit, random
mod = cython_import()
fs = [
('math.isqrt', math.isqrt),
('gmpy2.isqrt', gmpy2.isqrt),
('ISqrt_GMP', lambda n: mod.ISqrt('GMP', n)),
('ISqrt_AndersKaseorg', lambda n: mod.ISqrt('AndersKaseorg', n)),
('ISqrt_Babylonian', lambda n: mod.ISqrt('Babylonian', n)),
('ISqrt_ChordTangent', lambda n: mod.ISqrt('ChordTangent', n)),
]
times = [0] * len(fs)
ntests = 1 << 6
bits = 50000
for i in range(ntests):
n = random.randrange(1 << (bits - 1), 1 << bits)
ref = None
for j, (fn, f) in enumerate(fs):
timeit_cnt = 3
tim = timeit.timeit(lambda: f(n), number = timeit_cnt) / timeit_cnt
times[j] += tim
x = f(n)
if j == 0:
ref = x
else:
assert x == ref, (fn, ref, x)
print('Bits', bits)
print('\n'.join([f'{fs[i][0]:>19}: {round(times[i] / ntests * 1000, 3):>7} ms' for i in range(len(fs))]))


if __name__ == '__main__':
main()

C + + :

上网试试!

#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <tuple>
#include <iostream>
#include <string>
#include <type_traits>
#include <sstream>


#include <gmpxx.h>


#define ASSERT_MSG(cond, msg) { if (!(cond)) throw std::runtime_error("Assertion (" #cond ") failed at line " + std::to_string(__LINE__) + "! Msg '" + std::string(msg) + "'."); }
#define ASSERT(cond) ASSERT_MSG(cond, "")
#define LN { std::cout << "LN " << __LINE__ << std::endl; }


using u32 = uint32_t;
using u64 = uint64_t;


template <typename T>
size_t BitLen(T n) {
if constexpr(std::is_same_v<std::decay_t<T>, mpz_class>)
return mpz_sizeinbase(n.get_mpz_t(), 2);
else {
size_t cnt = 0;
while (n >= (1ULL << 32)) {
cnt += 32;
n >>= 32;
}
while (n >= (1 << 8)) {
cnt += 8;
n >>= 8;
}
while (n) {
++cnt;
n >>= 1;
}
return cnt;
}
}


template <typename T>
T ISqrt_Babylonian(T const & y) {
// https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method
if (y <= 1)
return y;
T x = T(1) << (BitLen(y) / 2), a = 0, b = 0, limit = 3;
while (true) {
size_t constexpr loops = 3;
for (size_t i = 0; i < loops; ++i) {
if (i + 1 >= loops)
a = x;
b = y;
b /= x;
x += b;
x >>= 1;
}
if (b < a)
std::swap(a, b);
if (b - a > limit)
continue;
++b;
for (size_t i = 0; a <= b; ++a, ++i)
if (a * a > y) {
if (i == 0)
break;
else
return a - 1;
}
ASSERT(false);
}
}


template <typename T>
T ISqrt_AndersKaseorg(T const & n) {
// https://stackoverflow.com/a/53983683/941531
if (n > 0) {
T y = 0, x = T(1) << ((BitLen(n) + 1) >> 1);
while (true) {
y = (x + n / x) >> 1;
if (y >= x)
return x;
x = y;
}
} else if (n == 0)
return 0;
else
ASSERT_MSG(false, "square root not defined for negative numbers");
}


template <typename T>
T ISqrt_GMP(T const & y) {
// https://gmplib.org/manual/Integer-Roots
mpz_class r, n;
bool constexpr is_mpz = std::is_same_v<std::decay_t<T>, mpz_class>;
if constexpr(is_mpz)
n = y;
else {
static_assert(sizeof(T) <= 8);
n = u32(y >> 32);
n <<= 32;
n |= u32(y);
}
mpz_sqrt(r.get_mpz_t(), n.get_mpz_t());
if constexpr(is_mpz)
return r;
else
return (u64(mpz_get_ui(mpz_class(r >> 32).get_mpz_t())) << 32) | u64(mpz_get_ui(mpz_class(r & u32(-1)).get_mpz_t()));
}


template <typename T>
std::string IntToStr(T n) {
if constexpr(std::is_same_v<std::decay_t<T>, mpz_class>)
return n.get_str();
else {
std::ostringstream ss;
ss << n;
return ss.str();
}
}


template <typename T>
T KthRoot_ChordTangent(T const & n, size_t k = 2) {
// https://i.stack.imgur.com/et9O0.jpg
if (n <= 1)
return n;
auto KthPow = [&](auto const & x){
T y = x * x;
for (size_t i = 2; i < k; ++i)
y *= x;
return y;
};
auto KthPowDer = [&](auto const & x){
T y = x * u32(k);
for (size_t i = 1; i + 1 < k; ++i)
y *= x;
return y;
};
size_t root_bit_len = (BitLen(n) + k - 1) / k;
T   hi = T(1) << root_bit_len,
x_begin = hi >> 1, x_end = hi,
y_begin = KthPow(x_begin), y_end = KthPow(x_end),
x_mid = 0, y_mid = 0, x_n = 0, y_n = 0, tangent_x = 0, chord_x = 0;
for (size_t icycle = 0; icycle < (1 << 30); ++icycle) {
//std::cout << "x_begin, x_end = " << IntToStr(x_begin) << ", " << IntToStr(x_end) << ", n " << IntToStr(n) << std::endl;
if (x_end <= x_begin + 2)
break;
if constexpr(0) { // Do Binary Search step if needed
x_mid = (x_begin + x_end) >> 1;
y_mid = KthPow(x_mid);
if (y_mid > n) {
x_end = x_mid; y_end = y_mid;
} else {
x_begin = x_mid; y_begin = y_mid;
}
}
// (y_end - y_begin) / (x_end - x_begin) = (n - y_begin) / (x_n - x_begin) ->
x_n = x_begin + (n - y_begin) * (x_end - x_begin) / (y_end - y_begin);
y_n = KthPow(x_n);
tangent_x = x_n + (n - y_n) / KthPowDer(x_n) + 1;
chord_x = x_n + (n - y_n) * (x_end - x_n) / (y_end - y_n);
//ASSERT(chord_x <= tangent_x);
x_begin = chord_x; x_end = tangent_x;
y_begin = KthPow(x_begin); y_end = KthPow(x_end);
//ASSERT(y_begin <= n);
//ASSERT(y_end > n);
}
for (size_t i = 0; x_begin <= x_end; ++x_begin, ++i)
if (x_begin * x_begin > n) {
if (i == 0)
break;
else
return x_begin - 1;
}
ASSERT(false);
return 0;
}


mpz_class FromLimbs(uint64_t * limbs, uint64_t * cnt) {
mpz_class r;
mpz_import(r.get_mpz_t(), *cnt, -1, 8, -1, 0, limbs);
return r;
}


void ToLimbs(mpz_class const & n, uint64_t * limbs, uint64_t * cnt) {
uint64_t cnt_before = *cnt;
size_t cnt_res = 0;
mpz_export(limbs, &cnt_res, -1, 8, -1, 0, n.get_mpz_t());
ASSERT(cnt_res <= cnt_before);
std::memset(limbs + cnt_res, 0, (cnt_before - cnt_res) * 8);
*cnt = cnt_res;
}


void ISqrt_ChordTangent_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(KthRoot_ChordTangent<mpz_class>(FromLimbs(limbs, cnt), 2), limbs, cnt);
}


void ISqrt_GMP_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_GMP<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}


void ISqrt_AndersKaseorg_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_AndersKaseorg<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}


void ISqrt_Babylonian_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_Babylonian<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}


// Testing


#include <chrono>
#include <random>
#include <vector>
#include <iomanip>


inline double Time() {
static auto const gtb = std::chrono::high_resolution_clock::now();
return std::chrono::duration_cast<std::chrono::duration<double>>(std::chrono::high_resolution_clock::now() - gtb)
.count();
}


template <typename T, typename F>
std::vector<T> Test0(std::string const & test_name, size_t bits, size_t ntests, F && f) {
std::mt19937_64 rng{123};
std::vector<T> nums;
for (size_t i = 0; i < ntests; ++i) {
T n = 0;
for (size_t j = 0; j < bits; j += 32) {
size_t const cbits = std::min<size_t>(32, bits - j);
n <<= cbits;
n ^= u32(rng()) >> (32 - cbits);
}
nums.push_back(n);
}
auto tim = Time();
for (auto & n: nums)
n = f(n);
tim = Time() - tim;
std::cout << "Test " << std::setw(15) << ("'" + test_name + "'")
<< ", bits " << std::setw(6) << bits << ", time "
<< std::fixed << std::setprecision(6) << std::setw(9) << tim / ntests << " sec" << std::endl;
return nums;
}


void Test() {
auto f = [](auto ty, size_t bits, size_t ntests){
using T = std::decay_t<decltype(ty)>;
auto tim = Time();
auto a = Test0<T>("GMP",           bits, ntests, [](auto const & x){ return ISqrt_GMP<T>(x); });
auto b = Test0<T>("AndersKaseorg", bits, ntests, [](auto const & x){ return ISqrt_AndersKaseorg<T>(x); });
ASSERT(b == a);
auto c = Test0<T>("Babylonian",    bits, ntests, [](auto const & x){ return ISqrt_Babylonian<T>(x); });
ASSERT(c == a);
auto d = Test0<T>("ChordTangent",  bits, ntests, [](auto const & x){ return KthRoot_ChordTangent<T>(x); });
ASSERT(d == a);
std::cout << "Bits " << bits << " nums " << ntests << " time " << std::fixed << std::setprecision(1) << (Time() - tim) << " sec" << std::endl;
};
for (auto p: std::vector<std::pair<int, int>>\{\{15, 1 << 19}, {30, 1 << 19}})
f(u64(), p.first, p.second);
for (auto p: std::vector<std::pair<int, int>>\{\{64, 1 << 15}, {8192, 1 << 10}, {50000, 1 << 5}})
f(mpz_class(), p.first, p.second);
}


int main() {
try {
Test();
return 0;
} catch (std::exception const & ex) {
std::cout << "Exception: " << ex.what() << std::endl;
return -1;
}
}