Оптимизация скорости кода
Код выполняется без ошибок, но вот при увеличении количества данных скорость снижается и решение не проходит по условию в какой то момент. Помогите оптимизировать код. Возможно есть другой вариант для расчета формулы.
Спасибо.
import numpy as np
ar = np.loadtxt('input.txt', skiprows=1)
ar.astype(dtype=np.float16)
ar = ar[ar[:, 0].argsort()]
s = 0
p = 0
un = np.unique(ar[:, 0])
for _ in un:
for row in ar[ar[:, 0] == _][:, 1]:
s += (
ar[np.where((ar[:, 0] > _)
& ((ar[:, 1] > row)))].shape[0] +
ar[np.where((ar[:, 0] > _)
& ((ar[:, 1] == row)))].shape[0] / 2
)
p += (
ar[np.where(ar[:, 0] > _)].shape[0]
)
print(round(s / p, 7))
Оптимизировал не много код но дальше уменьшить время не удается.
import numpy as np
ar = np.loadtxt('input.txt', skiprows=1)
ar = ar[ar[:, 0].argsort()]
s = 0
p = 0
un = np.unique(ar[:, 0])
un_1, duplicate_count = np.unique(ar, axis=0, return_counts=True)
pivot_table = np.concatenate((un_1, (np.array([duplicate_count])).T), axis=1)
for _ in un:
for row in pivot_table[pivot_table[:, 0] == _][:, 1]:
s += (
(
np.sum(pivot_table[np.where((pivot_table[:, 0] > _)
& (pivot_table[:, 1] > row))][:, 2])
+ np.sum(pivot_table[np.where((pivot_table[:, 0] > _)
& (pivot_table[:, 1] == row))][:, 2])
/ 2
)
* pivot_table[np.where((pivot_table[:, 0] == _) & (pivot_table[:, 1] == row))][:, 2]
)
p += (
np.sum(pivot_table[np.where(pivot_table[:, 0] > _)][:, 2])
* pivot_table[np.where((pivot_table[:, 0] == _) & (pivot_table[:, 1] == row))][:, 2]
)
print(round(float(s / p), 7))
Ответы (1 шт):
Ваш алгоритм квадратичный. И любой алгоритм который сравнивает между собой все пары записей будет таким, как бы вы его не оптимизировали.
Быстрое решение основано на алгоритме подсчёта количества инверсий в перестановке/массиве - модификация сортировки слиянием считает число инверсий за NlogN. Суммарное число инверсий нам не интересно, но для каждого элемента массива можно сосчитать сколько элементов меньших его стоит перед ним.
Входные пары (t_i, y_i) сортируются по t_i (если несколько t_i равны, то y_i упорядочиваются по убыванию) и дополняются двумя счётчиками:
t_i - точное значение y_i - предсказанное значение lt_i - число элементов, стоящих перед текущим, таких что `y` меньше текущего eq_i - число элементов, стоящих перед текущим, таких что `y` равен текущему
Сортировка слиянием по списку записей упорядоченных по t_i строит список записей упорядоченных по y_i.
lt_i считается в два прохода: сперва в сортировке слиянием подсчитываются меньшие либо равные элементы, затем из них вычитается количество равных.
eq_i считает в один проход по уже упорядоченному списку.
Обладая счётчиками lt_i и eq_i можно вычислить ROC-AUC за линейное время.
Суммарная сложность алгоритма NlogN. У меня нет доступа проверяющей системе, я сравнивал работу вашего решения со своим на случайных наборах данных. Проверялись как наборы уникальных значений, так и повторяющиеся. Во всех случаях ROC-AUC совпадал в первых шести знаках.
Наибольшие трудности доставила обработка повторяющихся t_i и y_i в различных сочетаниях. Я опущу подробности, их слишком много.
def merge(a, b):
c = []
i = 0
j = 0
while i < len(a) and j < len(b):
if a[i][1] <= b[j][1]:
c.append(a[i])
i += 1
else:
b[j][2] += i
c.append(b[j])
j += 1
while i < len(a):
c.append(a[i])
i += 1
while j < len(b):
b[j][2] += i
c.append(b[j])
j += 1
return c
def merge_sort(a):
if len(a) <= 1:
return a
h = len(a) // 2
a1 = merge_sort(a[:h])
a2 = merge_sort(a[h:])
return merge(a1, a2)
def fill_counters(a):
a = merge_sort(a)
prev_y = None
c_y = 0
for e in a:
if e[1] == prev_y:
c_y += 1
else:
c_y = 0
e[2] -= c_y
prev_y = e[1]
prev_y = None
c_y = 0
c_t = 0
prev_t = None
for e in a:
if e[1] == prev_y:
c_y += 1
if e[0] != prev_t:
c_t = c_y
else:
c_y = 0
c_t = 0
e[3] = c_t
prev_y = e[1]
prev_t = e[0]
def main():
array = [
[float(t) for t in input().split()] + [0, 0]
for _ in range(int(input()))
]
array.sort(key=lambda e: (e[0], -e[1]))
fill_counters(array)
n = 0
d = 0
k = 0
prev_t = None
for i, e in enumerate(array):
if e[0] != prev_t:
k = i
n += e[2] + e[3] / 2
d += k
prev_t = e[0]
print(n / d)
main()
Времена выполнения:
размер разница время время списка ROC-AUC работы работы (N) NumPy (c) merge_sort (с) 10000 0.000000 3.794 0.119 20000 0.000000 13.699 0.225 30000 0.000000 27.586 0.318 40000 0.000000 48.019 0.460 50000 0.000000 75.546 0.643 60000 0.000000 106.449 0.717 70000 0.000000 149.666 0.892 80000 0.000000 183.675 0.922 90000 0.000000 230.494 1.073 100000 0.000000 303.946 1.214
P.S. print(round(s / p, 7)) - печатать округлённые значения лучше форматированием. Вещественные десятичные дроби не представляются точно в компьютере. Вы рискуете напечатать неокруглённый результат. Так лучше: print(f'{s / p:.7f}').
