Я строю классификатор изображений в TensorFlow, и в моих данных обучения есть дисбаланс классов. Поэтому при вычислении потерь мне нужно взвешивать потери для каждого класса за счет обратной частоты этого класса в данных обучения.TensorFlow: определение переменных в глобальном масштабе
Итак, вот мой код:
# Get the softmax from the final layer of the network
softmax = tf.nn.softmax(final_layer)
# Weight the softmax by the inverse frequency of the weights
weighted_softmax = tf.mul(softmax, class_weights)
# Compute the cross entropy
cross_entropy = -tf.reduce_sum(y_ * tf.log(softmax))
# Define the optimisation
train_step = tf.train.AdamOptimizer(1e-5).minimize(cross_entropy)
# Run the training
session.run(tf.initialize_all_variables())
for i in range(10000):
# Get the next batch
batch = datasets.train.next_batch(64)
# Run a training step
train_step.run(feed_dict = {x: batch[0], y_: batch[1]})
Мой вопрос: Могу ли я хранить class_weights
, как только в tf.constant(...)
в глобальном масштабе? Или мне нужно передать его в качестве параметра при вычислении cross_entropy
?
Причина, по которой мне интересно, заключается в том, что class_weights
отличается для каждой партии. Поэтому я обеспокоен тем, что, если он определен только в глобальной области, тогда, когда граф Tensor Flow сконструирован, он просто принимает начальные значения в class_weights
, а затем никогда не обновляет их. Если бы я должен был пройти class_weights
, используя feed_dict
при вычислении weighted_softmax
, то я прямо говорю Tensor Flow использовать последние обновленные значения в class_weights
.
Любая помощь будет оценена по достоинству. Благодаря!
Если я буду продолжать иметь минибазы с неуравновешенными классами, должен ли я по-прежнему взвешивать статистику глобального класса? Я бы подумал, что имеет смысл весить статистикой класса в этой мини-баре ...? Вычисляемые градиенты предназначены только для этой мини-бара, и поэтому они не должны заботиться о статистике глобального класса ... – Karnivaurus
Вы делаете хороший момент. Одно дело в том, что, когда мини-партии отбираются из более широкого распределения классов. Если вы должны были весить мини-партию, вес будет сходиться к весу более широкого распределения по полной итерации любым способом. Поэтому я думаю, что с учетом того, что эффективнее всего вычислять class_weights один раз. Хотел бы я найти какую-то цитату для этого, но пока не могу ... –