diff --git a/ACL_PyTorch/contrib/cv/detection/YoloWorld/README.md b/ACL_PyTorch/contrib/cv/detection/YoloWorld/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..32e92c26d0028a5f8519df435947aa37d5b7593e
--- /dev/null
+++ b/ACL_PyTorch/contrib/cv/detection/YoloWorld/README.md
@@ -0,0 +1,100 @@
+# YoloWorld High-Performance (TorchAir) Inference Guide
+
+- [Overview](#overview)
+- [Inference Environment Setup](#inference-environment-setup)
+- [Quick Start](#quick-start)
+  - [Get the Source Code](#get-the-source-code)
+  - [Model Inference](#model-inference)
+- [Inference Performance & Accuracy](#inference-performance--accuracy)
+
+******
+
+# Overview
+YOLO-World is a breakthrough extension of the traditional YOLO family: by fusing text semantics into the detector, it removes the pain point of a fixed class set.
+
+- Version notes:
+  ```
+  url=https://github.com/ultralytics/ultralytics
+  tag=v8.3.180
+  model_name=yolov8x-worldv2.pt
+  ```
+
+# Inference Environment Setup
+- The model requires the following plugins and drivers.
+
+  **Table 1** Version compatibility
+
+  | Component                    | Version | Setup Guide |
+  | ---------------------------- | ------- | ----------- |
+  | Firmware & drivers           | 24.0.0  | [PyTorch framework inference environment setup](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/pies) |
+  | CANN                         | 8.2.RC1 | - |
+  | Python                       | 3.11    | - |
+  | PyTorch                      | 2.6.0   | - |
+  | Ascend Extension for PyTorch | 2.6.0   | - |
+
+  Note: for Atlas 800I A2 / Atlas 300I Pro inference cards, choose the firmware and driver versions that match your CANN version.
+
+
+# Quick Start
+
+## Get the Source Code
+1. Clone this repository
+
+   ```
+   git clone https://gitee.com/ascend/ModelZoo-PyTorch.git
+   cd ModelZoo-PyTorch/ACL_PyTorch/contrib/cv/detection/YoloWorld
+   ```
+
+2. Clone the model repository and apply the adaptation patch
+   ```bash
+   git clone https://github.com/ultralytics/ultralytics.git
+   cd ultralytics
+   git checkout v8.3.180
+   git apply ../adapt-diff.patch
+   pip install -e .
+   cd ..
+   ```
+
+3. Install dependencies
+   ```bash
+   pip3 install -r requirements.txt
+   ```
+
+4. Download the model weights
+
+   The scripts in this repository load `yolov8x-worldv2.pt`:
+   GitHub release: https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8x-worldv2.pt
+   ModelScope mirror: https://modelscope.cn/models/anyforge/ultralytics-allmodels/files
+
+5. After downloading, the directory tree looks like this
+
+   ```shell
+   YoloWorld
+   ├── ultralytics           // cloned from the upstream repository
+   ├── yolov8x-worldv2.pt    // downloaded model weights
+   ├── adapt-diff.patch
+   ├── infer.py              // custom inference script provided by this repository
+   ├── README.md
+   └── requirements.txt
+   ```
+
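+6. (Optional) Sanity-check the NPU software stack
+
+   Before running inference you can confirm that the Ascend plugin and the TorchAir backend import cleanly. This is a minimal check, assuming `torch_npu` and `torchair` were installed per Table 1 (both are imported by the patched code); it is not part of the original workflow:
+
+   ```python
+   import torch
+   import torch_npu   # Ascend device plugin; registers the torch.npu namespace
+   import torchair    # graph backend used by adapt-diff.patch
+
+   print("NPU available:", torch.npu.is_available())
+   ```
+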
+## Model Inference
+
+1. Run the accuracy validation
+
+   ```bash
+   # Work around accuracy anomalies caused by NaN values
+   export INF_NAN_MODE_FORCE_DISABLE=1
+
+   python3 infer.py
+   ```
+
+2. Run the performance test
+   ```bash
+   python3 performance_test.py
+   ```
+   The validation metrics and performance figures are printed to the console.
+
+# Inference Performance & Accuracy
+Measured on Atlas 800I A2:
+
+| Model           | Hardware | End-to-end Performance |
+|-----------------|----------|------------------------|
+| yolov8x-worldv2 | 800I A2  | 69.64 FPS              |
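+
+For reference, `adapt-diff.patch` makes three kinds of changes to ultralytics: it pads each input batch to a square with `ReplicationPad2d` so the compiled graph always sees one static shape, it runs the forward pass under float16 autocast, and it wraps the loaded network with TorchAir's `torch.compile` backend (it also rewrites the `Detect` head to recompute anchors per call instead of caching them on `self`). The compile pattern in isolation looks like the sketch below; the toy module is purely illustrative, not part of the delivered code:
+
+```python
+import torch
+from torch import nn
+import torch_npu   # registers the Ascend "npu" device
+import torchair
+
+# Same TorchAir configuration the patch inserts into ultralytics' AutoBackend
+config = torchair.CompilerConfig()
+config.experimental_config.frozen_parameter = True  # treat weights as graph constants
+npu_backend = torchair.get_npu_backend(compiler_config=config)
+
+net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.SiLU()).to("npu").eval()
+net = torch.compile(net, backend=npu_backend, fullgraph=True, dynamic=False)
+
+with torch.no_grad():
+    out = net(torch.randn(1, 3, 640, 640, device="npu"))  # first call triggers graph capture
+```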
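+
+`performance_test.py` times whole `model(...)` calls, so host-side pre- and post-processing is included in the FPS figure. To bracket the measurement with explicit device synchronization instead, a minimal sketch (assuming `torch_npu` provides `torch.npu.synchronize()`, the NPU counterpart of the CUDA API) could look like this:
+
+```python
+import time
+import numpy as np
+import torch
+import torch_npu  # registers the "npu" device and the torch.npu namespace
+from ultralytics import YOLO
+
+model = YOLO("yolov8x-worldv2.pt").to("npu")
+image = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
+
+for _ in range(10):          # warm-up lets graph compilation finish
+    model(image, verbose=False)
+
+torch.npu.synchronize()      # drain queued work before starting the clock
+start = time.time()
+for _ in range(100):
+    model(image, verbose=False)
+torch.npu.synchronize()      # make sure all 100 runs actually finished
+print(f"FPS: {100 / (time.time() - start):.2f}")
+```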
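+
+The scripts above validate against the dataset's fixed vocabulary. As a sketch of the open-vocabulary capability mentioned in the overview, the upstream `set_classes` API lets you swap in a custom prompt list at inference time (the image path below is a placeholder; substitute your own test image):
+
+```python
+from ultralytics import YOLO
+
+model = YOLO("yolov8x-worldv2.pt")
+
+# Restrict detection to a custom text vocabulary
+model.set_classes(["person", "bus", "traffic light"])
+
+results = model.predict("path/to/image.jpg")
+results[0].show()
+```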
diff --git a/ACL_PyTorch/contrib/cv/detection/YoloWorld/adapt-diff.patch b/ACL_PyTorch/contrib/cv/detection/YoloWorld/adapt-diff.patch
new file mode 100644
index 0000000000000000000000000000000000000000..11dc1f1733af84e59fa283aa07ec53a3bd2b89cb
--- /dev/null
+++ b/ACL_PyTorch/contrib/cv/detection/YoloWorld/adapt-diff.patch
@@ -0,0 +1,140 @@
+diff --git a/ultralytics/engine/predictor.py b/ultralytics/engine/predictor.py
+index b96c66e60..193e9bb8c 100644
+--- a/ultralytics/engine/predictor.py
++++ b/ultralytics/engine/predictor.py
+@@ -41,6 +41,7 @@ from typing import Any, Dict, List, Optional, Union
+ import cv2
+ import numpy as np
+ import torch
++from torch import nn
+ 
+ from ultralytics.cfg import get_cfg, get_save_dir
+ from ultralytics.data import load_inference_source
+@@ -179,7 +180,16 @@ class BasePredictor:
+             if self.args.visualize and (not self.source_type.tensor)
+             else False
+         )
+-        return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
++        # Pad the input to a square so the compiled graph always sees one static shape
++        _, _, height, width = im.shape
++        img_ht, img_wt = max(height, width), max(height, width)
++        left_pad = (img_wt - width) // 2
++        right_pad = img_wt - width - left_pad
++        top_pad = (img_ht - height) // 2
++        bot_pad = img_ht - height - top_pad
++        im = nn.ReplicationPad2d(padding=(left_pad, right_pad, top_pad, bot_pad))(im)
++        with torch.autocast(device_type="npu", dtype=torch.float16):
++            return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
+ 
+     def pre_transform(self, im: List[np.ndarray]) -> List[np.ndarray]:
+         """
+diff --git a/ultralytics/engine/validator.py b/ultralytics/engine/validator.py
+index 59a4bfd1c..4d87e7a74 100644
+--- a/ultralytics/engine/validator.py
++++ b/ultralytics/engine/validator.py
+@@ -156,7 +156,8 @@ class BaseValidator:
+             callbacks.add_integration_callbacks(self)
+             model = AutoBackend(
+                 weights=model or self.args.model,
+-                device=select_device(self.args.device, self.args.batch),
++                # Force validation onto the Ascend NPU
++                device=torch.device("npu"),
+                 dnn=self.args.dnn,
+                 data=self.args.data,
+                 fp16=self.args.half,
+diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py
+index eca996dc6..c56ce9901 100644
+--- a/ultralytics/nn/autobackend.py
++++ b/ultralytics/nn/autobackend.py
+@@ -12,6 +12,9 @@ import cv2
+ import numpy as np
+ import torch
+ import torch.nn as nn
++import torch_npu
++import torchair
++from torch_npu.contrib import transfer_to_npu
+ from PIL import Image
+ 
+ from ultralytics.utils import ARM64, IS_JETSON, LINUX, LOGGER, PYTHON_VERSION, ROOT, YAML
+@@ -204,6 +207,10 @@ class AutoBackend(nn.Module):
+             model.half() if fp16 else model.float()
+             ch = model.yaml.get("channels", 3)
+             self.model = model  # explicitly assign for to(), cpu(), cuda(), half()
++            _config = torchair.CompilerConfig()
++            _config.experimental_config.frozen_parameter = True
++            npu_backend = torchair.get_npu_backend(compiler_config=_config)
++            model = torch.compile(model, backend=npu_backend, fullgraph=True, dynamic=False)
+             pt = True
+ 
+         # PyTorch
+diff --git a/ultralytics/nn/modules/block.py b/ultralytics/nn/modules/block.py
+index 78ce2b745..0fa586dd6 100644
+--- a/ultralytics/nn/modules/block.py
++++ b/ultralytics/nn/modules/block.py
+@@ -702,7 +702,12 @@ class C2fAttn(nn.Module):
+             (torch.Tensor): Output tensor after processing.
+         """
+         y = list(self.cv1(x).split((self.c, self.c), 1))
+-        y.extend(m(y[-1]) for m in self.m)
++        # Unrolled loop, equivalent to the generator above, keeps the traced graph static
++        for lm in self.m:
++            y.append(lm(y[-1]))
+         y.append(self.attn(y[-1], guide))
+         return self.cv2(torch.cat(y, 1))
+ 
+diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py
+index 52d824063..051ca9665 100644
+--- a/ultralytics/nn/modules/head.py
++++ b/ultralytics/nn/modules/head.py
+@@ -160,9 +160,16 @@ class Detect(nn.Module):
+         # Inference path
+         shape = x[0].shape  # BCHW
+         x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
+-        if self.format != "imx" and (self.dynamic or self.shape != shape):
+-            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
+-            self.shape = shape
++        # Recompute anchors on every call instead of caching them on self,
++        # avoiding shape-dependent control flow inside the compiled graph
++        anchors, strides = make_anchors(x, self.stride, 0.5)
++        anchors = anchors.transpose(0, 1)
++        strides = strides.transpose(0, 1)
+ 
+         if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:  # avoid TF FlexSplitV ops
+             box = x_cat[:, : self.reg_max * 4]
+@@ -176,10 +183,15 @@ class Detect(nn.Module):
+             grid_h = shape[2]
+             grid_w = shape[3]
+             grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
+-            norm = self.strides / (self.stride[0] * grid_size)
+-            dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
++            norm = strides / (self.stride[0] * grid_size)
++            dbox = self.decode_bboxes(self.dfl(box) * norm, anchors.unsqueeze(0) * norm[:, :2])
+         else:
+-            dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
++            dbox = self.decode_bboxes(self.dfl(box), anchors.unsqueeze(0)) * strides
+         if self.export and self.format == "imx":
+             return dbox.transpose(1, 2), cls.sigmoid().permute(0, 2, 1)
+         return torch.cat((dbox, cls.sigmoid()), 1)
+diff --git a/ultralytics/utils/tal.py b/ultralytics/utils/tal.py
+index 3a2091f66..ed8ba55a0 100644
+--- a/ultralytics/utils/tal.py
++++ b/ultralytics/utils/tal.py
+@@ -375,7 +375,8 @@ def make_anchors(feats, strides, grid_cell_offset=0.5):
+         sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift y
+         sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
+         anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
+-        stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
++        # Equivalent to torch.full; builds the stride tensor as ones * stride
++        stride_tensor.append(torch.ones((h * w, 1), dtype=dtype, device=device) * stride)
+     return torch.cat(anchor_points), torch.cat(stride_tensor)
+ 
diff --git a/ACL_PyTorch/contrib/cv/detection/YoloWorld/infer.py b/ACL_PyTorch/contrib/cv/detection/YoloWorld/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..75d84b3ee08048b4d853da405ea4f42766481c7e2
--- /dev/null
+++ b/ACL_PyTorch/contrib/cv/detection/YoloWorld/infer.py
@@ -0,0 +1,11 @@
+from ultralytics import YOLO
+
+# Create a YOLO-World model
+model = YOLO("yolov8x-worldv2.pt")  # or pick yolov8m/l-worldv2.pt for a different size
+
+# Run validation on the COCO8 example dataset; mAP metrics are printed by val()
+metrics = model.val(data="coco8.yaml")
\ No newline at end of file
diff --git a/ACL_PyTorch/contrib/cv/detection/YoloWorld/performance_test.py b/ACL_PyTorch/contrib/cv/detection/YoloWorld/performance_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..608c1e3ee6eb59875a66d822cd0e1f43ec75fe80
--- /dev/null
+++ b/ACL_PyTorch/contrib/cv/detection/YoloWorld/performance_test.py
@@ -0,0 +1,99 @@
+import os
+import time
+from pathlib import Path
+
+import numpy as np
+from ultralytics import YOLO
+
+# NPU runtime tuning knobs
+os.environ['TASK_QUEUE_ENABLE'] = '2'
+os.environ['PYTORCH_NPU_ALLOC_CONF'] = 'expandable_segments:True'
+os.environ['ACLNN_CACHE_LIMIT'] = '100000'
+os.environ["CPU_AFFINITY_CONF"] = '1'
+
+
+def test_yolo_world(model_path, data_path, img_size=640, conf_threshold=0.25, iou_threshold=0.5):
+    """
+    Test the accuracy and performance of a YOLO-World model.
+
+    Args:
+        model_path: path to the model weights file
+        data_path: path to the test dataset configuration file
+        img_size: input image size
+        conf_threshold: confidence threshold
+        iou_threshold: IoU threshold
+    """
+    print("=" * 50)
+    print(f"Testing YOLO-World model: {Path(model_path).name}")
+    print(f"Test dataset: {Path(data_path).name}")
+    print(f"Image size: {img_size}, confidence threshold: {conf_threshold}, IoU threshold: {iou_threshold}")
+    print("=" * 50)
+
+    # Load the model
+    start_load = time.time()
+    model = YOLO(model_path).to('npu')
+    load_time = time.time() - start_load
+    print(f"Model load time: {load_time:.4f}s")
+
+    # Accuracy test (via the val method)
+    print("\nStarting accuracy test...")
+    start_val = time.time()
+    metrics = model.val(
+        data=data_path,
+        imgsz=img_size,
+        conf=conf_threshold,
+        iou=iou_threshold,
+        verbose=False  # suppress per-class output
+    )
+    val_time = time.time() - start_val
+    print(f"Accuracy test finished in {val_time:.4f}s")
+
+    # Performance test (inference speed)
+    print("\nStarting performance test...")
+    # Build a test image (random pixels standing in for real input)
+    test_image = np.random.randint(0, 255, (img_size, img_size, 3), dtype=np.uint8)
+
+    # Warm up the model (the first inferences are usually slow)
+    for _ in range(10):
+        model(test_image, imgsz=img_size, conf=conf_threshold, iou=iou_threshold, verbose=False)
+
+    # Timed runs
+    num_runs = 100  # number of timed inferences
+    start_infer = time.time()
+    for _ in range(num_runs):
+        model(test_image, imgsz=img_size, conf=conf_threshold, iou=iou_threshold, verbose=False)
+    infer_time = time.time() - start_infer
+
+    # Compute performance metrics
+    avg_infer_time = infer_time / num_runs * 1000  # milliseconds per image
+    fps = num_runs / infer_time
+
+    # Report
+    print("\n===== Performance metrics =====")
+    print(f"Average inference time: {avg_infer_time:.2f} ms")
+    print(f"Throughput (FPS): {fps:.2f}")
+
+
+if __name__ == "__main__":
+    # Model and dataset configuration
+    MODEL_PATH = "yolov8x-worldv2.pt"  # or yolov8m-worldv2.pt / yolov8l-worldv2.pt
+    DATA_PATH = "coco8.yaml"  # test dataset configuration file
+
+    # Test parameters
+    IMG_SIZE = 640
+    CONF_THRESHOLD = 0.25
+    IOU_THRESHOLD = 0.5
+
+    # Run the test
+    test_yolo_world(
+        model_path=MODEL_PATH,
+        data_path=DATA_PATH,
+        img_size=IMG_SIZE,
+        conf_threshold=CONF_THRESHOLD,
+        iou_threshold=IOU_THRESHOLD
+    )
\ No newline at end of file
diff --git a/ACL_PyTorch/contrib/cv/detection/YoloWorld/requirements.txt b/ACL_PyTorch/contrib/cv/detection/YoloWorld/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..755adf5563a8bce979c22a500e0897bab633d511
--- /dev/null
+++ b/ACL_PyTorch/contrib/cv/detection/YoloWorld/requirements.txt
@@ -0,0 +1,3 @@
+numpy==1.26.4
+opencv-python==4.9.0.80
+polars==1.33.1
\ No newline at end of file