2016-04-04 2 views
1

Я строю классификатор изображений в TensorFlow, и в моих данных обучения есть дисбаланс классов. Поэтому при вычислении потерь мне нужно взвешивать потери для каждого класса за счет обратной частоты этого класса в данных обучения.TensorFlow: определение переменных в глобальном масштабе

Итак, вот мой код:

# Get the softmax from the final layer of the network 
softmax = tf.nn.softmax(final_layer) 
# Weight the softmax by the inverse frequency of the weights 
weighted_softmax = tf.mul(softmax, class_weights) 
# Compute the cross entropy 
cross_entropy = -tf.reduce_sum(y_ * tf.log(softmax)) 
# Define the optimisation 
train_step = tf.train.AdamOptimizer(1e-5).minimize(cross_entropy) 

# Run the training 
session.run(tf.initialize_all_variables()) 
for i in range(10000): 
    # Get the next batch 
    batch = datasets.train.next_batch(64) 
    # Run a training step 
    train_step.run(feed_dict = {x: batch[0], y_: batch[1]}) 

Мой вопрос: Могу ли я хранить class_weights, как только в tf.constant(...) в глобальном масштабе? Или мне нужно передать его в качестве параметра при вычислении cross_entropy?

Причина, по которой мне интересно, заключается в том, что class_weights отличается для каждой партии. Поэтому я обеспокоен тем, что, если он определен только в глобальной области, тогда, когда граф Tensor Flow сконструирован, он просто принимает начальные значения в class_weights, а затем никогда не обновляет их. Если бы я должен был пройти class_weights, используя feed_dict при вычислении weighted_softmax, то я прямо говорю Tensor Flow использовать последние обновленные значения в class_weights.

Любая помощь будет оценена по достоинству. Благодаря!

ответ

1

Я думаю, что class_weights a tf.constant в порядке. Классовое взвешивание должно выполняться для всего набора данных, а не для мини-партии.

Другой подход к этому, который вы, возможно, захотите рассмотреть, - это выборка, чтобы каждая партия имела равные числа каждого класса?

+0

Если я буду продолжать иметь минибазы с неуравновешенными классами, должен ли я по-прежнему взвешивать статистику глобального класса? Я бы подумал, что имеет смысл весить статистикой класса в этой мини-баре ...? Вычисляемые градиенты предназначены только для этой мини-бара, и поэтому они не должны заботиться о статистике глобального класса ... – Karnivaurus

+1

Вы делаете хороший момент. Одно дело в том, что, когда мини-партии отбираются из более широкого распределения классов. Если вы должны были весить мини-партию, вес будет сходиться к весу более широкого распределения по полной итерации любым способом. Поэтому я думаю, что с учетом того, что эффективнее всего вычислять class_weights один раз. Хотел бы я найти какую-то цитату для этого, но пока не могу ... –

Смежные вопросы