Как ускорить вычисления вероятностей на python
Есть код, считает правильно, но долго - около 10 сек. Можно ли как-то ускорить вычисления?
import scipy.stats as sps
mo1, mo2, sig_h, sig_a, summ = 50, 50, 10, 10, 150
def find_prob(mo1, mo2, sig_h, sig_a, summ):
prob = 0
start = time.time()
for i in range(0, 301):
for j in range(0, 301):
if i + j < summ:
prob = prob + (sps.norm(loc=mo1, scale=sig_h).cdf(i + 0.5) - sps.norm(loc=mo1, scale=sig_h).cdf(i - 0.5))*(sps.norm(loc=mo2, scale=sig_a).cdf(j + 0.5) - sps.norm(loc=mo2, scale=sig_a).cdf(j - 0.5))
end = time.time()
print('время: ' + str(end - start))
return prob
Ответы (2 шт):
Автор решения: Stanislav Volodarskiy
→ Ссылка
Всё вынести из циклов:
import scipy.stats as sps
import time
def find_prob(mo1, mo2, sig_h, sig_a, summ):
n1 = sps.norm(loc=mo1, scale=sig_h)
n2 = sps.norm(loc=mo2, scale=sig_a)
prob = 0
start = time.time()
a = [n1.cdf(i + 0.5) - n1.cdf(i - 0.5) for i in range(summ)]
b = [n2.cdf(j + 0.5) - n2.cdf(j - 0.5) for j in range(summ)]
for i, ai in enumerate(a):
for j, bj in enumerate(b):
if i + j < summ:
prob = prob + ai * bj
end = time.time()
print('время: ' + str(end - start))
return prob
# mo1, mo2, sig_h, sig_a, summ = 50, 50, 10, 10, 150
print(find_prob(50, 50, 10, 10, 150))
$ python temp.py время: 0.07240009307861328 0.9997664485607246
Сократить О-большое с квадрата до линии и перевести все вычисления на NumPy:
import numpy as np
import scipy.stats as sps
import time
def find_prob(mo1, mo2, sig_h, sig_a, summ):
prob = 0
start = time.time()
x = np.arange(summ + 1)
a = np.diff(sps.norm(loc=mo1, scale=sig_h).cdf(x - 0.5))
b = sps.norm(loc=mo2, scale=sig_a).cdf(x - 0.5)
b -= b[0]
b = b[1:]
prob = np.dot(a, b[::-1])
end = time.time()
print('время: ' + str(end - start))
return prob
# mo1, mo2, sig_h, sig_a, summ = 50, 50, 10, 10, 150
print(find_prob(50, 50, 10, 10, 150))
$ python temp.py время: 0.0018396377563476562 0.9997664485607203
Код ускорен примерно в восемнадцать тысяч раз. Тут и О-большое и константа.
Автор решения: CrazyElf
→ Ссылка
Ещё вариант с lru_cache, плюс break для досрочного выхода из цикла. Для чистоты эксперимента кэш чистится.
import scipy.stats as sps
import time
from functools import lru_cache
mo1, mo2, sig_h, sig_a, summ = 50, 50, 10, 10, 150
s1 = sps.norm(loc=mo1, scale=sig_h)
s2 = sps.norm(loc=mo2, scale=sig_a)
@lru_cache(None)
def cdf1(i):
return s1.cdf(i)
@lru_cache(None)
def cdf2(i):
return s2.cdf(i)
def find_prob(mo1, mo2, sig_h, sig_a, summ):
prob = 0
start = time.time()
for i in range(summ):
for j in range(summ):
if i + j < summ:
prob = prob + (cdf1(i + 0.5) - cdf1(i - 0.5)) * (cdf2(j + 0.5) - cdf2(j - 0.5))
else:
break
end = time.time()
print('время: ' + str(end - start))
return prob
cdf1.cache_clear()
cdf2.cache_clear()
print(find_prob(mo1, mo2, sig_h, sig_a, summ))
# время: 0.048853158950805664
# 0.9997664485607246