diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/dense_to_jagged/test_dense_to_jagged.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/dense_to_jagged/test_dense_to_jagged.py
index 04c8d4165491198fecd7b8ab7fa4bff30a6afbb1..9d4f7b3cfcbf978a7b2bb812aadf6672186339ed 100644
--- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/dense_to_jagged/test_dense_to_jagged.py
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/dense_to_jagged/test_dense_to_jagged.py
@@ -29,16 +29,52 @@ logging.getLogger().setLevel(logging.INFO)
 torch.ops.load_library(f"{sysconfig.get_path('purelib')}/libfbgemm_npu_api.so")
 
 DENSE_DIM0 = [128, 40]  # test different batch sizes
-DENSE_DIM1 = [210] # fixed feature dim 1
-DENSE_DIM2 = [1, 8] # fixed feature dim 2
+DENSE_DIM1 = [210]  # fixed feature dim 1
+DENSE_DIM2 = [1, 8]  # fixed feature dim 2
 DIM_LIST = list(itertools.product(DENSE_DIM0, DENSE_DIM1, DENSE_DIM2))
 
 DENSE_DATATYPE = [torch.float32, torch.int64]  # test different data types
-OFFSET_DATATYPE = [torch.int32, torch.int64] # offset data types
+OFFSET_DATATYPE = [torch.int32, torch.int64]  # offset data types
 TYPE_LIST = list(itertools.product(DENSE_DATATYPE, OFFSET_DATATYPE))
 
 
-def get_result(device, denses, offsets, types, use_output_size):
+def dense_to_jagged_wrapper(dense, offsets, total_L=None):
+    return DenseToJagged.apply(dense, offsets, total_L)
+
+
+def dense_to_jagged(dense, offsets, total_L=None):
+    # when total_L is not given, the last cumulative offset is the total element count
+    if total_L is None:
+        total_L = offsets[0][-1].item()
+    out0, out1 = torch.ops.mxrec.dense_to_jagged(dense.to(DEVICE), offsets, total_L)
+    return out0.to(DEVICE), out1
+
+
+def jagged_to_padded_dense(values, offsets, max_lengths, padding_value):
+    return torch.ops.mxrec.jagged_to_padded_dense(
+        values=values.to(DEVICE),
+        offsets=offsets,
+        max_lengths=max(max_lengths),
+        padding_value=padding_value,
+    )
+
+
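+# Python-level autograd.Function used by the tests as a reference for the C++
+# autograd registration below: backward() re-pads the jagged gradient with
+# zeros, so padded positions of the dense input receive zero gradient.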
+class DenseToJagged(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, dense, offsets, total_L=None):
+        ctx.save_for_backward(*offsets)
+        out0, out1 = dense_to_jagged(dense, offsets, total_L)
+        ctx.dense_shape = dense.shape
+        return out0, out1
+
+    @staticmethod
+    def backward(ctx, grad_out0, grad_out1):
+        offsets = list(ctx.saved_tensors)
+        max_len = ctx.dense_shape[1]
+        grad_dense = jagged_to_padded_dense(grad_out0, offsets, [max_len], 0.0)
+        return grad_dense, None, None
+
+
+def get_result(device, denses, offsets, types, output_size=None):
     dense_datatype, offset_datatype = types
     dense_torch = torch.from_numpy(denses).to(dense_datatype).to(device)
     offsets_torch = torch.from_numpy(offsets).to(offset_datatype).to(device)
@@ -46,11 +82,6 @@ def get_result(device, denses, offsets, types, use_output_size):
 
     # compute cumulative offsets
     jagged_id_offset = torch.ops.fbgemm.asynchronous_complete_cumsum(offsets_torch)
 
-    # get the output size (the last offset is the total element count)
-    output_size = None
-    if use_output_size:
-        output_size = jagged_id_offset[-1]
-
     # core op: dense tensor -> jagged tensor
     jagged_embedding = torch.ops.fbgemm.dense_to_jagged(dense_torch, [jagged_id_offset], output_size)[0]
     return jagged_embedding.cpu()
@@ -58,17 +89,100 @@ def get_result(device, denses, offsets, types, use_output_size):
 
 @pytest.mark.parametrize("dims", DIM_LIST)
 @pytest.mark.parametrize("types", TYPE_LIST)
-@pytest.mark.parametrize("use_output_size", [True, False])  # test with and without output_size
-def test_dense_to_jagged(dims, types, use_output_size):
+@pytest.mark.parametrize("output_size_type", ["none", "exact"])  # test different output_size scenarios
+def test_dense_to_jagged(dims, types, output_size_type):
     dense_dim0, dense_dim1, dense_dim2 = dims
     # 1. Generate random input data
     denses = np.random.randn(dense_dim0, dense_dim1, dense_dim2).astype(np.float32)
     offsets = np.random.randint(0, dense_dim1, dense_dim0)  # generate random offsets
 
+    # compute the actual output size
+    actual_size = np.sum(offsets)
+
+    # set output_size according to the test scenario
+    output_size = None
+    if output_size_type == "exact":
+        output_size = actual_size
+
     # 2. Get the CPU and NPU results separately
-    golden_result = get_result(torch.device("cpu"), denses, offsets, types, use_output_size)
-    npu_result = get_result(torch.device(DEVICE), denses, offsets, types, use_output_size)
+    golden_result = get_result(torch.device("cpu"), denses, offsets, types, output_size)
+    npu_result = get_result(torch.device(DEVICE), denses, offsets, types, output_size)
 
     # 3. Compare results (tolerance 1e-4)
+    # the normal case should match exactly
     result_forward = torch.abs(golden_result[0] - npu_result[0]) < 1e-4
-    logging.info(result_forward.all().item())  # log whether all results passed
\ No newline at end of file
+    assert result_forward.all().item()
+
+    # ===== backward pass verification =====
+    # 4. Prepare trainable parameters
+    dense_datatype, offset_datatype = types
+    dense_torch = torch.from_numpy(denses).to(dense_datatype).to(DEVICE)
+    offsets_torch = torch.from_numpy(offsets).to(offset_datatype).to(DEVICE)
+
+    # compute cumulative offsets
+    jagged_id_offset = torch.ops.fbgemm.asynchronous_complete_cumsum(offsets_torch)
+
+    input_dense_npu = dense_torch.clone().to(torch.float32).to(DEVICE).requires_grad_(True)
+    input_dense_npu_py = dense_torch.clone().to(torch.float32).to(DEVICE).requires_grad_(True)
+
+    # 5. NPU forward pass
+    npu_jagged_for_grad = torch.ops.mxrec.dense_to_jagged(
+        input_dense_npu,
+        [jagged_id_offset.to(DEVICE)],
+        output_size
+    )[0]
+
+    # 6. Forward pass through the NPU Python implementation
+    npu_py_jagged_for_grad = dense_to_jagged_wrapper(
+        input_dense_npu_py,
+        [jagged_id_offset.to(DEVICE)],
+        output_size
+    )[0]
+
+    # 7. Generate a random gradient (same shape as the output)
+    grad_output = torch.randn_like(npu_jagged_for_grad)
+
+    # 8. NPU backward pass
+    npu_jagged_for_grad.backward(grad_output.to(DEVICE))
+    npu_grad_input = input_dense_npu.grad
+
+    # 9. NPU Python backward pass
+    npu_py_jagged_for_grad.backward(grad_output.to(DEVICE))
+    npu_py_grad_input = input_dense_npu_py.grad
+
+    # 10. Compare gradients
+    assert torch.allclose(
+        npu_py_grad_input.cpu(),
+        npu_grad_input.cpu(),
+        atol=1e-4,
+        rtol=1e-4
+    ), f"NPU Python gradient does not match the NPU gradient\nNPU Python gradient:\n{npu_py_grad_input.cpu()}\nNPU gradient:\n{npu_grad_input.cpu()}"
+
+
+# test cases dedicated to error handling
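+# The NPU kernel validates output_size against the cumulative offsets, so zero,
+# negative, too-large, and too-small values are all expected to raise RuntimeError.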
+@pytest.mark.parametrize("dims", [(128, 210, 8)])  # fixed dims to simplify the test
+@pytest.mark.parametrize("types", [(torch.float32, torch.int32)])  # fixed types to simplify the test
+def test_dense_to_jagged_edge_cases(dims, types):
+    dense_dim0, dense_dim1, dense_dim2 = dims
+    # 1. Generate random input data
+    denses = np.random.randn(dense_dim0, dense_dim1, dense_dim2).astype(np.float32)
+    offsets = np.random.randint(0, dense_dim1, dense_dim0)
+
+    # compute the actual output size
+    actual_size = np.sum(offsets)
+
+    # test the case where output_size is 0
+    with pytest.raises(RuntimeError):
+        get_result(torch.device(DEVICE), denses, offsets, types, 0)
+
+    # test a negative output_size
+    with pytest.raises(RuntimeError):
+        get_result(torch.device(DEVICE), denses, offsets, types, -1)
+
+    # test an output_size larger than actual_size
+    with pytest.raises(RuntimeError):
+        get_result(torch.device(DEVICE), denses, offsets, types, actual_size + 10)
+
+    # test an output_size smaller than actual_size
+    with pytest.raises(RuntimeError):
+        get_result(torch.device(DEVICE), denses, offsets, types, max(1, actual_size - 10))
\ No newline at end of file
diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/expand_into_jagged_permute/test_expand_into_jagged_permute.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/expand_into_jagged_permute/test_expand_into_jagged_permute.py
new file mode 100644
index 0000000000000000000000000000000000000000..7028a0021d24b115f2dfcf14687a6b0ca283678e
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/expand_into_jagged_permute/test_expand_into_jagged_permute.py
@@ -0,0 +1,29 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright 2025. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import itertools
+import logging
+import sysconfig
+
+import pytest
+import fbgemm_gpu
+import numpy as np
+import torch_npu
+import torch
+
+DEVICE = "npu:0"
+logging.getLogger().setLevel(logging.INFO)
+torch.ops.load_library(f"{sysconfig.get_path('purelib')}/libfbgemm_npu_api.so")
\ No newline at end of file
diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/jagged_to_padded_dense/test_jagged_to_padded_dense.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/jagged_to_padded_dense/test_jagged_to_padded_dense.py
index d0e40b595e410a9275ed4d1645f684082cdc8028..f79c63109a5bb28acb5c8243d984c04da727e6ac 100644
--- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/jagged_to_padded_dense/test_jagged_to_padded_dense.py
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/jagged_to_padded_dense/test_jagged_to_padded_dense.py
@@ -28,6 +28,32 @@ torch.ops.load_library(f"{sysconfig.get_path('purelib')}/libfbgemm_npu_api.so")
 
 DEVICE = "npu:0"
 
 
+def jagged_to_padded_dense_wrapper(values, offsets, max_lengths, padding_value):
+    return JaggedToPaddedDense.apply(values, offsets, max_lengths, padding_value)
+
+
+# Python-level autograd.Function mirroring the C++ registration: the backward of
+# jagged_to_padded_dense is dense_to_jagged (drop the padded positions).
+class JaggedToPaddedDense(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, values, offsets, max_lengths, padding_value):
+        ctx.save_for_backward(*offsets)
+        ctx.total_L = values.shape[0]
+        return torch.ops.mxrec.jagged_to_padded_dense_forward(
+            values=values.to(DEVICE),
+            offsets=offsets,
+            max_lengths=max(max_lengths),
+            padding_value=padding_value,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        offsets = list(ctx.saved_tensors)
+        total_L = ctx.total_L
+        if total_L is None:
+            total_L = offsets[0][-1].item()
+        grad_values = torch.ops.mxrec.jagged_to_padded_dense_backward(grad_output.to(DEVICE), offsets, total_L)
+        return grad_values, None, None, None
+
+
 def generate_jagged_tensor(batch_size, max_seq_len, num_heads, attention_dim):
     """
     Generate jagged tensor test data
@@ -79,6 +105,7 @@ def test_jagged_to_padded_dense(batch_size,
     2. Compute the baseline result with FBGEMM's CPU implementation
     3. Call the NPU operator to compute its result
     4. Compare the two (tolerance 1e-4)
+    5. New: verify the autograd support
     """
     # 1. Generate test data
     jagged_tensor, seq_offsets, total_sequences = generate_jagged_tensor(
@@ -88,6 +115,7 @@
     input_flat = jagged_tensor.reshape(total_sequences, num_heads * attention_dim)
     fbgemm_offsets = torch.from_numpy(seq_offsets)
 
+    # ===== forward pass verification =====
     # 3. Call the FBGEMM CPU implementation
     fbgemm_dense = torch.ops.fbgemm.jagged_to_padded_dense(
         input_flat,
@@ -104,7 +132,7 @@
         0.0
     )
 
-    # 5. Compare results
+    # 5. Compare forward pass results
     assert torch.allclose(
         fbgemm_dense.reshape(-1),
         npu_dense.cpu().reshape(-1),
@@ -112,3 +140,42 @@
         rtol=1e-4
     ), f"NPU result does not match the FBGEMM CPU result\nFBGEMM:\n{fbgemm_dense}\nNPU:\n{npu_dense.cpu()}"
 
+    # ===== backward pass verification =====
+    # 6. Prepare trainable parameters
+    input_flat_npu = input_flat.clone().to(DEVICE).requires_grad_(True)
+    input_flat_npu_py = input_flat.clone().to(DEVICE).requires_grad_(True)
+
+    # 7. NPU forward pass
+    npu_dense_for_grad = torch.ops.mxrec.jagged_to_padded_dense(
+        input_flat_npu,
+        [fbgemm_offsets.to(DEVICE)],
+        [max_seq_len] if use_list_max_lengths else max_seq_len,
+        0.0
+    )
+
+    # 8. Forward pass through the NPU Python implementation
+    npu_py_dense_for_grad = jagged_to_padded_dense_wrapper(
+        input_flat_npu_py,
+        [fbgemm_offsets.to(DEVICE)],
+        [max_seq_len],
+        0.0
+    )
+
+    # 9. Generate a random gradient (same shape as the output)
+    grad_output = torch.randn_like(npu_dense_for_grad)
+
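+    # Both paths build an autograd graph: the C++ op dispatches through its
+    # AutogradPrivateUse1 registration, the Python wrapper through autograd.Function.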
+    # 10. NPU backward pass
+    npu_dense_for_grad.backward(grad_output.to(DEVICE))
+    npu_grad_input = input_flat_npu.grad
+
+    # 11. NPU Python backward pass
+    npu_py_dense_for_grad.backward(grad_output.to(DEVICE))
+    npu_py_grad_input = input_flat_npu_py.grad
+
+    # 12. Compare gradients
+    assert torch.allclose(
+        npu_py_grad_input.cpu(),
+        npu_grad_input.cpu(),
+        atol=1e-4,
+        rtol=1e-4
+    ), f"NPU Python gradient does not match the NPU gradient\nNPU Python gradient:\n{npu_py_grad_input.cpu()}\nNPU gradient:\n{npu_grad_input.cpu()}"
\ No newline at end of file
diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/dense_to_jagged/dense_to_jagged.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/dense_to_jagged/dense_to_jagged.cpp
index 65dc56b6955775b4021199fecf42a7a825d238b3..94d7ea267e719b8a3a882b57e01d7773b711365d 100644
--- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/dense_to_jagged/dense_to_jagged.cpp
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/dense_to_jagged/dense_to_jagged.cpp
@@ -11,9 +11,36 @@
 #include <torch/library.h>
 #include "../common/pytorch_npu_helper.hpp"
 
+using torch::autograd::AutogradContext;
+using torch::autograd::Function;
+using torch::autograd::Variable;
 using tensor_list = std::vector<at::Tensor>;
 using namespace at;
 
+at::Tensor jagged_to_padded_dense_forward_npu(const at::Tensor& values,
+                                              const tensor_list& offsets,
+                                              const int64_t max_lengths,
+                                              const double padding_value)
+{
+    TORCH_CHECK(values.dim() == 2,
+                "values must be a 2D tensor, but got ", values.dim(), "D tensor");
+    TORCH_CHECK(offsets.size() == 1,
+                "offsets must contain exactly 1 tensor, but got ", offsets.size(), " tensors");
+    const auto& offset_tensor = offsets[0];
+    TORCH_CHECK(offset_tensor.defined(),
+                "offset tensor must be defined (non-null)");
+    TORCH_CHECK(offset_tensor.dim() == 1,
+                "offset tensor must be 1D, but got ", offset_tensor.dim(), "D");
+    TORCH_CHECK(max_lengths > 0, "max_lengths must be positive, but got ", max_lengths);
+    const at::OptionalDeviceGuard guard(device_of(values));
+    auto values_contin = values.contiguous();
+    auto output =
+        at::full({offsets[0].size(0) - 1, max_lengths, values.size(1)}, padding_value, values.options());
+    EXEC_NPU_CMD(aclnnJaggedToPaddedDense, values_contin, offsets[0], max_lengths, padding_value, output);
+    return output;
+};
+
 // currently only 3-D dense tensors are supported
 at::Tensor dense_to_jagged_forward_npu(const at::Tensor& dense,
                                        const tensor_list& offsets,
@@ -53,6 +80,63 @@ std::tuple<at::Tensor, tensor_list> dense_to_jagged_npu(const at::Tensor& dense,
     return {dense_to_jagged_forward_npu(dense, offsets, total_L), offsets};
 };
 
+// backward op: uses jagged_to_padded_dense as the backward
+at::Tensor dense_to_jagged_backward_npu(const at::Tensor& values,
+                                        const tensor_list& offsets,
+                                        const int64_t max_lengths,
+                                        const double padding_value)
+{
+    return jagged_to_padded_dense_forward_npu(values, offsets, max_lengths, padding_value);
+};
+
+// autograd Function class
+class DenseToJaggedFunction : public torch::autograd::Function<DenseToJaggedFunction> {
+public:
+    static at::Tensor forward(AutogradContext* ctx,
+                              const at::Tensor& dense,
+                              const tensor_list& offsets,
+                              const c10::optional<c10::SymInt> total_L)
+    {
+        at::AutoDispatchBelowADInplaceOrView guard;
+        ctx->save_for_backward({dense, offsets[0]});
+
+        return dense_to_jagged_forward_npu(dense, offsets, total_L);
+    }
+
+    static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs)
+    {
+        auto grad_output = grad_outputs[0];
+        auto saved = ctx->get_saved_variables();
+        auto dense = saved[0];
+        auto offsets_tensor = saved[1];
+
+        tensor_list offsets = {offsets_tensor};
+        int64_t max_len = dense.size(1);
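+        // dim 1 of the saved dense input is the padded length; re-padding the
+        // jagged gradient to {B, max_len, D} with zeros sends zero gradient to
+        // the positions the forward pass dropped.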
+
+        // call jagged_to_padded_dense as the backward
+        auto grad_dense = dense_to_jagged_backward_npu(
+            grad_output, offsets, max_len, 0.0);
+
+        // return gradients: grad_dense, None, None
+        return {grad_dense, Variable(), Variable()};
+    }
+};
+
+// autograd entry points
+at::Tensor dense_to_jagged_autograd(const at::Tensor& dense,
+                                    const tensor_list& offsets,
+                                    const c10::optional<c10::SymInt> total_L)
+{
+    return DenseToJaggedFunction::apply(dense, offsets, total_L);
+}
+
+std::tuple<at::Tensor, tensor_list> dense_to_jagged_npu_autograd(const at::Tensor& dense,
+                                                                 const tensor_list& offsets,
+                                                                 const c10::optional<c10::SymInt> total_L)
+{
+    return {dense_to_jagged_autograd(dense, offsets, total_L), offsets};
+};
+
 TORCH_LIBRARY_FRAGMENT(mxrec, m)
 {
     m.def("dense_to_jagged_forward(Tensor dense, "
@@ -62,16 +146,34 @@ TORCH_LIBRARY_FRAGMENT(mxrec, m)
     m.def("dense_to_jagged(Tensor dense, "
           " Tensor[] offsets, "
          " SymInt? total_L=None) -> (Tensor, Tensor[])");
+
+    m.def("dense_to_jagged_backward(Tensor values, "
+          " Tensor[] offsets, "
+          " int max_lengths, "
+          " float padding_value) -> Tensor");
 }
 
 TORCH_LIBRARY_IMPL(mxrec, PrivateUse1, m)
 {
     m.impl("dense_to_jagged_forward", &dense_to_jagged_forward_npu);
     m.impl("dense_to_jagged", &dense_to_jagged_npu);
+    m.impl("dense_to_jagged_backward", &dense_to_jagged_backward_npu);
 }
 
 TORCH_LIBRARY_IMPL(fbgemm, PrivateUse1, m)
 {
     m.impl("dense_to_jagged_forward", &dense_to_jagged_forward_npu);
     m.impl("dense_to_jagged", &dense_to_jagged_npu);
+    m.impl("dense_to_jagged_backward", &dense_to_jagged_backward_npu);
 }
+
+// register the autograd implementations
+TORCH_LIBRARY_IMPL(mxrec, AutogradPrivateUse1, m)
+{
+    m.impl("dense_to_jagged", &dense_to_jagged_npu_autograd);
+}
+
+TORCH_LIBRARY_IMPL(fbgemm, AutogradPrivateUse1, m)
+{
+    m.impl("dense_to_jagged", &dense_to_jagged_npu_autograd);
+}
\ No newline at end of file
diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/expand_into_jagged_permute/CMakeLists.txt b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/expand_into_jagged_permute/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2479e966db40d3990b42da26a1fa631e570c0518
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/expand_into_jagged_permute/CMakeLists.txt
@@ -0,0 +1,12 @@
+cmake_minimum_required(VERSION 3.10)
+
+project(expand_into_jagged_permute)
+
+include("${CMAKE_CURRENT_LIST_DIR}/../common/CommonTorchOpConfig.cmake")
+
+add_library(expand_into_jagged_permute SHARED expand_into_jagged_permute.cpp)
+
+target_compile_features(expand_into_jagged_permute PRIVATE cxx_std_17)
+target_compile_options(expand_into_jagged_permute PRIVATE -D_GLIBCXX_USE_CXX11_ABI=${GLIBCXX_ABI})
+
+target_link_libraries(expand_into_jagged_permute PUBLIC c10 torch torch_cpu torch_npu c_sec)
\ No newline at end of file
diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/expand_into_jagged_permute/build_ops.sh b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/expand_into_jagged_permute/build_ops.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6c4180e80563fe07d004cf59a821add657c01305
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/expand_into_jagged_permute/build_ops.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+set -e
+rm -rf build
+mkdir -p build
+cmake -B build
+cmake --build build -j
+chmod 550 ./build/*.so
+export LD_LIBRARY_PATH=$ASCEND_OPP_PATH/vendors/expand_into_jagged_permute/op_api/lib:$LD_LIBRARY_PATH
diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/expand_into_jagged_permute/expand_into_jagged_permute.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/expand_into_jagged_permute/expand_into_jagged_permute.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..b2da136a93315ccbb9a9642f2333cba8d31d42ce
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/expand_into_jagged_permute/expand_into_jagged_permute.cpp
@@ -0,0 +1,19 @@
+/**
+ * @file expand_into_jagged_permute.cpp
+ *
+ * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+ */
+#include <torch/library.h>
+#include <torch/torch.h>
+
+#include "../common/pytorch_npu_helper.hpp"
+using torch::autograd::AutogradContext;
+using torch::autograd::Function;
+using torch::autograd::Variable;
+using tensor_list = std::vector<at::Tensor>;
+using namespace at;
+
diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/jagged_to_padded_dense/jagged_to_padded_dense.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/jagged_to_padded_dense/jagged_to_padded_dense.cpp
index 14e5d0536c53ef143e2f3db683f4547207b31075..0c77b2964d3139d3c448d28a14ebff851f6d07a2 100644
--- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/jagged_to_padded_dense/jagged_to_padded_dense.cpp
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/jagged_to_padded_dense/jagged_to_padded_dense.cpp
@@ -1,5 +1,5 @@
 /**
-* @file jagged_to_padded_dense.cpp
+ * @file jagged_to_padded_dense.cpp
 *
 * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved.
 *
@@ -11,6 +11,9 @@
 #include <torch/library.h>
 #include "../common/pytorch_npu_helper.hpp"
 
+using torch::autograd::AutogradContext;
+using torch::autograd::Function;
+using torch::autograd::Variable;
 using tensor_list = std::vector<at::Tensor>;
 using namespace at;
 
@@ -98,31 +101,109 @@ at::Tensor jagged_to_padded_dense_backward_npu(const at::Tensor& grad_output,
     return dense_to_jagged_forward_npu(grad_output, offsets, total_L);
 };
 
+// autograd Function classes
+class JaggedToPaddedDenseV1 : public torch::autograd::Function<JaggedToPaddedDenseV1> {
+public:
+    static at::Tensor forward(AutogradContext* ctx,
+                              const at::Tensor& values,
+                              const tensor_list& offsets,
+                              const int64_t max_lengths,
+                              const double padding_value)
+    {
+        at::AutoDispatchBelowADInplaceOrView guard;
+        ctx->save_for_backward({values, offsets[0]});
+        ctx->saved_data["max_lengths"] = max_lengths;
+        ctx->saved_data["padding_value"] = padding_value;
+        return jagged_to_padded_dense_forward_npu_v1(values, offsets, max_lengths, padding_value);
+    }
+
+    static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs)
+    {
+        auto grad_output = grad_outputs[0];
+        auto saved = ctx->get_saved_variables();
+        auto values = saved[0];
+        auto offsets_tensor = saved[1];
+        tensor_list offsets = {offsets_tensor};
+
+        // the jagged gradient has one row per value, so total_L is values.size(0)
+        int64_t total_L = values.size(0);
+        auto grad_input = jagged_to_padded_dense_backward_npu(grad_output, offsets, total_L);
+        return {grad_input, Variable(), Variable(), Variable()};
+    }
+};
+
+class JaggedToPaddedDenseV2 : public torch::autograd::Function<JaggedToPaddedDenseV2> {
+public:
+    static at::Tensor forward(AutogradContext* ctx,
+                              const at::Tensor& values,
+                              const tensor_list& offsets,
+                              const at::IntArrayRef max_lengths,
+                              const double padding_value)
+    {
+        at::AutoDispatchBelowADInplaceOrView guard;
+        ctx->save_for_backward({values, offsets[0]});
+        ctx->saved_data["max_lengths"] = max_lengths[0];  // save the first element
+        ctx->saved_data["padding_value"] = padding_value;
+        return jagged_to_padded_dense_forward_npu_v2(values, offsets, max_lengths, padding_value);
+    }
+
+    static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs)
+    {
+        auto grad_output = grad_outputs[0];
+        auto saved = ctx->get_saved_variables();
+        auto values = saved[0];
+        auto offsets_tensor = saved[1];
+        tensor_list offsets = {offsets_tensor};
+
+        int64_t total_L = values.size(0);
+        auto grad_input = jagged_to_padded_dense_backward_npu(grad_output, offsets, total_L);
+        return {grad_input, Variable(), Variable(), Variable()};
+    }
+};
+
+// autograd entry points
+at::Tensor jagged_to_padded_dense_npu_v1_autograd(const at::Tensor& values,
+                                                  const tensor_list& offsets,
+                                                  const int64_t max_lengths,
+                                                  const double padding_value)
+{
+    return JaggedToPaddedDenseV1::apply(values, offsets, max_lengths, padding_value);
+}
+
+at::Tensor jagged_to_padded_dense_npu_v2_autograd(const at::Tensor& values,
+                                                  const tensor_list& offsets,
+                                                  const at::IntArrayRef max_lengths,
+                                                  const double padding_value)
+{
+    return JaggedToPaddedDenseV2::apply(values, offsets, max_lengths, padding_value);
+}
+
 } // namespace fbgemm_npu
 
 TORCH_LIBRARY_FRAGMENT(mxrec, m)
 {
     m.def("jagged_to_padded_dense.v1(Tensor values, "
-         " Tensor[] offsets, "
-         " int max_lengths, "
-         " float padding_value) -> Tensor");
+          " Tensor[] offsets, "
+          " int max_lengths, "
+          " float padding_value) -> Tensor");
 
     // new overload with int[] max_lengths
     m.def("jagged_to_padded_dense.v2(Tensor values, "
-         " Tensor[] offsets, "
-         " int[] max_lengths, "
-         " float padding_value) -> Tensor");
+          " Tensor[] offsets, "
+          " int[] max_lengths, "
+          " float padding_value) -> Tensor");
 
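+    // the standalone forward/backward entry points below are what the Python
+    // autograd.Function in the demo tests calls directly.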
     m.def("jagged_to_padded_dense_forward.v1(Tensor values, "
-         " Tensor[] offsets, "
-         " int max_lengths, "
-         " float padding_value) -> Tensor");
+          " Tensor[] offsets, "
+          " int max_lengths, "
+          " float padding_value) -> Tensor");
 
     // new overload with int[] max_lengths
     m.def("jagged_to_padded_dense_forward.v2(Tensor values, "
-         " Tensor[] offsets, "
-         " int[] max_lengths, "
-         " float padding_value) -> Tensor");
+          " Tensor[] offsets, "
+          " int[] max_lengths, "
+          " float padding_value) -> Tensor");
 
-    m.def("jagged_to_padded_dense_backward(Tensor grad, Tensor[] offsets, int total_L) -> Tensor");
+    m.def("jagged_to_padded_dense_backward(Tensor grad, "
+          " Tensor[] offsets, "
+          " int total_L) -> Tensor");
 }
 
 TORCH_LIBRARY_IMPL(mxrec, PrivateUse1, m)
@@ -158,3 +239,16 @@ TORCH_LIBRARY_IMPL(fbgemm, PrivateUse1, m)
                                        TORCH_FN(fbgemm_npu::jagged_to_padded_dense_forward_npu_v2)));
     m.impl("jagged_to_padded_dense_backward", &fbgemm_npu::jagged_to_padded_dense_backward_npu);
 }
+
+// register the autograd implementations
+TORCH_LIBRARY_IMPL(mxrec, AutogradPrivateUse1, m)
+{
+    m.impl("jagged_to_padded_dense.v1", TORCH_FN(fbgemm_npu::jagged_to_padded_dense_npu_v1_autograd));
+    m.impl("jagged_to_padded_dense.v2", TORCH_FN(fbgemm_npu::jagged_to_padded_dense_npu_v2_autograd));
+}
+
+TORCH_LIBRARY_IMPL(fbgemm, AutogradPrivateUse1, m)
+{
+    m.impl("jagged_to_padded_dense.v1", TORCH_FN(fbgemm_npu::jagged_to_padded_dense_npu_v1_autograd));
+    m.impl("jagged_to_padded_dense.v2", TORCH_FN(fbgemm_npu::jagged_to_padded_dense_npu_v2_autograd));
+}
\ No newline at end of file