Скажем, у меня есть массив расстояний x=[1,2,1,3,3,2,1,5,1,1].

Я хочу получить индексы от x, где cumsum достигает 10, в данном случае idx = [4,9].

Таким образом, cumsum перезапускается после выполнения условия.

Я могу сделать это с помощью цикла, но циклы медленны для больших массивов, и мне было интересно, смогу ли я сделать это vectorized способом.

8
user3194861 5 Июл 2019 в 16:51

3 ответа

Лучший ответ

Вот один с нумбой и инициализацией массива -

from numba import njit

@njit
def cumsum_breach_numba2(x, target, result):
    total = 0
    iterID = 0
    for i,x_i in enumerate(x):
        total += x_i
        if total >= target:
            result[iterID] = i
            iterID += 1
            total = 0
    return iterID

def cumsum_breach_array_init(x, target):
    x = np.asarray(x)
    result = np.empty(len(x),dtype=np.uint64)
    idx = cumsum_breach_numba2(x, target, result)
    return result[:idx]

< Сильный > Задержки

Включая @piRSquared's solutions и используя настройку сравнения из того же поста -

In [58]: np.random.seed([3, 1415])
    ...: x = np.random.randint(100, size=1000000).tolist()

# @piRSquared soln1
In [59]: %timeit list(cumsum_breach(x, 10))
10 loops, best of 3: 73.2 ms per loop

# @piRSquared soln2
In [60]: %timeit cumsum_breach_numba(np.asarray(x), 10)
10 loops, best of 3: 69.2 ms per loop

# From this post
In [61]: %timeit cumsum_breach_array_init(x, 10)
10 loops, best of 3: 39.1 ms per loop

Numba: добавление или инициализация массива

Для более детального изучения того, как помогает инициализация массива, что, по-видимому, является большой разницей между двумя реализациями numba, давайте разберем их на данных массива, поскольку создание данных массива само по себе было тяжелым во время выполнения, и они оба зависят от него -

In [62]: x = np.array(x)

In [63]: %timeit cumsum_breach_numba(x, 10)# with appending
10 loops, best of 3: 31.5 ms per loop

In [64]: %timeit cumsum_breach_array_init(x, 10)
1000 loops, best of 3: 1.8 ms per loop

Чтобы у вывода был собственный объем памяти, мы можем сделать копию. Хотя это не сильно изменит ситуацию -

In [65]: %timeit cumsum_breach_array_init(x, 10).copy()
100 loops, best of 3: 2.67 ms per loop
6
Divakar 5 Июл 2019 в 16:17

Забавный метод

sumlm = np.frompyfunc(lambda a,b:a+b if a < 10 else b,2,1)
newx=sumlm.accumulate(x, dtype=np.object)
newx
array([1, 3, 4, 7, 10, 2, 3, 8, 9, 10], dtype=object)
np.nonzero(newx==10)

(array([4, 9]),)
4
YO and BEN_W 5 Июл 2019 в 14:22

Циклы не всегда плохие (особенно когда они вам нужны). Кроме того, нет инструмента или алгоритма, который сделает это быстрее, чем O (n). Итак, давайте просто сделаем хорошую петлю.

Функция генератора

def cumsum_breach(x, target):
    total = 0
    for i, y in enumerate(x):
        total += y
        if total >= target:
            yield i
            total = 0

list(cumsum_breach(x, 10))

[4, 9]

Как раз во время компиляции с Numba

Numba - это сторонняя библиотека, которую необходимо установить.
Numba может быть привередливой о том, какие функции поддерживаются. Но это работает.
Кроме того, как отмечает Дивакар, Numba лучше работает с массивами

from numba import njit

@njit
def cumsum_breach_numba(x, target):
    total = 0
    result = []
    for i, y in enumerate(x):
        total += y
        if total >= target:
            result.append(i)
            total = 0

    return result

cumsum_breach_numba(x, 10)

Тестирование двух

Потому что мне так хотелось ¯\_(ツ)_/¯

< Сильный > Настройка

np.random.seed([3, 1415])
x0 = np.random.randint(100, size=1_000_000)
x1 = x0.tolist()

< Сильный > Точность

i0 = cumsum_breach_numba(x0, 200_000)
i1 = list(cumsum_breach(x1, 200_000))

assert i0 == i1

< Сильный > Время

%timeit cumsum_breach_numba(x0, 200_000)
%timeit list(cumsum_breach(x1, 200_000))

582 µs ± 40.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
64.3 ms ± 5.66 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Нумба была на порядок в 100 раз быстрее.

Для более верного теста от яблок к яблокам я конвертирую список в массив Numpy

%timeit cumsum_breach_numba(np.array(x1), 200_000)
%timeit list(cumsum_breach(x1, 200_000))

43.1 ms ± 202 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
62.8 ms ± 327 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Что приводит их примерно к равному.

6
piRSquared 5 Июл 2019 в 15:03