
Java调用OnnxRuntime推理(纯代码)
测试用的onnx模型和图片涉及保密内容,因此不提供。如果本地环境有问题,或者需要搭建本地开发环境或部署环境,可以参考:搭建opencv和JavaOnnxRuntime环境
代码中已注明示例所用onnx模型的结构(输入名为 images,形状为 [1, 3, 640, 640];输出形状为 [1, N, 5 + 类别数])。使用前请先用模型结构可视化工具(例如 Netron)确认你的模型结构与之一致,否则代码可能会出现异常。
下面的代码还用到了 lombok、fastjson、hutool 等依赖,它们没有列在下方的 Maven 依赖中,请自行添加。
Maven依赖
<!-- Build properties: Java 8 bytecode plus shared version/classifier values for the JavaCV artifacts below. -->
<properties>
<maven.compiler.source>8</maven.compiler.source>
<maven.compiler.target>8</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<javacv.version>1.5.9</javacv.version>
<!-- Native classifiers; each native artifact below is declared once per platform.
     NOTE(review): "system.liunx64" is a misspelling of linux, but it is used consistently below. -->
<system.windowsx64>windows-x86_64</system.windowsx64>
<system.liunx64>linux-x86_64</system.liunx64>
</properties>
<dependencies>
<!-- CPU-only ONNX Runtime, left commented out; use it instead of onnxruntime_gpu when no CUDA is available. -->
<!-- https://mvnrepository.com/artifact/com.microsoft.onnxruntime/onnxruntime -->
<!-- <dependency>-->
<!-- <groupId>com.microsoft.onnxruntime</groupId>-->
<!-- <artifactId>onnxruntime</artifactId>-->
<!-- <version>1.17.3</version>-->
<!-- </dependency>-->
<!-- https://mvnrepository.com/artifact/com.microsoft.onnxruntime/onnxruntime_gpu -->
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime_gpu</artifactId>
<version>1.17.3</version>
</dependency>
<!-- NOTE(review): no version is given, so this only resolves if a parent POM or dependencyManagement
     supplies one. None of the Java code in this article appears to use Groovy. -->
<dependency>
<groupId>org.codehaus.groovy</groupId>
<artifactId>groovy</artifactId>
</dependency>
<!-- javacv+javacpp核心库-->
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>javacv</artifactId>
<version>${javacv.version}</version>
</dependency>
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>javacpp-platform</artifactId>
<version>${javacv.version}</version>
</dependency>
<!-- 最小opencv依赖包 ,必须包含上面的javacv+javacpp -->
<!-- (Minimal OpenCV natives; requires the javacv + javacpp artifacts above.) Windows and Linux x86_64 variants follow. -->
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>opencv</artifactId>
<version>4.7.0-${javacv.version}</version>
<classifier>${system.windowsx64}</classifier>
</dependency>
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>opencv</artifactId>
<version>4.7.0-${javacv.version}</version>
<classifier>${system.liunx64}</classifier>
</dependency>
<!-- OpenBLAS natives for the same two platforms. -->
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>openblas</artifactId>
<version>0.3.23-${javacv.version}</version>
<classifier>${system.windowsx64}</classifier>
</dependency>
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>openblas</artifactId>
<version>0.3.23-${javacv.version}</version>
<classifier>${system.liunx64}</classifier>
</dependency>
<!-- FlyCapture natives for the same two platforms. -->
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>flycapture</artifactId>
<version>2.13.3.31-${javacv.version}</version>
<classifier>${system.windowsx64}</classifier>
</dependency>
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>flycapture</artifactId>
<version>2.13.3.31-${javacv.version}</version>
<classifier>${system.liunx64}</classifier>
</dependency>
</dependencies>
Java代码
实体类
OnnxModelHolder
/**
 * Holds a loaded ONNX model session together with the input geometry read from the model.
 * <p>
 * Note: the field order below defines the parameter order of the Lombok-generated all-args constructor.
 *
 * @since 2024/5/6 13:59
 * @author JunPzx
 */
@NoArgsConstructor
@AllArgsConstructor
@Accessors(chain = true)
@Data
public class OnnxModelHolder {
    /** ONNX Runtime environment the session was created from. */
    private OrtEnvironment env;
    /** Inference session for the loaded model. */
    private OrtSession session;
    /** Class labels keyed by the class index as a string ("0", "1", ...). */
    private JSONObject tags;
    /** Batch size, taken from dimension 0 of the model's "images" input. */
    private long count;
    /** Number of input channels, taken from dimension 1 of the "images" input. */
    private long channels;
    /** Input height in pixels, taken from dimension 2 of the "images" input. */
    private long netHeight;
    /** Input width in pixels, taken from dimension 3 of the "images" input. */
    private long netWidth;
}
PredictResult
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.Accessors;
/**
 * A single detection result: an axis-aligned box plus its label and scores.
 * <p>
 * Note: the field order below defines the parameter order of the Lombok-generated all-args constructor.
 *
 * @since 2024/5/6 17:17
 * @author JunPzx
 */
@NoArgsConstructor
@Data
@AllArgsConstructor
@Accessors(chain = true)
public class PredictResult {
    /**
     * Minimum y coordinate (top edge) of the box.
     */
    private Float ymin;
    /**
     * Minimum x coordinate (left edge) of the box.
     */
    private Float xmin;
    /**
     * Maximum y coordinate (bottom edge) of the box.
     */
    private Float ymax;
    /**
     * Maximum x coordinate (right edge) of the box.
     */
    private Float xmax;
    /**
     * Score of the highest-scoring class for this box ("probability").
     */
    private Float percentage;
    /**
     * Box confidence score ("objectness").
     */
    private Float confidence;
    /**
     * Class label of the highest-scoring class.
     */
    private String name;
}
逻辑代码
OnnxModelLoader
import ai.onnxruntime.*;
import com.alibaba.fastjson.JSONObject;
import com.junpzx.linglong.modelruntime.domain.OnnxModelHolder;
import lombok.extern.slf4j.Slf4j;
import org.opencv.core.Core;
import java.util.Map;
/**
 * Loads an ONNX detection model and records the input geometry needed for pre-processing.
 *
 * @since 2024/5/6 13:57
 * @author JunPzx
 */
@Slf4j
public class OnnxModelLoader {
    /** Name of the model input whose shape is read; the predict handler feeds the same input name. */
    private static final String INPUT_NAME = "images";

    /**
     * Loads a model for CPU inference.
     *
     * @param modelPath path of the .onnx file
     * @return holder with the session, labels and input geometry
     * @throws Exception if the model cannot be loaded or its input is not a fixed-size 4-D tensor
     */
    public OnnxModelHolder loadModel(String modelPath) throws Exception {
        return loadModel(modelPath, false, null);
    }

    /**
     * Loads a model and tries to run it on the given CUDA device.
     * If CUDA cannot be enabled, the error is logged and the default (CPU) provider is used.
     *
     * @param modelPath   path of the .onnx file
     * @param gpuDeviceId CUDA device id to execute on
     * @return holder with the session, labels and input geometry
     * @throws Exception if the model cannot be loaded or its input is not a fixed-size 4-D tensor
     */
    public OnnxModelHolder loadModel(String modelPath, int gpuDeviceId) throws Exception {
        return loadModel(modelPath, true, gpuDeviceId);
    }

    private OnnxModelHolder loadModel(String modelPath, Boolean enableGpu, Integer gpuDeviceId) throws Exception {
        OnnxModelHolder onnxModelHolder = new OnnxModelHolder();
        onnxModelHolder.setEnv(OrtEnvironment.getEnvironment());
        // The options are only needed while the session is created; close them to release their native memory.
        try (OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions()) {
            if (Boolean.TRUE.equals(enableGpu)) {
                try {
                    sessionOptions.addCUDA(gpuDeviceId);
                } catch (Exception e) {
                    // Deliberate best effort: keep going without CUDA rather than failing the load.
                    log.error("加载Onnx模型指定GPU CUDA发生错误", e);
                }
            }
            onnxModelHolder.setSession(onnxModelHolder.getEnv().createSession(modelPath, sessionOptions));
        }
        OrtSession session = onnxModelHolder.getSession();
        OnnxModelMetadata metadata = session.getMetadata();
        Map<String, NodeInfo> infoMap = session.getInputInfo();
        NodeInfo input = infoMap.get(INPUT_NAME);
        if (input == null || !(input.getInfo() instanceof TensorInfo)) {
            throw new IllegalStateException("ONNX model has no tensor input named '" + INPUT_NAME
                    + "', available inputs: " + infoMap.keySet());
        }
        TensorInfo nodeInfo = (TensorInfo) input.getInfo();
        // TODO temporary: hard-coded labels; later read them from the model's own metadata.
        JSONObject tags = new JSONObject();
        tags.put("0", "miner");
        tags.put("1", "helmet");
        tags.put("2", "head");
        onnxModelHolder.setTags(tags);
        // Reading labels from the model metadata (not yet enabled):
        // String nameClass = metadata.getCustomMetadata().get("names");
        // onnxModelHolder.setTags(JSONObject.parseObject(nameClass.replace("\"", "\"\"")));
        printlnInfo(metadata, infoMap, nodeInfo, onnxModelHolder.getTags());
        long[] shape = nodeInfo.getShape();
        // Pre-processing needs concrete channel/height/width values (NCHW layout assumed).
        if (shape.length != 4 || shape[1] <= 0 || shape[2] <= 0 || shape[3] <= 0) {
            throw new IllegalStateException("Expected a fixed-size 4-D (NCHW) input, got " + nodeInfo);
        }
        // Batch: one image per run. A non-positive value marks a dynamic dimension, which is fed as 1.
        onnxModelHolder.setCount(shape[0] > 0 ? shape[0] : 1);
        // Channel count (3 for the example model)
        onnxModelHolder.setChannels(shape[1]);
        // Input height (640 for the example model)
        onnxModelHolder.setNetHeight(shape[2]);
        // Input width (640 for the example model)
        onnxModelHolder.setNetWidth(shape[3]);
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
        return onnxModelHolder;
    }

    /**
     * Prints the model metadata, inputs and labels to stdout for diagnostics.
     */
    private void printlnInfo(OnnxModelMetadata metadata, Map<String, NodeInfo> infoMap, TensorInfo nodeInfo,
                             JSONObject names) {
        System.out.println("-------打印模型信息开始--------");
        System.out.println("getProducerName=" + metadata.getProducerName());
        System.out.println("getGraphName=" + metadata.getGraphName());
        System.out.println("getDescription=" + metadata.getDescription());
        System.out.println("getDomain=" + metadata.getDomain());
        System.out.println("getVersion=" + metadata.getVersion());
        System.out.println("getCustomMetadata=" + metadata.getCustomMetadata());
        System.out.println("getInputInfo=" + infoMap);
        System.out.println("nodeInfo=" + nodeInfo);
        System.out.println("类别信息:" + names);
        System.out.println("-------打印模型信息结束--------");
    }
}
PredictHandler
package com.xatl.linglong.modelruntime.service;
import com.xatl.linglong.modelruntime.domain.PredictResult;
import java.awt.image.BufferedImage;
import java.io.InputStream;
import java.util.List;
/**
 * Produces detection results ({@link PredictResult}: labelled boxes with scores) for an image.
 * Three input forms are supported: a file path, an input stream, or an in-memory image.
 *
 * @author JunPzx
 * @since 2024/5/6 14:27
 */
public interface PredictHandler {
    /**
     * Runs prediction on the image file at the given path.
     *
     * @param targetPath path of the image file
     * @return detection results
     * @throws Exception implementation-specific failure, e.g. unreadable file or inference error
     */
    List<PredictResult> predict(String targetPath) throws Exception;
    /**
     * Runs prediction on an encoded image read from a stream.
     *
     * @param is stream supplying the encoded image bytes
     * @return {@link List }<{@link PredictResult }> detection results
     * @throws Exception implementation-specific failure, e.g. I/O error or undecodable image
     */
    List<PredictResult> predict(InputStream is) throws Exception;
    /**
     * Runs prediction on an in-memory image.
     *
     * @param bufferedImage image to run detection on
     * @return {@link List }<{@link PredictResult }> detection results
     * @throws Exception implementation-specific failure, e.g. inference error
     */
    List<PredictResult> predict(BufferedImage bufferedImage) throws Exception;
}
OnnxModelPredictHandler
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtSession;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.junpzx.linglong.modelruntime.domain.OnnxModelHolder;
import com.junpzx.linglong.modelruntime.domain.PredictResult;
import lombok.extern.slf4j.Slf4j;
import org.opencv.core.*;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.InputStream;
import java.nio.FloatBuffer;
import java.nio.file.Files;
import java.util.*;
/**
 * Prediction handler that runs a YOLOv5-style ONNX detection model.
 * <p>
 * Pipeline: decode the image, letterbox it to the net size, build an NCHW float tensor, run the session,
 * filter rows by confidence, apply NMS, and map the boxes back to original image coordinates.
 *
 * @author JunPzx
 * @since 2024/5/6 14:29
 */
@Slf4j
public class OnnxModelPredictHandler implements PredictHandler {
    /** IoU above which a lower-confidence box is suppressed during NMS. */
    private static final float DEFAULT_NMS_THRESHOLD = 0.45f;
    /** Minimum box confidence (row index 4) for a candidate to be kept. */
    private static final float DEFAULT_CONFIDENCE_THRESHOLD = 0.25f;
    /** Index of the first class score in an output row: [cx, cy, w, h, confidence, class scores...]. */
    private static final int CLASS_SCORE_OFFSET = 5;
    /** Name of the model input tensor that is fed. */
    private static final String INPUT_NAME = "images";
    private final OnnxModelHolder onnxModelHolder;

    public OnnxModelPredictHandler(OnnxModelHolder onnxModelHolder) {
        this.onnxModelHolder = onnxModelHolder;
    }

    public OnnxModelPredictHandler(String onnxModelPath) throws Exception {
        this.onnxModelHolder = new OnnxModelLoader().loadModel(onnxModelPath);
    }

    @Override
    public List<PredictResult> predict(String targetPath) throws Exception {
        // Read the raw bytes and let OpenCV decode them (avoids imread's path handling)
        byte[] bytes = Files.readAllBytes(new File(targetPath).toPath());
        return predictEncoded(bytes);
    }

    /**
     * Reads the whole stream as an encoded image and runs prediction on it.
     * The stream is not closed; it remains owned by the caller.
     */
    @Override
    public List<PredictResult> predict(InputStream is) throws Exception {
        Objects.requireNonNull(is, "is");
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        byte[] chunk = new byte[8192];
        int read;
        while ((read = is.read(chunk)) != -1) {
            buffer.write(chunk, 0, read);
        }
        return predictEncoded(buffer.toByteArray());
    }

    /**
     * Runs prediction on an in-memory image of any {@link BufferedImage} type.
     */
    @Override
    public List<PredictResult> predict(BufferedImage bufferedImage) throws Exception {
        Objects.requireNonNull(bufferedImage, "bufferedImage");
        int width = bufferedImage.getWidth();
        int height = bufferedImage.getHeight();
        // Redraw into a fresh TYPE_3BYTE_BGR image. Its backing array is then exactly width*height*3
        // interleaved B,G,R bytes, which matches OpenCV's CV_8UC3 layout.
        BufferedImage bgr = new BufferedImage(width, height, BufferedImage.TYPE_3BYTE_BGR);
        Graphics2D g = bgr.createGraphics();
        try {
            g.drawImage(bufferedImage, 0, 0, null);
        } finally {
            g.dispose();
        }
        byte[] pixels = ((DataBufferByte) bgr.getRaster().getDataBuffer()).getData();
        Mat originalMat = new Mat(height, width, CvType.CV_8UC3);
        try {
            originalMat.put(0, 0, pixels);
            return predictMat(originalMat);
        } finally {
            originalMat.release();
        }
    }

    /** Decodes encoded image bytes (jpg/png/...) to a BGR Mat and runs prediction on it. */
    private List<PredictResult> predictEncoded(byte[] bytes) throws Exception {
        if (bytes.length == 0) {
            throw new IllegalArgumentException("Image data is empty");
        }
        MatOfByte matOfByte = new MatOfByte(bytes);
        Mat originalMat = Imgcodecs.imdecode(matOfByte, Imgcodecs.IMREAD_COLOR);
        try {
            return predictMat(originalMat);
        } finally {
            matOfByte.release();
            originalMat.release();
        }
    }

    /** Shared pipeline for all entry points; the input Mat is BGR and is not modified. */
    private List<PredictResult> predictMat(Mat originalMat) throws Exception {
        if (originalMat.empty()) {
            // Otherwise the zero width/height would divide by zero in the letterbox scaling.
            throw new IllegalArgumentException("Input could not be decoded as an image");
        }
        Mat resizeMat = resizeWithPadding(originalMat);
        try (OnnxTensor tensor = transferTensor(resizeMat);
             OrtSession.Result ortResult = onnxModelHolder.getSession()
                     .run(Collections.singletonMap(INPUT_NAME, tensor))) {
            OnnxTensor tensorOrtResult = (OnnxTensor) ortResult.get(0);
            // Assumes a YOLOv5-style output of shape [1, N, 5 + classes]
            float[][][] dataRes = (float[][][]) tensorOrtResult.getValue();
            // Confidence filter, then NMS
            JSONArray result = filterRec2(filterRec1(dataRes[0]));
            // Map the boxes back to original-image coordinates
            JSONArray originalMatResult = transferSrc2Dst(result, originalMat.width(), originalMat.height());
            return originalMatResult.toJavaList(PredictResult.class);
        } finally {
            resizeMat.release();
        }
    }

    /**
     * Scales the image to fit the net size while keeping its aspect ratio, then pads it (letterbox)
     * to exactly netWidth x netHeight.
     *
     * @param src source image
     * @return a new padded image of exactly the net size
     */
    private Mat resizeWithPadding(Mat src) {
        int netW = (int) onnxModelHolder.getNetWidth();
        int netH = (int) onnxModelHolder.getNetHeight();
        int oldW = src.width();
        int oldH = src.height();
        double r = Math.min((double) netW / oldW, (double) netH / oldH);
        int newUnpadW = (int) Math.round(oldW * r);
        int newUnpadH = (int) Math.round(oldH * r);
        // Keep the half padding fractional. When the leftover is odd (d = k.5), the -0.1/+0.1 rounding
        // puts k on one side and k+1 on the other, so the result is exactly the net size.
        double dw = (netW - newUnpadW) / 2.0;
        double dh = (netH - newUnpadH) / 2.0;
        int top = (int) Math.round(dh - 0.1);
        int bottom = (int) Math.round(dh + 0.1);
        int left = (int) Math.round(dw - 0.1);
        int right = (int) Math.round(dw + 0.1);
        Mat dst = new Mat();
        Imgproc.resize(src, dst, new Size(newUnpadW, newUnpadH));
        Core.copyMakeBorder(dst, dst, top, bottom, left, right, Core.BORDER_CONSTANT);
        return dst;
    }

    /**
     * Converts a net-sized BGR image into the model's NCHW float input tensor (RGB, scaled to [0, 1]).
     * The given Mat is modified in place.
     *
     * @param dst net-sized image
     * @return model input tensor; the caller must close it
     */
    private OnnxTensor transferTensor(Mat dst) throws Exception {
        Imgproc.cvtColor(dst, dst, Imgproc.COLOR_BGR2RGB);
        // convertTo keeps the channel count; only the depth changes to 32-bit float
        dst.convertTo(dst, CvType.CV_32FC1, 1. / 255);
        int channels = (int) onnxModelHolder.getChannels();
        float[] hwc = new float[channels
                * (int) onnxModelHolder.getNetWidth()
                * (int) onnxModelHolder.getNetHeight()];
        if (dst.total() * dst.channels() != hwc.length) {
            throw new IllegalStateException("Pre-processed image has " + dst.total() * dst.channels()
                    + " values but the model input expects " + hwc.length);
        }
        dst.get(0, 0, hwc);
        float[] chw = whc2cwh(hwc);
        return OnnxTensor.createTensor(onnxModelHolder.getEnv(),
                FloatBuffer.wrap(chw),
                new long[]{onnxModelHolder.getCount(), onnxModelHolder.getChannels(),
                        onnxModelHolder.getNetHeight(), onnxModelHolder.getNetWidth()});
    }

    /**
     * Reorders interleaved pixel data [height, width, channel] into planar [channel, height, width].
     *
     * @param src interleaved data
     * @return planar data of the same length
     */
    private float[] whc2cwh(float[] src) {
        int channels = (int) onnxModelHolder.getChannels();
        float[] chw = new float[src.length];
        int j = 0;
        for (int ch = 0; ch < channels; ++ch) {
            for (int i = ch; i < src.length; i += channels) {
                chw[j++] = src[i];
            }
        }
        return chw;
    }

    /**
     * Keeps rows whose box confidence reaches the threshold and converts them to JSON detections.
     *
     * @param data output rows [cx, cy, w, h, confidence, class scores...]
     * @return detections with xyxy boxes in net coordinates
     */
    private JSONArray filterRec1(float[][] data) {
        JSONArray recList = new JSONArray();
        for (float[] bbox : data) {
            float confidence = bbox[4];
            // Coarse filter on box confidence first, so rejected rows cost nothing further
            if (confidence < DEFAULT_CONFIDENCE_THRESHOLD) {
                continue;
            }
            float[] xyxy = xywh2xyxy(new float[]{bbox[0], bbox[1], bbox[2], bbox[3]});
            // Class scores fill the rest of the row; take their count from the row, not a fixed 80 classes
            float[] classInfo = Arrays.copyOfRange(bbox, CLASS_SCORE_OFFSET, bbox.length);
            int maxIndex = getMaxIndex(classInfo);
            float maxValue = classInfo[maxIndex];
            String maxClass = onnxModelHolder.getTags().getString(String.valueOf(maxIndex));
            JSONObject detect = new JSONObject();
            detect.put("name", maxClass);
            // Score of the best class
            detect.put("percentage", maxValue);
            // Box confidence
            detect.put("confidence", confidence);
            detect.put("xmin", xyxy[0]);
            detect.put("ymin", xyxy[1]);
            detect.put("xmax", xyxy[2]);
            detect.put("ymax", xyxy[3]);
            recList.add(detect);
        }
        return recList;
    }

    /**
     * Class-agnostic non-maximum suppression.
     * <p>
     * 1. Sort boxes by confidence, descending.<p>
     * 2. Keep the highest-scoring box.<p>
     * 3. Remove every remaining box whose IoU with it exceeds the threshold (the box itself included).<p>
     * 4. Repeat until no boxes remain.<p>
     *
     * @param data detections to filter; emptied by this call
     * @return kept detections
     */
    private JSONArray filterRec2(JSONArray data) {
        JSONArray res = new JSONArray();
        // Compare numerically; comparing the decimal strings only worked by accident for plain notation
        data.sort(Comparator.comparingDouble((Object obj) -> ((JSONObject) obj).getFloatValue("confidence"))
                .reversed());
        while (!data.isEmpty()) {
            JSONObject max = data.getJSONObject(0);
            res.add(max);
            Iterator<Object> it = data.iterator();
            while (it.hasNext()) {
                JSONObject obj = (JSONObject) it.next();
                if (calculateIoU(max, obj) > DEFAULT_NMS_THRESHOLD) {
                    it.remove();
                }
            }
        }
        return res;
    }

    /**
     * Converts [center x, center y, width, height] into [x1, y1, x2, y2].
     * (x1, y1) is the top-left corner in image coordinates and (x2, y2) is the bottom-right corner.
     * The result is clamped to the net area.
     *
     * @param bbox box as center/size
     * @return box as corners
     */
    private float[] xywh2xyxy(float[] bbox) {
        float x = bbox[0];
        float y = bbox[1];
        float w = bbox[2];
        float h = bbox[3];
        float x1 = x - w * 0.5f;
        float y1 = y - h * 0.5f;
        float x2 = x + w * 0.5f;
        float y2 = y + h * 0.5f;
        return new float[]{
                x1 < 0 ? 0 : x1,
                y1 < 0 ? 0 : y1,
                x2 > onnxModelHolder.getNetWidth() ? onnxModelHolder.getNetWidth() : x2,
                y2 > onnxModelHolder.getNetHeight() ? onnxModelHolder.getNetHeight() : y2};
    }

    /**
     * Returns the index of the largest value; ties resolve to the first occurrence.
     *
     * @param array non-empty array to search
     * @return index of the maximum
     */
    private int getMaxIndex(float[] array) {
        int maxIndex = 0;
        float maxVal = array[0];
        for (int i = 1; i < array.length; i++) {
            if (array[i] > maxVal) {
                maxVal = array[i];
                maxIndex = i;
            }
        }
        return maxIndex;
    }

    /**
     * Intersection over union of two boxes, using the inclusive-pixel (+1) convention.
     *
     * @param box1 first box
     * @param box2 second box
     * @return IoU in [0, 1] for well-formed boxes
     */
    private double calculateIoU(JSONObject box1, JSONObject box2) {
        double x1 = Math.max(box1.getDouble("xmin"), box2.getDouble("xmin"));
        double y1 = Math.max(box1.getDouble("ymin"), box2.getDouble("ymin"));
        double x2 = Math.min(box1.getDouble("xmax"), box2.getDouble("xmax"));
        double y2 = Math.min(box1.getDouble("ymax"), box2.getDouble("ymax"));
        double intersectionArea = Math.max(0, x2 - x1 + 1) * Math.max(0, y2 - y1 + 1);
        double box1Area = (box1.getDouble("xmax") - box1.getDouble("xmin") + 1) * (box1.getDouble("ymax") - box1.getDouble("ymin") + 1);
        double box2Area = (box2.getDouble("xmax") - box2.getDouble("xmin") + 1) * (box2.getDouble("ymax") - box2.getDouble("ymin") + 1);
        double unionArea = box1Area + box2Area - intersectionArea;
        return intersectionArea / unionArea;
    }

    /**
     * Maps boxes from letterboxed net coordinates back to original-image coordinates,
     * clamped to [0, size - 1]. Entries of {@code data} are updated in place.
     *
     * @param data detections in net coordinates
     * @param srcw original image width
     * @param srch original image height
     * @return detections in original-image coordinates
     */
    private JSONArray transferSrc2Dst(JSONArray data, int srcw, int srch) {
        JSONArray res = new JSONArray();
        float gain = Math.min((float) onnxModelHolder.getNetWidth() / srcw, (float) onnxModelHolder.getNetHeight() / srch);
        float padW = (onnxModelHolder.getNetWidth() - srcw * gain) * 0.5f;
        float padH = (onnxModelHolder.getNetHeight() - srch * gain) * 0.5f;
        for (Object n : data) {
            JSONObject obj = (JSONObject) n;
            float xmin = obj.getFloat("xmin");
            float ymin = obj.getFloat("ymin");
            float xmax = obj.getFloat("xmax");
            float ymax = obj.getFloat("ymax");
            obj.put("xmin", Math.max(0, Math.min(srcw - 1, (xmin - padW) / gain)));
            obj.put("ymin", Math.max(0, Math.min(srch - 1, (ymin - padH) / gain)));
            obj.put("xmax", Math.max(0, Math.min(srcw - 1, (xmax - padW) / gain)));
            obj.put("ymax", Math.max(0, Math.min(srch - 1, (ymax - padH) / gain)));
            res.add(obj);
        }
        return res;
    }
}
测试使用
import cn.hutool.core.thread.ThreadUtil;
import com.xatl.junpzx.modelruntime.domain.OnnxModelHolder;
import com.xatl.junpzx.modelruntime.domain.PredictResult;
import com.xatl.junpzx.modelruntime.service.OnnxModelLoader;
import com.xatl.junpzx.modelruntime.service.OnnxModelPredictHandler;
import javax.imageio.ImageIO;
import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.List;
/**
 * Manual smoke test: loads the model, runs ten predictions concurrently and prints their timings.
 * {@link #pointBox} can be used to draw results on an image in a window.
 *
 * @since 2024/5/6 15:05
 * @author JunPzx
 */
public class OnnxPredictTest {
    public static void main(String[] args) throws Exception {
        String onnxModelPath = "E:\\Desktop\\model\\miner-helmet-head.onnx";
        long time = System.currentTimeMillis();
        OnnxModelLoader loader = new OnnxModelLoader();
        // CPU inference
        OnnxModelHolder onnxModelHolder = loader.loadModel(onnxModelPath);
        // GPU inference on CUDA device 0
        // OnnxModelHolder onnxModelHolder = loader.loadModel(onnxModelPath, 0);
        long loadedTime = System.currentTimeMillis();
        System.out.println("加载消耗时间:" + (loadedTime - time));
        OnnxModelPredictHandler onnxModelPredictHandler = new OnnxModelPredictHandler(onnxModelHolder);
        for (int i = 0; i < 10; i++) {
            int finalI = i;
            String imagePath = "E:\\Desktop\\model\\saveImg\\img_0010" + i + ".jpg";
            ThreadUtil.execAsync(() -> {
                long predictBeginTime = System.currentTimeMillis();
                try {
                    onnxModelPredictHandler.predict(imagePath);
                } catch (Exception e) {
                    // Nothing observes this async task's outcome, so a rethrown exception would go unreported.
                    System.err.println("第" + finalI + "次推理失败:" + imagePath);
                    e.printStackTrace();
                    return;
                }
                long curr = System.currentTimeMillis() - predictBeginTime;
                System.out.println("第" + finalI + "次推理消耗时间:" + curr);
            });
        }
    }

    /**
     * Draws the detected boxes with their label and confidence on the image and shows it in a window.
     * On failure, the error is reported and the JVM exits with status 1.
     *
     * @param pic image file path
     * @param box detections in image coordinates
     */
    public static void pointBox(String pic, List<PredictResult> box) {
        if (box.isEmpty()) {
            System.out.println("暂无识别目标");
            return;
        }
        try {
            BufferedImage img = ImageIO.read(new File(pic));
            if (img == null) {
                throw new IllegalArgumentException("Unsupported or unreadable image: " + pic);
            }
            Graphics2D graph = img.createGraphics();
            try {
                graph.setStroke(new BasicStroke(2));
                graph.setFont(new Font("Serif", Font.BOLD, 20));
                graph.setColor(Color.RED);
                for (PredictResult n : box) {
                    float w = n.getXmax() - n.getXmin();
                    float h = n.getYmax() - n.getYmin();
                    graph.drawRect(n.getXmin().intValue(), n.getYmin().intValue(), (int) w, (int) h);
                    BigDecimal bigDecimal = BigDecimal.valueOf(n.getConfidence()).setScale(2, RoundingMode.DOWN);
                    graph.drawString(n.getName() + " " + bigDecimal, n.getXmin() - 1, n.getYmin() - 5);
                }
            } finally {
                graph.dispose();
            }
            // Swing components must be created and shown on the event dispatch thread
            SwingUtilities.invokeLater(() -> showImage(img));
        } catch (Exception e) {
            // Previously this exited with status 0 and no message, hiding the failure
            e.printStackTrace();
            System.exit(1);
        }
    }

    /** Shows the image in a window; closing the window exits the JVM. */
    private static void showImage(BufferedImage img) {
        JFrame frame = new JFrame("Image Dialog");
        frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        frame.getContentPane().add(new JLabel(new ImageIcon(img)));
        // Size the frame to its content so the whole image is visible inside the window decorations
        frame.pack();
        frame.setVisible(true);
    }
}
本文是原创文章,采用 CC BY-NC-ND 4.0 协议,完整转载请注明来自 星辰大海-Secret丶君
评论
匿名评论
隐私政策
你无需删除空行,直接评论以获取最佳展示效果