2016-10-30 5 views
1

В настоящее время я реализую http://www.aclweb.org/anthology/P15-1061 в тензорном потоке.Эффективно вычислять функцию потерь по положению в Tensorflow

Я выполнил парного ранжирования функции потерь (раздел 2.5 бумаги) следующим образом:

s_theta_y = tf.gather(tf.reshape(s_theta, [-1]), y_true_index) 
s_theta_c_temp = tf.reshape(tf.gather(tf.reshape(s_theta, [-1]), y_neg_index), [-1, classes_size]) 
s_theta_c = tf.reduce_max(s_theta_c_temp, reduction_indices=[1]) 

мне пришлось использовать tf.gather, а не tf.gather_nd, поскольку последний еще не реализован с градиентом спуск. Мне также пришлось преобразовать все индексы, чтобы они были правильными с матрицей сглаживания.

Если tf.gather_nd был реализован с градиентным спуском, мой код был бы следующим образом:

s_theta_y = tf.gather_nd(s_theta, y_t_index) 
s_theta_c_temp = tf.gather_nd(s_theta, y_neg_index) 
s_theta_c = tf.reduce_max(s_theta_c_temp, reduction_indices=[1]) 

s_theta является вычисленной оценкой для каждой метки класса, как и в работе. y_true_index содержит индекс истинного класса, чтобы вычислить s_theta_y. y_neg_index - это индекс всех отрицательных классов, его размеры либо # class-1, либо #class - это отношение классифицируется как другое.

Однако несколько предложений относятся к категории «Другие», поэтому s_theta_y не существует, и мы не должны принимать его во внимание для расчета. Чтобы справиться с таким случаем, у меня есть постоянный коэффициент 0, который отменяет этот термин и имеет один и тот же размерный вектор для отрицательного класса, я просто копирую случайное значение индекса, потому что в конце нас интересует только максимальное значение среди всех отрицательных классов (а не индекса).

Есть ли более эффективный способ вычисления этих терминов в функции потерь? У меня сложилось впечатление, что использование tf.gather с такой большой изменчивостью очень медленное.

ответ

1

Конечно, это похоже, что gather_nd - это то, что вы хотите, но пока не будут внедрены градиенты, я бы без колебаний использовал ваше решение reshape() так как reshape() практически свободен.

C++ implementation of the reshape() op похоже, что он много работает, но это всего лишь быстрая проверка ошибок информации о форме. «Работа» происходит в CopyFrom на линии 90, которая звучит, как будто она может быть дорогой, но на самом деле является просто копией указателя (CopyFrom вызывает CopyFromInternal, который копирует указатель).

Это имеет смысл: базовый буфер - это всего лишь плоский массив чисел в row-major order, и это упорядочение не зависит от информации о форме. По той же причине что-то вроде tf.transpose() будет требует копирования вообще.

Смежные вопросы