2017-02-01 9 views
0

Я уже сохраняю эту модель во время обучения, но мне тяжело ее загрузить и оценить.Восстановить и оценить модель Tensorflow

Я пробовал разные подходы, но мне не удалось загрузить сохраненную модель и оценить ее, чтобы получить ее предсказание над некоторым образцом, который является файлом изображения.

Может ли кто-нибудь помочь в этом? как я видел, это не так сложно, но я скучаю по этому поводу.

#!/usr/bin/python 
import tensorflow as tf 

BATCH_SIZE = 128 
NUM_EXAMPLES_PER_EPOCH = 50000 
VALIDATION_SIZE = 10000 
WIDTH = 128 
HEIGHT = 64 
CHANNELS = 3 
CLASSES = 10 
NUMBERS = 4 


def inference(inputs): 

    with tf.variable_scope("conv_pool_1"): 
     kernel = tf.get_variable(name="kernel", 
           shape=[5, 5, 3, 48], 
           initializer=tf.truncated_normal_initializer(stddev=0.05), 
           dtype=tf.float32) 
     biases = tf.get_variable(name="biases", 
           shape=[48], 
           initializer=tf.constant_initializer(value=0.), 
           dtype=tf.float32) 
     conv = tf.nn.conv2d(input=inputs, 
          filter=kernel, 
          strides=[1, 1, 1, 1], 
          padding="SAME") 
     conv_bias = tf.nn.bias_add(value=conv, 
            bias=biases, 
            name="add_biases") 
     relu = tf.nn.relu(conv_bias) 
     pool = tf.nn.max_pool(value=relu, 
           ksize=[1, 2, 2, 1], 
           strides=[1, 2, 2, 1], 
           padding="SAME", 
           name="pooling") 

    with tf.variable_scope("conv_pool_2"): 
     kernel = tf.get_variable(name="kernel", 
           shape=[5, 5, 48, 64], 
           initializer=tf.truncated_normal_initializer(stddev=0.05), 
           dtype=tf.float32) 
     biases = tf.get_variable(name="biases", 
           shape=[64], 
           initializer=tf.constant_initializer(value=0.), 
           dtype=tf.float32) 
     conv = tf.nn.conv2d(input=pool, 
          filter=kernel, 
          strides=[1, 1, 1, 1], 
          padding="SAME") 
     conv_bias = tf.nn.bias_add(value=conv, 
            bias=biases, 
            name="add_biases") 
     relu = tf.nn.relu(conv_bias) 
     pool = tf.nn.max_pool(value=relu, 
           ksize=[1, 2, 1, 1], 
           strides=[1, 2, 1, 1], 
           padding="SAME", 
           name="pooling") 
    with tf.variable_scope("conv_pool_3"): 
     kernel = tf.get_variable(name="kernel", 
           shape=[5, 5, 64, 128], 
           initializer=tf.truncated_normal_initializer(stddev=0.05), 
           dtype=tf.float32) 
     biases = tf.get_variable(name="biases", 
           shape=[128], 
           initializer=tf.constant_initializer(value=0.), 
           dtype=tf.float32) 
     conv = tf.nn.conv2d(input=pool, 
          filter=kernel, 
          strides=[1, 1, 1, 1], 
          padding="SAME") 
     conv_bias = tf.nn.bias_add(value=conv, 
            bias=biases, 
            name="add_biases") 
     relu = tf.nn.relu(conv_bias) 
     pool = tf.nn.max_pool(value=relu, 
           ksize=[1, 2, 2, 1], 
           strides=[1, 2, 2, 1], 
           padding="SAME", 
           name="pooling") 
    reshape = tf.reshape(pool, 
         shape=[BATCH_SIZE, -1], 
         name="reshape") 
    dims = reshape.get_shape().as_list()[-1] 
    with tf.variable_scope("fully_conn"): 
     weights = tf.get_variable(name="weights", 
            shape=[dims, 2048], 
            initializer=tf.truncated_normal_initializer(stddev=0.05), 
            dtype=tf.float32) 
     biases = tf.get_variable(name="biases", 
           shape=[2048], 
           initializer=tf.constant_initializer(value=0.), 
           dtype=tf.float32) 
     output = tf.nn.xw_plus_b(x=reshape, 
           weights=weights, 
           biases=biases) 
     conn = tf.nn.relu(output) 
    with tf.variable_scope("output"): 
     weights = tf.get_variable(name="weights", 
            shape=[2048, NUMBERS * CLASSES], 
            initializer=tf.truncated_normal_initializer(stddev=0.05), 
            dtype=tf.float32) 
     biases = tf.get_variable(name="biases", 
           shape=[NUMBERS * CLASSES], 
           initializer=tf.constant_initializer(value=0.), 
           dtype=tf.float32) 
     logits = tf.nn.xw_plus_b(x=conn, 
           weights=weights, 
           biases=biases) 
     reshape = tf.reshape(logits, shape=[BATCH_SIZE, NUMBERS, CLASSES]) 
    return reshape 


def loss(logits, labels): 
    cross_entropy_per_number = tf.nn.softmax_cross_entropy_with_logits(logits, labels) 
    cross_entropy = tf.reduce_mean(cross_entropy_per_number) 
    tf.add_to_collection("loss", cross_entropy) 
    return cross_entropy 


def evaluation(logits, labels): 
    prediction = tf.argmax(logits, 2) 
    actual = tf.argmax(labels, 2) 
    equal = tf.equal(prediction, actual) 
    # equal = tf.reduce_all(equal, 1) 
    accuracy = tf.reduce_mean(tf.cast(equal, tf.float32), name="accuracy") 
    return accuracy 


def train(loss, learning_rate=0.00001): 
    optimizer = tf.train.GradientDescentOptimizer(learning_rate) 
    train_op = optimizer.minimize(loss) 
    return train_op 

ответ

0

Как вы это можете сохранить? Вы пробовали: (для экономии)

saver = tf.train.Saver() 
sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 
saver.save(sess, 'my-model') 

(для загрузки)

sess = tf.Session() 
new_saver = tf.train.import_meta_graph('my-model.meta') 
new_saver.restore(sess, tf.train.latest_checkpoint('./')) 

Официальный ссылка на это: https://www.tensorflow.org/versions/master/api_docs/python/state_ops/exporting_and_importing_meta_graphs (или заменить номер версии, такие как r0.12 для master в URL).

0

Теперь я загружая его правильно Теперь

saver = tf.train.import_meta_graph('model/model.ckpt.meta') 

init = tf.group(tf.initialize_all_variables(), 
       tf.initialize_local_variables()) 
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) 

saver.restore(sess, 'model/model.ckpt') 

Теперь я пытаюсь получить предсказание, я думаю, что это что-то вроде этого, но я не знаю, как я буду получать переменные из модели, Я создал ранее, чтобы получить его прогноз:

prediction=tf.argmax(y_conv,1) 
prediction.eval(feed_dict={x: [imvalue],keep_prob: 1.0}, session=sess) 
Смежные вопросы