Мне нужно создать трехмерный тензор, подобный этому (5,3,2), например array ([[[0, 0], [0, 1], [0, 0]], [[1, 0], [0 , 0], [0, 0]], [[0, 0], [1, 0], [...
2 ответа
Попробуйте сгенерировать случайный массив, затем найдите max
:
a = np.random.rand(5,3,2)
out = (a == a.max(axis=(1,2))[:,None,None]).astype(int)
Самый простой способ сделать это, вероятно, - создать массив нулей и установить случайный индекс равным 1. В NumPy это может выглядеть так:
import numpy as np
K, M, N = 5, 3, 2
i = np.random.randint(0, M, K)
j = np.random.randint(0, N, K)
x = np.zeros((K, M, N))
x[np.arange(K), i, j] = 1
В JAX это может выглядеть примерно так:
import jax.numpy as jnp
from jax import random
K, M, N = 5, 3, 2
key1, key2 = random.split(random.PRNGKey(0))
i = random.randint(key1, (K,), 0, M)
j = random.randint(key2, (K,), 0, N)
x = jnp.zeros((K, M, N)).at[jnp.arange(K), i, j].set(1)
Похожие вопросы
Связанные вопросы
Новые вопросы
python
Python - это многопарадигмальный, динамически типизированный, многоцелевой язык программирования. Он разработан для быстрого изучения, понимания и использования, а также для обеспечения чистого и единообразного синтаксиса. Обратите внимание, что Python 2 официально не поддерживается с 01.01.2020. Тем не менее, для вопросов о Python, связанных с версией, добавьте тег [python-2.7] или [python-3.x]. При использовании варианта Python (например, Jython, PyPy) или библиотеки (например, Pandas и NumPy) включите его в теги.