Количество разных чисел по модулю (оптимизация)

столкнулся вроде бы как с простой задачкой, но не получается уложиться в ограничение по памяти, заданное на сайте.

Задан отсортированный массив целых чисел. Найдите количество различных по модулю чисел среди элементов массива.

Входные данные: Первая строка содержит количество чисел n (n ≤ 2 * 10^6). Вторая строка содержит n целых чисел, отсортированных по возрастанию. Массив может содержать одинаковые элементы.

Выходные данные: Выведите количество различных по модулю чисел.

Пример : Входные данные:

9
-1 -1 -1 -1 0 1 1 1 1

Выходные данные:

2

У меня два теста кушают 131 072 KiB, когда заданное ограничение - 128 MiB Есть подозрения, что нужно всё сделать через словари, но я не совсем понимаю, как это оформить. Ссылка на задание Вот мой код:

n = int(input())
lst_1 = [int(el) for el in input().split()]
counter = 0
lst = set(lst_1)
r = []
for i in lst:
    r.append(abs(i))
print(len(set(r)))

Ответы (3 шт):

Автор решения: Zhihar

вот решение задачи за линейное время без множеств и прочего, что жрет и память и дополнительное время:

основные моменты:

  1. мы идем слева направо по отрицательным числам (если и положительные - это значения не имеет)

  2. мы идем справа налево по положительным числам

  3. каждое уникальное левое число увеличивает счетчик уникальных чисел на 1

  4. каждое уникальное правое число такое, что оно больше левого числа по модулю увеличивает счетчик на 1

код:

n = int(input())
data = list(map(int, input().split()))

count = 0

pos = 0
pos_r = len(data) - 1

while pos <= pos_r:
    # получить отрицательное число
    num = data[pos]
    count += 1

    # пропустить одинаковые отрицательные числа
    while pos <= pos_r and data[pos] == num:
        pos += 1

    # выйти, если все отрицательные числа перебраны
    if pos > pos_r:
        break

    # перебрать все положительные числа до модуля искомого отрицательного
    while data[pos_r] >= abs(num) and pos < pos_r:
        num_r = data[pos_r]

        if num_r > abs(num):
            count += 1

        # пропустить одинаковые положительные числа
        while pos_r >= pos and data[pos_r] == num_r:
            pos_r -= 1

print(count)

код для тестирования (с оптимизациями, которые дали 30-50% прирост скорости):

import time

limit = 10000000
data = [i for i in range(-limit, limit)]

start_time = time.perf_counter()

count = 0

pos = 0
pos_r = len(data) - 1

while pos <= pos_r:
    # получить отрицательное число
    num = data[pos]
    count += 1

    # пропустить одинаковые отрицательные числа
    pos += 1    # ОПТИМИЗАЦИЯ позволяет избежать входа в цикл, если числа находятся рядом
    while pos <= pos_r and data[pos] == num:
        pos += 1

    # выйти, если все отрицательные числа перебраны
    if pos > pos_r:
        break

    # перебрать все положительные числа до модуля искомого отрицательного
    num_abs = abs(num)  # ОПТИМИЗАЦИЯ, однократное вычисление абсолютного значения вне цикла, где оно используется
    while data[pos_r] >= num_abs and pos < pos_r:
        num_r = data[pos_r]

        if num_r > num_abs:
            count += 1

        # пропустить одинаковые положительные числа
        pos_r -= 1    # ОПТИМИЗАЦИЯ позволяет избежать входа в цикл, если числа находятся рядом
        while pos_r >= pos and data[pos_r] == num_r:
            pos_r -= 1

end_time = time.perf_counter ()

print(count)
print(end_time - start_time, "seconds")
print((2 * limit / (end_time - start_time)) * 4 / (1024 * 1024), "MB/sec")
→ Ссылка
Автор решения: Stanislav Volodarskiy

Задача оказалась про скорость ввода данных. Вот решение которое проходит все тесты (результаты):

import sys


def read():
    READ_SIZE = 1024
    tail = ''
    while True:
        block = sys.stdin.read(READ_SIZE)
        if len(block) == 0:
            yield from tail.split()
            return
        text = tail + block

        last_ws = text.rfind(' ')
        if last_ws == -1:
            tail = text
            continue

        yield from text[:last_ws].split()
        tail = text[last_ws:]
    

input()
print(len(set(abs(int(w)) for w in read())))

Ограничения задачи - одна секунда, 128MB. Я сделал несколько тестов, чтобы понять как эти ограничения работают.

Прочитаем данные самым привычным для Питона способом:

input()
print(len(input().split()))

Результаты. То что это решение не решает задачу не важно. Меня интересует только память и время. Тесты 9 и 10 исчерпывают память полностью. На что расходуется память? На строку, которую вернул input(). На миллион (буквально) маленьких строчек, которые нарезал .split() и на список хранящий эти строки.

Главный вывод: каким бы ни был алгоритм обработки этого списка в дальнейшем, мы не помещаемся в память. Так читать данные нельзя.

А как можно? Единственное решение высокого уровня, которое читает слова из строки без создания длинного списка - re.finditer. Если вы знаете другие, подскажите.

import re

input()
print(sum(1 for _ in re.finditer('[0123456789]+', input())))

Результаты. Снова нас не интересует правильность, только память и время. На этот раз по памяти мы проходит свободно. Наибольшее потребление на девятом тесте - 24MB. Память расходуется на входную строку и миллион (буквально) объектов которые описывают найденные подстроки. Эти объекты не должны хранится, полагаю что сборщик мусора не успевает их убирать. Не суть.

Время работы этого варинта обескураживает: re.finditer тратит 412ms на восьмой тест, .split() - 146ms. re.finditer в 2.8 раза хуже.

Времени, которое остаётся после чтения данных re.finditer не хватает чтобы посчитать правильный ответ:

import re

input()
print(len(set(int(m.group(0)) for m in re.finditer('[0123456789]+', input()))))

Результаты. Этот вариант даёт верные ответы, но не проходит по времени в тестах 9 и 10. Сколько времени нам не хватает? Восьмой тест в этом варианте - 840ms, в предыдущем - 412ms. Девятый тест в предыдущем - 758ms. Предполагая что обе программы имеют линейную сложность, получаем пропорцию для девятого теста на последнем коде: 758 / 412 * 840 = 1547ms. На полсекунды больше лимита по времени.

Напишем ввод руками. Функция read читает входной поток небольшими кусками, не хранит его целиком, выдаёт наружу как можно быстрее. Проверим как быстро читаются данные таким образом:

import sys


def read():
    READ_SIZE = 1024
    tail = ''
    while True:
        block = sys.stdin.read(READ_SIZE)
        if len(block) == 0:
            yield from tail.split()
            return
        text = tail + block

        last_ws = text.rfind(' ')
        if last_ws == -1:
            tail = text
            continue

        yield from text[:last_ws].split()
        tail = text[last_ws:]
    

input()
print(sum(1 for _ in read()))

Результаты. Девятый тест занимает 374ms. Аналогичное время для re.finditer - 758ms. Рукописное чтение обогнало библиотечное в два раза - неожиданный результат.

Сравним времена input().split(), re.finditer(..., input()) и read() на тестах которые прошли по памяти для всех трёх вариантах:

вариант                      тест 7   тест 8 

input().split()               125       146
read()                        185       193
re.finditer(..., input())     395       412

Самописный read() в полтора раза медленнее библиотечного input().split(), что подтверждает правило "чем выше уровень, тем быстрее код на Питоне". re.finditer разочаровал.

Для решения задачи пришлось идиоматичное но расточительное по памяти решение (input().split()) заменить менее быстрым но экономным рукописным (read()).

→ Ссылка
Автор решения: SergFSM

вот, попробовал сделать со словарем, но не знаю пройдет ли тест, попробуйте:

from re import finditer

s = '-33 -11 -3 -2 -1 -1 0 1 1 1 1 4 5 11 33'
d = {}

for m in finditer(r'\d+', s):
    d[m[0]] = d.get(m[0], 0) + 1

print(sum(v==1 for v in d.values()))  # 5
→ Ссылка