2017-01-26 8 views
1

Я хочу рассчитать функцию потерь в моем DNN по-другому в зависимости от значения метки.Тензорный поток: использование значения в тензоре как параметр

Концептуально это что-то вроде этого:

def loss(logits, labels): 

    if labels[0] == 0: 
     return loss_function_1(logits, labels) 
    else: 
     return loss_function_2(logits, labels) 

Очевидно, что это не будет работать, потому что я не могу сделать это сравнение на объект тензором. Я также не могу использовать eval(), потому что я получаю сообщение об ошибке, что сеть не определена. У меня есть другой вариант?

ответ

1

Вы можете использовать tf.cond конструкцию для этого:

tf.cond(labels[0] == 0, lambda: loss_function_1(logits, labels), 
         lambda: loss_function_2(logits, labels)) 
Смежные вопросы