登录
注册
开源
企业版
高校版
搜索
帮助中心
使用条款
关于我们
开源
企业版
高校版
私有云
模力方舟
AI 队友
登录
注册
Gitee 2025年度开源项目评选启动,快来选出你心中的最佳开源项目!
代码拉取完成,页面将自动刷新
仓库状态说明
开源项目
>
人工智能
>
机器学习/深度学习
&&
捐赠
捐赠前请先登录
取消
前往登录
扫描微信二维码支付
取消
支付完成
支付提示
将跳转至支付宝完成支付
确定
取消
Watch
不关注
关注所有动态
仅关注版本发行动态
关注但不提醒动态
88
Star
653
Fork
1.5K
Ascend
/
pytorch
暂停
代码
Issues
41
Pull Requests
350
Wiki
统计
流水线
服务
质量分析
Jenkins for Gitee
腾讯云托管
腾讯云 Serverless
悬镜安全
阿里云 SAE
Codeblitz
SBOM
我知道了,不再自动展开
更新失败,请稍后重试!
移除标识
内容风险标识
本任务被
标识为内容中包含代码安全 Bug、隐私泄露等敏感信息,仓库外成员不可访问
torch.compile 在开启 dynamic 选项(dynamic=True)时存在精度问题
TODO
#ICV3N8
缺陷
Chrisrr
创建于
2025-08-27 17:52
一、问题现象(附报错日志上下文): torch.compile存在精度问题 二、软件版本: -- CANN 版本 (e.g., CANN 3.0.x,5.x.x): 8.2.rc1.alpha002 -- Pytorch 版本: 2.6.0 -- Python 版本 (e.g., Python 3.7.5): Python 3.11.9 三、测试步骤: 直接使用`python test.py`运行这个用例,会产生精度问题 ``` import torch def add(x, y): return x+y dy_model = torch.compile(add, dynamic=True, backend="inductor") def test_add(shape0, shape1, failed_it): x = torch.randn(shape0, shape1).npu() y = torch.randn(shape0, shape1).npu() expect = x + y output = dy_model(x, y) assert torch.allclose(expect, output, rtol=1e-5, atol=1e-8), f"exp: {expect}\n out: {output}\n{failed_it}" test_add(32, 16, 0) test_add(16, 16, 0) print("test begin") for i in range(100): test_add(64, 16, i) ``` 四、日志信息: 这个问题产生的原因是,`dynamic=True`的情况下,编译出的triton算子调用时仅使用了一个hardcode的参数而不是一个变量: ```python triton_unk_fused_add_0.run(arg2_1, arg3_1, buf0, 512, grid=grid(512), stream=stream0) ``` 这里的`512`是错误的,应该是`s0 * s1`。 <details> <summary>点击查看完整Triton代码</summary> ```python # AOT ID: ['1_inference'] from ctypes import c_void_p, c_long, c_int import torch import math import random import os import tempfile from math import inf, nan from torch._inductor.hooks import run_intermediate_hooks from torch._inductor.utils import maybe_profile from torch._inductor.codegen.memory_planning import _align as align from torch import device, empty_strided from torch._inductor.async_compile import AsyncCompile from torch._inductor.select_algorithm import extern_kernels from torch._inductor.codegen.multi_kernel import MultiKernelCall import torch_npu import triton import triton.language as tl from torch._inductor.runtime.triton_heuristics import ( split_scan_grid, grid_combo_kernels, start_graph, end_graph, cooperative_reduction_grid, ) from torch_npu._inductor.npu_triton_heuristics import grid import torch_npu from torch_npu._inductor import get_current_raw_stream as get_raw_stream from torch_npu._inductor import get_current_raw_stream as get_raw_stream aten = torch.ops.aten inductor_ops = torch.ops.inductor _quantized = 
torch.ops._quantized assert_size_stride = torch._C._dynamo.guards.assert_size_stride empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor alloc_from_pool = torch.ops.inductor._alloc_from_pool async_compile = AsyncCompile() empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p # kernel path: /tmp/torchinductor_root/nw/cnw6pnjoek2mihg5t34p5x2plndxppftcub6qwi3c7a6tve6lt65.py # Topologically Sorted Source Nodes: [add], Original ATen: [aten.add] # Source node to ATen node mapping: # add => add # Graph fragment: # %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, %arg3_1), kwargs = {}) # SchedulerNodes: [SchedulerNode(name='op0')] triton_unk_fused_add_0 = async_compile.triton('triton_unk_fused_add_0', ''' import triton import triton.language as tl from triton.compiler.compiler import AttrsDescriptor from torch._inductor.runtime import triton_helpers, triton_heuristics from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties from torch._inductor.runtime import triton_helpers from torch_npu._inductor import npu_triton_heuristics from torch_npu._inductor import npu_triton_helpers from torch_npu._inductor.runtime import NPUDeviceProperties from torch_npu._inductor.npu_triton_helpers import libdevice, math as tl_math import torch import torch_npu @npu_triton_heuristics.pointwise_npu_index( size_hints=[512], filename=__file__, triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'x0_numel': 'i32'}, 'device': NPUDeviceProperties(type='npu', index=0, multi_processor_count=48, cc='Ascend910B2C', major=None, regs_per_multiprocessor=None, 
max_threads_per_multi_processor=None, warp_size=32), 'constants': {}, 'mix_mode': 'aiv'}, inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_unk_fused_add_0', 'mutated_arg_names': [], 'backend_hash': '0be6d125c3717a68a7dfaaa1b045a8176af516f6b99c17e4b6e1d57a267eb930', 'split_axis': [0], 'tiling_axis': [0], 'axis_names': ['x0'], 'low_dims': {0}, 'numof_reduction_axis': 0, 'split_axis_dtype': torch.float32, 'dual_reduction': False, 'traced_graph_hash': 'TRACED_GRAPH_HASH', 'traced_graph_dir': 'TRACED_GRAPH_DIR', 'store_cubin': False, 'force_disable_caches': False, 'profile_bandwidth_with_do_bench_using_profiling': False}, min_elem_per_thread=0 ) @triton.jit def triton_unk_fused_add_0(in_ptr0, in_ptr1, out_ptr0, x0_numel, X0BLOCK: tl.constexpr, X0BLOCK_SUB: tl.constexpr): x0_offset = tl.program_id(0) * X0BLOCK base_x0= tl.arange(0, X0BLOCK_SUB) loops_x0 = (X0BLOCK + X0BLOCK_SUB - 1) // X0BLOCK_SUB for loop_x0 in range(loops_x0): x0 = x0_offset + (loop_x0 * X0BLOCK_SUB) + base_x0 x0_mask = x0 < min(X0BLOCK+x0_offset, x0_numel) tmp0 = tl.load(in_ptr0 + (x0), x0_mask) tmp1 = tl.load(in_ptr1 + (x0), x0_mask) tmp2 = tmp0 + tmp1 tl.store(out_ptr0 + (x0), tmp2, x0_mask) ''', device_str='npu') async_compile.wait(globals()) del async_compile def call(args): arg0_1, arg1_1, arg2_1, arg3_1 = args args.clear() s0 = arg0_1 s1 = arg1_1 buf0 = empty_strided((s0, s1), (s1, 1), device='npu', dtype=torch.float32) # Topologically Sorted Source Nodes: [add], Original ATen: [aten.add] stream0 = get_raw_stream(0) triton_unk_fused_add_0.run(arg2_1, arg3_1, buf0, 512, grid=grid(512), stream=stream0) return (buf0, ) def benchmark_compiled_module(times=10, repeat=10): from torch._dynamo.testing import rand_strided from torch._inductor.utils import print_performance arg0_1 = 32 arg1_1 = 16 arg2_1 = rand_strided((32, 16), (16, 1), device='npu:0', dtype=torch.float32) arg3_1 = rand_strided((32, 16), (16, 1), device='npu:0', dtype=torch.float32) fn = lambda: call([arg0_1, arg1_1, 
arg2_1, arg3_1]) return print_performance(fn, times=times, repeat=repeat) if __name__ == "__main__": from torch._inductor.wrapper_benchmark import compiled_module_main compiled_module_main('None', benchmark_compiled_module) ``` </details> 而GPU上的代码生成正确: 
一、问题现象(附报错日志上下文): torch.compile存在精度问题 二、软件版本: -- CANN 版本 (e.g., CANN 3.0.x,5.x.x): 8.2.rc1.alpha002 -- Pytorch 版本: 2.6.0 -- Python 版本 (e.g., Python 3.7.5): Python 3.11.9 三、测试步骤: 直接使用`python test.py`运行这个用例,会产生精度问题 ``` import torch def add(x, y): return x+y dy_model = torch.compile(add, dynamic=True, backend="inductor") def test_add(shape0, shape1, failed_it): x = torch.randn(shape0, shape1).npu() y = torch.randn(shape0, shape1).npu() expect = x + y output = dy_model(x, y) assert torch.allclose(expect, output, rtol=1e-5, atol=1e-8), f"exp: {expect}\n out: {output}\n{failed_it}" test_add(32, 16, 0) test_add(16, 16, 0) print("test begin") for i in range(100): test_add(64, 16, i) ``` 四、日志信息: 这个问题产生的原因是,`dynamic=True`的情况下,编译出的triton算子调用时仅使用了一个hardcode的参数而不是一个变量: ```python triton_unk_fused_add_0.run(arg2_1, arg3_1, buf0, 512, grid=grid(512), stream=stream0) ``` 这里的`512`是错误的,应该是`s0 * s1`。 <details> <summary>点击查看完整Triton代码</summary> ```python # AOT ID: ['1_inference'] from ctypes import c_void_p, c_long, c_int import torch import math import random import os import tempfile from math import inf, nan from torch._inductor.hooks import run_intermediate_hooks from torch._inductor.utils import maybe_profile from torch._inductor.codegen.memory_planning import _align as align from torch import device, empty_strided from torch._inductor.async_compile import AsyncCompile from torch._inductor.select_algorithm import extern_kernels from torch._inductor.codegen.multi_kernel import MultiKernelCall import torch_npu import triton import triton.language as tl from torch._inductor.runtime.triton_heuristics import ( split_scan_grid, grid_combo_kernels, start_graph, end_graph, cooperative_reduction_grid, ) from torch_npu._inductor.npu_triton_heuristics import grid import torch_npu from torch_npu._inductor import get_current_raw_stream as get_raw_stream from torch_npu._inductor import get_current_raw_stream as get_raw_stream aten = torch.ops.aten inductor_ops = torch.ops.inductor _quantized = 
torch.ops._quantized assert_size_stride = torch._C._dynamo.guards.assert_size_stride empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor alloc_from_pool = torch.ops.inductor._alloc_from_pool async_compile = AsyncCompile() empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p # kernel path: /tmp/torchinductor_root/nw/cnw6pnjoek2mihg5t34p5x2plndxppftcub6qwi3c7a6tve6lt65.py # Topologically Sorted Source Nodes: [add], Original ATen: [aten.add] # Source node to ATen node mapping: # add => add # Graph fragment: # %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, %arg3_1), kwargs = {}) # SchedulerNodes: [SchedulerNode(name='op0')] triton_unk_fused_add_0 = async_compile.triton('triton_unk_fused_add_0', ''' import triton import triton.language as tl from triton.compiler.compiler import AttrsDescriptor from torch._inductor.runtime import triton_helpers, triton_heuristics from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties from torch._inductor.runtime import triton_helpers from torch_npu._inductor import npu_triton_heuristics from torch_npu._inductor import npu_triton_helpers from torch_npu._inductor.runtime import NPUDeviceProperties from torch_npu._inductor.npu_triton_helpers import libdevice, math as tl_math import torch import torch_npu @npu_triton_heuristics.pointwise_npu_index( size_hints=[512], filename=__file__, triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'x0_numel': 'i32'}, 'device': NPUDeviceProperties(type='npu', index=0, multi_processor_count=48, cc='Ascend910B2C', major=None, regs_per_multiprocessor=None, 
max_threads_per_multi_processor=None, warp_size=32), 'constants': {}, 'mix_mode': 'aiv'}, inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_unk_fused_add_0', 'mutated_arg_names': [], 'backend_hash': '0be6d125c3717a68a7dfaaa1b045a8176af516f6b99c17e4b6e1d57a267eb930', 'split_axis': [0], 'tiling_axis': [0], 'axis_names': ['x0'], 'low_dims': {0}, 'numof_reduction_axis': 0, 'split_axis_dtype': torch.float32, 'dual_reduction': False, 'traced_graph_hash': 'TRACED_GRAPH_HASH', 'traced_graph_dir': 'TRACED_GRAPH_DIR', 'store_cubin': False, 'force_disable_caches': False, 'profile_bandwidth_with_do_bench_using_profiling': False}, min_elem_per_thread=0 ) @triton.jit def triton_unk_fused_add_0(in_ptr0, in_ptr1, out_ptr0, x0_numel, X0BLOCK: tl.constexpr, X0BLOCK_SUB: tl.constexpr): x0_offset = tl.program_id(0) * X0BLOCK base_x0= tl.arange(0, X0BLOCK_SUB) loops_x0 = (X0BLOCK + X0BLOCK_SUB - 1) // X0BLOCK_SUB for loop_x0 in range(loops_x0): x0 = x0_offset + (loop_x0 * X0BLOCK_SUB) + base_x0 x0_mask = x0 < min(X0BLOCK+x0_offset, x0_numel) tmp0 = tl.load(in_ptr0 + (x0), x0_mask) tmp1 = tl.load(in_ptr1 + (x0), x0_mask) tmp2 = tmp0 + tmp1 tl.store(out_ptr0 + (x0), tmp2, x0_mask) ''', device_str='npu') async_compile.wait(globals()) del async_compile def call(args): arg0_1, arg1_1, arg2_1, arg3_1 = args args.clear() s0 = arg0_1 s1 = arg1_1 buf0 = empty_strided((s0, s1), (s1, 1), device='npu', dtype=torch.float32) # Topologically Sorted Source Nodes: [add], Original ATen: [aten.add] stream0 = get_raw_stream(0) triton_unk_fused_add_0.run(arg2_1, arg3_1, buf0, 512, grid=grid(512), stream=stream0) return (buf0, ) def benchmark_compiled_module(times=10, repeat=10): from torch._dynamo.testing import rand_strided from torch._inductor.utils import print_performance arg0_1 = 32 arg1_1 = 16 arg2_1 = rand_strided((32, 16), (16, 1), device='npu:0', dtype=torch.float32) arg3_1 = rand_strided((32, 16), (16, 1), device='npu:0', dtype=torch.float32) fn = lambda: call([arg0_1, arg1_1, 
arg2_1, arg3_1]) return print_performance(fn, times=times, repeat=repeat) if __name__ == "__main__": from torch._inductor.wrapper_benchmark import compiled_module_main compiled_module_main('None', benchmark_compiled_module) ``` </details> 而GPU上的代码生成正确: 
评论 (
1
)
登录
后才可以发表评论
状态
TODO
TODO
WIP
DONE
CLOSED
REJECTED
负责人
未设置
标签
未设置
项目
未立项任务
未立项任务
里程碑
未关联里程碑
未关联里程碑
Pull Requests
未关联
未关联
关联的 Pull Requests 被合并后可能会关闭此 issue
分支
未关联
分支 (79)
标签 (186)
master
v2.8.0
v2.1.0
v2.6.0
v2.7.1
v2.5.1
v2.6.0-7.1.0
v2.5.1-7.1.0
v2.1.0-7.1.0
revert-merge-23967-master
revert-merge-23966-v2.8.0
revert-merge-23965-v2.7.1
revert-merge-23964-v2.6.0
revert-merge-23962-v2.5.1
revert-merge-23789-v2.1.0
v2.1.0-7.0.0
v2.4.0-7.0.0
v2.4.0
v2.3.1
v2.3.1-7.0.0
v2.5.1-7.0.0
v2.4.0-6.0.0
v2.3.1-6.0.0
v2.1.0-6.0.0
v2.1.0-6.0.rc3
v2.3.1-6.0.rc3
v2.4.0-6.0.rc3
v2.2.0
v1.11.0-6.0.rc1
v2.1.0-6.0.rc1
v2.2.0-6.0.rc1
v1.11.0-6.0.rc2
v2.1.0-6.0.rc2
v2.2.0-6.0.rc2
v2.3.1-6.0.rc2
v1.11.0
v2.1.0-5.0.0
v2.0.1-5.0.0
v1.11.0-5.0.0
v2.0.1
v2.1.0-5.0.rc3
v1.11.0-5.0.rc3
v2.0.1-5.0.rc3
v1.11.0-5.0.rc3.3
v1.8.1
v1.11.0-x1
v1.8.1-5.0.rc3
v1.11.0-5.0.rc2.2
v1.11.0-zj
v1.11.0-5.0.rc2.1
v2.0.1-5.0.rc2
v1.11.0-5.0.rc2
v1.8.1-5.0.rc2
v2.0.0-5.0.rc2
v1.8.1-5.0.rc1
v1.11.0-5.0.rc1
v1.11.0-yd
v1.11.0-xf
v1.11.0-infer
v1.11.0-bigkernel
v1.11.0-host_api
v1.8.1-3.0.0
v1.11.0-5.0.rc2.t100
v1.8.1-5.0.rc2.t100
v1.8.1-3.0.0-dev
v1.11.0-3.0.0
v2.0-dev
v1.8.1-3.0.rc3
v1.5.0-3.0.0
v1.5.0
v1.8.1-3.0.rc1
v1.11.0-3.0.rc3
v1.8.1-3.0.rc2
v1.5.0-3.0.rc3
v1.5.0-3.0.rc2
2.0.4.tr5
v1.5.0-3.0.rc1
2.0.2.tr5
2.0.3.tr5
v7.2.RC1.alpha002-pytorch2.8.0
v7.2.RC1.alpha002-pytorch2.7.1
v7.2.RC1.alpha002-pytorch2.6.0
v7.2.RC1.alpha002-pytorch2.1.0
v7.1.0.2-pytorch2.5.1
v7.1.0.2-pytorch2.6.0
v7.1.0.2-pytorch2.1.0
v7.0.0.1-pytorch2.4.0
v7.0.0.1-pytorch2.1.0
v7.2.RC1.alpha001-pytorch2.8.0
v7.2.RC1.alpha001-pytorch2.7.1
v7.2.RC1.alpha001-pytorch2.6.0
v7.2.RC1.alpha001-pytorch2.5.1
v7.2.RC1.alpha001-pytorch2.1.0
v7.1.0.1-pytorch2.6.0
v7.1.0.1-pytorch2.5.1
v7.1.0.1-pytorch2.1.0
v7.1.0-pytorch2.6.0
v7.1.0-pytorch2.5.1
v7.1.0-pytorch2.1.0
v7.1.RC1.alpha003-pytorch2.6.0
v7.1.RC1.alpha003-pytorch2.5.1
v7.1.RC1.alpha003-pytorch2.1.0
v7.1.RC1.alpha002-pytorch2.7.1
v7.1.RC1.alpha002-pytorch2.6.0
v7.1.RC1.alpha002-pytorch2.5.1
v7.1.RC1.alpha002-pytorch2.4.0
v7.1.RC1.alpha002-pytorch2.3.1
v7.1.RC1.alpha002-pytorch2.1.0
v6.0.0.1-pytorch2.4.0
v6.0.0.1-pytorch2.3.1
v6.0.0.1-pytorch2.1.0
v7.1.RC1.alpha001-pytorch2.6.0
v7.1.RC1.alpha001-pytorch2.5.1
v7.1.RC1.alpha001-pytorch2.4.0
v7.1.RC1.alpha001-pytorch2.3.1
v7.1.RC1.alpha001-pytorch2.1.0
v7.0.0-pytorch2.5.1
v7.0.0-pytorch2.4.0
v7.0.0-pytorch2.3.1
v7.0.RC1.alpha002-pytorch2.6.0
v7.0.0-pytorch2.1.0
v7.0.RC1.alpha002-pytorch2.5.1
v7.0.RC1.alpha002-pytorch2.4.0
v7.0.RC1.alpha002-pytorch2.3.1
v7.0.RC1.alpha002-pytorch2.1.0
v7.0.RC1.alpha001-pytorch2.5.1
v7.0.RC1.alpha001-pytorch2.1.0
v7.0.RC1.alpha001-pytorch2.4.0
v7.0.RC1.alpha001-pytorch2.3.1
v6.0.0-pytorch2.4.0
v6.0.0-pytorch2.3.1
v6.0.0-pytorch2.1.0
v6.0.0.alpha003-pytorch2.4.0
v6.0.0.alpha003-pytorch2.3.1
v6.0.0.alpha003-pytorch2.1.0
v6.0.0.alpha002-pytorch2.4.0
v6.0.0.alpha002-pytorch2.3.1
v6.0.0.alpha002-pytorch2.1.0
v6.0.0.alpha001-pytorch2.5.1
v6.0.rc3-pytorch2.4.0
v6.0.rc3-pytorch2.3.1
v6.0.rc3-pytorch2.1.0
v6.0.0.alpha001-pytorch2.4.0
v6.0.0.alpha001-pytorch2.3.1
v6.0.0.alpha001-pytorch2.1.0
v6.0.rc2.1-pytorch1.11.0
v6.0.rc2.1-pytorch2.3.1
v6.0.rc2.1-pytorch2.2.0
v6.0.rc2.1-pytorch2.1.0
v6.0.rc3.alpha003-pytorch2.3.1
v6.0.rc3.alpha003-pytorch2.1.0
v6.0.rc3.alpha001-pytorch2.4.0
v6.0.rc3.alpha002-pytorch2.3.1
v6.0.rc3.alpha002-pytorch2.2.0
v6.0.rc3.alpha002-pytorch2.1.0
v6.0.rc3.alpha002-pytorch1.11.0
v6.0.rc2-pytorch2.1.0
v6.0.rc2-pytorch2.3.1
v6.0.rc2-pytorch2.2.0
v6.0.rc2-pytorch1.11.0
v6.0.rc3.alpha001-pytorch2.3.1
v6.0.rc3.alpha001-pytorch2.2.0
v6.0.rc3.alpha001-pytorch2.1.0
v6.0.rc3.alpha001-pytorch1.11.0
v6.0.rc2.alpha002-pytorch2.3.1
v6.0.rc2.alpha003-pytorch1.11.0
v6.0.rc2.alpha003-pytorch2.2.0
v6.0.rc2.alpha003-pytorch2.1.0
v6.0.rc1.1-pytorch2.2.0
v6.0.rc1.1-pytorch2.1.0
v6.0.rc1.1-pytorch1.11.0
v5.0.1.2-pytorch1.11.0
v5.0.1.2-pytorch2.1.0
v5.0.1.2-pytorch2.0.1
v6.0.rc2.alpha002-pytorch2.2.0
v6.0.rc2.alpha002-pytorch2.1.0
v6.0.rc2.alpha002-pytorch1.11.0
v6.0.rc1-pytorch2.2.0
v6.0.rc1-pytorch2.1.0
v6.0.rc1-pytorch1.11.0
v6.0.rc2.alpha001-pytorch2.2.0
v6.0.rc2.alpha001-pytorch2.1.0
v6.0.rc2.alpha001-pytorch1.11.0
v6.0.rc1.alpha003-pytorch2.0.1
v6.0.rc1.alpha003-pytorch2.1.0
v5.0.1.1-pytorch2.0.1
v5.0.1.1-pytorch1.11.0
v5.0.1.1-pytorch2.1.0
v6.0.rc1.alpha003-pytorch1.11.0
v6.0.rc1.alpha002-pytorch2.1.0
v6.0.rc1.alpha002-pytorch1.11.0
v6.0.rc1.alpha002-pytorch2.0.1
v6.0.rc1.alpha001-pytorch2.2.0
v5.0.1-pytorch2.1.0
v5.0.1-pytorch2.0.1
v5.0.1-pytorch1.11.0
v6.0.RC1.alpha001-pytorch2.0.1
v6.0.RC1.alpha001-pytorch2.1.0
v6.0.RC1.alpha001-pytorch1.11.0
v5.0.0-pytorch2.1.0
v5.0.0-pytorch2.0.1
v5.0.0-pytorch1.11.0
v5.0.0.alpha003-pytorch2.1.0
v5.0.0.alpha003-pytorch2.0.1
v5.0.0.alpha003-pytorch1.11.0
v5.0.rc3.3-pytorch1.11.0
v5.0.rc3.2-pytorch1.11.0
v5.0.0.alpha002-pytorch2.1.0
v5.0.0.alpha002-pytorch2.0.1
v5.0.0.alpha002-pytorch1.11.0
v5.0.rc3.1-pytorch1.11.0
v5.0.0.alpha001-pytorch2.1.0
v5.0.0.alpha001-pytorch2.0.1
v5.0.0.alpha001-pytorch1.11.0
v5.0.rc3-pytorch2.1.0
v5.0.rc3-pytorch2.0.1
v5.0.rc3-pytorch1.11.0
v5.0.rc3.alpha003-pytorch2.0.1
v5.0.rc3.alpha003-pytorch1.11.0
v5.0.rc3.alpha003-pytorch1.8.1
v5.0.rc2.2-pytorch1.11.0
v5.0.rc2.1-pytorch1.11.0
v5.0.rc3.alpha002-pytorch2.0.1
v5.0.rc3.alpha002-pytorch1.11.0
v5.0.rc3.alpha002-pytorch1.8.1
v5.0.rc2-pytorch2.0.1
v5.0.rc2-pytorch1.11.0
v5.0.rc2-pytorch1.8.1
v5.0.rc3.alpha001-pytorch1.8.1
v5.0.rc3.alpha001-pytorch1.11.0
v5.0.rc2.alpha003-pytorch1.11.0
v5.0.rc2.alpha003-pytorch1.8.1
v5.0.rc2.alpha002-pytorch1.11.0
v5.0.rc2.alpha002-pytorch1.8.1
v5.0.rc1.alpha003-pytorch1.11.0
v5.0.rc1.alpha003-pytorch1.8.1
v5.0.rc1-pytorch1.11.0
v5.0.rc1-pytorch1.8.1
v5.0.rc1.alpha002-pytorch1.11.0
v5.0.rc1.alpha002-pytorch1.8.1
v5.0.rc1.alpha001-pytorch1.11.0
v5.0.rc1.alpha001-pytorch1.8.1
v3.0.0-pytorch1.11.0
v3.0.0-pytorch1.8.1
v3.0.0-pytorch1.5.0
v3.0.alpha006-pytorch1.8.1
v3.0.alpha005-pytorch1.8.1
v3.0.alpha003-pytorch1.8.1
v3.0.rc3-pytorch1.11.0
v3.0.rc3-pytorch1.8.1
v3.0.rc3-pytorch1.5.0
v3.0.rc2-pytorch1.8.1
v3.0.rc2-pytorch1.5.0
v3.0.rc1-pytorch1.8.1
v3.0.rc1-pytorch1.5.0
v2.0.4
v2.0.4-rc2
v2.0.4-rc1
v2.0.3.1
v2.0.3
v2.0.3-rc4
v2.0.3-rc3
v2.0.3-rc2
v2.0.3-rc1
v2.0.2
开始日期   -   截止日期
-
置顶选项
不置顶
置顶等级:高
置顶等级:中
置顶等级:低
优先级
不指定
严重
主要
次要
不重要
预计工期
(小时)
参与者(1)
Python
1
https://gitee.com/ascend/pytorch.git
git@gitee.com:ascend/pytorch.git
ascend
pytorch
pytorch
点此查找更多帮助
搜索帮助
Git 命令在线学习
如何在 Gitee 导入 GitHub 仓库
Git 仓库基础操作
企业版和社区版功能对比
SSH 公钥设置
如何处理代码冲突
仓库体积过大,如何减小?
如何找回被删除的仓库数据
Gitee 产品配额说明
GitHub仓库快速导入Gitee及同步更新
什么是 Release(发行版)
将 PHP 项目自动发布到 packagist.org
仓库举报
回到顶部
登录提示
该操作需登录 Gitee 帐号,请先登录后再操作。
立即登录
没有帐号,去注册