From 66d14ab12e57cb2ccffb3e7ae53a9472386db8e1 Mon Sep 17 00:00:00 2001
From: xiajunhua
Date: Fri, 22 Aug 2025 16:26:03 +0800
Subject: [PATCH] hstu backward optimize

Optimize the HSTU dense backward operator in three ways:

- For bf16 gradients, set the A1/B1 buffer depth of the qkMatmul,
  qGradMatmul, kGradMatmul and vGradMatmul tilings to 4.
- Clear the shared qGradAccumTemp workspace in parallel: each vector core
  clears its own slice of the buffer instead of a single InitGlobalMemory
  call covering the whole buffer.
- Parallelize the qGrad copy-out: DoCopyQGrad now splits the
  batchSize * headNum (batch, head) tasks across all vector cores instead
  of running everything on block 0, and the per-task copy loop moves into
  the new DoCopyBlockQGrad helper. The jagged kernel drops its
  GetBlockIdx() == 0 gate accordingly.

---
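Notes (illustrations, not part of the patch; nothing below ships with the
kernel):

The new workspace clear splits qGradAccumTempSpace elements into one slice
per vector core, with the last core also taking the remainder. A minimal
host-side C++ sketch of the same arithmetic, using made-up sizes, that
checks the slices tile the buffer exactly:

    // Replays the per-core clear partition from hstu_dense_backward_kernel.h.
    // Variable names mirror the kernel; the sizes are arbitrary examples.
    #include <cassert>
    #include <cstdint>

    int main()
    {
        const int64_t qGradAccumTempSpace = 1003;  // example element count
        const int64_t aivNum = 8;                  // example vector-core count

        const int64_t unitClear = qGradAccumTempSpace / aivNum;
        const int64_t leftClear = qGradAccumTempSpace % aivNum;

        int64_t cleared = 0;
        for (int64_t blockIdx = 0; blockIdx < aivNum; blockIdx++) {
            uint64_t globalOffset = blockIdx * unitClear;
            uint64_t clearLen = unitClear;
            if (blockIdx == aivNum - 1) {
                clearLen += leftClear;  // tail remainder goes to the last core
            }
            // Contiguous, non-overlapping: each slice starts where the
            // previous one ended.
            assert(globalOffset == static_cast<uint64_t>(cleared));
            cleared += static_cast<int64_t>(clearLen);
        }
        assert(cleared == qGradAccumTempSpace);  // slices cover the buffer exactly
        return 0;
    }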
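The rewritten DoCopyQGrad flattens the (batch, head) pairs into flat task
ids and deals the first taskNum % aivNum cores one extra task each, so the
per-core ranges are contiguous and non-overlapping. A host-side sketch of
that split, again with example sizes only, verifying that every task is
assigned exactly once and showing how a task id decodes back to a
(batch, head) pair:

    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    int main()
    {
        const int64_t batchSize = 7, headNum = 3, aivNum = 4;  // example sizes
        const int64_t taskNum = batchSize * headNum;
        const int64_t coreTask = taskNum / aivNum;     // base tasks per core
        const int64_t coreSplitId = taskNum % aivNum;  // cores taking one extra

        int64_t covered = 0;
        for (int64_t blockIdx = 0; blockIdx < aivNum; blockIdx++) {
            int64_t taskNumOfThisCore, offsetOfThisCore;
            if (blockIdx >= coreSplitId) {
                taskNumOfThisCore = coreTask;
                offsetOfThisCore = coreSplitId * (coreTask + 1) +
                                   (blockIdx - coreSplitId) * coreTask;
            } else {
                taskNumOfThisCore = coreTask + 1;
                offsetOfThisCore = blockIdx * (coreTask + 1);
            }
            // Each core starts where the previous one ended.
            assert(offsetOfThisCore == covered);
            covered += taskNumOfThisCore;
            for (int64_t t = 0; t < taskNumOfThisCore; t++) {
                int64_t batch = (offsetOfThisCore + t) / headNum;
                int64_t head = (offsetOfThisCore + t) % headNum;
                printf("core %ld -> batch %ld head %ld\n",
                       (long)blockIdx, (long)batch, (long)head);
            }
        }
        assert(covered == taskNum);  // every (batch, head) task assigned once
        return 0;
    }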
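Inside DoCopyBlockQGrad, the staging buffer holds one head contiguously as
[rows, headDim] while qGrad stores [seq, headNum, headDim], so the strided
DataCopy writes each headDim row with a gap of (headNum - 1) * headDim
elements. A small sketch, with example sizes, checking that the kernel's
curOutOffset formula is equivalent to the row-wise form:

    #include <cassert>
    #include <cstdint>

    int main()
    {
        const int64_t headNum = 4, headDim = 64;  // example sizes
        const int64_t seqBase = 10, headIdx = 2;  // example batch start / head
        const int64_t rowsDone = 3;               // rows copied so far

        // Offset as the kernel computes it (element units):
        int64_t totalCopied = rowsDone * headDim;  // == totalLen - remain
        int64_t curOutOffset = seqBase * headNum * headDim +
                               headIdx * headDim + totalCopied * headNum;
        // Equivalent row-wise form of the interleaved layout:
        int64_t expected = (seqBase + rowsDone) * headNum * headDim +
                           headIdx * headDim;
        assert(curOutOffset == expected);
        return 0;
    }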
 .../op_host/hstu_dense_backward.cpp           |  12 ++
 .../hstu_dense_backward_jagged_kernel.h       |   5 +-
 .../op_kernel/hstu_dense_backward_kernel.h    | 109 ++++++++++++------
 3 files changed, 84 insertions(+), 42 deletions(-)

diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp
index 5dc0f296..4cb2afc8 100644
--- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp
+++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp
@@ -133,6 +133,18 @@ static ge::graphStatus TilingCommonFunc(gert::TilingContext *context, HstuDenseB
         return ge::GRAPH_FAILED;
     }
 
+    if (gradType == ge::DataType::DT_BF16) {
+        int64_t depth = 4;
+        tiling.qkMatmul.set_depthA1(depth);
+        tiling.qkMatmul.set_depthB1(depth);
+        tiling.qGradMatmul.set_depthA1(depth);
+        tiling.qGradMatmul.set_depthB1(depth);
+        tiling.kGradMatmul.set_depthA1(depth);
+        tiling.kGradMatmul.set_depthB1(depth);
+        tiling.vGradMatmul.set_depthA1(depth);
+        tiling.vGradMatmul.set_depthB1(depth);
+    }
+
     context->SetBlockDim(coreNum);
     tiling.set_aivNum(vecCoreNum);
     tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_jagged_kernel.h b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_jagged_kernel.h
index f11cfe55..403303f2 100644
--- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_jagged_kernel.h
+++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_jagged_kernel.h
@@ -569,10 +569,7 @@ public:
     __aicore__ inline void CopyQGradToOutput()
     {
         SyncAll();
-
-        if (GetBlockIdx() == 0) {
-            this->DoCopyQGrad(backwardTilingData->seqOffset);
-        }
+        this->DoCopyQGrad(backwardTilingData->seqOffset);
     }
 
 protected:
diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h
index d4e50f98..3b24e77b 100644
--- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h
+++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h
@@ -156,7 +156,18 @@ public:
 
         // All cores share one piece of global memory and accumulate into it; it must be
         // cleared on every run so residue from the previous run cannot affect this result.
         // After the multi-core clear, SyncAll must be called to keep the cores synchronized.
-        InitGlobalMemory(qGradAccumTemp, qGradAccumTempSpace, static_cast<float>(0));
+        int64_t unitClear = qGradAccumTempSpace / aivNum;
+        int64_t leftClear = qGradAccumTempSpace % aivNum;
+        uint64_t globalOffset = GetBlockIdx() * unitClear;
+        uint64_t clearLen = unitClear;
+        if (GetBlockIdx() == aivNum - 1) {
+            clearLen += leftClear;
+        }
+        GlobalTensor<float> thisBlockQGrad;
+        thisBlockQGrad.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(
+            reinterpret_cast<__gm__ uint8_t *>(workspace) + aivNum * totalTempSpaceForOneVec +
+            globalOffset * sizeof(float)), clearLen);
+        InitGlobalMemory(thisBlockQGrad, clearLen, static_cast<float>(0));
         SyncAll();
     }
 
@@ -850,50 +861,72 @@ public:
         }
     }
 
-    __aicore__ inline void DoCopyQGrad(const uint32_t *seqOffset)
+    __aicore__ inline void DoCopyBlockQGrad(int64_t thisBatchIdx, int64_t headIdx, int64_t curSeqLen,
+                                            const uint32_t *seqOffset)
     {
-        for (int64_t batchIdx = 0; batchIdx < batchSize; batchIdx++) {
-            int64_t curSeqLen = static_cast<int64_t>(seqOffset[batchIdx + 1] - seqOffset[batchIdx]);
-            for (int64_t headIdx = 0; headIdx < headNum; headIdx++) {
-                int64_t totalLen = curSeqLen * headDim;
-                int64_t remain = totalLen;
-                int64_t thisLen = vecOnceDataNum;
-                while (remain > 0) {
-                    if (thisLen > remain) {
-                        thisLen = remain;
-                    }
+        int64_t totalLen = curSeqLen * headDim;
+        int64_t remain = totalLen;
+        int64_t thisLen = vecOnceDataNum;
+        while (remain > 0) {
+            if (thisLen > remain) {
+                thisLen = remain;
+            }
 
-                    int64_t curOffset = (headNum * seqOffset[batchIdx] * headDim) + (headIdx * totalLen) +
-                                        (totalLen - remain);
-                    LocalTensor<float> input = queueVecScoreQK.AllocTensor<float>();
-                    DataCopy(input, qGradAccumTemp[curOffset], thisLen);
-                    queueVecScoreQK.EnQue(input);
+            int64_t curOffset = (headNum * seqOffset[thisBatchIdx] * headDim) + (headIdx * totalLen) +
+                                (totalLen - remain);
+            LocalTensor<float> input = queueVecScoreQK.AllocTensor<float>();
+            DataCopy(input, qGradAccumTemp[curOffset], thisLen);
+            queueVecScoreQK.EnQue(input);
 
-                    LocalTensor<float> newInput = queueVecScoreQK.DeQue<float>();
-                    LocalTensor<qType> output = queueOutputTemp.AllocTensor<qType>();
-                    if (std::is_same<qType, float>::value) {
-                        DataCopy(output.template ReinterpretCast<float>(), newInput, thisLen);
-                    } else {
-                        Cast(output, newInput, RoundMode::CAST_RINT, thisLen);
-                    }
-                    queueOutputTemp.EnQue(output);
-                    queueVecScoreQK.FreeTensor(newInput);
+            LocalTensor<float> newInput = queueVecScoreQK.DeQue<float>();
+            LocalTensor<qType> output = queueOutputTemp.AllocTensor<qType>();
+            if (std::is_same<qType, float>::value) {
+                DataCopy(output.template ReinterpretCast<float>(), newInput, thisLen);
+            } else {
+                Cast(output, newInput, RoundMode::CAST_RINT, thisLen);
+            }
+            queueOutputTemp.EnQue(output);
+            queueVecScoreQK.FreeTensor(newInput);
 
-                    LocalTensor<qType> newOutput = queueOutputTemp.DeQue<qType>();
+            LocalTensor<qType> newOutput = queueOutputTemp.DeQue<qType>();
 
-                    uint16_t blockCount = thisLen / headDim;
-                    uint16_t blockLen = headDim * sizeof(qType) / DATA_ALIGN_BYTES;
-                    uint16_t dstStride = (headNum - 1) * headDim * sizeof(qType) / DATA_ALIGN_BYTES;
-                    DataCopyParams copyParams{blockCount, blockLen, 0, dstStride};
+            uint16_t blockCount = thisLen / headDim;
+            uint16_t blockLen = headDim * sizeof(qType) / DATA_ALIGN_BYTES;
+            uint16_t dstStride = (headNum - 1) * headDim * sizeof(qType) / DATA_ALIGN_BYTES;
+            DataCopyParams copyParams{blockCount, blockLen, 0, dstStride};
 
-                    int64_t curOutOffset = seqOffset[batchIdx] * headNum * headDim +
-                                           headIdx * headDim + (totalLen - remain) * headNum;
-                    DataCopy(qGrad[curOutOffset], newOutput, copyParams);
-                    queueOutputTemp.FreeTensor(newOutput);
+            int64_t curOutOffset = seqOffset[thisBatchIdx] * headNum * headDim +
+                                   headIdx * headDim + (totalLen - remain) * headNum;
+            DataCopy(qGrad[curOutOffset], newOutput, copyParams);
+            queueOutputTemp.FreeTensor(newOutput);
 
-                    remain = remain - thisLen;
-                }
-            }
+            remain = remain - thisLen;
+        }
+    }
+
+    __aicore__ inline void DoCopyQGrad(const uint32_t *seqOffset)
+    {
+        int64_t batchIdx = GetBlockIdx();
+        int64_t taskNum = batchSize * headNum;
+        int64_t coreTask = taskNum / aivNum;
+        int64_t coreSplitId = taskNum % aivNum;
+
+        int64_t taskNumOfThisCore = 0;
+        int64_t offsetOfThisCore = 0;
+        if (batchIdx >= coreSplitId) {
+            taskNumOfThisCore = coreTask;
+            offsetOfThisCore = coreSplitId * (coreTask + 1) + (batchIdx - coreSplitId) * coreTask;
+        } else {
+            taskNumOfThisCore = coreTask + 1;
+            offsetOfThisCore = batchIdx * (coreTask + 1);
+        }
+
+        for (int64_t taskId = 0; taskId < taskNumOfThisCore; taskId++) {
+            int64_t thisBatchIdx = (offsetOfThisCore + taskId) / headNum;
+            int64_t headIdx = (offsetOfThisCore + taskId) % headNum;
+
+            int64_t curSeqLen = static_cast<int64_t>(seqOffset[thisBatchIdx + 1] - seqOffset[thisBatchIdx]);
+            DoCopyBlockQGrad(thisBatchIdx, headIdx, curSeqLen, seqOffset);
         }
     }
 
-- 
Gitee