Предположим, нам даны два двумерных массива a и b с одинаковым количеством строк. Предположим также, что мы знаем, что каждая строка i из a и b имеет не более одного общего элемента, хотя этот элемент может встречаться несколько раз. Как мы можем найти этот элемент максимально эффективно?

Пример:

import numpy as np

a = np.array([[1, 2, 3],
              [2, 5, 2],
              [5, 4, 4],
              [2, 1, 3]])

b = np.array([[4, 5],
              [3, 2],
              [1, 5],
              [0, 5]])

desiredResult = np.array([[np.nan],
                          [2],
                          [5],
                          [np.nan]])

Легко придумать прямолинейную реализацию, применив intersect1d вдоль первой оси:

from intertools import starmap

desiredResult = np.array(list(starmap(np.intersect1d, zip(a, b))))

По-видимому, использование встроенных в Python операций над наборами еще быстрее. Преобразовать результат в желаемую форму легко.

Однако мне нужна максимально эффективная реализация. Следовательно, мне не нравится starmap, поскольку я предполагаю, что он требует вызова python для каждой строки. Я хотел бы использовать чисто векторизованный вариант, и был бы рад, если бы он даже использовал наши дополнительные знания о том, что в строке есть не более одного общего значения.

У кого-нибудь есть идеи, как я мог бы ускорить задачу и более элегантно реализовать решение? Я был бы согласен с использованием кода C или Cython, но усилия по написанию кода должны быть не слишком большими.

6
Samufi 5 Июл 2019 в 03:41

3 ответа

Лучший ответ

Подход № 1

Вот векторизованный на основе searchsorted2d -

# Sort each row of a and b in-place
a.sort(1)
b.sort(1)

# Use 2D searchsorted row-wise between a and b
idx = searchsorted2d(a,b)

# "Clip-out" out of bounds indices
idx[idx==a.shape[1]] = 0

# Get mask of valid ones i.e. matches
mask = np.take_along_axis(a,idx,axis=1)==b

# Use argmax to get first match as we know there's at most one match
match_val = np.take_along_axis(b,mask.argmax(1)[:,None],axis=1)

# Finally use np.where to choose between valid match 
# (decided by any one True in each row of mask)
out = np.where(mask.any(1)[:,None],match_val,np.nan)

Подход № 2

Основанный на Numba для эффективности памяти -

from numba import njit

@njit(parallel=True)
def numba_f1(a,b,out):
    n,a_ncols = a.shape
    b_ncols = b.shape[1]
    for i in range(n):
        for j in range(a_ncols):
            for k in range(b_ncols):
                m = a[i,j]==b[i,k]
                if m:
                    break
            if m:
                out[i] = a[i,j]
                break
    return out

def find_first_common_elem_per_row(a,b):
    out = np.full(len(a),np.nan)
    numba_f1(a,b,out)
    return out

Подход № 3

Вот еще один векторизованный на основе укладки и сортировки -

r = np.arange(len(a))
ab = np.hstack((a,b))
idx = ab.argsort(1)
ab_s = ab[r[:,None],idx]
m = ab_s[:,:-1] == ab_s[:,1:]
m2 = (idx[:,1:]*m)>=a.shape[1]
m3 = m & m2
out = np.where(m3.any(1),b[r,idx[r,m3.argmax(1)+1]-a.shape[1]],np.nan)

Подход № 4

Для элегантного мы можем использовать broadcasting для ресурсоемкого метода -

m = (a[:,None]==b[:,:,None]).any(2)
out = np.where(m.any(1),b[np.arange(len(a)),m.argmax(1)],np.nan)
4
Divakar 5 Июл 2019 в 06:11

Не уверен, что это быстрее, но мы можем попробовать пару вещей здесь:

Метод 1 np.intersect1d со списком

[np.intersect1d(arr[0], arr[1]) for arr in list(zip(a,b))]

# Out
[array([], dtype=int32), array([2]), array([5]), array([], dtype=int32)]

Или перечислить:

[np.intersect1d(arr[0], arr[1]).tolist() for arr in list(zip(a,b))]

# Out
[[], [2], [5], []]

Метод 2 set со списком:

[list(set(arr[0]) & set(arr[1])) for arr in list(zip(a,b))]

# Out
[[], [2], [5], []]
0
Erfan 5 Июл 2019 в 01:37

Проведя некоторое исследование, я обнаружил, что проверка того, являются ли два списка непересекающимися, выполняется в O (n + m) , где n и m - это длины списков (см. здесь ) . Идея состоит в том, что вставка и поиск элементов выполняются в постоянное время для хэш-карт. Поэтому для вставки всех элементов из первого списка в хэш-карту требуются операции O (n) , а для проверки для каждого элемента во втором списке, находится ли он уже в хэш-карте, требуется O (m). ) операции. Поэтому решения, основанные на сортировке, которые выполняются в O (n log (n) + m log (m)) , не являются оптимальными асимптотически.

Хотя решения @Divakar очень эффективны во многих случаях, они менее эффективны, если второе измерение велико. Тогда решение, основанное на хэш-картах, лучше подходит. Я реализовал это следующим образом в cython:

import numpy as np
cimport numpy as np
import cython
from libc.math cimport NAN
from libcpp.unordered_map cimport unordered_map
np.import_array()

@cython.boundscheck(False)
@cython.wraparound(False)
def get_common_element2d(np.ndarray[double, ndim=2] arr1, 
                         np.ndarray[double, ndim=2] arr2):

    cdef np.ndarray[double, ndim=1] result = np.empty(arr1.shape[0])
    cdef int dim1 = arr1.shape[1]
    cdef int dim2 = arr2.shape[1]
    cdef int i, j
    cdef unordered_map[double, int] tmpset = unordered_map[double, int]()

    for i in range(arr1.shape[0]):
        for j in range(dim1):
            # insert arr1[i, j] as key without assigned value
            tmpset[arr1[i, j]]
        for j in range(dim2):
            # check whether arr2[i, j] is in tmpset
            if tmpset.count(arr2[i,j]):
                result[i] = arr2[i,j]
                break
        else:
            result[i] = NAN
        tmpset.clear()

    return result

Я создал контрольные примеры следующим образом:

import numpy as np
import timeit
from itertools import starmap
from mycythonmodule import get_common_element2d

m, n = 3000, 3000
a = np.random.rand(m, n)
b = np.random.rand(m, n)

for i, row in enumerate(a):
    if np.random.randint(2):
        common = np.random.choice(row, 1)
        b[i][np.random.choice(np.arange(n), np.random.randint(min(n,20)), False)] = common

# we need to copy the arrays on each test run, otherwise they 
# will remain sorted, which would bias the results

%timeit [set(aa).intersection(bb) for aa, bb in zip(a.copy(), b.copy())]
# returns 3.11 s ± 56.8 ms

%timeit list(starmap(np.intersect1d, zip(a.copy(), b.copy)))
# returns 1.83 s ± 55.4

# test sorting method
# divakarsMethod1 is the appraoch #1 in @Divakar's answer
%timeit divakarsMethod1(a.copy(), b.copy())
# returns 1.88 s ± 18 ms

# test hash map method
%timeit get_common_element2d(a.copy(), b.copy())
# returns 1.46 s ± 22.6 ms

Эти результаты показывают, что наивный подход на самом деле лучше, чем некоторые векторизованные версии. Тем не менее, векторизованные алгоритмы разыгрывают свои сильные стороны, если рассматривать много строк с меньшим количеством столбцов (другой вариант использования). В этих случаях векторизованные подходы более чем в 5 раз быстрее, чем наивный метод, и метод сортировки оказывается наилучшим.

Заключение . Я пойду с версией для Cython, основанной на HashMap, поскольку она является одним из наиболее эффективных вариантов в обоих случаях использования. Если бы мне сначала пришлось настроить Cython, я бы использовал метод сортировки.

2
Samufi 5 Июл 2019 в 08:48