Оптимизация алгоритма поиска троек
Не проходит по времени. Не могу оптимизировать.
Условие:
Вам дан массив a из n чисел. Надо посчитать количество таких троек i, j, k
, что i < j < k
и a_i > a_j > a_k
.
В первой строке записано одно целое число n (3 ≤ n ≤ 10^6)
— количество чисел. Следующая строка содержит n целых чисел a_i (1 ≤ i ≤ n, 1 ≤ a_i ≤ 10^9)
— сам массив.
На выходе одно число - ответ на задачу.
Примеры:
Вход: 3
3 2 1
Выход: 1
Вход: 4
8 6 3 1
Выход: 4
Мой код:
def count_inversions(n, arr):
count = 0
for i in range(n):
for j in range(i + 1, n):
for k in range(j + 1, n):
if arr[i] > arr[j] > arr[k]:
count += 1
return count
n = int(input())
arr = list(map(int, input().split()))
result = count_inversions(n, arr)
print(result)
С правками из ответа выше. Не проходит по времени
def merge_count_split_inv(arr, temp_arr, left, mid, right):
i = left
j = mid + 1
k = left
inv_count = 0
while i <= mid and j <= right:
if arr[i] <= arr[j]:
temp_arr[k] = arr[i]
i += 1
else:
temp_arr[k] = arr[j]
inv_count += (mid - i + 1)
j += 1
k += 1
while i <= mid:
temp_arr[k] = arr[i]
i += 1
k += 1
while j <= right:
temp_arr[k] = arr[j]
j += 1
k += 1
for i in range(left, right + 1):
arr[i] = temp_arr[i]
return inv_count
def merge_sort_and_count(arr, temp_arr, left, right):
inv_count = 0
if left < right:
mid = (left + right) // 2
inv_count += merge_sort_and_count(arr, temp_arr, left, mid)
inv_count += merge_sort_and_count(arr, temp_arr, mid + 1, right)
inv_count += merge_count_split_inv(arr, temp_arr, left, mid, right)
return inv_count
def count_inversions(arr):
n = len(arr)
temp_arr = [0] * n
return merge_sort_and_count(arr, temp_arr, 0, n - 1)
def count_triplets(n, arr):
total_count = 0
for j in range(1, n - 1):
left_count = 0
right_count = 0
for i in range(j):
if arr[i] > arr[j]:
left_count += 1
for k in range(j + 1, n):
if arr[j] > arr[k]:
right_count += 1
total_count += left_count * right_count
return total_count
n = int(input())
arr = list(map(int, input().split()))
result = count_triplets(n, arr)
print(result)
Ответы (1 шт):
Теория
Фиксируем j. Отыскиваем количество i: i < j, ai > aj. Отыскиваем количество k: j < k, aj > ak. Произведение этих количеств – вклад в общую сумму всех троек со средним индексом j.
Суммируем для всех j, что даёт квадратичный алгоритм по сравнению с оригинальным кубическим.
Отыскиваем модификацию сортировки слиянием, которая подсчитывает количество инверсий для всех элементов массива за NlogN. Используем её два раза: для элементов перед j и для элементов после j. Далее как в квадратичном алгоритме.
Что даёт алгоритм со сложностью NlogN и памятью N.
Практика
Массив a преобразуем в массив троек: [aj, lj, rj]. Первоначально все lj и rj равны нулю. В конце работы алгоритма
lj станет равной числу тех элементов a, которые левее позиции j и больше aj;
rj станет равной числу тех элементов a, которые правее позиции j и меньше aj.
Заодно массив будет упорядочен по возрастанию *aj, но для нас это не важно. Нас интересуют только lj и rj.
Как устроена сортировка слиянием, объяснять не буду. При слиянии увеличиваются lj и rj у элементов, которые копируются в объединённый список.
Когда копируется элемент b[j]
, он "обгоняет" все элементы из хвоста a
. Все они в исходном массиве были левее и больше b[j]
. Таких элементов len(a) - i
штук.
Когда копируется элемент a[i]
, он никого не обгоняет, но к этому моменту его ранее "обогнали" j
элементов из b
. Все они в исходном массиве были правее и меньше a[i]
.
def merge(a, b):
c = []
i = 0
j = 0
while i < len(a) and j < len(b):
if a[i][0] <= b[j][0]:
a[i][2] += j
c.append(a[i])
i += 1
else:
b[j][1] += len(a) - i
c.append(b[j])
j += 1
while i < len(a):
a[i][2] += len(b)
c.append(a[i])
i += 1
while j < len(b):
c.append(b[j])
j += 1
return c
def merge_sort(a):
if len(a) <= 1:
return a
h = len(a) // 2
a1 = merge_sort(a[:h])
a2 = merge_sort(a[h:])
return merge(a1, a2)
def main():
input() # skip n
a = [[int(w), 0, 0] for w in input().split()]
b = merge_sort(a)
print(sum(l * r for _, l, r in b))
main()