2015-12-28 5 views
8

Я пытаюсь реализовать предложение из ответов: Tensorflow: how to save/restore a model?tensorflow: сохранение и восстановление сеанса

У меня есть объект, который оборачивает в tensorflow модели в sklearn стиле.

import tensorflow as tf 
class tflasso(): 
    saver = tf.train.Saver() 
    def __init__(self, 
       learning_rate = 2e-2, 
       training_epochs = 5000, 
        display_step = 50, 
        BATCH_SIZE = 100, 
        ALPHA = 1e-5, 
        checkpoint_dir = "./", 
      ): 
     ... 

    def _create_network(self): 
     ... 


    def _load_(self, sess, checkpoint_dir = None): 
     if checkpoint_dir: 
      self.checkpoint_dir = checkpoint_dir 

     print("loading a session") 
     ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir) 
     if ckpt and ckpt.model_checkpoint_path: 
      self.saver.restore(sess, ckpt.model_checkpoint_path) 
     else: 
      raise Exception("no checkpoint found") 
     return 

    def fit(self, train_X, train_Y , load = True): 
     self.X = train_X 
     self.xlen = train_X.shape[1] 
     # n_samples = y.shape[0] 

     self._create_network() 
     tot_loss = self._create_loss() 
     optimizer = tf.train.AdagradOptimizer(self.learning_rate).minimize(tot_loss) 

     # Initializing the variables 
     init = tf.initialize_all_variables() 
     " training per se" 
     getb = batchgen(self.BATCH_SIZE) 

     yvar = train_Y.var() 
     print(yvar) 
     # Launch the graph 
     NUM_CORES = 3 # Choose how many cores to use. 
     sess_config = tf.ConfigProto(inter_op_parallelism_threads=NUM_CORES, 
                  intra_op_parallelism_threads=NUM_CORES) 
     with tf.Session(config= sess_config) as sess: 
      sess.run(init) 
      if load: 
       self._load_(sess) 
      # Fit all training data 
      for epoch in range(self.training_epochs): 
       for (_x_, _y_) in getb(train_X, train_Y): 
        _y_ = np.reshape(_y_, [-1, 1]) 
        sess.run(optimizer, feed_dict={ self.vars.xx: _x_, self.vars.yy: _y_}) 
       # Display logs per epoch step 
       if (1+epoch) % self.display_step == 0: 
        cost = sess.run(tot_loss, 
          feed_dict={ self.vars.xx: train_X, 
            self.vars.yy: np.reshape(train_Y, [-1, 1])}) 
        rsq = 1 - cost/yvar 
        logstr = "Epoch: {:4d}\tcost = {:.4f}\tR^2 = {:.4f}".format((epoch+1), cost, rsq) 
        print(logstr) 
        self.saver.save(sess, self.checkpoint_dir + 'model.ckpt', 
         global_step= 1+ epoch) 

      print("Optimization Finished!") 
     return self 

Когда я бегу:

tfl = tflasso() 
tfl.fit(train_X, train_Y , load = False) 

Я получаю выход:

Epoch: 50 cost = 38.4705 R^2 = -1.2036 
    b1: 0.118122 
Epoch: 100 cost = 26.4506 R^2 = -0.5151 
    b1: 0.133597 
Epoch: 150 cost = 22.4330 R^2 = -0.2850 
    b1: 0.142261 
Epoch: 200 cost = 20.0361 R^2 = -0.1477 
    b1: 0.147998 

Однако, когда я пытаюсь восстановить параметры (даже не убивая объект): tfl.fit(train_X, train_Y , load = True)

Я получаю странные результаты. Прежде всего, загруженное значение не соответствует сохраненному.

loading a session 
loaded b1: 0.1   <------- Loaded another value than saved 
Epoch: 50 cost = 30.8483 R^2 = -0.7670 
    b1: 0.137484 

Каков правильный способ загрузки и, возможно, сначала проверить сохраненные переменные?

+0

документация о тензорном потоке лишена довольно простых примеров, вам нужно копать в папках с примерами и понимать их в основном на своем собственном – diffeomorphism

ответ

10

TL; DR: Вы должны попытаться переработать этот класс так, что self.create_network() называется (я) только один раз, и (б) перед tf.train.Saver() построен.

Здесь есть две тонкие проблемы, связанные с структурой кода и поведением tf.train.Saver constructor по умолчанию. Когда вы строите заставку без аргументов (как в вашем коде), она собирает текущий набор переменных в вашей программе и добавляет ops к графу для сохранения и восстановления. В вашем коде, когда вы вызываете tflasso(), он построит заставку, и переменных не будет (потому что create_network() еще не был вызван). В результате контрольная точка должна быть пустой.

Вторая проблема заключается в том, что — по умолчанию — Формат сохраненной контрольной точки - это карта от name property of a variable к ее текущему значению. Если вы создаете две переменные с тем же именем, они будут автоматически «uniquified» по TensorFlow:

v = tf.Variable(..., name="weights") 
assert v.name == "weights" 
w = tf.Variable(..., name="weights") 
assert v.name == "weights_1" # The "_1" is added by TensorFlow. 

Следствием этого является то, что, когда вы звоните self.create_network() во втором вызове tfl.fit(), переменные будут иметь разные имена из имен, которые хранятся в контрольной точке —, или были бы, если бы заставка была построена после сети. (Вы можете избежать этого, передавая Name- Variable словарь конструктору заставки, но это, как правило, довольно неудобно.)

Есть два основных пути решения проблемы:

  1. В каждом вызове tflasso.fit(), создать всю модель заново, определив новый tf.Graph, затем в этом графике построим сеть и создадим tf.train.Saver.

  2. РЕКОМЕНДУЕТСЯ Создание сети, то tf.train.Saver в tflasso конструктора, и повторно использовать этот график на каждый вызов tflasso.fit().Обратите внимание, что вам, возможно, потребуется сделать еще одну работу по реорганизации вещей (в частности, я не уверен, что вы делаете с self.X и self.xlen), но это должно быть возможно с помощью placeholders и подачи.

+0

спасибо! 'Xlen' используется в' self._create_network() 'для установки размера ввода' X' (placeholder init: 'self.vars.xx = tf.placeholder (" float ", shape = [None, self.xlen ]) '). Из того, что вы говорите, предпочтительным способом является передача 'xlen' в инициализатор. –

+0

Есть ли способ сбросить значения uniquifier/clear old tf при повторной инициализации объекта? –

+1

Для этого вам нужно создать новый 'tf.Graph' и сделать его по умолчанию, прежде чем вы (i) создадите сеть и (ii) сделаете' Saver'. Если вы оберните тело 'tflasso.fit()' в 'с помощью tf.Graph(). As_default():' block и переместите конструкцию 'Saver' внутри этого блока, имена должны быть одинаковыми каждый раз, когда вы вызовите 'fit()'. – mrry

Смежные вопросы