Ниже приведен фрагмент кода, который, учитывая state
, генерирует action
из зависимого от состояния распределения (prob_policy
). Затем веса графика обновляются в соответствии с потерями, которые в -1 раз превышают вероятность выбора этого действия. В следующем примере как среднее (mu
), так и ковариация (sigma
) MultivariateNormal обучаемы / изучены.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
# make the graph
state = tf.placeholder(tf.float32, (1, 2), name="state")
mu = tf.contrib.layers.fully_connected(
inputs=state,
num_outputs=2,
biases_initializer=tf.ones_initializer)
sigma = tf.contrib.layers.fully_connected(
inputs=state,
num_outputs=2,
biases_initializer=tf.ones_initializer)
sigma = tf.squeeze(sigma)
mu = tf.squeeze(mu)
prob_policy = tfp.distributions.MultivariateNormalDiag(loc=mu, scale_diag=sigma)
action = prob_policy.sample()
picked_action_prob = prob_policy.prob(action)
loss = -tf.log(picked_action_prob)
optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
train_op = optimizer.minimize(loss)
# run the optimizer
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
state_input = np.expand_dims([0.,0.],0)
_, action_loss = sess.run([train_op, loss], { state: state_input })
print(action_loss)
Тем не менее, когда я заменяю эту строку
prob_policy = tfp.distributions.MultivariateNormalDiag(loc=mu, scale_diag=sigma)
Со следующей строкой (и закомментируйте строки, которые генерируют сигма-слой и сожмите его)
prob_policy = tfp.distributions.MultivariateNormalDiag(loc=mu, scale_diag=[1.,1.])
Я получаю следующую ошибку
ValueError: No gradients provided for any variable, check your graph for ops that do not support gradients, between variables ["<tf.Variable 'fully_connected/weights:0' shape=(2, 2) dtype=float32_ref>", "<tf.Variable 'fully_connected/biases:0' shape=(2,) dtype=float32_ref>"] and loss Tensor("Neg:0", shape=(), dtype=float32).
Я не понимаю, почему это происходит. Разве он по-прежнему не может принимать градиент по отношению к весам в слое mu
? Почему создание ковариации константы распределения внезапно делает ее недифференцируемой?
Сведения о системе:
- Тензор потока 1.13.1
- Вероятность тензорного потока 0.6.0
- Python 3.6.8
- MacOS 10.13.6
2 ответа
Существует проблема, вызванная некоторым кэшированием, которое мы делаем внутри MVNDiag (и других подклассов TransformedDistribution) для обратимости.
Если вы сделаете + 0
(как обходной путь) после вашего .sample (), градиент сработает.
Также я бы предложил использовать dist.log_prob(..)
вместо tf.log(dist.prob(..))
. Лучше цифры.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
# make the graph
state = tf.placeholder(tf.float32, (1, 2), name="state")
mu = tf.contrib.layers.fully_connected(
inputs=state,
num_outputs=2,
biases_initializer=tf.ones_initializer)
sigma = tf.contrib.layers.fully_connected(
inputs=state,
num_outputs=2,
biases_initializer=tf.ones_initializer)
sigma = tf.squeeze(sigma)
mu = tf.squeeze(mu)
prob_policy = tfp.distributions.MultivariateNormalDiag(loc=mu, scale_diag=[1.,1.])
action = prob_policy.sample() + 0
loss = -prob_policy.log_prob(action)
optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
train_op = optimizer.minimize(loss)
# run the optimizer
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
state_input = np.expand_dims([0.,0.],0)
_, action_loss = sess.run([train_op, loss], { state: state_input })
print(action_loss)
Я должен был изменить эту строку
action = prob_policy.sample()
К этой строке
action = tf.stop_gradient(prob_policy.sample())
Если у кого-то есть объяснение того, почему изучение весов ковариации делает весовые коэффициенты локально дифференцируемыми по отношению к потере, а если сделать ковариацию постоянной, то нет, и как это изменение линии влечет за собой это, я бы с удовольствием объяснил! Благодарность!
Похожие вопросы
Новые вопросы
python
Python - это многопарадигмальный, динамически типизированный, многоцелевой язык программирования. Он разработан для быстрого изучения, понимания и использования, а также для обеспечения чистого и единообразного синтаксиса. Обратите внимание, что Python 2 официально не поддерживается с 01.01.2020. Тем не менее, для вопросов о Python, связанных с версией, добавьте тег [python-2.7] или [python-3.x]. При использовании варианта Python (например, Jython, PyPy) или библиотеки (например, Pandas и NumPy) включите его в теги.