From 3b7f82df586f3270d124d1a728d7bec7c6716d4e Mon Sep 17 00:00:00 2001 From: Fred Chow Date: Wed, 26 May 2021 20:14:56 -0700 Subject: [PATCH] lower array nodes in the ssatab phase of mapleme when language is not java --- src/mapleall/maple_ir/include/mir_lower.h | 2 + src/mapleall/maple_ir/src/mir_lower.cpp | 195 ++++++++++++++++++++++ src/mapleall/maple_me/include/ssa_tab.h | 2 +- src/mapleall/maple_me/src/me_function.cpp | 4 +- src/mapleall/maple_me/src/ssa_tab.cpp | 45 +++-- 5 files changed, 228 insertions(+), 20 deletions(-) diff --git a/src/mapleall/maple_ir/include/mir_lower.h b/src/mapleall/maple_ir/include/mir_lower.h index 2f90d5c4ac..fb7969347c 100644 --- a/src/mapleall/maple_ir/include/mir_lower.h +++ b/src/mapleall/maple_ir/include/mir_lower.h @@ -66,6 +66,8 @@ class MIRLower { BlockNode *LowerBlock(BlockNode&); void LowerBrCondition(BlockNode &block); void LowerFunc(MIRFunction &func); + BaseNode *LowerFarray(ArrayNode *array); + BaseNode *LowerCArray(ArrayNode *array); void ExpandArrayMrt(MIRFunction &func); IfStmtNode *ExpandArrayMrtIfBlock(IfStmtNode &node); WhileStmtNode *ExpandArrayMrtWhileBlock(WhileStmtNode &node); diff --git a/src/mapleall/maple_ir/src/mir_lower.cpp b/src/mapleall/maple_ir/src/mir_lower.cpp index 4e249212b4..9df9582421 100644 --- a/src/mapleall/maple_ir/src/mir_lower.cpp +++ b/src/mapleall/maple_ir/src/mir_lower.cpp @@ -17,6 +17,18 @@ #define DO_LT_0_CHECK 1 namespace maple { + +static constexpr uint64 RoundUpConst(uint64 offset, uint32 align) { + return (-align) & (offset + align - 1); +} + +static inline uint64 RoundUp(uint64 offset, uint32 align) { + if (align == 0) { + return offset; + } + return RoundUpConst(offset, align); +} + LabelIdx MIRLower::CreateCondGotoStmt(Opcode op, BlockNode &blk, const IfStmtNode &ifStmt) { auto *brStmt = mirModule.CurFuncCodeMemPool()->New(op); brStmt->SetOpnd(ifStmt.Opnd(), 0); @@ -346,6 +358,189 @@ void MIRLower::LowerFunc(MIRFunction &func) { func.SetBody(newBody); } +BaseNode *MIRLower::LowerFarray(ArrayNode *array) { + auto *farrayType = static_cast(array->GetArrayType(GlobalTables::GetTypeTable())); + size_t eSize = GlobalTables::GetTypeTable().GetTypeFromTyIdx(farrayType->GetElemTyIdx())->GetSize(); + MIRType &arrayType = *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(array->GetPrimType())); + /* how about multi-dimension array? */ + if (array->GetIndex(0)->GetOpCode() == OP_constval) { + const ConstvalNode *constvalNode = static_cast(array->GetIndex(0)); + if (constvalNode->GetConstVal()->GetKind() == kConstInt) { + const MIRIntConst *pIntConst = static_cast(constvalNode->GetConstVal()); + CHECK_FATAL(mirModule.IsJavaModule() || pIntConst->GetValue() >= 0, "Array index should >= 0."); + int64 eleOffset = pIntConst->GetValue() * eSize; + + BaseNode *baseNode = array->GetBase(); + if (eleOffset == 0) { + return baseNode; + } + + MIRIntConst *eleConst = + GlobalTables::GetIntConstTable().GetOrCreateIntConst(eleOffset, arrayType); + BaseNode *offsetNode = mirModule.CurFuncCodeMemPool()->New(eleConst); + offsetNode->SetPrimType(array->GetPrimType()); + + BaseNode *rAdd = mirModule.CurFuncCodeMemPool()->New(OP_add); + rAdd->SetPrimType(array->GetPrimType()); + rAdd->SetOpnd(baseNode, 0); + rAdd->SetOpnd(offsetNode, 1); + return rAdd; + } + } + + BaseNode *rMul = nullptr; + + BaseNode *baseNode = array->GetBase(); + + BaseNode *rAdd = mirModule.CurFuncCodeMemPool()->New(OP_add); + rAdd->SetPrimType(array->GetPrimType()); + rAdd->SetOpnd(baseNode, 0); + rAdd->SetOpnd(rMul, 1); + return rAdd; +} + +BaseNode *MIRLower::LowerCArray(ArrayNode *array) { + MIRType *aType = array->GetArrayType(GlobalTables::GetTypeTable()); + if (aType->GetKind() == kTypeJArray) { + return array; + } + if (aType->GetKind() == kTypeFArray) { + return LowerFarray(array); + } + + MIRArrayType *arrayType = static_cast(aType); + /* There are two cases where dimension > 1. + * 1) arrayType->dim > 1. Process the current arrayType. (nestedArray = false) + * 2) arrayType->dim == 1, but arraytype->eTyIdx is another array. (nestedArray = true) + * Assume at this time 1) and 2) cannot mix. + * Along with the array dimension, there is the array indexing. + * It is allowed to index arrays less than the dimension. + * This is dictated by the number of indexes. + */ + bool nestedArray = false; + int dim = arrayType->GetDim(); + MIRType *innerType = nullptr; + MIRArrayType *innerArrayType = nullptr; + uint64 elemSize = 0; + if (dim == 1) { + innerType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(arrayType->GetElemTyIdx()); + if (innerType->GetKind() == kTypeArray) { + nestedArray = true; + do { + innerArrayType = static_cast(innerType); + elemSize = RoundUp(innerArrayType->GetElemType()->GetSize(), + arrayType->GetElemType()->GetAlign()); + dim++; + innerType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(innerArrayType->GetElemTyIdx()); + } while (innerType->GetKind() == kTypeArray); + } + } + + int32 numIndex = static_cast(array->NumOpnds()) - 1; + MIRArrayType *curArrayType = arrayType; + BaseNode *resNode = array->GetIndex(0); + if (dim > 1) { + BaseNode *prevNode = nullptr; + for (int i = 0; (i < dim) && (i < numIndex); i++) { + uint32 mpyDim = 1; + if (nestedArray) { + CHECK_FATAL(arrayType->GetSizeArrayItem(0) > 0, "Zero size array dimension"); + innerType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(curArrayType->GetElemTyIdx()); + curArrayType = static_cast(innerType); + while (innerType->GetKind() == kTypeArray) { + innerArrayType = static_cast(innerType); + mpyDim *= innerArrayType->GetSizeArrayItem(0); + innerType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(innerArrayType->GetElemTyIdx()); + } + } else { + CHECK_FATAL(arrayType->GetSizeArrayItem(static_cast(i)) > 0, "Zero size array dimension"); + for (int j = i + 1; j < dim; j++) { + mpyDim *= arrayType->GetSizeArrayItem(static_cast(j)); + } + } + + BaseNode *index = static_cast(array->GetIndex(static_cast(i))); + bool isConst = false; + int64 indexVal = 0; + if (index->op == OP_constval) { + ConstvalNode *constNode = static_cast(index); + indexVal = (static_cast(constNode->GetConstVal()))->GetValue(); + isConst = true; + MIRIntConst *newConstNode = mirModule.GetMemPool()->New( + indexVal * static_cast(mpyDim), + *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(array->GetPrimType()))); + BaseNode *newValNode = mirModule.CurFuncCodeMemPool()->New(newConstNode); + newValNode->SetPrimType(array->GetPrimType()); + if (i == 0) { + prevNode = newValNode; + continue; + } else { + resNode = newValNode; + } + } + if (i > 0 && isConst == false) { + resNode = array->GetIndex(static_cast(i)); + } + + BaseNode *mpyNode; + if (isConst) { + MIRIntConst *mulConst = mirModule.GetMemPool()->New( + static_cast(mpyDim) * indexVal, + *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(array->GetPrimType()))); + BaseNode *mulSize = mirModule.CurFuncCodeMemPool()->New(mulConst); + mulSize->SetPrimType(array->GetPrimType()); + mpyNode = mulSize; + } else if (mpyDim == 1 && prevNode) { + mpyNode = prevNode; + prevNode = resNode; + } else { + mpyNode = mirModule.CurFuncCodeMemPool()->New(OP_mul); + mpyNode->SetPrimType(array->GetPrimType()); + MIRIntConst *mulConst = mirModule.GetMemPool()->New( + mpyDim, *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(array->GetPrimType()))); + BaseNode *mulSize = mirModule.CurFuncCodeMemPool()->New(mulConst); + mulSize->SetPrimType(array->GetPrimType()); + mpyNode->SetOpnd(mulSize, 0); + mpyNode->SetOpnd(resNode, 1); + } + if (i == 0) { + prevNode = mpyNode; + continue; + } + BaseNode *newResNode = mirModule.CurFuncCodeMemPool()->New(OP_add); + newResNode->SetPrimType(array->GetPrimType()); + newResNode->SetOpnd(mpyNode, 0); + newResNode->SetOpnd(prevNode, 1); + prevNode = newResNode; + } + resNode = prevNode; + } + + BaseNode *rMul = nullptr; + // esize is the size of the array element (eg. int = 4 long = 8) + uint64 esize; + if (nestedArray) { + esize = elemSize; + } else { + esize = arrayType->GetElemType()->GetSize(); + } + Opcode opadd = OP_add; + MIRIntConst *econst = mirModule.GetMemPool()->New(esize, + *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(array->GetPrimType()))); + BaseNode *eSize = mirModule.CurFuncCodeMemPool()->New(econst); + eSize->SetPrimType(array->GetPrimType()); + rMul = mirModule.CurFuncCodeMemPool()->New(OP_mul); + rMul->SetPrimType(array->GetPrimType()); + rMul->SetOpnd(resNode, 0); + rMul->SetOpnd(eSize, 1); + BaseNode *baseNode = array->GetBase(); + BaseNode *rAdd = mirModule.CurFuncCodeMemPool()->New(opadd); + rAdd->SetPrimType(array->GetPrimType()); + rAdd->SetOpnd(baseNode, 0); + rAdd->SetOpnd(rMul, 1); + return rAdd; +} + IfStmtNode *MIRLower::ExpandArrayMrtIfBlock(IfStmtNode &node) { if (node.GetThenPart() != nullptr) { node.SetThenPart(ExpandArrayMrtBlock(*node.GetThenPart())); diff --git a/src/mapleall/maple_me/include/ssa_tab.h b/src/mapleall/maple_me/include/ssa_tab.h index c6d2256a5c..028e9f63b0 100644 --- a/src/mapleall/maple_me/include/ssa_tab.h +++ b/src/mapleall/maple_me/include/ssa_tab.h @@ -35,7 +35,7 @@ class SSATab : public AnalysisResult { ~SSATab() = default; - BaseNode *CreateSSAExpr(BaseNode &expr); + BaseNode *CreateSSAExpr(BaseNode *expr); void CreateSSAStmt(StmtNode &stmt, const BB *curbb); bool HasDefBB(OStIdx oidx) { return oidx < defBBs4Ost.size() && defBBs4Ost[oidx] && !defBBs4Ost[oidx]->empty(); diff --git a/src/mapleall/maple_me/src/me_function.cpp b/src/mapleall/maple_me/src/me_function.cpp index 7676ce7618..20755bd385 100644 --- a/src/mapleall/maple_me/src/me_function.cpp +++ b/src/mapleall/maple_me/src/me_function.cpp @@ -133,7 +133,9 @@ void MeFunction::Prepare(unsigned long rangeNum) { MIRLower mirLowerer(mirModule, CurFunction()); mirLowerer.Init(); mirLowerer.SetLowerME(); - mirLowerer.SetLowerExpandArray(); + if (mirModule.IsJavaModule()) { + mirLowerer.SetLowerExpandArray(); + } ASSERT(CurFunction() != nullptr, "nullptr check"); mirLowerer.LowerFunc(*CurFunction()); } diff --git a/src/mapleall/maple_me/src/ssa_tab.cpp b/src/mapleall/maple_me/src/ssa_tab.cpp index 7405f2dc0b..261099d9d6 100644 --- a/src/mapleall/maple_me/src/ssa_tab.cpp +++ b/src/mapleall/maple_me/src/ssa_tab.cpp @@ -18,6 +18,8 @@ #include "ssa_mir_nodes.h" #include "opcode_info.h" #include "mir_function.h" +#include "mir_lower.h" +#include "me_option.h" // Allocate data structures to store SSA information. Only statement nodes and // tree nodes that incur defs and uses are relevant. Tree nodes are made larger @@ -25,47 +27,54 @@ // stored in class SSATab's StmtsSSAPart, which has an array of pointers indexed // by the stmtID field of each statement node. namespace maple { -BaseNode *SSATab::CreateSSAExpr(BaseNode &expr) { - if (expr.GetOpCode() == OP_addrof || expr.GetOpCode() == OP_dread) { - if (expr.IsSSANode()) { - return mirModule.CurFunction()->GetCodeMemPool()->New(static_cast(expr)); +BaseNode *SSATab::CreateSSAExpr(BaseNode *expr) { + bool arrayLowered = false; + if (expr->GetOpCode() == OP_array && !mirModule.IsJavaModule() && + MeOption::strengthReduction /* && in-main-me-phase */) { + MIRLower mirLower(mirModule, mirModule.CurFunction()); + expr = mirLower.LowerCArray(static_cast(expr)); + arrayLowered = true; + } + if (expr->GetOpCode() == OP_addrof || expr->GetOpCode() == OP_dread) { + if (expr->IsSSANode()) { + return mirModule.CurFunction()->GetCodeMemPool()->New(*static_cast(expr)); } - auto &addrofNode = static_cast(expr); - AddrofSSANode *ssaNode = mirModule.CurFunction()->GetCodeMemPool()->New(addrofNode); + AddrofNode *addrofNode = static_cast(expr); + AddrofSSANode *ssaNode = mirModule.CurFunction()->GetCodeMemPool()->New(*addrofNode); MIRSymbol *st = mirModule.CurFunction()->GetLocalOrGlobalSymbol(ssaNode->GetStIdx()); OriginalSt *ost = FindOrCreateSymbolOriginalSt(*st, mirModule.CurFunction()->GetPuidx(), ssaNode->GetFieldID()); versionStTable.CreateZeroVersionSt(ost); ssaNode->SetSSAVar(*versionStTable.GetZeroVersionSt(ost)); return ssaNode; - } else if (expr.GetOpCode() == OP_regread) { - auto ®ReadNode = static_cast(expr); - RegreadSSANode *ssaNode = mirModule.CurFunction()->GetCodeMemPool()->New(regReadNode); + } else if (expr->GetOpCode() == OP_regread) { + RegreadNode *regReadNode = static_cast(expr); + RegreadSSANode *ssaNode = mirModule.CurFunction()->GetCodeMemPool()->New(*regReadNode); OriginalSt *ost = originalStTable.FindOrCreatePregOriginalSt(ssaNode->GetRegIdx(), mirModule.CurFunction()->GetPuidx()); versionStTable.CreateZeroVersionSt(ost); ssaNode->SetSSAVar(*versionStTable.GetZeroVersionSt(ost)); return ssaNode; - } else if (expr.GetOpCode() == OP_iread) { - auto &ireadNode = static_cast(expr); - IreadSSANode *ssaNode = mirModule.CurFunction()->GetCodeMempool()->New(ireadNode); - BaseNode *newOpnd = CreateSSAExpr(*ireadNode.Opnd(0)); + } else if (expr->GetOpCode() == OP_iread) { + IreadNode *ireadNode = static_cast(expr); + IreadSSANode *ssaNode = mirModule.CurFunction()->GetCodeMempool()->New(*ireadNode); + BaseNode *newOpnd = CreateSSAExpr(ireadNode->Opnd(0)); if (newOpnd != nullptr) { ssaNode->SetOpnd(newOpnd, 0); } return ssaNode; } - for (size_t i = 0; i < expr.NumOpnds(); ++i) { - BaseNode *newOpnd = CreateSSAExpr(*expr.Opnd(i)); + for (size_t i = 0; i < expr->NumOpnds(); ++i) { + BaseNode *newOpnd = CreateSSAExpr(expr->Opnd(i)); if (newOpnd != nullptr) { - expr.SetOpnd(newOpnd, i); + expr->SetOpnd(newOpnd, i); } } - return nullptr; + return arrayLowered ? expr : nullptr; } void SSATab::CreateSSAStmt(StmtNode &stmt, const BB *curbb) { for (size_t i = 0; i < stmt.NumOpnds(); ++i) { - BaseNode *newOpnd = CreateSSAExpr(*stmt.Opnd(i)); + BaseNode *newOpnd = CreateSSAExpr(stmt.Opnd(i)); if (newOpnd != nullptr) { stmt.SetOpnd(newOpnd, i); } -- Gitee