Python Numpy, пересечение двумерных массивов
Имеется два массива np:
a = np.array([
[1, 2, 3, 5],
[5, 6, 3, 3],
[6, 8, 1, 7],
[9, 5, 1, 9],
[6, 4, 8, 12],
[5, 9, 2, 14],
])
b = np.array([
[1, 2, 3, 3],
[3, 3, 3, 6],
[6, 8, 1, 8],
[0, 0, 0, 1],
[5, 9, 2, 9],
])
Необходимо создать два массива a_ и b_. В эти массивы должны быть записаны строки, которые повторяются в обоих массивах по первым 3м элементам. То есть должны получиться такие массивы:
a_ = np.array([
[1, 2, 3, 5],
[6, 8, 1, 7],
[5, 9, 2, 14],
])
b_ = np.array([
[1, 2, 3, 3],
[6, 8, 1, 8],
[5, 9, 2, 9],
])
Ничего умнее цикла придумать не могу, может есть какая то функция о которой я не знаю и не могу нагуглить.
Можно в принципе поудалять неподходящие строки на месте прям в массивах a и b не создавая новых массивов a_ и b_ если есть возможность.
Ответы (3 шт):
Извиняюсь, сейчас не могу комментарий написать
a1 = a[:, :-1]
b1 = b[:, :-1]
m = (a1[:, None] == b1).all(-1).any(1)
print(a[m])
[[ 1 2 3 5]
[ 6 8 1 7]
[ 5 9 2 14]]
Можно воспользоваться библиотекой SciPy и функцией cdist для определения парных расстояний между векторами. Если расстояния равны нулю, то вектора равны. Для скорости, наверное, лучше использовать расстояние Хэмминга (matching).
import numpy as np
from scipy.spatial.distance import cdist
a = np.array([
[1, 2, 3, 5],
[5, 6, 3, 3],
[6, 8, 1, 7],
[9, 5, 1, 9],
[6, 4, 8, 12],
[5, 9, 2, 14],
])
b = np.array([
[1, 2, 3, 3],
[3, 3, 3, 6],
[6, 8, 1, 8],
[0, 0, 0, 1],
[5, 9, 2, 9],
])
ind1, ind2=np.nonzero(cdist(a[:, :-1],b[:, :-1], 'matching')==0)
print(a[ind1,:])
print(b[ind2,:])
Как правильно заметил Stanislav Volodarskiy, скорость и потребление памяти на больших массивах будут не очень хорошими, поэтому вот еще более быстрый вариант:
import numpy as np
h=1000
n=100000
a = np.random.randint(0, high=h, size=(n, 4))
b = np.random.randint(0, high=h, size=(n, 4))
def select(a, b):
a1=np.apply_along_axis(lambda x: hash(x.tobytes()), 1, a[:,:3])
b1=np.apply_along_axis(lambda x: hash(x.tobytes()), 1, b[:,:3])
_,idx1, idx2=np.intersect1d(a1, b1, return_indices=True)
print(a[idx1])
print(b[idx2])
select(a,b)
select фильтрует массив a относительно b. Сложность по памяти и времени линейная:
def select(a, b):
def key(v):
return tuple(v[:3])
b_set = set(map(key, b))
return a[np.apply_along_axis(lambda v: key(v) in b_set, 1, a)]
Например два массива по 100 тысяч строк укладываются в полторы секунды:
h = 1000
n = 100000
a = np.random.randint(0, high=h, size=(n, 4))
b = np.random.randint(0, high=h, size=(n, 4))
print(select(a, b))
print(select(b, a))
$ time python test.py [[290 811 482 291] [227 427 977 574] [904 141 433 193] [309 933 872 123] [341 585 128 447] [ 15 989 351 81] [303 64 711 972] [653 212 892 430]] [[303 64 711 838] [309 933 872 560] [653 212 892 127] [ 15 989 351 955] [904 141 433 975] [227 427 977 508] [341 585 128 518] [290 811 482 830]] real 0m1.134s user 0m1.324s sys 0m0.360s