2017-02-01 3 views
3

Я написал следующий класс сверточной нейронной сети (CNN) в Tensorflow [Я попытался пропустить некоторые строки кода для ясности.]загрузить несколько моделей в Tensorflow

class CNN: 
def __init__(self, 
       num_filters=16,  # initial number of convolution filters 
      num_layers=5,   # number of convolution layers 
      num_input=2,   # number of channels in input 
      num_output=5,   # number of channels in output 
      learning_rate=1e-4, # learning rate for the optimizer 
      display_step = 5000, # displays training results every display_step epochs 
      num_epoch = 10000,  # number of epochs for training 
      batch_size= 64,  # batch size for mini-batch processing 
      restore_file=None,  # restore file (default: None) 

      ): 

       # define placeholders 
       self.image = tf.placeholder(tf.float32, shape = (None, None, None,self.num_input)) 
       self.groundtruth = tf.placeholder(tf.float32, shape = (None, None, None,self.num_output)) 

       # builds CNN and compute prediction 
       self.pred = self._build() 

       # I have already created a tensorflow session and saver objects 
       self.sess = tf.Session() 
       self.saver = tf.train.Saver() 

       # also, I have defined the loss function and optimizer as 
       self.loss = self._loss_function() 
       self.optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self.loss) 

       if restore_file is not None: 
        print("model exists...loading from the model") 
        self.saver.restore(self.sess,restore_file) 
       else: 
        print("model does not exist...initializing") 
        self.sess.run(tf.initialize_all_variables()) 

def _build(self): 
    #builds CNN 

def _loss_function(self): 
    # computes loss 


# 
def train(self, train_x, train_y, val_x, val_y): 
    # uses mini batch to minimize the loss 
    self.sess.run(self.optimizer, feed_dict = {self.image:sample, self.groundtruth:gt}) 


    # I save the session after n=10 epochs as: 
    if epoch%n==0: 
     self.saver.save(sess,'snapshot',global_step = epoch) 

# finally my predict function is 
def predict(self, X): 
    return self.sess.run(self.pred, feed_dict={self.image:X}) 

Я тренировался два CNNs для двух отдельных задач, независимо друг от друга. Каждый из них занял около 1 дня. Скажем, model1 и model2 сохраняются как «snapshot-model1-10000» и «snapshot-model2-10000» (с соответствующими метафайлами) соответственно. Я могу тестировать каждую модель и вычислять ее производительность отдельно.

Теперь я хочу загрузить эти две модели в один скрипт. Я бы естественно попытаться сделать, как показано ниже:

cnn1 = CNN(..., restore_file='snapshot-model1-10000',..........) 
cnn2 = CNN(..., restore_file='snapshot-model2-10000',..........) 

я сталкиваюсь ошибку [сообщение об ошибке долго. Я просто скопировал/наклеил фрагмент.]

NotFoundError: Tensor name "Variable_26/Adam_1" not found in checkpoint files /home/amitkrkc/codes/A549_models/snapshot-hela-95000 
    [[Node: save_1/restore_slice_85 = RestoreSlice[dt=DT_FLOAT, preferred_shard=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save_1/Const_0, save_1/restore_slice_85/tensor_name, save_1/restore_slice_85/shape_and_slice)]] 

Есть ли способ, чтобы загрузить из этих двух файлов два отдельных CNNs? Любое предложение/комментарий/отзывы приветствуются.

Спасибо,

ответ

4

Да, есть. Используйте отдельные графики.

g1 = tf.Graph() 
g2 = tf.Graph() 

with g1.as_default(): 
    cnn1 = CNN(..., restore_file='snapshot-model1-10000',..........) 
with g2.as_default(): 
    cnn2 = CNN(..., restore_file='snapshot-model2-10000',..........) 

EDIT:

Если вы хотите их в одном графике. Вам придется переименовать некоторые переменные. Одна из идей есть каждый CNN в отдельном объеме, и пусть заставки переменных ручек в этой области видимости, например:

saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES), scope='model1') 

и CNN обернуть всю свою конструкцию в сфере:

with tf.variable_scope('model1'): 
    ... 

edit2:

Другого идея переименовывает переменные, которые управляет хранителем (поскольку я предполагаю, что вы хотите использовать сохраненные контрольные точки, не переучивая все. Сохранение позволяет использовать разные имена переменных в графике и в контрольной точке, посмотреть документацию для инициализации.

+0

спасибо. Ваше первое предложение отлично подходит для моего дела. – Amit

0

Я столкнулся с той же проблемой и не смог решить проблему (без переподготовки) с любым решением, которое я нашел в Интернете. Так что я сделал загрузку каждой модели в два отдельных потока, которые общаются с основным потоком. Это достаточно просто, чтобы написать код, вам просто нужно быть осторожным при синхронизации потоков. В моем случае каждый поток получил вход для своей проблемы и вернулся к основному потоку вывода. Он работает без каких-либо заметных накладных расходов.

Смежные вопросы