diff --git a/compiler/optimizer/optimizations/checks_elimination.cpp b/compiler/optimizer/optimizations/checks_elimination.cpp index 30171376e08637359d6e9eb514e3e8c17f473dc9..10f9eae728769d2a22edcc0edf9cae1bebb8d244 100644 --- a/compiler/optimizer/optimizations/checks_elimination.cpp +++ b/compiler/optimizer/optimizations/checks_elimination.cpp @@ -69,7 +69,7 @@ void ChecksElimination::VisitNullCheck(GraphVisitor *v, Inst *inst) COMPILER_LOG(DEBUG, CHECKS_ELIM) << "Start visit NullCheck with id = " << inst->GetId(); auto ref = inst->GetInput(0).GetInst(); - static_cast(v)->TryRemoveDominatedNullChecks(inst, ref); + static_cast(v)->TryRemoveDominatedNullChecks(inst, ref); if (!static_cast(v)->TryRemoveCheck(inst)) { COMPILER_LOG(DEBUG, CHECKS_ELIM) << "NullCheck couldn't be deleted"; @@ -577,6 +577,9 @@ void ChecksElimination::PushNewBoundsCheck(Loop *loop, Inst *len_array, Inst *in void ChecksElimination::TryRemoveDominatedNullChecks(Inst *inst, Inst *ref) { + if (is_osr_mode_) { + return; + } for (auto &user : ref->GetUsers()) { auto user_inst = user.GetInst(); if (((user_inst->GetOpcode() == Opcode::IsInstance && !user_inst->CastToIsInstance()->GetOmitNullCheck()) || @@ -598,6 +601,10 @@ void ChecksElimination::TryRemoveDominatedNullChecks(Inst *inst, Inst *ref) template void ChecksElimination::TryRemoveDominatedChecks(Inst *inst, CheckInputs check_inputs) { + if (is_osr_mode_) { + // In osr mode checks from different blocks not interchangeable in case of SaveStateOSR between + return; + } for (auto &user : inst->GetInput(0).GetInst()->GetUsers()) { auto user_inst = user.GetInst(); // NOLINTNEXTLINE(readability-magic-numbers) diff --git a/compiler/optimizer/optimizations/checks_elimination.h b/compiler/optimizer/optimizations/checks_elimination.h index 863c28bc3b5d58636dffc6a9268aa2a3b2eb142c..0c42c975cbb66c40a41061a6b4d96d6558626d74 100644 --- a/compiler/optimizer/optimizations/checks_elimination.h +++ b/compiler/optimizer/optimizations/checks_elimination.h @@ -48,7 +48,8 @@ public: : Optimization(graph), bounds_checks_(graph->GetLocalAllocator()->Adapter()), checks_for_move_out_of_loop_(graph->GetLocalAllocator()->Adapter()), - checks_must_throw_(graph->GetLocalAllocator()->Adapter()) + checks_must_throw_(graph->GetLocalAllocator()->Adapter()), + is_osr_mode_(graph->IsOsrMode()) { } @@ -179,6 +180,8 @@ private: InstVector checks_for_move_out_of_loop_; InstVector checks_must_throw_; + const bool is_osr_mode_; + bool is_applied_ {false}; bool is_loop_deleted_ {false}; }; diff --git a/compiler/tests/checks_elimination_test.cpp b/compiler/tests/checks_elimination_test.cpp index 49db051a6ba6f502d77f4ee982a206889657d355..545cf238d03b36dfbb8e33505b005968fc77fbdb 100644 --- a/compiler/tests/checks_elimination_test.cpp +++ b/compiler/tests/checks_elimination_test.cpp @@ -5199,6 +5199,115 @@ TEST_F(ChecksEliminationTest, NegOverflowAndZeroCheck3) ASSERT_TRUE(GraphComparator().Compare(graph1, graph2)); } +TEST_F(ChecksEliminationTest, SaveStateOSR) +{ + auto graph1 = CreateEmptyGraph(); + auto osr_graph1 = CreateOsrGraph(); // this 2 graphs content same except inst 4 SaveStateOsr + GRAPH(graph1) + { + PARAMETER(0, 0).u32(); + PARAMETER(42, 1).ref(); + CONSTANT(1, 0); + CONSTANT(2, 1); + + BASIC_BLOCK(2, 3) + { + INST(15, Opcode::SaveState).Inputs(42).SrcVregs({42}); + INST(16, Opcode::NullCheck).ref().Inputs(42, 15); + INST(17, Opcode::LenArray).s32().Inputs(16); + } + + BASIC_BLOCK(3, 5, 4) + { + INST(3, Opcode::Phi).s32().Inputs(0, 7); + INST(5, Opcode::Compare).b().SrcType(DataType::Type::INT32).CC(CC_LE).Inputs(3, 1); + INST(6, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(5); + } + BASIC_BLOCK(4, 3) + { + INST(7, Opcode::Sub).s32().Inputs(3, 2); + INST(8, Opcode::SaveState).Inputs(42).SrcVregs({42}); + INST(9, Opcode::NullCheck).ref().Inputs(42, 8); + INST(10, Opcode::LenArray).s32().Inputs(9); + } + BASIC_BLOCK(5, -1) + { + INST(11, Opcode::Return).u32().Inputs(3); + } + } + GRAPH(osr_graph1) + { + PARAMETER(0, 0).u32(); + PARAMETER(42, 1).ref(); + CONSTANT(1, 0); + CONSTANT(2, 1); + + BASIC_BLOCK(2, 3) + { + INST(15, Opcode::SaveState).Inputs(42).SrcVregs({42}); + INST(16, Opcode::NullCheck).ref().Inputs(42, 15); + INST(17, Opcode::LenArray).s32().Inputs(16); + } + + BASIC_BLOCK(3, 5, 4) + { + INST(3, Opcode::Phi).s32().Inputs(0, 7); + INST(4, Opcode::SaveStateOsr).Inputs(3, 42).SrcVregs({1, 42}); + INST(5, Opcode::Compare).b().SrcType(DataType::Type::INT32).CC(CC_LE).Inputs(3, 1); + INST(6, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(5); + } + BASIC_BLOCK(4, 3) + { + INST(7, Opcode::Sub).s32().Inputs(3, 2); + INST(8, Opcode::SaveState).Inputs(42).SrcVregs({42}); + INST(9, Opcode::NullCheck).ref().Inputs(42, 8); + INST(10, Opcode::LenArray).s32().Inputs(9); + } + BASIC_BLOCK(5, -1) + { + INST(11, Opcode::Return).u32().Inputs(3); + } + } + auto graph2 = CreateEmptyGraph(); + GRAPH(graph2) + { + PARAMETER(0, 0).u32(); + PARAMETER(42, 1).ref(); + CONSTANT(1, 0); + CONSTANT(2, 1); + + BASIC_BLOCK(2, 3) + { + INST(15, Opcode::SaveState).Inputs(42).SrcVregs({42}); + INST(16, Opcode::NullCheck).ref().Inputs(42, 15); + INST(17, Opcode::LenArray).s32().Inputs(16); + } + + BASIC_BLOCK(3, 5, 4) + { + INST(3, Opcode::Phi).s32().Inputs(0, 7); + INST(5, Opcode::Compare).b().SrcType(DataType::Type::INT32).CC(CC_LE).Inputs(3, 1); + INST(6, Opcode::IfImm).SrcType(DataType::BOOL).CC(CC_NE).Imm(0).Inputs(5); + } + BASIC_BLOCK(4, 3) + { + INST(7, Opcode::Sub).s32().Inputs(3, 2); + INST(8, Opcode::SaveState).Inputs(42).SrcVregs({42}); + INST(43, Opcode::NOP); + INST(10, Opcode::LenArray).s32().Inputs(16); + } + BASIC_BLOCK(5, -1) + { + INST(11, Opcode::Return).u32().Inputs(3); + } + } + + ASSERT_TRUE(graph1->RunPass()); + ASSERT_TRUE(GraphComparator().Compare(graph1, graph2)); + + ASSERT_FALSE(osr_graph1->RunPass()); +} + // NOLINTEND(readability-magic-numbers) } // namespace panda::compiler diff --git a/compiler/tests/unit_test.h b/compiler/tests/unit_test.h index 6d84fdfbbc1e8f6a45aad99aca1598eeb0df1970..3cf66df948f193ff69e1829864582ccde4ea78d6 100644 --- a/compiler/tests/unit_test.h +++ b/compiler/tests/unit_test.h @@ -162,6 +162,11 @@ public: return GetAllocator()->New(GetAllocator(), GetLocalAllocator(), arch, false); } + Graph *CreateOsrGraph() const + { + return GetAllocator()->New(GetAllocator(), GetLocalAllocator(), arch_, true); + } + Graph *CreateGraphStartEndBlocks(bool is_dynamic = false) const { auto graph = GetAllocator()->New(GetAllocator(), GetLocalAllocator(), arch_, is_dynamic, false);