2016-07-22 2 views
0

Я работаю над проектом, который заключается в локализации объекта в изображении. Метод, который я собираюсь принять, основан на алгоритме локализации в CS231n-8.Как свести к минимуму две потери с помощью TensorFlow?

Структура сети имеет две головки оптимизации, головку классификации и регрессионную головку. Как я могу свести их к минимуму при обучении сети?

У меня есть одна идея, которая суммирует оба из них в одну потерю. Но проблема заключается в потере классификации: потеря softmax и потеря регрессии - потеря l2, что означает, что они имеют различный диапазон. Я не думаю, что это лучший способ.

ответ

2

Это зависит от статуса вашей сети.

Если ваша сеть просто способна извлекать функции [вы используете весы, хранящиеся в какой-либо другой сети], вы можете установить эти веса как константы, а затем отдельно подготовить две главы классификации, поскольку градиент не будет проходить через корыто константы.

Если вы не используете вес из заранее подготовленных модели, вы

  1. должно обучить сеть для извлечения функции: таким образом обучить сеть с использованием головки классификации и пусть поток градиента от классификации направляйте к первому сверточному фильтру. Таким образом, ваша сеть теперь может классифицировать объекты, объединяющие извлеченные функции.
  2. Преобразуйте в постоянные тензоры изученные веса сверточных фильтров и классификационную головку и подготовьте регрессионную головку.

Регрессионная головка научится совмещать функции, извлеченные из сверточного слоя, адаптируя его параметры, чтобы минимизировать потери L2.

Tl; др:

  1. Поезд сеть для классификации первой.
  2. Преобразование каждого полученного параметра в постоянный тензор, используя graph_util.convert_variables_to_constants, как показано в сценарии 'freeze_graph`.
  3. Поезд регрессионной головки.
+0

Отлично! Спасибо за Ваш ответ. Могу ли я задать еще один вопрос, как представить класс «фона» при обучении регрессионной головы? Я использую нули сейчас. Есть ли способ лучше? –

+2

Вы тренируете регрессионную головку: вам не нужен класс фона. Ваш набор поездов содержит положение ограничивающей рамки объекта, поэтому вы научитесь регрессировать эти координаты и ничего больше. Фон - это все, что находится за пределами этих координат. Я использую класс фона только в головке классификации, который не используется во время обучения (я показываю только изображения, содержащие объект), но используется, когда сеть запускается в режиме локализации. Если я предскажу некоторые координаты с головой регрессии, но глава классификации говорит мне, что это фон, я пропускаю местоположение – nessuno

+0

Это имеет смысл. Спасибо. –