Почему результат функции на python отличен от результата на C++ и как исправить?
Функция на Python:
def inv(a, n):
if a == 0:
return 0
lm, hm = 1, 0
low, high = a % n, n
while low > 1:
r = high//low
nm, new = hm-lm*r, high-low*r
lm, low, hm, high = nm, new, lm, low
return lm % n
Функция на C++:
int inv(int a, int n) {
if (a == 0) {
return 0;
}
int lm = 1;
int hm = 0;
int low = a % n;
int high = n;
while (low > 1) {
int r = high / low;
int nm = hm - lm * r;
int nw = high - low * r;
lm = nm;
low = nw;
hm = lm;
high = low;
}
return lm % n;
}
Python: inv(198411, 524287) -> 448444, C++: inv(198411, 524287) -> 0
Как исправить функцию на C++ что бы она работала как на Python?
Ответы (2 шт):
int inv(int a, int n)
{
if (a < 1 or n < 2)
return -1;
int high = n, hm = 0, low = a, lm = 1;
while (low)
{
int q = high / low;
int u = high - q*low;
int v = hm - q*lm;
high = low;
hm = lm;
low = u;
lm = v;
}
return high == 1 ? (hm + n) % n : -1;
}
Запись lm, low, hm, high = nm, new, lm, low аналогична
(lm, low, hm, high) = (nm, new, lm, low)
и она кортежу присваивает кортеж. Сначала создаётся один, а потом второй. Так как вы присваиваете последовательно в C++, не используя никаких дополнительных структур, то происходит ошибка логическая. Сначала вы присваиваете
lm = nm;
// потом уже другое значение
hm = lm;
и происходит эта ошибка. Чтобы её исправить, можно добавить запасные переменные с копиями этих значений.
int oldlm = lm ;
lm = nm;
int oldlow = low ;
low = nw;
hm = oldlm;
high = oldlow;
Далее функция взятия остатка от деления немного отличается от C++ варианта при отрицательных переменных. В C++ она возвращает отрицательное число, а в Python - положительное. Это можно исправить, добавив делитель.
return ( lm + n ) % n;
В итоге получилось такое :
int inv(int a, int n) {
if (a == 0) {
return 0;
}
int lm = 1;
int hm = 0;
int low = (a + n) % n ;
int high = n;
while (low > 1) {
int r = high / low;
int nm = hm - lm * r;
int nw = high - low * r;
int oldlm = lm ;
lm = nm;
int oldlow = low ;
low = nw;
hm = oldlm;
high = oldlow;
}
return ( lm + n ) % n ;
}