import org.springframework.core.io.ClassPathResource;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import java.io.BufferedInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
public class TensorFlowModelLoader {
public Session loadModel(String modelPath) throws IOException {
// 读取模型文件
byte[] modelBytes = Files.readAllBytes(Path.of(modelPath));
// 加载TensorFlow模型
Graph graph = TensorFlow.createGraph();
Session session = new Session(graph);
session.run(modelBytes);
return session;
}
public float[] predict(Session session, float[] inputData) {
// 创建输入Tensor
try (Tensor<Float> inputTensor = Tensor.create(inputData)) {
// 运行模型进行预测
String[] outputNames = {"output"}; // 假设输出节点名为"output"
Tensor<Float> resultTensor = session.runner()
.feed("input", inputTensor) // 假设输入节点名为"input"
.fetch(outputNames)
.run()
.get(0).expect(Float.class);
// 获取预测结果
float[] result = resultTensor.copyTo(new float[10]); // 假设输出形状为[10]
resultTensor.close();
return result;
}
}
public static void main(String[] args) throws IOException {
TensorFlowModelLoader loader = new TensorFlowModelLoader();
Session session = loader.loadModel("path/to/your/model.pb");
float[] inputData = {0.1f, 0.2f, 0.3f}; // 示例输入数据
float[] prediction = loader.predict(session, inputData);
// 输出预测结果
for (float p : prediction) {
System.out.println(p);
}
// 关闭Session
session.close();
}
}
这段代码展示了如何在Spring Boot应用中加载TensorFlow模型并进行预测。首先,它定义了一个loadModel
方法来读取模型文件并创建一个TensorFlowSession
。predict
方法接受一个Session
和输入数据,创建输入Tensor,运行模型,并获取输出预测结果。最后,在main
方法中,我们加载模型,进行预测,并关闭Session
。