2016-08-14 4 views
7

Как я могу читать переменные и их состояния с контрольной точки?Tensorflow. Перечислить переменные в контрольной точке

Я работаю с автокодировщиками, и моя контрольная точка содержит полное состояние сети, то есть кодировщик, декодер, оптимизатор и т. Д. Я хочу обманывать кодировки и, следовательно, в моем режиме оценки потребуется только часть декодера сети.

Тот же вопрос более абстрактным образом: как я могу читать только определенные переменные из существующей контрольной точки для повторного использования в другой модели?

Должен ли я просто назвать свою переменную соответственно? Или есть способ получить что-то вроде:

w_init = read_from_state(state_location, var_name) 

def read_from_state(state_location, var_name): 
    # the magic goes here 
    pass 

ответ

14

Там в list_variables метод checkpoint_utils.py который позволяет просматривать все сохраненные переменные.

Тем не менее, для вашего использования может быть проще восстановить с помощью Saver. Если вы знаете имена переменных при сохранении контрольной точки, вы можете создать новую заставку и сообщить ей инициализировать эти имена в новые объекты Variable (возможно, с разными именами). Это используется в примере CIFAR для выбора восстановления subset of variables. См Choosing which Variables to Save and Restore в Howto

0

Другим способом, который будет печатать все контрольные точки тензоров (или только один, если он указано) вместе с их содержанием:

from tensorflow.python.tools import inspect_checkpoint as inch 
inch.print_tensors_in_checkpoint_file('path/to/ckpt', '', True) 
""" 
Args: 
    file_name: Name of the checkpoint file. 
    tensor_name: Name of the tensor in the checkpoint file to print. 
    all_tensors: Boolean indicating whether to print all tensors. 
""" 

Он всегда будет печатать содержание тензора.

И, в то время как мы в этом, вот как использовать checkpoint_utils, предложенный предыдущим ответом:

from tensorflow.contrib.framework.python.framework import checkpoint_utils 
    var_list = checkpoint_utils.list_variables('path/to/ckpt') 
    for v in var_list: print(v)