How to find integer nth roots?

Posted 2019-01-09 06:23

Question:

I want to find the greatest integer less than or equal to the kth root of n. I tried

int(n**(1/k))

But for n=125, k=3 this gives the wrong answer! I happen to know that 5 cubed is 125.

>>> int(125**(1/3))
4

What's a better algorithm?


Background: In 2011, this slip-up cost me beating Google Code Jam. https://code.google.com/codejam/contest/dashboard?c=1150486#s=p2
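The requirement can be stated entirely in integers: x is the answer exactly when x**k <= n < (x+1)**k. Here is a minimal checker for that property that lets you test any candidate without trusting floating point. is_floor_root is a hypothetical helper name, not something from a library.

```python
def is_floor_root(x: int, k: int, n: int) -> bool:
    """True if x is the greatest integer whose k-th power does not exceed n.

    Hypothetical helper. Uses only exact integer arithmetic and assumes
    x >= 0, k >= 1, n >= 0.
    """
    return x ** k <= n < (x + 1) ** k


print(is_floor_root(4, 3, 125))  # False: 5**3 == 125, so 4 is too small
print(is_floor_root(5, 3, 125))  # True
print(is_floor_root(4, 3, 124))  # True
```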

Answer 1:

One solution first brackets the answer between lo and hi by repeatedly multiplying hi by 2 until n lies between lo**k and hi**k, then uses binary search to compute the exact answer:

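def iroot(k, n):
    # Bracket the root: double hi until hi**k >= n.
    hi = 1
    while pow(hi, k) < n:
        hi *= 2
    lo = hi // 2  # integer division; hi / 2 would give a float in Python 3
    # Bisect, keeping lo**k < n <= hi**k (for n >= 1).
    while hi - lo > 1:
        mid = (lo + hi) // 2
        midToK = pow(mid, k)
        if midToK < n:
            lo = mid
        elif n < midToK:
            hi = mid
        else:
            return mid  # exact root found
    if pow(hi, k) == n:
        return hi
    else:
        return lo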
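The bisection only ever compares exact integer powers, so it is easy to cross-check. The following is a minimal sketch that assumes Python 3.8+ for math.isqrt. It uses the same doubling-then-bisection iroot, with integer division when setting lo, and compares it for k = 2 against the standard library's exact integer square root. The test values are arbitrary.

```python
import math


def iroot(k, n):
    """Floor of the k-th root of n (k >= 1, n >= 0): doubling, then bisection."""
    hi = 1
    while pow(hi, k) < n:  # grow hi until hi**k >= n
        hi *= 2
    lo = hi // 2  # integer division keeps lo an int
    while hi - lo > 1:  # for n >= 1: lo**k < n <= hi**k throughout
        mid = (lo + hi) // 2
        mid_to_k = pow(mid, k)
        if mid_to_k < n:
            lo = mid
        elif n < mid_to_k:
            hi = mid
        else:
            return mid
    return hi if pow(hi, k) == n else lo


# For k == 2, compare against math.isqrt (Python 3.8+), which is exact.
for n in list(range(10_000)) + [10**40, 10**40 - 1, 2**200 + 12345]:
    assert iroot(2, n) == math.isqrt(n), n
print("iroot(2, n) agrees with math.isqrt on all tested n")
```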

A different solution uses Newton's method, which works perfectly well on integers:

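def iroot(k, n):
    # Integer Newton iteration for x**k = n. Assumes n >= 1: for n == 0
    # and k > 1, pow(0, k-1) is 0 and the floor division below fails.
    u, s = n, n+1
    while u < s:  # stop as soon as the iterate stops decreasing
        s = u
        t = (k-1) * s + n // pow(s, k-1)
        u = t // k
    return s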
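The loop above starts from s = n, so for large n it spends many iterations just shrinking the guess. A common variant starts from a power of two derived from n.bit_length(). That starting value is guaranteed to be at least the root, and the loop stops as soon as an iterate fails to decrease. This is a sketch of that variant, not the answer's code, and iroot_newton is a hypothetical name.

```python
def iroot_newton(k, n):
    """Floor of the k-th root of n for ints k >= 1, n >= 0 (integer Newton sketch)."""
    if n < 2:
        return n
    # Start from 2**ceil(bits/k); since n < 2**bits, this is above the true root.
    x = 1 << -(-n.bit_length() // k)
    while True:
        # One integer Newton step for f(x) = x**k - n.
        y = ((k - 1) * x + n // x ** (k - 1)) // k
        if y >= x:  # iterates decrease until they reach the floor root
            return x
        x = y


print(iroot_newton(3, 125))     # 5
print(iroot_newton(3, 124))     # 4
print(iroot_newton(2, 10**40))  # 10**20
```

Each floored step never drops below the true floor root (by the AM-GM inequality), and it strictly decreases while the guess is too large. So the first non-decreasing step signals the answer.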


Answer 2:

How about:

def nth_root(val, n):
    ret = int(val**(1./n))
    return ret + 1 if (ret + 1) ** n == val else ret

print(nth_root(124, 3))
print(nth_root(125, 3))
print(nth_root(126, 3))
print(nth_root(1, 100))

Here, both val and n are expected to be positive integers. That way the check in the return expression uses only exact integer arithmetic and cannot introduce a rounding error of its own.

Note that accuracy is only guaranteed when val**(1./n) is fairly small. Once the result of that expression deviates from the true answer by more than 1, the method will no longer give the correct answer (it'll give the same approximate answer as your original version).

As for why int(125**(1/3)) is 4:

In [1]: '%.20f' % 125**(1./3)
Out[1]: '4.99999999999999911182'

int() truncates that to 4.
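You can inspect the stored value yourself. decimal.Decimal converts a float exactly, because every binary double has a finite decimal expansion, and float.hex() shows the same value in base 16. Here is a small sketch that prints both:

```python
from decimal import Decimal

x = 125 ** (1 / 3)
# Decimal(float) converts the binary double exactly, with no rounding.
print(Decimal(x))
print(x.hex())  # the same value in hexadecimal floating-point notation
print(int(x))   # int() truncates toward zero
```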



Answer 3:

My cautious solution after being so badly burned:

def nth_root(N,k):
    """Return greatest integer x such that x**k <= N"""
    x = int(N**(1/k))      
    while (x+1)**k <= N:
        x += 1
    while x**k > N:
        x -= 1
    return x
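The two while loops repair any starting guess, so this version is correct whenever N**(1/k) can be computed at all. The float only affects how many correction steps run: it raises OverflowError once N is too large to convert to a float, and the step count grows with the guess's error. Below is a quick brute-force check against the definition. It is a sketch, and the ranges are arbitrary.

```python
def nth_root(N, k):
    """Return greatest integer x such that x**k <= N."""
    x = int(N ** (1 / k))  # float guess, possibly off by one or more
    while (x + 1) ** k <= N:  # guess too small: step up
        x += 1
    while x ** k > N:  # guess too large: step down
        x -= 1
    return x


# Exhaustive check for small inputs against the defining inequality.
for k in range(1, 8):
    for N in range(20_000):
        x = nth_root(N, k)
        assert x ** k <= N < (x + 1) ** k, (N, k)
print("all checks passed")
```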


Answer 4:

Why not try this:

125 ** (1 / float(3)) 

or

pow(125, 1 / float(3))

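It displays as 5.0 when printed under Python 2, but the value is the same as the one shown above for 125**(1./3), just below 5, so int() still gives 4. Use round() before converting to int.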
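In Python 2, str() for floats (which print used) showed only 12 significant digits, and that is enough to turn 4.999999999999999 into 5. This sketch compares that formatting with the full repr, and int() with round():

```python
x = 125 ** (1 / float(3))
print(repr(x))           # full round-trip representation
print('%.12g' % x)       # 12 significant digits, like Python 2's str()
print(int(x), round(x))  # truncation vs. rounding to the nearest integer
```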



Answer 5:

Here it is in Lua, using the Newton-Raphson method:

> function nthroot (x, n) local r = 1; for i = 1, 16 do r = (((n - 1) * r) + x / (r ^ (n - 1))) / n end return r end
> return nthroot(125,3)
5
> 

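Python version (a Python 2 session, where / between two ints is floor division, so the result below is the int 5; under Python 3 the same code returns a float):

>>> def nthroot(x, n):
...     r = 1
...     for i in range(16):
...         r = (((n - 1) * r) + x / (r ** (n - 1))) / n
...     return r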
... 
>>> nthroot(125,3)
5
>>> 


Answer 6:

I wonder if starting off with a method based on logarithms can help pin down the sources of rounding error. For example:

import math
def power_floor(n, k):
    return int(math.exp(1.0 / k * math.log(n)))

def nth_root(val, n):
    ret = int(val**(1./n))
    return ret + 1 if (ret + 1) ** n == val else ret

cases = [
    (124, 3),
    (125, 3),
    (126, 3),
    (1, 100),
    ]


for n, k in cases:
    print("{0:d} vs {1:d}".format(nth_root(n, k), power_floor(n, k)))

prints out

4 vs 4
5 vs 5
5 vs 5
1 vs 1
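One concrete advantage of the logarithm route is range. The expression n ** (1./k) first converts n to a float, so it raises OverflowError once n exceeds about 1.8e308. math.log, by contrast, accepts arbitrarily large Python ints. The guess is still only a float with about 16 significant digits, so it must still be corrected with exact integer arithmetic. A small sketch:

```python
import math

n = 10 ** 400  # too large to convert to a float

try:
    print(n ** (1 / 3))
except OverflowError as exc:
    print("n ** (1/3) raised OverflowError:", exc)

# math.log accepts arbitrarily large ints, so the log-based form still works
# as long as the *result* fits in a float.
guess = math.exp(math.log(n) / 3)
print(guess)  # roughly 2.15e133, accurate to only about 16 significant digits
```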


Answer 7:

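def nth_root(n, k):
    x = n**(1./k)
    y = int(x)
    # Rounds any non-integral x up, so this is a ceiling, not a floor:
    # it fixes 125 (x = 4.999...) but returns 6 for n=126, k=3.
    return y + 1 if y != x else y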


Answer 8:

You can round to the nearest integer instead of truncating (Python's int() truncates toward zero):

def rtn(x):
    return int(x + 0.5)

>>> rtn(125 ** (1/3))
5
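Rounding fixes 125 but answers a different question: it returns the nearest integer, so for 124 (cube root about 4.987) it gives 5 rather than the floor 4. The following sketch keeps the rounding and restores the floor with one exact integer check. It assumes the float is accurate to within 0.5, which holds for moderately sized inputs but not for huge ones. floor_root_via_round is a hypothetical name.

```python
def floor_root_via_round(n, k):
    """Floor of the k-th root of n, assuming n ** (1 / k) is accurate to within 0.5.

    Hypothetical helper: rounds to the nearest integer, then uses exact
    integer arithmetic to step down if the rounded value overshoots.
    """
    r = round(n ** (1 / k))  # nearest integer; may be one above the floor
    if r ** k > n:
        r -= 1
    return r


print(floor_root_via_round(124, 3))  # 4
print(floor_root_via_round(125, 3))  # 5
print(floor_root_via_round(126, 3))  # 5
```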


Answer 9:

int(125**(1/3)) should mathematically be 5, so this is ordinary floating-point rounding error: internally the result is 4.999999999999999, which int() truncates to 4. The same problem affects any algorithm that works in floating point. One simple ad hoc workaround is to add a tiny number, e.g. int((125**(1/3)) + 0.00000001).
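The epsilon has to be smaller than the gap between the true root and the next integer, and that gap shrinks as n grows. This sketch uses k = 2 with math.sqrt, because IEEE 754 requires square root to be correctly rounded, so the numbers are reproducible. Here, adding 0.00000001 turns a correct answer into a wrong one. math.isqrt needs Python 3.8+.

```python
import math

n = 60_000_000 ** 2 - 1  # exactly representable as a float (below 2**53)

# math.sqrt is correctly rounded (IEEE 754), so these results are reproducible.
plain = int(math.sqrt(n))
nudged = int(math.sqrt(n) + 0.00000001)
exact = math.isqrt(n)  # exact integer square root, Python 3.8+

print(plain, nudged, exact)  # 59999999 60000000 59999999
```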



Answer 10:

In Python 2, put this at the top of your module so that 1/3 is true division (0.333...) instead of integer division (0):

from __future__ import division

Then use any of the techniques above. The import only makes 1/3 behave as it does in Python 3; by itself it does not stop int(125**(1/3)) from being 4.
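The import matters because in plain Python 2, 1/3 between two ints is floor division and evaluates to 0, so 125**(1/3) is 125**0, which is 1. This Python 3 sketch shows the difference. In Python 3, / always means true division and // is the explicit floor-division operator.

```python
# Python 3 semantics (what `from __future__ import division` enables in Python 2).
print(1 / 3)            # true division: 0.3333333333333333
print(1 // 3)           # floor division: 0
print(125 ** (1 // 3))  # 125 ** 0 == 1, which is what plain 1/3 gives in Python 2
```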



Answer 11:

def nthrootofm(a, n):
    root = pow(a, 1 / n)
    return 'rounded: {}'.format(round(root))

a = 125
n = 3
q = nthrootofm(a, n)
print(q)

This simply rounds the float result with round() and prints it with a format string; maybe it helps.