Просматривая TF Dev Summit talk о потенциальных (и очень экспериментальных) преимуществах включения XLA на графиках TensorFlow, я подумал, что немного экспериментирую с ним.Включить XLA при использовании tf.contrib.learn.Estimator
Вопрос: При использовании tf.contrib.learn.Estimator
, как включить JIT XLA?
Я флаг некоторые опс для JIT XLA по
with tf.device("/job:localhost/replica:0/task:0/device:XLA_GPU:0"):
output = tf.add(input1, input2)
Однако документы предупреждают, что это означало только для тестирования. Я хотел бы быть в состоянии сделать это, используя рекомендуемый способ
# Config to turn on JIT compilation
config = tf.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
sess = tf.Session(config=config)
, но я не могу понять, способ установки tf.Session
снаружи model_fn
.
Этот псевдо-код может прояснить проблему лучше:
def char_cnn_model(features, target, mode, params, model_dir):
# do stuff to build the CNN
...
return tf.contrib.learn.ModelFnOps(mode=mode,
predictions=predictions_dict,
loss=loss,
train_op=train_op,
eval_metric_ops=eval_metric_ops)
def main(unused_argv):
# load training data, test data etc
...
classifier = learn.Estimator(
model_fn=char_cnn_model, # defined above
model_dir=model_dir,
params=params,
config=tf.contrib.learn.RunConfig(save_checkpoints_secs=60,
log_device_placement=True,
tf_random_seed=7))
classifier.fit(
X_train,
y_train,
steps=5000,
monitors=[validation_monitor]) # validation_monitor defined in main
accuracy_score = classifier.evaluate(x=X_test, y=y_test)
tf.contrib.learn.RunConfig
казался хорошим кандидатом, но это не есть что-то для сеанса (который я думаю, имеет смысл, почему бы RunConfig есть указатель на существующую сессию?)
Я мог бы переусердствовать, и tf.get_default_session
может быть всем, что мне нужно, но могу ли я изменить конфигурацию сеанса после его создания?