Моя реализация выглядит следующим образом, но я думаю, что должны быть более эффективные реализации.
логит-анализ: непересчитанных оценки, тензор, форма = (batch_size, n_classes)
этикетки: тензор, форма = (batch_size,)
batch_size, n_classes: Int
def multi_class_hinge_loss(logits, label, batch_size, n_classes):
# get the correct logit
flat_logits = tf.reshape(logits, (-1,))
correct_id = tf.range(0, batch_size) * n_classes + label
correct_logit = tf.gather(flat_logits, correct_id)
# get the wrong maximum logit
max_label = tf.argmax(logits, 1)
top2, _ = tf.nn.top_k(logits, k=2, sorted=True)
top2 = tf.split(1, 2, top2)
for i in xrange(2):
top2[i] = tf.reshape(top2[i], (batch_size,))
wrong_max_logit = tf.select(tf.equal(max_label, label), top2[1], top2[0])
# calculate multi-class hinge loss
return tf.reduce_mean(tf.maximum(0., 1. + wrong_max_logit - correct_logit))
какая версия TF у вас есть? В последней версии 'top_k' имеет градиенты –