Выбрать лучшие перекрытия интервалов

У меня есть список из нескольких взвешенных интервалов - это кортежи (value, start, stop):

#   0   1   2   3   4   5   6   7   8   9   10
# 1         [___]                           
# 2 [___________________________________]   
# 3 [_______]                               
# 8     [_______________________]           
# 7                     [___________________]
# 6                                 [___]   
# 4     [_______________________]           

# input
# (value: float, start: float, stop: float)
list_in = [(2, 0, 9), (3, 0, 2), (8, 1, 7), (4, 1, 7), (1, 3, 4), (5, 5, 10), (6, 8, 9)]

Я хочу найти все частичные перекрытия и выбрать из них "лучший", то есть, с наибольшим value.

Для моего примера правильный ответ:

#   0   1   2   3   4   5   6   7   8   9   10
# 1                                         
# 2                                         
# 3 [___]                                   
# 8     [_______________________]           
# 7                             [___________]
# 6                                         
# 4                                         

# output
list_out = [(3, 0, 1), (8, 1, 7), (7, 7, 10)]

Я постараюсь не усложнять всеми условиями, которые у меня есть, но вот минимум из того что вы должны знать:

  1. Входной массив кортежей никак не отсортирован
  2. У меня есть реализация дерева интервалов, если алгоритм производительнее или проще с ним - используйте его

3) Нужно учитывать -float('inf') для начала интервала и float('inf') для конца интервала.

  1. Если значение value у обоих интервалов совпадает, то лучше тот - который короче. Если при этом они одинаковы по длинне, то это точная копия интервала - игнорируйте точные копии.
  2. Игнорируйте интервалы, у которых start == stop

Если это упростит вам задачу, то вот все возможные кейсы взаимного расположения интервалов:

    # All cases for two intervals:
    #    0    1    2    3    4    5    6    7    8    9
    #                   [______________]               
    #                                                  
    #  1 [_________]    |              |                 altc_altd_bltc_bltd
    #  2      [_________]              |                 altc_altd_beqc_bltd
    #  3           [_________]         |                 altc_altd_bgtc_bltd
    #  4                [_________]    |                 aeqc_altd_bgtc_bltd
    #  5                |     [___]    |                 agtc_altd_bgtc_bltd
    #  6                [______________]                 aeqc_altd_bgtc_beqd
    #  7           [___________________]                 altc_altd_bgtc_beqd
    #  8           [________________________]            altc_altd_bgtc_bgtd
    #  9                [___________________]            aeqc_altd_bgtc_bgtd
    # 10                |    [_________]                 agtc_altd_bgtc_beqd
    # 11                |         [_________]            agtc_altd_bgtc_bgtd
    # 12                |              [_________]       agtc_aeqb_bgtc_bgtd
    # 13                |              |    [_________]  agtc_agtd_bgtc_bgtd    

Все кейсы для интервала и точки:

value, start, stop = interval
point = x

point < start
point == start
start < point and point < stop
point == stop
stop < point

Любые мои попытки решить эту задачу "в лоб" вызывают комбинаторный взрыв и превращают код в callback hell

Пожалуйста, помогите. ЗЫ: я могу вставить свою попытку, но она занимает 200 строк кода.


Благодаря ответу @StanislavVolodarskiy, я смог упростить код до следующего:

import heapq
from typing import Callable, Optional


__all__ = ['DataInterval', 'inlay_right_to_left', 'inlay_data_intervals', 'TestInlayTwoWeigtedIntervals']

class DataInterval(object):
    _identifier_counter = 0
    
    def __init__(self, weight, start, stop) -> None:
        self._identifier = self._identifier_counter
        self._identifier_counter += 1
        
        self.data: float = weight
        self.start: float = start
        self.stop: float = stop
    
    @property
    def identifier(self):
        return self._identifier
    
    def __eq__(self, other: object) -> bool:
        if isinstance(other, self.__class__):
            return self._identifier == other._identifier
        
        return False
    
    def __hash__(self) -> int:
        return hash(self._identifier)

def inlay_data_intervals(intervals: list[DataInterval], key: Optional[Callable[[DataInterval], float]]=None):
    if key is None:
        key: Callable[[DataInterval], float] = lambda interval: (interval.data, interval.start, interval.stop)
    
    precomputed_values = []
    
    # event queue
    # list[(x, is_end, index)]
    events: list[tuple] = []
    for index, interval in enumerate(intervals):
        precomputed_values.append(key(interval))
        events.append((interval.start, False, index))
        events.append((interval.stop, True, index))
    
    events.sort()
    
    # status
    deleted = [False] * (len(events) // 2)
    heap = []

    # print
    last_start = None
    last_index = None

    for x, is_end, index in events:
        # update status
        if is_end:
            deleted[index] = True
        else:
            heapq.heappush(heap, (precomputed_values[index], index))

        # remove deleted items from heap top
        while len(heap) > 0 and deleted[heap[0][1]]:
            heapq.heappop(heap)

        current_value, current_index = heap[0] if len(heap) > 0 else (None, None)
        last_value = precomputed_values[last_index] if last_index is not None else None
        if current_value != last_value:
            if last_index is not None and last_start < x:
                yield DataInterval(intervals[last_index].data, last_start, x)
            
            last_start = x
            last_index = current_index

Также есть Тесты

Громоздко, но я не разобрался как заставить VS Code видеть параметризованные тесты.


Ответы (2 шт):

Автор решения: Akina

Показываю на исходных данных.

Шаг 1-1. Сбор точек. После него получаем массив

0,1,2,3,5,7,8,9,10

Шаг 1-2. Конвертация в минимальные интервалы. Получаем набор

0-1
1-2
2-3
3-5
5-7
7-8
8-9
9-10

Шаг 2. Для каждого интервала подбираем перекрывающие их исходные интервалы, собираем веса, выбираем максимальный

0-1 веса 2,3
1-2 веса 2,3,8,4
2-3 веса 1,2,8,4
3-5 веса 2,8,4
5-7 веса 2,8,7,4
7-8 веса 2,7
8-9 веса 2,7,6
9-10 веса 7

Шаг 3. Объединяем соседние интервалы, если их веса равны. Получаем

0-1 веса 2,3 ИТОГО 0-1 вес 3
1-2 веса 2,3,8,4
2-3 веса 1,2,8,4
3-5 веса 2,8,4
5-7 веса 2,8,7,4 ИТОГО 1-7 вес 8
7-8 веса 2,7
8-9 веса 2,7,6
9-10 веса 7 ИТОГО 7-10 вес 7
→ Ссылка
Автор решения: Stanislav Volodarskiy

Задача решается заметанием. События - концы интервалов. Очередь событий - сортированный массив, при равных временах начала отрезков сортируются до концов. Статус: множество отрезков, покрывающих данную точку. Структура для множества - максимальная куча значений, поддерживающая удаление элементов.

Обработка события:

  • если это начало отрезка, добавить его значение в множество
  • если это конец отрезка, пометить его удалённым в множестве

Для печати результата хранится начало текущего максимума (если оно есть) и сам текущий максимум. Если максимум изменился в этом событии,

  • печатается тройка (максимум, начало, текущее событие);
  • обновляется начало и текущий максимум.
import heapq


def min_values(intervals):
    # event queue
    events = [] # (x, is_end, index, value, label)
    for index, (start, stop, value, label) in enumerate(intervals):
        events.append((start, False, index, value, label))
        events.append((stop , True , index, value, label))
    events.sort()

    # status
    deleted = [False] * (len(events) // 2)
    heap = []

    # print
    last_start = None
    last_value = None
    last_label = None

    for x, is_end, index, value, label in events:
        # update status
        if is_end:
            deleted[index] = True
        else:
            heapq.heappush(heap, (value, index, label))

        # remove deleted items from heap top
        while heap and deleted[heap[0][1]]:
            heapq.heappop(heap)
        
        # print
        value = heap[0][0] if heap else None
        label = heap[0][2] if heap else None
        if value != last_value:
            if last_value is not None and last_start < x:
                yield last_label, last_start, x 
            last_start = x
            last_value = value
            last_label = label


def max_values(intervals):
    return min_values(
        (start, stop, (-value, stop - start), value)
        for value, start, stop in intervals
    )


def test():
    for intervals, expected in (
        # value left > value right
        ([(2, 3, 6), (1, 3, 6)], [(2, 3, 6)]),
        ([(2, 3, 6), (1, 0, 2)], [(1, 0, 2), (2, 3, 6)]),
        ([(2, 3, 6), (1, 1, 3)], [(1, 1, 3), (2, 3, 6)]),
        ([(2, 3, 6), (1, 2, 4)], [(1, 2, 3), (2, 3, 6)]),
        ([(2, 3, 6), (1, 3, 5)], [(2, 3, 6)]),
        ([(2, 3, 6), (1, 4, 5)], [(2, 3, 6)]),
        ([(2, 3, 6), (1, 5, 7)], [(2, 3, 6), (1, 6, 7)]),
        ([(2, 3, 6), (1, 6, 8)], [(2, 3, 6), (1, 6, 8)]),
        ([(2, 3, 6), (1, 7, 9)], [(2, 3, 6), (1, 7, 9)]),
        ([(2, 3, 6), (1, 2, 7)], [(1, 2, 3), (2, 3, 6), (1, 6, 7)]),
        
        # value left == value right
        ([(2, 3, 6), (2, 3, 6)], [(2, 3, 6)]), # is copy
        ([(2, 3, 6), (2, 0, 2)], [(2, 0, 2), (2, 3, 6)]), # 2. / (6. - 3.) < 2. / (2. - 0.)
        ([(2, 3, 6), (2, 1, 3)], [(2, 1, 3), (2, 3, 6)]), # 2. / (6. - 3.) < 2. / (3. - 1.)
        ([(2, 3, 6), (2, 2, 4)], [(2, 2, 4), (2, 4, 6)]), # 2. / (6. - 3.) < 2. / (4. - 2.)
        ([(2, 3, 6), (2, 3, 5)], [(2, 3, 5), (2, 5, 6)]), # 2. / (6. - 3.) < 2. / (5. - 3.)
        ([(2, 3, 6), (2, 4, 5)], [(2, 3, 4), (2, 4, 5), (2, 5, 6)]), # 2. / (6. - 3.) < 2. / (5. - 4.)
        ([(2, 3, 6), (2, 5, 7)], [(2, 3, 5), (2, 5, 7)]), # 2. / (6. - 3.) < 2. / (7. - 5.)
        ([(2, 3, 6), (2, 6, 8)], [(2, 3, 6), (2, 6, 8)]), # 2. / (6. - 3.) < 2. / (8. - 6.)
        ([(2, 3, 6), (2, 7, 9)], [(2, 3, 6), (2, 7, 9)]), # 2. / (6. - 3.) < 2. / (9. - 7.)
        ([(2, 3, 6), (2, 2, 7)], [(2, 2, 3), (2, 3, 6), (2, 6, 7)]), # 2. / (6. - 3.) > 2. / (7. - 2.)
        
        # value left < value right
        ([(2, 3, 6), (3, 3, 6)], [(3, 3, 6)]),
        ([(2, 3, 6), (3, 0, 2)], [(3, 0, 2), (2, 3, 6)]),
        ([(2, 3, 6), (3, 1, 3)], [(3, 1, 3), (2, 3, 6)]),
        ([(2, 3, 6), (3, 2, 4)], [(3, 2, 4), (2, 4, 6)]),
        ([(2, 3, 6), (3, 3, 5)], [(3, 3, 5), (2, 5, 6)]),
        ([(2, 3, 6), (3, 4, 5)], [(2, 3, 4), (3, 4, 5), (2, 5, 6)]),
        ([(2, 3, 6), (3, 5, 7)], [(2, 3, 5), (3, 5, 7)]),
        ([(2, 3, 6), (3, 6, 8)], [(2, 3, 6), (3, 6, 8)]),
        ([(2, 3, 6), (3, 7, 9)], [(2, 3, 6), (3, 7, 9)]),
        ([(2, 3, 6), (3, 2, 7)], [(3, 2, 7)]),
    ):
        actual = list(max_values(intervals))
        if actual != expected:
            print(intervals, expected, actual)


test()
→ Ссылка