The test ONNX model and images involve confidential material, so they are not provided. If your local environment has problems, or you need to set up a local or deployment environment, see: Setting up an OpenCV and Java ONNX Runtime environment

The code clearly documents the structure of the ONNX model used in this example. Before running it, use the website below to check that your model's structure matches; otherwise the code may fail.
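
If you cannot inspect the model in a visualizer, a quick alternative is to print its inputs, outputs and metadata with the ONNX Runtime Java API. This is only a minimal sketch: the model path is a placeholder, and the expected input shown in the comment is simply what the code in this article assumes.

import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;

public class InspectOnnxModel {
    public static void main(String[] args) throws Exception {
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        try (OrtSession session = env.createSession("path/to/model.onnx", new OrtSession.SessionOptions())) {
            // The code in this article expects a single input named "images" with shape [1, 3, 640, 640]
            System.out.println("inputs   = " + session.getInputInfo());
            System.out.println("outputs  = " + session.getOutputInfo());
            System.out.println("metadata = " + session.getMetadata().getCustomMetadata());
        }
    }
}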

Some Maven dependencies used by the Java code may be missing from the snippet below; add them to your project yourself (see the note after the dependency list).

Maven dependencies

 <properties>
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <javacv.version>1.5.9</javacv.version>
        <system.windowsx64>windows-x86_64</system.windowsx64>
        <system.linux64>linux-x86_64</system.linux64>
</properties>

<dependencies>
        <!-- https://mvnrepository.com/artifact/com.microsoft.onnxruntime/onnxruntime -->
<!--        <dependency>-->
<!--            <groupId>com.microsoft.onnxruntime</groupId>-->
<!--            <artifactId>onnxruntime</artifactId>-->
<!--            <version>1.17.3</version>-->
<!--        </dependency>-->

        <!-- https://mvnrepository.com/artifact/com.microsoft.onnxruntime/onnxruntime_gpu -->
        <dependency>
            <groupId>com.microsoft.onnxruntime</groupId>
            <artifactId>onnxruntime_gpu</artifactId>
            <version>1.17.3</version>
        </dependency>

        <dependency>
            <groupId>org.codehaus.groovy</groupId>
            <artifactId>groovy</artifactId>
        </dependency>

        <!-- javacv + javacpp core libraries -->
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>javacv</artifactId>
            <version>${javacv.version}</version>
        </dependency>
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>javacpp-platform</artifactId>
            <version>${javacv.version}</version>
        </dependency>
        <!-- Minimal OpenCV dependency set; requires the javacv + javacpp artifacts above -->
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>opencv</artifactId>
            <version>4.7.0-${javacv.version}</version>
            <classifier>${system.windowsx64}</classifier>
        </dependency>
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>opencv</artifactId>
            <version>4.7.0-${javacv.version}</version>
            <classifier>${system.linux64}</classifier>
        </dependency>
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>openblas</artifactId>
            <version>0.3.23-${javacv.version}</version>
            <classifier>${system.windowsx64}</classifier>
        </dependency>
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>openblas</artifactId>
            <version>0.3.23-${javacv.version}</version>
            <classifier>${system.linux64}</classifier>
        </dependency>
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>flycapture</artifactId>
            <version>2.13.3.31-${javacv.version}</version>
            <classifier>${system.windowsx64}</classifier>
        </dependency>
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>flycapture</artifactId>
            <version>2.13.3.31-${javacv.version}</version>
            <classifier>${system.linux64}</classifier>
        </dependency>
    </dependencies>
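
Besides the dependencies above, the Java code below also relies on Lombok (@Data, @Slf4j and friends), Fastjson (com.alibaba.fastjson), an SLF4J binding for logging, and Hutool (cn.hutool.core.thread.ThreadUtil in the test class). These are the packages the article expects you to add yourself; the coordinates below are a sketch to add inside the <dependencies> block, and the versions are examples, not necessarily the ones the author used.

        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>1.18.30</version>
            <scope>provided</scope>
        </dependency>
        <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>fastjson</artifactId>
            <version>1.2.83</version>
        </dependency>
        <dependency>
            <groupId>cn.hutool</groupId>
            <artifactId>hutool-core</artifactId>
            <version>5.8.25</version>
        </dependency>
        <!-- Any SLF4J implementation works for the @Slf4j logging, e.g. slf4j-simple -->
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-simple</artifactId>
            <version>1.7.36</version>
        </dependency>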

Java code

Entity classes

OnnxModelHolder

import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
import com.alibaba.fastjson.JSONObject;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.Accessors;

/**
 * Holder for a loaded ONNX model and its input metadata
 *
 * @since 2024/5/6 13:59
 * @author JunPzx
 */
@NoArgsConstructor
@AllArgsConstructor
@Accessors(chain = true)
@Data
public class OnnxModelHolder {

    /** ONNX Runtime environment */
    private OrtEnvironment env;

    /** Session created from the loaded model */
    private OrtSession session;

    /** Label map: class index (as a string) -> class name */
    private JSONObject tags;

    /** Batch size (number of images per inference) */
    private long count;

    /** Number of input channels */
    private long channels;

    /** Network input height */
    private long netHeight;

    /** Network input width */
    private long netWidth;
}

PredictResult

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.Accessors;

/**
 * Prediction result object
 *
 * @since 2024/5/6 17:17
 * @author JunPzx
 */
@NoArgsConstructor
@Data
@AllArgsConstructor
@Accessors(chain = true)
public class PredictResult {
    /**
     * Minimum y coordinate
     */
    private Float ymin;
    /**
     * Minimum x coordinate
     */
    private Float xmin;
    /**
     * Maximum y coordinate
     */
    private Float ymax;
    /**
     * Maximum x coordinate
     */
    private Float xmax;
    /**
     * Class probability
     */
    private Float percentage;
    /**
     * Box confidence
     */
    private Float confidence;
    /**
     * Class name
     */
    private String name;
}

Core logic

OnnxModelLoader


import ai.onnxruntime.*;
import com.alibaba.fastjson.JSONObject;
import com.junpzx.linglong.modelruntime.domain.OnnxModelHolder;
import lombok.extern.slf4j.Slf4j;
import org.opencv.core.Core;

import java.util.Map;

/**
 * ONNX model loader
 *
 * @since 2024/5/6 13:57
 * @author JunPzx
 */
@Slf4j
public class OnnxModelLoader {

    public OnnxModelHolder loadModel(String modelPath) throws Exception {
        return loadModel(modelPath, false, null);
    }

    public OnnxModelHolder loadModel(String modelPath, int gpuDeviceId) throws Exception {
        return loadModel(modelPath, true, gpuDeviceId);
    }


    private OnnxModelHolder loadModel(String modelPath, Boolean enableGpu, Integer gpuDeviceId) throws Exception {
        OnnxModelHolder onnxModelHolder = new OnnxModelHolder();

        onnxModelHolder.setEnv(OrtEnvironment.getEnvironment());
        OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
        if (enableGpu) {
            try {
                // The GPU device ID to execute on
                sessionOptions.addCUDA(gpuDeviceId);
            } catch (Exception e) {
                log.error("加载Onnx模型指定GPU CUDA发生错误", e);
            }
        }
        onnxModelHolder.setSession(onnxModelHolder.getEnv().createSession(modelPath, sessionOptions));
        OnnxModelMetadata metadata = onnxModelHolder.getSession().getMetadata();
        Map<String, NodeInfo> infoMap = onnxModelHolder.getSession().getInputInfo();
        // The code assumes a single model input named "images" (check the model structure as noted above)
        TensorInfo nodeInfo = (TensorInfo) infoMap.get("images").getInfo();
        // TODO: temporary hard-coded label map; read the labels from the model itself later
        onnxModelHolder.setTags(new JSONObject() {
            {
                put("0", "miner");
                put("1", "helmet");
                put("2", "head");
            }
        });
        // Reading the label map from the model metadata instead (see the sketch after this class):
        //            String nameClass = metadata.getCustomMetadata().get("names");
        //            onnxModelHolder.setTags(JSONObject.parseObject(nameClass.replace("\"", "\"\"")));
        // Print the parsed model information
        printlnInfo(metadata, infoMap, nodeInfo, onnxModelHolder.getTags());
        // 1: batch size, the model processes one image per run
        onnxModelHolder.setCount(nodeInfo.getShape()[0]);
        // 3: number of input channels
        onnxModelHolder.setChannels(nodeInfo.getShape()[1]);
        // 640: network input height
        onnxModelHolder.setNetHeight(nodeInfo.getShape()[2]);
        // 640: network input width
        onnxModelHolder.setNetWidth(nodeInfo.getShape()[3]);
        // Load the OpenCV native library. With the org.bytedeco artifacts it is usually loaded via
        // Loader.load(org.bytedeco.opencv.opencv_java.class); System.loadLibrary on its own requires
        // the native library to be on java.library.path.
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
        return onnxModelHolder;
    }


    private void printlnInfo(OnnxModelMetadata metadata, Map<String, NodeInfo> infoMap, TensorInfo nodeInfo,
                             JSONObject names) {
        System.out.println("-------打印模型信息开始--------");
        System.out.println("getProducerName=" + metadata.getProducerName());
        System.out.println("getGraphName=" + metadata.getGraphName());
        System.out.println("getDescription=" + metadata.getDescription());
        System.out.println("getDomain=" + metadata.getDomain());
        System.out.println("getVersion=" + metadata.getVersion());
        System.out.println("getCustomMetadata=" + metadata.getCustomMetadata());
        System.out.println("getInputInfo=" + infoMap);
        System.out.println("nodeInfo=" + nodeInfo);
        System.out.println("类别信息:" + names);
        System.out.println("-------打印模型信息结束--------");
    }
}
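
The commented-out lines in loadModel hint at reading the label map from the model's custom metadata instead of hard-coding it. For a YOLO-style export, the "names" metadata entry typically looks like {0: 'miner', 1: 'helmet', 2: 'head'}, which is not valid JSON, so it needs a little cleanup first. The helper below is a minimal sketch under that assumption; parseNames is a name I made up, not part of the original code.

import com.alibaba.fastjson.JSONObject;

public class OnnxLabelParser {

    /**
     * Parse a YOLO-style "names" metadata string such as {0: 'miner', 1: 'helmet', 2: 'head'}
     * into the same JSONObject structure the loader currently builds by hand.
     */
    public static JSONObject parseNames(String names) {
        JSONObject tags = new JSONObject();
        if (names == null || names.trim().isEmpty()) {
            return tags;
        }
        // Strip the surrounding braces, then split the "index: 'label'" pairs on commas
        String body = names.trim().replaceAll("^\\{|\\}$", "");
        for (String entry : body.split(",")) {
            String[] kv = entry.split(":", 2);
            if (kv.length != 2) {
                continue;
            }
            String key = kv[0].trim();
            String value = kv[1].trim().replaceAll("^['\"]|['\"]$", "");
            tags.put(key, value);
        }
        return tags;
    }
}

With this in place, onnxModelHolder.setTags(parseNames(metadata.getCustomMetadata().get("names"))) would replace the hard-coded tag map in the loader.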

PredictHandler

package com.xatl.linglong.modelruntime.service;

import com.xatl.linglong.modelruntime.domain.PredictResult;

import java.awt.image.BufferedImage;
import java.io.InputStream;
import java.util.List;

/**
 * Prediction handler
 *
 * @author JunPzx
 * @since 2024/5/6 14:27
 */
public interface PredictHandler {

    /**
     * Predict
     *
     * @param targetPath path of the target image
     * @return prediction results
     */
    List<PredictResult> predict(String targetPath) throws Exception;

    /**
     * Predict
     *
     * @param is input stream of the target image
     * @return {@link List }<{@link PredictResult }>
     * @throws Exception on failure
     */
    List<PredictResult> predict(InputStream is) throws Exception;

    /**
     * Predict
     *
     * @param bufferedImage buffered image to predict on
     * @return {@link List }<{@link PredictResult }>
     * @throws Exception on failure
     */
    List<PredictResult> predict(BufferedImage bufferedImage) throws Exception;

}

OnnxModelPredictHandler


import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtSession;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.junpzx.linglong.modelruntime.domain.OnnxModelHolder;
import com.junpzx.linglong.modelruntime.domain.PredictResult;
import lombok.extern.slf4j.Slf4j;
import org.opencv.core.*;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.InputStream;
import java.nio.FloatBuffer;
import java.nio.file.Files;
import java.util.*;

/**
 * ONNX model prediction handler
 *
 * @author JunPzx
 * @since 2024/5/6 14:29
 */
@Slf4j
public class OnnxModelPredictHandler implements PredictHandler {

    private static final float DEFAULT_NMS_THRESHOLD = 0.45f;

    private static final float DEFAULT_CONFIDENCE_THRESHOLD = 0.25f;

    private final OnnxModelHolder onnxModelHolder;

    public OnnxModelPredictHandler(OnnxModelHolder onnxModelHolder) {
        this.onnxModelHolder = onnxModelHolder;
    }

    public OnnxModelPredictHandler(String onnxModelPath) throws Exception {
        this.onnxModelHolder = new OnnxModelLoader().loadModel(onnxModelPath);
    }

    @Override
    public List<PredictResult> predict(String targetPath) throws Exception {
        // Read the image to be predicted
        File sourceImage = new File(targetPath);
        byte[] bytes = Files.readAllBytes(sourceImage.toPath());
        MatOfByte matOfByte = new MatOfByte(bytes);
        // Decode the image
        Mat originalMat = Imgcodecs.imdecode(matOfByte, Imgcodecs.IMREAD_COLOR);
        return predict(originalMat);
    }

    /**
     * Run the full prediction pipeline on a decoded OpenCV image
     *
     * @param originalMat decoded image
     * @return prediction results
     */
    private List<PredictResult> predict(Mat originalMat) throws Exception {
        // Letterbox-resize the image to the network input size
        Mat resizeMat = resizeWithPadding(originalMat);
        // Convert the image into the tensor format the ONNX model expects
        try (OnnxTensor tensor = transferTensor(resizeMat)) {
            // Run inference
            try (OrtSession.Result ortResult = onnxModelHolder.getSession().run(Collections.singletonMap("images", tensor))) {
                OnnxTensor tensorOrtResult = (OnnxTensor) ortResult.get(0);
                float[][][] dataRes = (float[][][]) tensorOrtResult.getValue();
                float[][] data = dataRes[0];
                // Filter the raw predictions
                JSONArray result = filterRec1(data);
                result = filterRec2(result);
                // Map the box coordinates back to the original image
                JSONArray originalMatResult = transferSrc2Dst(result, originalMat.width(), originalMat.height());
                // Assemble and return the results
                return originalMatResult.toJavaList(PredictResult.class);
            }
        }
    }
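
    // Note: PredictHandler also declares predict(InputStream) and predict(BufferedImage);
    // without them this class does not compile. The two implementations below are a minimal
    // sketch: they decode the input and delegate to the shared Mat pipeline above. Re-encoding
    // the BufferedImage as an in-memory PNG is an assumption, not part of the original article.
    @Override
    public List<PredictResult> predict(InputStream is) throws Exception {
        // Read the whole stream into memory, then decode it with OpenCV
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        byte[] chunk = new byte[8192];
        int read;
        while ((read = is.read(chunk)) != -1) {
            buffer.write(chunk, 0, read);
        }
        Mat originalMat = Imgcodecs.imdecode(new MatOfByte(buffer.toByteArray()), Imgcodecs.IMREAD_COLOR);
        return predict(originalMat);
    }

    @Override
    public List<PredictResult> predict(BufferedImage bufferedImage) throws Exception {
        // Re-encode the BufferedImage as PNG in memory, then decode it with OpenCV
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        ImageIO.write(bufferedImage, "png", buffer);
        Mat originalMat = Imgcodecs.imdecode(new MatOfByte(buffer.toByteArray()), Imgcodecs.IMREAD_COLOR);
        return predict(originalMat);
    }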

    /**
     * Resize the image to the network input size with letterbox padding
     *
     * @param src source image
     * @return resized and padded image
     */
    private Mat resizeWithPadding(Mat src) {
        Mat dst = new Mat();
        int oldW = src.width();
        int oldH = src.height();
        double r = Math.min((double) onnxModelHolder.getNetWidth() / oldW, (double) onnxModelHolder.getNetHeight() / oldH);
        int newUnpadW = (int) Math.round(oldW * r);
        int newUnpadH = (int) Math.round(oldH * r);
        int dw = (Long.valueOf(onnxModelHolder.getNetWidth()).intValue() - newUnpadW) / 2;
        int dh = (Long.valueOf(onnxModelHolder.getNetHeight()).intValue() - newUnpadH) / 2;
        int top = (int) Math.round(dh - 0.1);
        int bottom = (int) Math.round(dh + 0.1);
        int left = (int) Math.round(dw - 0.1);
        int right = (int) Math.round(dw + 0.1);
        Imgproc.resize(src, dst, new Size(newUnpadW, newUnpadH));
        Core.copyMakeBorder(dst, dst, top, bottom, left, right, Core.BORDER_CONSTANT);
        return dst;
    }

    /**
     * Convert the image into the model's input tensor
     *
     * @param dst preprocessed image at the network input size
     * @return model input tensor
     */
    private OnnxTensor transferTensor(Mat dst) throws Exception {
        // OpenCV decodes images as BGR; convert to RGB and scale pixel values to [0, 1]
        Imgproc.cvtColor(dst, dst, Imgproc.COLOR_BGR2RGB);
        dst.convertTo(dst, CvType.CV_32FC1, 1. / 255);
        float[] whc = new float[Long.valueOf(onnxModelHolder.getChannels()).intValue()
                * Long.valueOf(onnxModelHolder.getNetWidth()).intValue()
                * Long.valueOf(onnxModelHolder.getNetHeight()).intValue()];
        dst.get(0, 0, whc);
        float[] chw = whc2cwh(whc);
        return OnnxTensor.createTensor(onnxModelHolder.getEnv(),
                FloatBuffer.wrap(chw),
                new long[]{onnxModelHolder.getCount(), onnxModelHolder.getChannels(),
                        onnxModelHolder.getNetHeight(), onnxModelHolder.getNetWidth()});
    }

    /**
     * Rearrange the interleaved pixel data from [height, width, channel] (HWC)
     * into planar [channel, height, width] (CHW) order
     *
     * @param src array to rearrange
     * @return rearranged array
     */
    private float[] whc2cwh(float[] src) {
        float[] chw = new float[src.length];
        int j = 0;
        for (int ch = 0; ch < 3; ++ch) {
            for (int i = ch; i < src.length; i += 3) {
                chw[j] = src[i];
                j++;
            }
        }
        return chw;
    }

    /**
     * Keep detections whose box confidence exceeds the threshold and assemble them
     *
     * @param data raw model output, one row per candidate box
     * @return filtered detections
     */
    private JSONArray filterRec1(float[][] data) {
        JSONArray recList = new JSONArray();
        for (float[] bbox : data) {
            float[] xywh = new float[]{bbox[0], bbox[1], bbox[2], bbox[3]};
            float[] xyxy = xywh2xyxy(xywh);
            float confidence = bbox[4];
            // The per-class scores follow the 4 box values and the box confidence
            float[] classInfo = Arrays.copyOfRange(bbox, 5, bbox.length);
            int maxIndex = getMaxIndex(classInfo);
            float maxValue = classInfo[maxIndex];
            String maxClass = (String) onnxModelHolder.getTags().get(String.valueOf(maxIndex));
            // Coarse filtering by box confidence first
            if (confidence >= DEFAULT_CONFIDENCE_THRESHOLD) {
                JSONObject detect = new JSONObject();
                detect.put("name", maxClass);
                // class probability
                detect.put("percentage", maxValue);
                // box confidence
                detect.put("confidence", confidence);
                detect.put("xmin", xyxy[0]);
                detect.put("ymin", xyxy[1]);
                detect.put("xmax", xyxy[2]);
                detect.put("ymax", xyxy[3]);
                recList.add(detect);
            }
        }
        return recList;
    }

    /**
     * Filter predictions with NMS (non-maximum suppression)
     * <p>
     * 1. Sort all bounding boxes by confidence score, usually in descending order.<p>
     * 2. Take the box with the highest score and add it to the final output list.<p>
     * 3. Compute the IoU of the remaining boxes against the selected box.<p>
     * 4. Remove boxes whose IoU with the selected box exceeds the predefined threshold.<p>
     * 5. Repeat until all boxes have been processed.<p>
     *
     * @param data detections to filter
     * @return filtered detections
     */
    private JSONArray filterRec2(JSONArray data) {
        JSONArray res = new JSONArray();
        // Sort by confidence as a number, in descending order
        data.sort(Comparator.comparing(obj -> ((JSONObject) obj).getFloat("confidence")).reversed());
        while (!data.isEmpty()) {
            JSONObject max = data.getJSONObject(0);
            res.add(max);
            Iterator<Object> it = data.iterator();
            while (it.hasNext()) {
                JSONObject obj = (JSONObject) it.next();
                double iou = calculateIoU(max, obj);
                if (iou > DEFAULT_NMS_THRESHOLD) {
                    it.remove();
                }
            }
        }
        return res;
    }

    /**
     * Convert a box from [center x, center y, width, height] to
     * [xmin, ymin, xmax, ymax], i.e. the top-left and bottom-right corners in image coordinates
     *
     * @param bbox box in xywh format
     * @return box in xyxy format
     */
    private float[] xywh2xyxy(float[] bbox) { 
        float x = bbox[0];
        float y = bbox[1];
        float w = bbox[2];
        float h = bbox[3];
        // top-left corner (x1, y1)
        float x1 = x - w * 0.5f;
        float y1 = y - h * 0.5f;
        // bottom-right corner (x2, y2)
        float x2 = x + w * 0.5f;
        float y2 = y + h * 0.5f;
        return new float[]{
                x1 < 0 ? 0 : x1,
                y1 < 0 ? 0 : y1,
                x2 > onnxModelHolder.getNetWidth() ? onnxModelHolder.getNetWidth() : x2,
                y2 > onnxModelHolder.getNetHeight() ? onnxModelHolder.getNetHeight() : y2};
    }

    /**
     * Get the index of the maximum value in an array
     *
     * @param array array to search
     * @return index of the maximum value
     */
    private int getMaxIndex(float[] array) {
        int maxIndex = 0;
        float maxVal = array[0];
        for (int i = 1; i < array.length; i++) {
            if (array[i] > maxVal) {
                maxVal = array[i];
                maxIndex = i;
            }
        }
        return maxIndex;
    }

    /**
     * Compute the IoU (intersection over union), a common metric for how much two bounding boxes overlap
     *
     * @param box1 first bounding box
     * @param box2 second bounding box
     * @return intersection over union
     */
    private double calculateIoU(JSONObject box1, JSONObject box2) {
        double x1 = Math.max(box1.getDouble("xmin"), box2.getDouble("xmin"));
        double y1 = Math.max(box1.getDouble("ymin"), box2.getDouble("ymin"));
        double x2 = Math.min(box1.getDouble("xmax"), box2.getDouble("xmax"));
        double y2 = Math.min(box1.getDouble("ymax"), box2.getDouble("ymax"));
        double intersectionArea = Math.max(0, x2 - x1 + 1) * Math.max(0, y2 - y1 + 1);
        double box1Area = (box1.getDouble("xmax") - box1.getDouble("xmin") + 1) * (box1.getDouble("ymax") - box1.getDouble("ymin") + 1);
        double box2Area = (box2.getDouble("xmax") - box2.getDouble("xmin") + 1) * (box2.getDouble("ymax") - box2.getDouble("ymin") + 1);
        double unionArea = box1Area + box2Area - intersectionArea;
        return intersectionArea / unionArea;
    }


    /**
     * Map box coordinates from the letterboxed network input back to the original image
     *
     * @param data filtered detections in network-input coordinates
     * @param srcw original image width
     * @param srch original image height
     * @return detections in original-image coordinates
     */
    private JSONArray transferSrc2Dst(JSONArray data, int srcw, int srch) {
        JSONArray res = new JSONArray();

        float gain = Math.min((float) onnxModelHolder.getNetWidth() / srcw, (float) onnxModelHolder.getNetHeight() / srch);
        float padW = (onnxModelHolder.getNetWidth() - srcw * gain) * 0.5f;
        float padH = (onnxModelHolder.getNetHeight() - srch * gain) * 0.5f;
        data.forEach(n -> {
            JSONObject obj = JSONObject.parseObject(n.toString());
            float xmin = obj.getFloat("xmin");
            float ymin = obj.getFloat("ymin");
            float xmax = obj.getFloat("xmax");
            float ymax = obj.getFloat("ymax");
            float xmin_ = Math.max(0, Math.min(srcw - 1, (xmin - padW) / gain));
            float ymin_ = Math.max(0, Math.min(srch - 1, (ymin - padH) / gain));
            float xmax_ = Math.max(0, Math.min(srcw - 1, (xmax - padW) / gain));
            float ymax_ = Math.max(0, Math.min(srch - 1, (ymax - padH) / gain));
            obj.put("xmin", xmin_);
            obj.put("ymin", ymin_);
            obj.put("xmax", xmax_);
            obj.put("ymax", ymax_);
            res.add(obj);
        });
        return res;
    }
}

Test usage

import cn.hutool.core.thread.ThreadUtil;
import com.xatl.junpzx.modelruntime.domain.OnnxModelHolder;
import com.xatl.junpzx.modelruntime.domain.PredictResult;
import com.xatl.junpzx.modelruntime.service.OnnxModelLoader;
import com.xatl.junpzx.modelruntime.service.OnnxModelPredictHandler;

import javax.imageio.ImageIO;
import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.List;

/**
 * ONNX prediction test
 * @since 2024/5/6 15:05
 * @author JunPzx
 */
public class OnnxPredictTest {

    public static void main(String[] args) throws Exception {
//        String imagePath = "E:\\Desktop\\model\\saveImg\\img_00100.jpg";
        String onnxModelPath = "E:\\Desktop\\model\\miner-helmet-head.onnx";
        long time = System.currentTimeMillis();
        OnnxModelLoader loader = new OnnxModelLoader();
        // CPU inference
        OnnxModelHolder onnxModelHolder = loader.loadModel(onnxModelPath);
        // GPU inference
        // OnnxModelHolder onnxModelHolder = loader.loadModel(onnxModelPath, 0);
        long loadedTime = System.currentTimeMillis();
        System.out.println("加载消耗时间:" + (loadedTime - time));
        OnnxModelPredictHandler onnxModelPredictHandler = new OnnxModelPredictHandler(onnxModelHolder);
        for (int i = 0; i < 10; i++) {
            int finalI = i;
            String imagePath = "E:\\Desktop\\model\\saveImg\\img_0010" + i + ".jpg";
            ThreadUtil.execAsync(() -> {
                long predictBeginTime = System.currentTimeMillis();
                try {
                    List<PredictResult> predictResult = onnxModelPredictHandler.predict(imagePath);
                    // Draw the boxes if you want to inspect the result visually:
                    // pointBox(imagePath, predictResult);
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
                long predictEndTime = System.currentTimeMillis();
                long curr = predictEndTime - predictBeginTime;
                System.out.println("第" + finalI + "次推理消耗时间:" + curr);
            });
        }
    }


    public static void pointBox(String pic, List<PredictResult> box) {
        if (box.isEmpty()) {
            System.out.println("暂无识别目标");
            return;
        }
        try {
            File imageFile = new File(pic);
            BufferedImage img = ImageIO.read(imageFile);
            Graphics2D graph = img.createGraphics();
            graph.setStroke(new BasicStroke(2));
            graph.setFont(new Font("Serif", Font.BOLD, 20));
            graph.setColor(Color.RED);
            box.forEach(n -> {
                float w = n.getXmax() - n.getXmin();
                float h = n.getYmax() - n.getYmin();
                graph.drawRect(
                        n.getXmin().intValue(),
                        n.getYmin().intValue(),
                        Float.valueOf(w).intValue(),
                        Float.valueOf(h).intValue());
                BigDecimal bigDecimal = BigDecimal.valueOf(n.getConfidence()).setScale(2, RoundingMode.DOWN);
                graph.drawString(n.getName() + " " + bigDecimal, n.getXmin() - 1, n.getYmin() - 5);
            });
            graph.dispose();
            JFrame frame = new JFrame("Image Dialog");
            frame.setSize(img.getWidth(), img.getHeight());
            JLabel label = new JLabel(new ImageIcon(img));
            frame.getContentPane().add(label);
            frame.setVisible(true);
            frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        } catch (Exception e) {
            e.printStackTrace();
            System.exit(0);
        }
    }
}