2017-01-18 4 views
2

У меня была переобученная начальная модель для моего собственного набора данных. Модель Tho построена на python, и теперь у меня есть сохраненный граф как файл .pb и файл метки как .txt. Теперь мне нужно предсказать использование этой модели для изображения через java. Может кто-нибудь, пожалуйста, помогите мнеЗапуск модели tensorflow, написанной на питоне для обучения и прогнозирования из java

ответ

3

Команда TensorFlow разрабатывает интерфейс Java, но пока не стабилен. Вы можете найти существующий код здесь: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/java и следить за его развитием здесь https://github.com/tensorflow/tensorflow/issues/5. Вы можете посмотреть GraphTest.java, SessionTest.java и TensorTest.java, чтобы узнать, как он используется (хотя, как объясняется, это может измениться в будущем). В принципе, вам нужно загрузить двоичный сохраненный график в объект Graph, создать с ним Session и запустить его с соответствующими значениями (как Tensor s), чтобы получить List<Tensor> с выходом. Соединенный из примеров, в источнике:

import java.nio.file.Files; 
import java.nio.file.Paths; 
import org.tensorflow.Graph; 
import org.tensorflow.Session; 
import org.tensorflow.Tensor; 

try (Graph graph = new Graph()) { 
    graph.importGraphDef(Files.readAllBytes(Paths.get("saved_model.pb")); 
    try (Session sess = new Session(graph)) { 
     try (Tensor x = Tensor.create(1.0f); 
      Tensor y = s.runner().feed("x", x).fetch("y").run().get(0)) { 
      System.out.println(y.floatValue()); 
     } 
    } 
} 
0

код, который я использовал, который работал читать protobuf файл, заканчивая .pb.

try (SavedModelBundle b = SavedModelBundle.load("/tmp/model", "serve")) { 
    Session sess = b.session(); 
    ... 
    float[][]matrix = sess.runner() 
     .feed("x", input) 
     .feed("keep_prob", keep_prob) 
     .fetch("y_conv") 
     .run() 
     .get(0) 
     .copyTo(new float[1][10]); 
    ... 
} 

код Python Я использовал, чтобы сохранить его было:

signature = tf.saved_model.signature_def_utils.build_signature_def(
    inputs = {'x': tf.saved_model.utils.build_tensor_info(x)}, 
    outputs = {'y_conv': tf.saved_model.utils.build_tensor_info(y_conv)}, 
) 
    builder = tf.saved_model.builder.SavedModelBuilder("/tmp/model") 
    builder.add_meta_graph_and_variables(sess, 
     [tf.saved_model.tag_constants.SERVING], 
     signature_def_map={tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature} 
    ) 
    builder.save() 
Смежные вопросы