From f9316171db19a0eb466184511c401f514af003d8 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 21 Oct 2025 03:21:55 +0000 Subject: [PATCH 1/2] =?UTF-8?q?=E5=AE=8C=E5=96=84=E5=A2=9E=E5=BC=BApth?= =?UTF-8?q?=E8=BD=AConnx=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../SwinTransformer_for_Pytorch/pth2onnx.py | 81 ++++++++++--------- 1 file changed, 45 insertions(+), 36 deletions(-) diff --git a/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py b/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py index 3b5263295e..472af622e9 100644 --- a/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py +++ b/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py @@ -1,70 +1,79 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - import os import argparse import torch import timm - +import numpy as np def pth2onnx(args): pth_path = args.input_path batch_size = args.batch_size - model_name = args.model_name out_path = args.out_path - # get size + checkpoint = torch.load(pth_path, map_location='cpu') + + config = checkpoint['config'] + state_dict = checkpoint['model'] + + model_name = config.MODEL.NAME + + model = timm.create_model( + model_name, + pretrained=False, + num_classes=config.MODEL.NUM_CLASSES + ) + + new_state_dict = {} + for k, v in state_dict.items(): + if k.startswith('module.'): + new_state_dict[k[7:]] = v + else: + new_state_dict[k] = v + + # 修复relative_position_index的形状不匹配问题 + for key in list(new_state_dict.keys()): + if 'relative_position_index' in key: + # 原始形状是 [2401],需要重塑为 [49, 49] + if new_state_dict[key].shape == torch.Size([2401]): + new_state_dict[key] = new_state_dict[key].view(49, 49) + + model.load_state_dict(new_state_dict, strict=False) + + model.eval() + if 's3' in model_name: - size = int(model_name.split('_')[3]) + input_size = int(model_name.split('_')[3]) else: - size = int(model_name.split('_')[4]) - input_data = torch.randn([batch_size, 3, size, size]).to(torch.float32) - input_names = ["image"] - output_names = ["out"] + input_size = int(model_name.split('_')[4]) - # build model - model = timm.create_model(model_name, checkpoint_path=pth_path) - model.eval() + input_data = torch.randn([batch_size, 3, input_size, input_size], dtype=torch.float32) + print(f"输入数据形状: {input_data.shape}") + + print("开始导出ONNX...") + # 导出ONNX torch.onnx.export( model, input_data, out_path, verbose=True, opset_version=11, - input_names=input_names, - output_names=output_names + input_names=["image"], + output_names=["output"], ) + print(f"✅ ONNX模型已保存到: {out_path}") def parse_arguments(): - parser = argparse.ArgumentParser(description='SwinTransformer onnx export.') + parser = argparse.ArgumentParser(description='Convert Swin-Tiny pth to onnx') parser.add_argument('-i', '--input_path', type=str, required=True, help='input path for pth model') parser.add_argument('-o', '--out_path', type=str, required=True, help='save path for output onnx model') - parser.add_argument('-n', '--model_name', type=str, default='swin_base_patch4_window12_384', - help='model name for swintransformer') parser.add_argument('-b', '--batch_size', type=int, default=1, help='batch size for output model') - args = parser.parse_args() - args.out_path = os.path.abspath(args.out_path) - os.makedirs(os.path.dirname(args.out_path), exist_ok=True) - return args - + return parser.parse_args() if __name__ == '__main__': args = parse_arguments() pth2onnx(args) + -- Gitee From 553c99fe3058f776be51a1185aa1c245395b399f Mon Sep 17 00:00:00 2001 From: Matrix_K Date: Wed, 22 Oct 2025 02:48:20 +0000 Subject: [PATCH 2/2] add --- ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py b/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py index 472af622e9..fd2bf74457 100644 --- a/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py +++ b/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py @@ -4,6 +4,7 @@ import torch import timm import numpy as np + def pth2onnx(args): pth_path = args.input_path batch_size = args.batch_size @@ -63,6 +64,7 @@ def pth2onnx(args): print(f"✅ ONNX模型已保存到: {out_path}") + def parse_arguments(): parser = argparse.ArgumentParser(description='Convert Swin-Tiny pth to onnx') parser.add_argument('-i', '--input_path', type=str, required=True, -- Gitee