Оптимизация алгоритма поиска троек

Не проходит по времени. Не могу оптимизировать.

Условие:

Вам дан массив a из n чисел. Надо посчитать количество таких троек i, j, k, что i < j < k и a_i > a_j > a_k.

В первой строке записано одно целое число n (3 ≤ n ≤ 10^6) — количество чисел. Следующая строка содержит n целых чисел a_i (1 ≤ i ≤ n, 1 ≤ a_i ≤ 10^9) — сам массив.

На выходе одно число - ответ на задачу.

Примеры:

Вход: 3 3 2 1 Выход: 1

Вход: 4 8 6 3 1 Выход: 4

Мой код:

def count_inversions(n, arr):
    count = 0
    for i in range(n):
        for j in range(i + 1, n):
            for k in range(j + 1, n):
                if arr[i] > arr[j] > arr[k]:
                    count += 1
    return count

n = int(input())
arr = list(map(int, input().split()))

result = count_inversions(n, arr)
print(result)

С правками из ответа выше. Не проходит по времени

def merge_count_split_inv(arr, temp_arr, left, mid, right):
i = left
j = mid + 1
k = left
inv_count = 0

while i <= mid and j <= right:
    if arr[i] <= arr[j]:
        temp_arr[k] = arr[i]
        i += 1
    else:
        temp_arr[k] = arr[j]
        inv_count += (mid - i + 1)
        j += 1
    k += 1

while i <= mid:
    temp_arr[k] = arr[i]
    i += 1
    k += 1

while j <= right:
    temp_arr[k] = arr[j]
    j += 1
    k += 1

for i in range(left, right + 1):
    arr[i] = temp_arr[i]
    
return inv_count

def merge_sort_and_count(arr, temp_arr, left, right):
    inv_count = 0
    if left < right:
        mid = (left + right) // 2
        inv_count += merge_sort_and_count(arr, temp_arr, left, mid)
        inv_count += merge_sort_and_count(arr, temp_arr, mid + 1, right)
        inv_count += merge_count_split_inv(arr, temp_arr, left, mid, right)
    return inv_count

def count_inversions(arr):
    n = len(arr)
    temp_arr = [0] * n
    return merge_sort_and_count(arr, temp_arr, 0, n - 1)

def count_triplets(n, arr):
    total_count = 0
    for j in range(1, n - 1):
        left_count = 0
        right_count = 0
        for i in range(j):
            if arr[i] > arr[j]:
                left_count += 1
        for k in range(j + 1, n):
            if arr[j] > arr[k]:
                right_count += 1
        total_count += left_count * right_count
    return total_count
n = int(input())
arr = list(map(int, input().split()))
result = count_triplets(n, arr)
print(result)

Ответы (1 шт):

Автор решения: Stanislav Volodarskiy

Теория

Фиксируем j. Отыскиваем количество i: i < j, ai > aj. Отыскиваем количество k: j < k, aj > ak. Произведение этих количеств – вклад в общую сумму всех троек со средним индексом j.

Суммируем для всех j, что даёт квадратичный алгоритм по сравнению с оригинальным кубическим.

Отыскиваем модификацию сортировки слиянием, которая подсчитывает количество инверсий для всех элементов массива за NlogN. Используем её два раза: для элементов перед j и для элементов после j. Далее как в квадратичном алгоритме.

Что даёт алгоритм со сложностью NlogN и памятью N.

Практика

Массив a преобразуем в массив троек: [aj, lj, rj]. Первоначально все lj и rj равны нулю. В конце работы алгоритма
lj станет равной числу тех элементов a, которые левее позиции j и больше aj;
rj станет равной числу тех элементов a, которые правее позиции j и меньше aj.

Заодно массив будет упорядочен по возрастанию *aj, но для нас это не важно. Нас интересуют только lj и rj.

Как устроена сортировка слиянием, объяснять не буду. При слиянии увеличиваются lj и rj у элементов, которые копируются в объединённый список.

Когда копируется элемент b[j], он "обгоняет" все элементы из хвоста a. Все они в исходном массиве были левее и больше b[j]. Таких элементов len(a) - i штук.

Когда копируется элемент a[i], он никого не обгоняет, но к этому моменту его ранее "обогнали" j элементов из b. Все они в исходном массиве были правее и меньше a[i].

def merge(a, b):
    c = []
    i = 0
    j = 0
    while i < len(a) and j < len(b):
        if a[i][0] <= b[j][0]:
            a[i][2] += j
            c.append(a[i])
            i += 1
        else:
            b[j][1] += len(a) - i
            c.append(b[j])
            j += 1
    while i < len(a):
        a[i][2] += len(b)
        c.append(a[i])
        i += 1
    while j < len(b):
        c.append(b[j])
        j += 1
    return c


def merge_sort(a):
    if len(a) <= 1:
        return a
    h = len(a) // 2
    a1 = merge_sort(a[:h])
    a2 = merge_sort(a[h:])
    return merge(a1, a2)


def main():
    input() # skip n
    a = [[int(w), 0, 0] for w in input().split()]
    b = merge_sort(a)
    print(sum(l * r for _, l, r in b))


main()
→ Ссылка