diff --git a/mxrec_add_ons/rec_for_torch/operators/backward_codegen_adagrad_unweighted_exact/op_kernel/backward_codegen_adagrad_unweighted_exact_kernel_unique.h b/mxrec_add_ons/rec_for_torch/operators/backward_codegen_adagrad_unweighted_exact/op_kernel/backward_codegen_adagrad_unweighted_exact_kernel_unique.h index 7b049438ccebfbf1686a2648aec515f676acd896..a42a79aa0f3af12eea6be33bdbceeaffaa3ad83b 100644 --- a/mxrec_add_ons/rec_for_torch/operators/backward_codegen_adagrad_unweighted_exact/op_kernel/backward_codegen_adagrad_unweighted_exact_kernel_unique.h +++ b/mxrec_add_ons/rec_for_torch/operators/backward_codegen_adagrad_unweighted_exact/op_kernel/backward_codegen_adagrad_unweighted_exact_kernel_unique.h @@ -147,7 +147,7 @@ private: GlobalTensor dynamicWeightsGT; GlobalTensor dynamicM1GT; - int numOfOut = 3; + int numOfOut = 2; int indicesNumOneBlock; int64_t thisMoment1Index; diff --git a/mxrec_add_ons/rec_for_torch/operators/backward_codegen_adagrad_unweighted_exact/op_kernel/backward_codegen_adam_unweighted_exact_kernel_unique.h b/mxrec_add_ons/rec_for_torch/operators/backward_codegen_adagrad_unweighted_exact/op_kernel/backward_codegen_adam_unweighted_exact_kernel_unique.h index bf3e5e121e377bf4a77b0bd4f885e2b2dbeeec52..eaa70f7fcf2fdad6d4fc4c62b58e525b837ef658 100644 --- a/mxrec_add_ons/rec_for_torch/operators/backward_codegen_adagrad_unweighted_exact/op_kernel/backward_codegen_adam_unweighted_exact_kernel_unique.h +++ b/mxrec_add_ons/rec_for_torch/operators/backward_codegen_adagrad_unweighted_exact/op_kernel/backward_codegen_adam_unweighted_exact_kernel_unique.h @@ -26,6 +26,7 @@ using namespace BackwardCodegenUnweightedExact; using namespace BackwardCodegenUnweightedExactUnique; namespace BackwardCodegenUnweightedAdamExactUnique { constexpr int NUM_OUTPUTS = 3; // grad, momentum1, momentum2 +constexpr int BUFFER_NUM = 1; // buffers per queue (1 = single buffering; set to 2 to enable double buffering) class BackwardCodegenAdamUnweightedExactKernelUnique : public BackwardCodegenUnweightedExactKernelUnique { public: 
__aicore__ inline BackwardCodegenAdamUnweightedExactKernelUnique() {} @@ -45,6 +46,8 @@ public: beta2sqrt = tilingData.beta2sqrt; numOfOut = NUM_OUTPUTS; + maxProcessLen = ubCanUsed / BUFFER_NUM / USE_QUEUE_NUM / NUM_OUTPUTS; // 最大处理embed dim + maxProcessLen = maxProcessLen / FLOAT_ALIGNMENT * FLOAT_ALIGNMENT; // 32B对齐 indicesNumOneBlock = blockLen / numOfOut / maxD; if (indicesNumOneBlock >= MAX_ARGS_PIPE_LEN) { indicesNumOneBlock = MAX_ARGS_PIPE_LEN; @@ -95,33 +98,33 @@ public: Muls(outLt[thisGradIndex], outLt[thisGradIndex], stepSize, totalLen); } - __aicore__ inline void CopyInNormal(int64_t *updateArgs, int thisLen, int embedDim) + __aicore__ inline void CopyInNormal(int64_t *updateArgs, int thisLen, int embedDim, int splitDim, int embOffset) { __gm__ int64_t* weightsOffsetsPtr = (__gm__ int64_t*)weightsOffsets; LocalTensor inputLt = queIn.template DeQue(); for (int64_t i = 0; i < thisLen; i++) { int64_t thisIndForThisTable = uniqueIdGT.GetValue(thisTableOffset + i); int64_t thisWeightOffset = *(weightsOffsetsPtr + tableIndex); - updateArgs[i] = thisWeightOffset + thisIndForThisTable * embedDim; - DataCopy(inputLt[i * maxD + thisMoment1Index], momentum1DevGT[updateArgs[i]], embedDim); - DataCopy(inputLt[i * maxD + thisMoment2Index], momentum2DevGT[updateArgs[i]], embedDim); + updateArgs[i] = thisWeightOffset + thisIndForThisTable * embedDim + embOffset; + DataCopy(inputLt[i * maxProcessLen + thisMoment1Index], momentum1DevGT[updateArgs[i]], splitDim); + DataCopy(inputLt[i * maxProcessLen + thisMoment2Index], momentum2DevGT[updateArgs[i]], splitDim); } queIn.template EnQue(inputLt); } - __aicore__ inline void CopyOutNormal(int64_t *outOffset, int thisLen, int embedDim) + __aicore__ inline void CopyOutNormal(int64_t *outOffset, int thisLen, int splitDim) { LocalTensor newOutLt = queOut.template DeQue(); SetAtomicAdd(); for (int64_t i = 0; i < thisLen; i++) { - int thisGradIndex = i * maxD; - DataCopy(weightsDevOutGT[outOffset[i]], newOutLt[thisGradIndex], 
embedDim); + int thisGradIndex = i * maxProcessLen; + DataCopy(weightsDevOutGT[outOffset[i]], newOutLt[thisGradIndex], splitDim); } SetAtomicNone(); for (int64_t i = 0; i < thisLen; i++) { - int thisGradIndex = i * maxD; - DataCopy(momentum1DevOutGT[outOffset[i]], newOutLt[thisMoment1Index + thisGradIndex], embedDim); - DataCopy(momentum2DevOutGT[outOffset[i]], newOutLt[thisMoment2Index + thisGradIndex], embedDim); + int thisGradIndex = i * maxProcessLen; + DataCopy(momentum1DevOutGT[outOffset[i]], newOutLt[thisMoment1Index + thisGradIndex], splitDim); + DataCopy(momentum2DevOutGT[outOffset[i]], newOutLt[thisMoment2Index + thisGradIndex], splitDim); } queOut.template FreeTensor(newOutLt); } @@ -130,44 +133,56 @@ public: { __gm__ int32_t* dOffsetsPtr = (__gm__ int32_t*)dOffsets; - indicesNumOneBlock = blockLen / numOfOut / maxD; + indicesNumOneBlock = blockLen / numOfOut / maxProcessLen; if (indicesNumOneBlock >= MAX_ARGS_PIPE_LEN) { indicesNumOneBlock = MAX_ARGS_PIPE_LEN; } - int64_t thisLen = thisTableLen; - int64_t remain = thisTableLen; int64_t embedDim = *(dOffsetsPtr + tableIndex + 1) - *(dOffsetsPtr + tableIndex); - - while (remain > 0) { - if (remain > indicesNumOneBlock) { - thisLen = indicesNumOneBlock; + int64_t embOffset = 0; + int64_t splitDim = embedDim; + while (embOffset < embedDim) { // embedim 拆分计算 + if (splitDim > maxProcessLen) { + splitDim = maxProcessLen; } - - int calcLen = thisLen * maxD; - thisMoment1Index = calcLen * M1_INDEX; - thisMoment2Index = calcLen * M2_INDEX; - remain -= thisLen; - LocalTensor inputLt = queIn.template AllocTensor(); - LocalTensor outputLt = queOut.template AllocTensor(); - - // copyIn - CpGm2Local(inputLt, outGT[thisTableOffset * maxD], calcLen); - queIn.template EnQue(inputLt); - // CopyIn - int64_t updateArgs[MAX_ARGS_PIPE_LEN]; - CopyInNormal(updateArgs, thisLen, embedDim); - // compute - inputLt = queIn.template DeQue(); - - ComputeAdam(inputLt, outputLt, calcLen); - queOut.template EnQue(outputLt); - - // 
copyOut - CopyOutNormal(updateArgs, thisLen, embedDim); - - queIn.template FreeTensor(inputLt); - thisTableOffset += thisLen; - thisLen = remain; + int64_t thisLen = thisTableLen; + int64_t remain = thisTableLen; + int64_t tableOffset = thisTableOffset; + while (remain > 0) { + if (remain > indicesNumOneBlock) { + thisLen = indicesNumOneBlock; + } + + int calcLen = thisLen * maxProcessLen; + thisMoment1Index = calcLen * M1_INDEX; + thisMoment2Index = calcLen * M2_INDEX; + remain -= thisLen; + LocalTensor inputLt = queIn.template AllocTensor(); + LocalTensor outputLt = queOut.template AllocTensor(); + + // copyIn Grad + uint16_t blockNum = splitDim / FLOAT_ALIGNMENT; + uint16_t srcGap = (embedDim - splitDim) / FLOAT_ALIGNMENT; + DataCopyParams param = {static_cast(thisLen), blockNum, srcGap, 0}; + DataCopy(inputLt, outGT[tableOffset * maxD + embOffset], param); + queIn.template EnQue(inputLt); + // CopyIn weights momentum + int64_t updateArgs[MAX_ARGS_PIPE_LEN]; + CopyInNormal(updateArgs, thisLen, embedDim, splitDim, embOffset); + // compute + inputLt = queIn.template DeQue(); + + ComputeAdam(inputLt, outputLt, calcLen); + queOut.template EnQue(outputLt); + + // copyOut + CopyOutNormal(updateArgs, thisLen, splitDim); + + queIn.template FreeTensor(inputLt); + tableOffset += thisLen; + thisLen = remain; + } + embOffset += splitDim; + splitDim = embedDim - embOffset; } } __aicore__ inline void Compute(Args args) @@ -203,6 +218,7 @@ private: int64_t iter; int numOfOut; int indicesNumOneBlock; + int maxProcessLen; int64_t thisMoment1Index; int64_t thisMoment2Index;