2016-03-20 2 views
11

я экономлю свое состояние сеанса, как так:Как получить global_step при восстановлении контрольных точек в Tensorflow?

self._saver = tf.saver() 
self._saver.save(self._session, '/network', global_step=self._time) 

Когда я позже восстановить, я хочу, чтобы получить значение global_step на контрольно-пропускном пункте я восстановить из. Это делается для того, чтобы установить из него некоторые гиперпараметры.

Хакерный способ сделать это - пропустить и проанализировать имена файлов в каталоге контрольной точки. Но угрюмый должен быть лучше, встроенный способ сделать это?

ответ

17

Общая картина должна иметь global_step переменную для отслеживания шагов

global_step = tf.Variable(0, name='global_step', trainable=False) 
train_op = optimizer.minimize(loss, global_step=global_step) 

Затем вы можете сохранить с

saver.save(sess, save_path, global_step=global_step) 

При восстановлении, значение global_step восстанавливается, а

+1

Это не работает, каждый раз, когда я возобновляют обучение переменной global_step сбрасывается в 0 –

+0

это будет означать, что global_step вы экономите до контрольного пункта 0 , или вы повторно инициализируете его до 0 после его восстановления –

+0

Это было бы хорошим решением, но если 'saver.restore' может вернуть global_step, это было бы просто. Мы можем просто сделать «global_step = saver.restore (...)» Как вы думаете, команда tenorflow может быть заинтересована в этом направлении? –

0

Текущая версия 0.10rc0 выглядит по-другому, нет tf.saver(). Теперь это tf.train.Saver(). Кроме того, команда save добавляет информацию к имени файла save_path для global_step, поэтому мы не можем просто вызвать восстановление на том же пути save_path, поскольку это не фактический файл сохранения.

Самый простой способ я вижу прямо сейчас, чтобы использовать SessionManager вместе с заставки, как это:

my_checkpoint_dir = "/tmp/checkpoint_dir" 
# make a saver to use with SessionManager for restoring 
saver = tf.train.Saver() 
# Build an initialization operation to run below. 
init = tf.initialize_all_variables() 
# use a SessionManager to help with automatic variable restoration 
sm = tf.train.SessionManager() 
# try to find the latest checkpoint in my_checkpoint_dir, then create a session with that restored 
# if no such checkpoint, then call the init_op after creating a new session 
sess = sm.prepare_session("", init_op=init, saver=saver, checkpoint_dir=my_checkpoint_dir)) 

Вот так. Теперь у вас есть сеанс, который либо восстановлен из my_checkpoint_dir (убедитесь, что каталог существует до его вызова), либо если там нет контрольной точки, он создает новый сеанс и вызывает init_op для инициализации ваших переменных.

Если вы хотите сохранить, вы просто сохраните любое имя, которое вы хотите в этом каталоге, и передайте файл global_step. Вот пример, когда я сохраняю переменную шага в цикле как global_step, поэтому он возвращается к этой точке если убить программу и перезапустить его, чтобы он восстанавливает блокпост:.

checkpoint_path = os.path.join(my_checkpoint_dir, 'model.ckpt') 
saver.save(sess, checkpoint_path, global_step=step) 

Это создает файлы в my_checkpoint_dir как «model.ckpt-1000», где 1000 является global_step прошло в случае, если он продолжает работать, то вы получите больше похоже на «model.ckpt-2000». При запуске программы SessionManager поднимает последнюю из них. Путь checkpoint_path может быть любым желаемым именем файла, если он находится в checkpoint_dir. Save() создаст этот файл с добавлением global_step (как показано выше). Он также создает индексный файл «контрольной точки», в результате которого SessionManager обнаруживает последнюю контрольную точку сохранения.

2

Это немного рубить, но и другие ответы не работает для меня на всех

ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 

#Extract from checkpoint filename 
step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1]) 

Update 9/2017

Я не уверен, если это начал работать из-за для обновления, но следующий метод, по-видимому, эффективен в том, чтобы global_step обновлялся и загружался должным образом:

Создать два ops.Один провести global_step и другое, чтобы увеличить его:

global_step = tf.Variable(0, trainable=False, name='global_step') 
    increment_global_step = tf.assign_add(global_step,1, 
              name = 'increment_global_step') 

Теперь в цикле обучения работать приращению Op каждый раз, когда вы запускаете свой тренировочный цит.

sess.run([train_op,increment_global_step],feed_dict=feed_dict) 

Если вы когда-либо хотите получить вам глобальное значение шага как целое число в любой момент, просто используйте следующую команду после загрузки модели:

sess.run(global_step) 

Это может быть полезно для создания имен файлов или вычисление какова ваша текущая эпоха, не имея второго тензорного потока Variable для хранения этого значения. Например, расчет текущей эпохи при загрузке будет что-то вроде:

loaded_epoch = sess.run(global_step)//(batch_size*num_train_records) 
1

Я была такая же проблема, как Лоуренс Ду, я не мог найти способ, чтобы получить global_step путем восстановления модели. Поэтому я применил his hack к the inception v3 training code in the Tensorflow/models github repo Я использую. В приведенном ниже коде также содержится исправление, связанное с pretrained_model_checkpoint_path.

Если у вас есть лучшее решение или знаете, что мне не хватает, оставьте комментарий!

В любом случае, этот код работает для меня:

... 

# When not restoring start at 0 
last_step = 0 
if FLAGS.pretrained_model_checkpoint_path: 
    # A model consists of three files, use the base name of the model in 
    # the checkpoint path. E.g. my-model-path/model.ckpt-291500 
    # 
    # Because we need to give the base name you can't assert (will always fail) 
    # assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path) 

    variables_to_restore = tf.get_collection(
     slim.variables.VARIABLES_TO_RESTORE) 
    restorer = tf.train.Saver(variables_to_restore) 
    restorer.restore(sess, FLAGS.pretrained_model_checkpoint_path) 
    print('%s: Pre-trained model restored from %s' % 
      (datetime.now(), FLAGS.pretrained_model_checkpoint_path)) 

    # HACK : global step is not restored for some unknown reason 
    last_step = int(os.path.basename(FLAGS.pretrained_model_checkpoint_path).split('-')[1]) 

    # assign to global step 
    sess.run(global_step.assign(last_step)) 

... 

for step in range(last_step + 1, FLAGS.max_steps): 

    ... 
0

просто отметить мое решение о сохранении глобального шага и восстановления.

Сохранить:

global_step = tf.Variable(0, trainable=False, name='global_step') 
saver.save(sess, model_path + model_name, global_step=_global_step) 

Восстановление:

if os.path.exists(model_path): 
    saver.restore(sess, tf.train.latest_checkpoint(model_path)) 
    print("Model restore finished, current globle step: %d" % global_step.eval()) 
Смежные вопросы