# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Export a Swin Transformer checkpoint (.pth) to an ONNX model."""

import argparse
import math
import os
import warnings

import numpy as np
import torch

_DATA_PARALLEL_PREFIX = 'module.'


def _strip_module_prefix(state_dict):
    """Return a new dict whose keys have a leading ``module.`` removed."""
    cut = len(_DATA_PARALLEL_PREFIX)
    return {
        (key[cut:] if key.startswith(_DATA_PARALLEL_PREFIX) else key): value
        for key, value in state_dict.items()
    }


def _square_relative_position_index(state_dict):
    """Reshape flat ``relative_position_index`` entries to ``(N, N)`` in place.

    Any 1-D entry whose key contains ``relative_position_index`` and whose
    length is a perfect square ``N * N`` is reshaped; e.g. a ``[2401]`` entry
    becomes ``[49, 49]``. Other entries are left untouched.

    Returns the same dict object for convenient chaining.
    """
    # Snapshot the items so values can be reassigned while looping.
    for key, value in list(state_dict.items()):
        if 'relative_position_index' not in key or value.dim() != 1:
            continue
        count = value.numel()
        side = int(round(math.sqrt(count)))
        if side * side == count:
            state_dict[key] = value.reshape(side, side)
    return state_dict


def _input_size_from_name(model_name):
    """Parse the square input resolution out of an underscore-separated name.

    Names containing ``s3`` carry the size in field 3, all others in field 4
    (e.g. ``swin_base_patch4_window12_384`` -> 384).
    """
    field = 3 if 's3' in model_name else 4
    return int(model_name.split('_')[field])


def _warn_on_key_mismatch(load_result):
    """Warn about keys that a non-strict ``load_state_dict`` skipped."""
    missing = list(getattr(load_result, 'missing_keys', ()))
    unexpected = list(getattr(load_result, 'unexpected_keys', ()))
    if missing:
        warnings.warn(f"state_dict keys missing from checkpoint: {missing}")
    if unexpected:
        warnings.warn(f"checkpoint keys not used by the model: {unexpected}")


def pth2onnx(args):
    """Build the model described by a checkpoint and export it to ONNX.

    Args:
        args: namespace with ``input_path`` (checkpoint file), ``out_path``
            (destination .onnx file) and ``batch_size`` (static batch
            dimension of the exported graph).

    The checkpoint must be a mapping with a ``'config'`` entry exposing
    ``MODEL.NAME`` and ``MODEL.NUM_CLASSES`` and a ``'model'`` entry holding
    the state dict.

    Raises:
        KeyError: if the checkpoint lacks ``'config'`` or ``'model'``.
        ValueError, IndexError: if the input size cannot be parsed from the
            model name.
    """
    # Imported lazily: timm is only needed to build the network, so the
    # helpers in this module stay importable without it.
    import timm

    pth_path = args.input_path
    batch_size = args.batch_size
    out_path = args.out_path

    # NOTE(review): torch.load unpickles the file; only run this on trusted
    # checkpoints.
    checkpoint = torch.load(pth_path, map_location='cpu')
    config = checkpoint['config']
    state_dict = checkpoint['model']
    model_name = config.MODEL.NAME

    model = timm.create_model(
        model_name,
        pretrained=False,
        num_classes=config.MODEL.NUM_CLASSES,
    )

    # Normalise key names, then fix the shape mismatch of the flattened
    # relative_position_index buffers.
    state_dict = _square_relative_position_index(
        _strip_module_prefix(state_dict)
    )

    # strict=False tolerates key differences; surface them instead of
    # silently exporting partially initialised weights.
    _warn_on_key_mismatch(model.load_state_dict(state_dict, strict=False))
    model.eval()

    input_size = _input_size_from_name(model_name)
    input_data = torch.randn(
        [batch_size, 3, input_size, input_size], dtype=torch.float32
    )
    print(f"输入数据形状: {input_data.shape}")
    print("开始导出ONNX...")

    # Make sure the destination directory exists; abspath keeps dirname
    # non-empty when out_path is a bare file name.
    os.makedirs(os.path.dirname(os.path.abspath(out_path)), exist_ok=True)

    # Export to ONNX.
    torch.onnx.export(
        model,
        input_data,
        out_path,
        verbose=True,
        opset_version=11,
        input_names=["image"],
        output_names=["output"],
    )

    print(f"✅ ONNX模型已保存到: {out_path}")


def parse_arguments(argv=None):
    """Parse command-line options.

    Args:
        argv: optional list of argument strings; ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``input_path``, ``out_path``, ``batch_size``.
    """
    parser = argparse.ArgumentParser(description='Convert Swin-Tiny pth to onnx')
    parser.add_argument('-i', '--input_path', type=str, required=True,
                        help='input path for pth model')
    parser.add_argument('-o', '--out_path', type=str, required=True,
                        help='save path for output onnx model')
    parser.add_argument('-b', '--batch_size', type=int, default=1,
                        help='batch size for output model')
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_arguments()
    pth2onnx(args)