2

Я ознакомился с вводным документом TF Slim и, насколько я понимаю, он принимает только одну партию данных изображения при каждом запуске (32 изображения). Очевидно, кто-то хочет пройти через это и тренироваться для множества разных партий. Вступление не распространяется на это. Как это можно сделать правильно. Я предполагаю, что должен быть какой-то способ указать функцию пакетной загрузки, которая должна быть вызвана автоматически при запуске события пакетного обучения, но я не могу найти простой пример для этого во вступлении.Пакетное обучение в Tensorflow Slim

# Note that this may take several minutes. 

import os 

from datasets import flowers 
from nets import inception 
from preprocessing import inception_preprocessing 

slim = tf.contrib.slim 
image_size = inception.inception_v1.default_image_size 


def get_init_fn(): 
    """Returns a function run by the chief worker to warm-start the training.""" 
    checkpoint_exclude_scopes=["InceptionV1/Logits", "InceptionV1/AuxLogits"] 

    exclusions = [scope.strip() for scope in checkpoint_exclude_scopes] 

    variables_to_restore = [] 
    for var in slim.get_model_variables(): 
     excluded = False 
     for exclusion in exclusions: 
      if var.op.name.startswith(exclusion): 
       excluded = True 
       break 
     if not excluded: 
      variables_to_restore.append(var) 

    return slim.assign_from_checkpoint_fn(
     os.path.join(checkpoints_dir, 'inception_v1.ckpt'), 
     variables_to_restore) 


train_dir = '/tmp/inception_finetuned/' 

with tf.Graph().as_default(): 
    tf.logging.set_verbosity(tf.logging.INFO) 

    dataset = flowers.get_split('train', flowers_data_dir) 
    images, _, labels = load_batch(dataset, height=image_size, width=image_size) 

    # Create the model, use the default arg scope to configure the batch norm parameters. 
    with slim.arg_scope(inception.inception_v1_arg_scope()): 
     logits, _ = inception.inception_v1(images, num_classes=dataset.num_classes, is_training=True) 

    # Specify the loss function: 
    one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes) 
    slim.losses.softmax_cross_entropy(logits, one_hot_labels) 
    total_loss = slim.losses.get_total_loss() 

    # Create some summaries to visualize the training process: 
    tf.scalar_summary('losses/Total Loss', total_loss) 

    # Specify the optimizer and create the train op: 
    optimizer = tf.train.AdamOptimizer(learning_rate=0.01) 
    train_op = slim.learning.create_train_op(total_loss, optimizer) 

    # Run the training: 
    final_loss = slim.learning.train(
     train_op, 
     logdir=train_dir, 
     init_fn=get_init_fn(), 
     number_of_steps=2) 


print('Finished training. Last batch loss %f' % final_loss) 
+0

Не является ли функция load_batch неопределенной в примере кода u shared? Я не знаком с вашим примером, но я начал бы читать эту функцию, чтобы понять пакетный процесс. – pltrdy

+0

Здесь указывается https://github.com/tensorflow/models/blob/master/slim/slim_walkthough.ipynb Но это ничего не делает, кроме того, что вы получаете пакет. –

+0

Значит, вам просто нужно перебирать партии! – pltrdy

ответ

1

slim.learning.train функция содержит обучающий цикл, поэтому код, который вы дали действительно фактически поезд на нескольких серий изображений.

См. here in the source code, где train_step_fn вызывается внутри цикла while. train_step (значение по умолчанию train_step_fn) содержит строку sess.run([train_op, global_step]...), которая фактически запускает тренировочную операцию на одной партии изображений.

+0

Ну, я положил оператор печати в функцию load_batch и прошел обучение более чем на 1 шаг и обнаружил, что функция пакетной загрузки вызывается только один раз, поэтому означает, что одни и те же данные используются для нескольких шагов, поэтому проблема. –

+0

Кроме того, я не указал функцию load_batch в вызове learning.train, поэтому как бы он «знал» использовать это для загрузки новых партий? –

+0

Я сделал больше исследований в этом, и кажется, что есть очередь, которая устанавливается, откуда пакет автоматически загружается каждый раз. Чтобы проверить это, у меня есть связанный с этим вопрос: http://stackoverflow.com/questions/41868871/tensorflow-slim-debugging-during-training. Прокомментируйте, если это возможно. –

Смежные вопросы