Да, как упоминалось в ответ @ Ярослава, можно и ключ ссылки он References: here и here. Я хочу подробно остановиться на этом ответе, представив конкретный пример.
Modulo opperation: Давай осуществлять поэлементную операцию по модулю в tensorflow (она уже существует, но его градиент не определен, но для примера, мы реализуем его с нуля).
Функция Numpy: Первым шагом является определение операции, которую мы хотим использовать для массивов numpy. Поэлементно по модулю opperation уже реализована в NumPy так легко:
import numpy as np
def np_mod(x,y):
return (x % y).astype(np.float32)
Причина .astype(np.float32)
потому, что по умолчанию tensorflow принимает float32 типы и если вы даете ему float64 (NumPy по умолчанию) она будет жаловаться ,
Градиент Функция: Далее нам нужно определить функцию градиента для нашей операции для каждого входа операции в качестве функции тензорного потока. Функция должна принимать очень специфическую форму. Необходимо принять представление тензорного потока операции op
и градиент выхода grad
и сказать, как распространять градиенты. В нашем случае градиенты операции mod
просты, производная равна 1 по первому аргументу и (почти всюду и бесконечна при конечном числе пятен, но давайте проигнорируем это, см. Подробности https://math.stackexchange.com/questions/1849280/derivative-of-remainder-function-wrt-denominator) с учетом ко второму аргументу. Итак, у нас есть
def modgrad(op, grad):
x = op.inputs[0] # the first argument (normally you need those to calculate the gradient, like the gradient of x^2 is 2x.)
y = op.inputs[1] # the second argument
return grad * 1, grad * tf.neg(tf.floordiv(x, y)) #the propagated gradient with respect to the first and second argument respectively
Функция grad должна возвращать n-кортеж, где n - количество аргументов операции. Обратите внимание, что нам нужно вернуть функции тензорного потока ввода.
Выполнение функции TF с градиентом: Как объяснено в источниках, указанных выше, есть хак для определения градиентов функции с использованием tf.RegisterGradient
[doc] и tf.Graph.gradient_override_map
[doc].
Копирование кода из harpone мы можем изменить функцию tf.py_func
, чтобы сделать его определить градиент в то же время: импорт tensorflow в ТФ
def py_func(func, inp, Tout, stateful=True, name=None, grad=None):
# Need to generate a unique name to avoid duplicates:
rnd_name = 'PyFuncGrad' + str(np.random.randint(0, 1E+8))
tf.RegisterGradient(rnd_name)(grad) # see _MySquareGrad for grad example
g = tf.get_default_graph()
with g.gradient_override_map({"PyFunc": rnd_name}):
return tf.py_func(func, inp, Tout, stateful=stateful, name=name)
Опция stateful
это сказать tensorflow ли всегда дает функция тот же вывод для одного и того же входа (stateful = False), и в этом случае тензорный поток может просто графом тензорного потока, это наш случай и, вероятно, будет иметь место в большинстве ситуаций.
Объединяя все вместе: Теперь, когда у нас есть все части, мы можем объединить их все вместе:
from tensorflow.python.framework import ops
def tf_mod(x,y, name=None):
with ops.op_scope([x,y], name, "mod") as name:
z = py_func(np_mod,
[x,y],
[tf.float32],
name=name,
grad=modgrad) # <-- here's the call to the gradient
return z[0]
tf.py_func
действует на списках тензоров (и возвращает список тензоров), то есть почему у нас есть [x,y]
(и возвращается z[0]
). И теперь мы закончили. И мы можем проверить это.
Тест:
with tf.Session() as sess:
x = tf.constant([0.3,0.7,1.2,1.7])
y = tf.constant([0.2,0.5,1.0,2.9])
z = tf_mod(x,y)
gr = tf.gradients(z, [x,y])
tf.initialize_all_variables().run()
print(x.eval(), y.eval(),z.eval(), gr[0].eval(), gr[1].eval())
[0,30000001 0,69999999 1,20000005 1,70000005] [0,2 0,5 1. 2,9000001] [0,10000001 0,19999999 0,20000005 1,70000005] [1. 1. 1. 1.] [-1. -1. -1. 0.]
Успех!
Большое спасибо за этот пост! У вас есть, как определить Tensorflow форму 'z'? 'x' имеет форму (4),' y' имеет форму (4), но Tensorflow не знает, что 'z' имеет форму (4). Только во время выполнения он решит, что форма равна 4. –
'z [0] = tf.reshape (z [0], [int (x.get_shape() [0])])' Я использовал это, чтобы обеспечить его выполнение , –
Вы уверены, что ваши градиенты верны? Спасибо! Http: //math.stackexchange.com/questions/1849280/derivative-of-remainder-function-wrt-denominator –