5

Все,Как проверить NaN в градиентах в Tensorflow при обновлении?

При подготовке большой модели с большим количеством образцов некоторые образцы могут быть причиной градиента NaN при обновлении параметра.

И я хочу найти эти образцы. Между тем я не хочу, чтобы этот градиент пакетных выборок обновлял параметр модели, потому что это может быть причиной того, что параметр модели NaN.

Итак, у любой есть хорошая идея справиться с этой проблемой?

Мой код, как показано ниже:

# Create an optimizer. 
    params = tf.trainable_variables() 
    opt = tf.train.AdamOptimizer(1e-3) 
    gradients = tf.gradients(self.loss, params) 

    max_gradient_norm = 10 
    clipped_gradients, self.gradient_norms = tf.clip_by_global_norm(gradients, 
                max_gradient_norm) 

    self.optimizer = opt.apply_gradients(zip(clipped_gradients, params)) 

ответ

9

Вы можете проверить ваши градиенты NaN по tf.check_numerics:

grad_check = tf.check_numerics(clipped_gradients) 
with tf.control_dependencies([grad_check]): 
    self.optimizer = opt.apply_gradients(zip(clipped_gradients, params)) 

grad_check бы бросить InvalidArgument если clipped_gradients является NaN или бесконечность.

tf.control_dependencies проверяет, чтобы оценка grad_check оценивалась перед применением градиентов.

Также см. tf.add_check_numerics_ops().

+0

Большое вам спасибо. Но у меня опять вопрос. Когда я добавляю tf.add_check_numerics_ops() в свой код, я получаю ошибку Out of Memory. И удалите эту строку, все в порядке. Моя модель действительно большая, поэтому tf.add_check_numerics_ops() будет выделять больше памяти GPU для операций проверки? – Issac

+0

В ядре 'tf.check_numerics' есть тензорная копия: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/check_numerics_op.cc#L60. Поэтому просто добавьте пару проверок с помощью 'tf.check_numerics'. Необходимо принять дополнительные меры предосторожности, если вы хотите использовать 'tf.add_check_numerics_ops()' который запускает 'tf.check_numerics' для всех тензоров с плавающей запятой. – yuefengz

0

Вы можете использовать tf.is_nan в сочетании с tf.cond, чтобы выполнить остальную часть кода, если потеря не NaN.

Смежные вопросы