From 7bd8d53bfdd8e9b852911f44c2609b8ac51c2dff Mon Sep 17 00:00:00 2001 From: Jamie Cui Date: Mon, 15 Dec 2025 18:51:15 +0800 Subject: [PATCH 1/3] feat(kcal_python): add arith operators and improve share management - Add comprehensive arithmetic operators (ADD, SUB, MUL, DIV, - comparisons, aggregates) Implement proper shared pointer - management for MPC shares Add extensive test suite for new - operators Improve input handling for both string and numeric data - types Fix memory management issues in KcalMpcShare and related - classes Add clang-format configuration and formatting scripts - Update .gitignore with comprehensive ignore patterns --- .clang-format | 112 +++++++++ .gitignore | 58 +++++ MPC/kcal_python/src/kcal_wrapper.cc | 266 +++++++++++++++++--- MPC/kcal_python/src/kcal_wrapper.h | 23 ++ MPC/kcal_python/test/arith_demo.py | 365 ++++++++++++++++++++++++++-- MPC/middleware/kcal/utils/io.cc | 22 +- MPC/middleware/kcal/utils/io.h | 103 ++++++-- build.sh | 0 format-all.sh | 19 ++ 9 files changed, 881 insertions(+), 87 deletions(-) create mode 100644 .clang-format create mode 100644 .gitignore create mode 100644 MPC/kcal_python/src/kcal_wrapper.h mode change 100644 => 100755 build.sh create mode 100755 format-all.sh diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..be97eb4 --- /dev/null +++ b/.clang-format @@ -0,0 +1,112 @@ +Language: Cpp +BasedOnStyle: Google +AccessModifierOffset: -4 +AlignAfterOpenBracket: Align +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignOperands: true +AlignTrailingComments: true +AllowAllParametersOfDeclarationOnNextLine: false +AllowShortBlocksOnASingleLine: false +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: None +AllowShortIfStatementsOnASingleLine: false +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: false +AlwaysBreakTemplateDeclarations: false +# AllowShortEnumsOnASingleLine: false +BinPackArguments: true +BinPackParameters: true +BraceWrapping: + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: true + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + AfterExternBlock: false + BeforeCatch: false + BeforeElse: false + IndentBraces: false + SplitEmptyFunction: false + SplitEmptyRecord: false + SplitEmptyNamespace: false +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Custom +BreakBeforeInheritanceComma: false +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +BreakConstructorInitializers: BeforeColon +BreakAfterJavaFieldAnnotations: false +BreakStringLiterals: true +ColumnLimit: 120 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: false +ConstructorInitializerAllOnOneLineOrOnePerLine: true +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +ExperimentalAutoDetectBinPacking: false +FixNamespaceComments: true +ForEachMacros: + - foreach + - Q_FOREACH + - BOOST_FOREACH +IncludeBlocks: Regroup +IncludeCategories: + - Regex: '^<.*\.h>' + Priority: 1 + - Regex: '^<.*' + Priority: 2 + - Regex: '.*\.pb\.h"$' + Priority: 5 + - Regex: '^"kcal.*' + Priority: 4 + - Regex: '^".*' + Priority: 3 +IncludeIsMainRegex: '(Test)?$' +IndentCaseLabels: true +IndentPPDirectives: None +IndentWidth: 4 +IndentWrappedFunctionNames: false +JavaScriptQuotes: Leave +JavaScriptWrapImports: true +KeepEmptyLinesAtTheStartOfBlocks: true +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBlockIndentWidth: 4 +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: true +PenaltyBreakAssignment: 2 +PenaltyBreakBeforeFirstCallParameter: 19 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 80 +PenaltyBreakString: 1000 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 80 +PointerAlignment: Right +ReflowComments: true +SortIncludes: true +SortUsingDeclarations: true +SpaceAfterCStyleCast: false +SpaceAfterTemplateKeyword: true +SpaceBeforeAssignmentOperators: true +SpaceBeforeParens: ControlStatements +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInContainerLiterals: false +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Cpp11 +TabWidth: 4 +UseTab: Never \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c2679c6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,58 @@ +# Others +.envrc +DataGuard-* +MPC/kcal_python/include +MPC/kcal_python/kcal +MPC/kcal_python/Makefile +MPC/kcal_python/dist +MPC/kcal_python/lib + +# Python +uv.lock +.pdm-python + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# Virtual environments +venv/ +env/ +ENV/ + +# IDEs +.vscode/ +.idea/ +*.swp +*.swo + +# OS +.DS_Store +Thumbs.db + +# Logs +*.log + +# Environment variables +.env + +# Distribution / packaging +.pytest_cache/ +.coverage +htmlcov/ + +# external deps +external/ + +# build directory +**/build/* + +# AI +CLAUDE.md +.claude +transcript.txt + +# project +.clang +.cache \ No newline at end of file diff --git a/MPC/kcal_python/src/kcal_wrapper.cc b/MPC/kcal_python/src/kcal_wrapper.cc index 66a9ddc..a501d87 100644 --- a/MPC/kcal_python/src/kcal_wrapper.cc +++ b/MPC/kcal_python/src/kcal_wrapper.cc @@ -1,16 +1,25 @@ // Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +#include "kcal_wrapper.h" + #include #include #include #include "context_ext.h" + #include "kcal/core/operator_base.h" #include "kcal/core/operator_manager.h" #include "kcal/operator/all_operator_register.h" +#include "kcal/operator/kcal_arithmetic.h" +#include "kcal/operator/kcal_avg.h" +#include "kcal/operator/kcal_make_share.h" +#include "kcal/operator/kcal_maximum.h" #include "kcal/operator/kcal_pir.h" #include "kcal/operator/kcal_psi.h" +#include "kcal/operator/kcal_reveal_share.h" +#include "kcal/operator/kcal_sum.h" #include "kcal/utils/io.h" namespace py = pybind11; @@ -19,30 +28,53 @@ namespace kcal { namespace { +using PyShare = std::shared_ptr; + void FeedKcalInput(const py::list &pyList, io::KcalInput *kcalInput) { - auto *dgString = new (std::nothrow) DG_String[pyList.size()]; - if (!dgString) { - throw std::bad_alloc(); + if (pyList.empty()) { + return; } - for (size_t i = 0; i < pyList.size(); ++i) { - if (!PyUnicode_Check(pyList[i].ptr())) { - throw std::runtime_error("need str"); + const auto &itemTemp = pyList[0]; + if (py::isinstance(itemTemp)) { + auto *dgString = new (std::nothrow) DG_String[pyList.size()]; + if (!dgString) { + throw std::bad_alloc(); } + for (size_t i = 0; i < pyList.size(); ++i) { + if (!PyUnicode_Check(pyList[i].ptr())) { + throw std::runtime_error("need str"); + } - Py_ssize_t sz; - const char *utf8 = PyUnicode_AsUTF8AndSize(pyList[i].ptr(), &sz); - if (!utf8) { - throw std::bad_alloc(); + Py_ssize_t sz; + const char *utf8 = PyUnicode_AsUTF8AndSize(pyList[i].ptr(), &sz); + if (!utf8) { + throw std::bad_alloc(); + } + + dgString[i].str = strdup(utf8); + dgString[i].size = static_cast(sz) + 1; } - dgString[i].str = strdup(utf8); - dgString[i].size = static_cast(sz) + 1; + DG_TeeInput **internalInput = kcalInput->GetSecondaryPointer(); + (*internalInput)->data.strings = dgString; + (*internalInput)->size = pyList.size(); + (*internalInput)->dataType = MPC_STRING; + } else { + auto inData = std::make_unique(pyList.size()); + for (size_t i = 0; i < pyList.size(); ++i) { + if (py::isinstance(pyList[i]) || py::isinstance(pyList[i])) { + inData[i] = pyList[i].cast(); + } else { + throw std::runtime_error("need number type"); + } + } + + DG_TeeInput **internalInput = kcalInput->GetSecondaryPointer(); + (*internalInput)->data.doubleNumbers = inData.release(); + (*internalInput)->size = pyList.size(); + (*internalInput)->dataType = MPC_DOUBLE; } - DG_TeeInput **internalInput = kcalInput->GetSecondaryPointer(); - (*internalInput)->data.strings = dgString; - (*internalInput)->size = pyList.size(); - (*internalInput)->dataType = MPC_STRING; } void FeedKcalPairList(const py::list &key, const py::list &value, io::KcalPairList *pairList) @@ -58,7 +90,6 @@ void FeedKcalPairList(const py::list &key, const py::list &value, io::KcalPairLi } size_t i = 0; for (i = 0; i < key.size(); ++i) { - pairList->Get()->dgPair[i].key = new (std::nothrow) DG_String(); pairList->Get()->dgPair[i].key = new (std::nothrow) DG_String(); pairList->Get()->dgPair[i].value = new (std::nothrow) DG_String(); if (!pairList->Get()->dgPair[i].key || !pairList->Get()->dgPair[i].value) { @@ -73,10 +104,10 @@ void FeedKcalPairList(const py::list &key, const py::list &value, io::KcalPairLi Py_ssize_t sz; const char *utf8 = PyUnicode_AsUTF8AndSize(key[i].ptr(), &sz); if (!utf8) { + throw std::bad_alloc(); } pairList->Get()->dgPair[i].key->str = strdup(utf8); - pairList->Get()->dgPair[i].key->str = strdup(utf8); pairList->Get()->dgPair[i].key->size = static_cast(sz) + 1; } // 填充 value @@ -90,7 +121,6 @@ void FeedKcalPairList(const py::list &key, const py::list &value, io::KcalPairLi throw std::bad_alloc(); } pairList->Get()->dgPair[i].value->str = strdup(utf8); - pairList->Get()->dgPair[i].value->str = strdup(utf8); pairList->Get()->dgPair[i].value->size = static_cast(sz) + 1; } } @@ -112,7 +142,7 @@ void FeedPsiOutput(io::KcalOutput &kcalOutput, py::list &pyList, DG_TeeMode mode void FeedKcalOutput(io::KcalOutput &kcalOutput, py::list &pyList) { auto *outPtr = kcalOutput.Get(); - auto dataType = kcalOutput.Get()->dataType; + auto dataType = outPtr->dataType; for (size_t i = 0; i < outPtr->size; ++i) { if (dataType == MPC_STRING) { pyList.append(outPtr->data.strings[i].str); @@ -124,7 +154,7 @@ void FeedKcalOutput(io::KcalOutput &kcalOutput, py::list &pyList) } } -} // namespace +} // namespace class PyCallbackAdapter { public: @@ -182,11 +212,11 @@ public: void BindIoClasses(py::module_ &m) { - py::class_(m, "MpcShare") + py::class_(m, "MpcShare") .def(py::init<>()) .def(py::init()) - .def_static("Create", &io::KcalMpcShare::Create, py::return_value_policy::take_ownership) .def("Set", &io::KcalMpcShare::Set) + .def("Create", &io::KcalMpcShare::Create, py::return_value_policy::take_ownership) .def( "Get", [](io::KcalMpcShare &self) -> DG_MpcShare * { return self.Get(); }, py::return_value_policy::reference) @@ -195,16 +225,16 @@ void BindIoClasses(py::module_ &m) py::class_(m, "MpcShareSet") .def(py::init<>()) - .def_static( - "Create", [](const std::vector &shares) { return io::KcalMpcShareSet::Create(shares); }, - py::return_value_policy::take_ownership) .def( "Get", [](io::KcalMpcShareSet &self) -> DG_MpcShareSet * { return self.Get(); }, py::return_value_policy::reference); py::class_(m, "Input") .def(py::init<>()) .def(py::init()) - .def_static("Create", &io::KcalInput::Create, py::return_value_policy::take_ownership) + .def_static( + "Create", + []() -> std::shared_ptr { return std::shared_ptr(io::KcalInput::Create()); }, + py::return_value_policy::take_ownership) .def("Set", &io::KcalInput::Set) .def("Get", &io::KcalInput::Get, py::return_value_policy::reference) .def("Fill", &io::KcalInput::Fill) @@ -219,7 +249,7 @@ void BindOtherOperators(py::module_ &m) py::class_>(m, "Psi") .def(py::init<>()) .def("run", [](Psi &self, const py::list &input, py::list &output, DG_TeeMode mode) -> int { - std::unique_ptr kcalInput(io::KcalInput::Create()); + std::shared_ptr kcalInput(io::KcalInput::Create()); FeedKcalInput(input, kcalInput.get()); io::KcalOutput kcalOutput; int ret = self.Run(kcalInput->Get(), kcalOutput.GetSecondaryPointer(), mode); @@ -230,7 +260,7 @@ void BindOtherOperators(py::module_ &m) .def(py::init<>()) .def("ServerPreProcess", [](Pir &self, const py::list &key, py::list &value) -> int { - std::unique_ptr kcalInput(io::KcalPairList::Create()); + std::shared_ptr kcalInput(io::KcalPairList::Create()); // build DG_PairList FeedKcalPairList(key, value, kcalInput.get()); int ret = self.ServerPreProcess(kcalInput->Get()); @@ -238,7 +268,7 @@ void BindOtherOperators(py::module_ &m) }) .def("ClientQuery", [](Pir &self, const py::list &input, py::list &output, DG_DummyMode mode) -> int { - std::unique_ptr kcalInput(io::KcalInput::Create()); + std::shared_ptr kcalInput(io::KcalInput::Create()); FeedKcalInput(input, kcalInput.get()); io::KcalOutput kcalOutput; int ret = self.ClientQuery(kcalInput->Get(), kcalOutput.GetSecondaryPointer(), mode); @@ -249,6 +279,129 @@ void BindOtherOperators(py::module_ &m) int ret = self.ServerAnswer(); return ret; }); + + // Arithmetic Operators + py::class_>(m, "Arithmetic").def(py::init<>()); + + // Basic Arithmetic Operations + py::class_>(m, "Add") + .def(py::init<>()) + .def("run", [](Add &self, const std::vector &shares, PyShare &outShare) -> int { + auto ptr = outShare.get(); + return self.Run(io::KcalMpcShareSet(shares), ptr); + }); + + py::class_>(m, "Sub") + .def(py::init<>()) + .def("run", [](Sub &self, const std::vector &shares, PyShare &outShare) -> int { + auto ptr = outShare.get(); + return self.Run(io::KcalMpcShareSet(shares), ptr); + }); + + py::class_>(m, "Mul") + .def(py::init<>()) + .def("run", [](Mul &self, const std::vector &shares, PyShare &outShare) -> int { + auto ptr = outShare.get(); + return self.Run(io::KcalMpcShareSet(shares), ptr); + }); + + py::class_>(m, "Div") + .def(py::init<>()) + .def("run", [](Div &self, const std::vector &shares, PyShare &outShare) -> int { + auto ptr = outShare.get(); + return self.Run(io::KcalMpcShareSet(shares), ptr); + }); + + // Comparison Operations + py::class_>(m, "Less") + .def(py::init<>()) + .def("run", [](Less &self, const std::vector &shares, PyShare &outShare) -> int { + auto ptr = outShare.get(); + return self.Run(io::KcalMpcShareSet(shares), ptr); + }); + + py::class_>(m, "LessEqual") + .def(py::init<>()) + .def("run", [](LessEqual &self, const std::vector &shares, PyShare &outShare) -> int { + auto ptr = outShare.get(); + return self.Run(io::KcalMpcShareSet(shares), ptr); + }); + + py::class_>(m, "Greater") + .def(py::init<>()) + .def("run", [](Greater &self, const std::vector &shares, PyShare &outShare) -> int { + auto ptr = outShare.get(); + return self.Run(io::KcalMpcShareSet(shares), ptr); + }); + + py::class_>(m, "GreaterEqual") + .def(py::init<>()) + .def("run", [](GreaterEqual &self, const std::vector &shares, PyShare &outShare) -> int { + auto ptr = outShare.get(); + return self.Run(io::KcalMpcShareSet(shares), ptr); + }); + + py::class_>(m, "Equal") + .def(py::init<>()) + .def("run", [](Equal &self, const std::vector &shares, PyShare &outShare) -> int { + auto ptr = outShare.get(); + return self.Run(io::KcalMpcShareSet(shares), ptr); + }); + + py::class_>(m, "NoEqual") + .def(py::init<>()) + .def("run", [](NoEqual &self, const std::vector &shares, PyShare &outShare) -> int { + auto ptr = outShare.get(); + return self.Run(io::KcalMpcShareSet(shares), ptr); + }); + + // Aggregate Operations + py::class_>(m, "Sum") + .def(py::init<>()) + .def("run", [](Sum &self, const std::vector &shares, PyShare &outShare) -> int { + auto ptr = outShare.get(); + return self.Run(io::KcalMpcShareSet(shares), ptr); + }); + + py::class_>(m, "Avg") + .def(py::init<>()) + .def("run", [](Avg &self, const std::vector &shares, PyShare &outShare) -> int { + auto ptr = outShare.get(); + return self.Run(io::KcalMpcShareSet(shares), ptr); + }); + + py::class_>(m, "Max") + .def(py::init<>()) + .def("run", [](Max &self, const std::vector &shares, PyShare &outShare) -> int { + auto ptr = outShare.get(); + return self.Run(io::KcalMpcShareSet(shares), ptr); + }); + + py::class_>(m, "Min") + .def(py::init<>()) + .def("run", [](Min &self, const std::vector &shares, PyShare &outShare) -> int { + auto ptr = outShare.get(); + return self.Run(io::KcalMpcShareSet(shares), ptr); + }); + + // Share Management + py::class_>(m, "MakeShare") + .def(py::init<>()) + .def("run", [](MakeShare &self, const py::list &input, int isRecvShare, PyShare &share) -> int { + io::KcalInput kcalInput(new DG_TeeInput()); + FeedKcalInput(input, &kcalInput); + auto data = share.get(); // must have this + return self.Run(kcalInput, isRecvShare, data); + }); + + py::class_>(m, "RevealShare") + .def(py::init<>()) + .def("run", [](RevealShare &self, const PyShare &share, py::list &output) -> int { + io::KcalOutput kcalOutput; + int ret = self.Run(share.get(), kcalOutput); + FeedKcalOutput(kcalOutput, output); + return ret; + }); } PYBIND11_MODULE(kcal, m) @@ -258,6 +411,23 @@ PYBIND11_MODULE(kcal, m) py::enum_(m, "AlgorithmsType") .value("PSI", KCAL_AlgorithmsType::PSI) .value("PIR", KCAL_AlgorithmsType::PIR) + .value("ARITHMETIC", KCAL_AlgorithmsType::ARITHMETIC) + .value("MAKE_SHARE", KCAL_AlgorithmsType::MAKE_SHARE) + .value("REVEAL_SHARE", KCAL_AlgorithmsType::REVEAL_SHARE) + .value("ADD", KCAL_AlgorithmsType::ADD) + .value("SUB", KCAL_AlgorithmsType::SUB) + .value("MUL", KCAL_AlgorithmsType::MUL) + .value("DIV", KCAL_AlgorithmsType::DIV) + .value("LESS", KCAL_AlgorithmsType::LESS) + .value("LESS_EQUAL", KCAL_AlgorithmsType::LESS_EQUAL) + .value("GREATER", KCAL_AlgorithmsType::GREATER) + .value("GREATER_EQUAL", KCAL_AlgorithmsType::GREATER_EQUAL) + .value("EQUAL", KCAL_AlgorithmsType::EQUAL) + .value("NO_EQUAL", KCAL_AlgorithmsType::NO_EQUAL) + .value("SUM", KCAL_AlgorithmsType::SUM) + .value("AVG", KCAL_AlgorithmsType::AVG) + .value("MAX", KCAL_AlgorithmsType::MAX) + .value("MIN", KCAL_AlgorithmsType::MIN) .export_values(); py::enum_(m, "TeeMode") @@ -316,6 +486,40 @@ PYBIND11_MODULE(kcal, m) return OperatorManager::CreateOperator(context->GetKcalContext()); case KCAL_AlgorithmsType::PIR: return OperatorManager::CreateOperator(context->GetKcalContext()); + case KCAL_AlgorithmsType::ARITHMETIC: + return OperatorManager::CreateOperator(context->GetKcalContext()); + case KCAL_AlgorithmsType::MAKE_SHARE: + return OperatorManager::CreateOperator(context->GetKcalContext()); + case KCAL_AlgorithmsType::REVEAL_SHARE: + return OperatorManager::CreateOperator(context->GetKcalContext()); + case KCAL_AlgorithmsType::ADD: + return OperatorManager::CreateOperator(context->GetKcalContext()); + case KCAL_AlgorithmsType::SUB: + return OperatorManager::CreateOperator(context->GetKcalContext()); + case KCAL_AlgorithmsType::MUL: + return OperatorManager::CreateOperator(context->GetKcalContext()); + case KCAL_AlgorithmsType::DIV: + return OperatorManager::CreateOperator
(context->GetKcalContext()); + case KCAL_AlgorithmsType::LESS: + return OperatorManager::CreateOperator(context->GetKcalContext()); + case KCAL_AlgorithmsType::LESS_EQUAL: + return OperatorManager::CreateOperator(context->GetKcalContext()); + case KCAL_AlgorithmsType::GREATER: + return OperatorManager::CreateOperator(context->GetKcalContext()); + case KCAL_AlgorithmsType::GREATER_EQUAL: + return OperatorManager::CreateOperator(context->GetKcalContext()); + case KCAL_AlgorithmsType::EQUAL: + return OperatorManager::CreateOperator(context->GetKcalContext()); + case KCAL_AlgorithmsType::NO_EQUAL: + return OperatorManager::CreateOperator(context->GetKcalContext()); + case KCAL_AlgorithmsType::SUM: + return OperatorManager::CreateOperator(context->GetKcalContext()); + case KCAL_AlgorithmsType::AVG: + return OperatorManager::CreateOperator(context->GetKcalContext()); + case KCAL_AlgorithmsType::MAX: + return OperatorManager::CreateOperator(context->GetKcalContext()); + case KCAL_AlgorithmsType::MIN: + return OperatorManager::CreateOperator(context->GetKcalContext()); default: throw std::runtime_error("Unsupported operator type"); } @@ -343,4 +547,4 @@ PYBIND11_MODULE(kcal, m) m.def("release_mpc_share", [](DG_MpcShare *share) { io::DataHelper::ReleaseMpcShare(&share); }); } -} // namespace kcal +} // namespace kcal diff --git a/MPC/kcal_python/src/kcal_wrapper.h b/MPC/kcal_python/src/kcal_wrapper.h new file mode 100644 index 0000000..cd4773f --- /dev/null +++ b/MPC/kcal_python/src/kcal_wrapper.h @@ -0,0 +1,23 @@ +// Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +#pragma once + +#include + +#include "kcal/utils/io.h" + +namespace kcal { + +// Obtain a non-owning vector of shared pointers +const std::vector BorrowPtrs(const std::vector &in) +{ + std::vector out; + out.reserve(in.size()); + for (const auto &elem : in) { + // FIXME(cuijiming): we should not use const_cast here, but it's harmless since we return a const + out.emplace_back(const_cast(&elem)); + } + return out; +} + +} // namespace kcal diff --git a/MPC/kcal_python/test/arith_demo.py b/MPC/kcal_python/test/arith_demo.py index 0110f33..038ddea 100644 --- a/MPC/kcal_python/test/arith_demo.py +++ b/MPC/kcal_python/test/arith_demo.py @@ -39,7 +39,8 @@ def on_recv_data(node_info: dict, buffer: memoryview) -> int: return socket_util.recv_data(s, buffer) -def psi_demo(is_server: bool): +def create_context(is_server: bool): + """Create and return KCAL context for testing""" config = kcal.Config() config.nodeId = 0 if is_server else 1 config.worldSize = 2 @@ -47,54 +48,366 @@ def psi_demo(is_server: bool): config.threadCount = 32 config.useSMAlg = False - context = kcal.Context.create(config, on_send_data, on_recv_data) + return kcal.Context.create(config, on_send_data, on_recv_data) - makeshare_op = kcal.create_operator(context, kcal.AlgorithmsType.NAKESHARE) - revealshare_op = kcal.create_operator(context, kcal.AlgorithmsType.NAKESHARE) + +def test_basic_arithmetic(context, is_server: bool): + """Test basic arithmetic operations: ADD, SUB, MUL, DIV""" + print("\n=== Testing Basic Arithmetic Operations ===") + + # Create operators + make_share_op = kcal.create_operator(context, kcal.AlgorithmsType.MAKE_SHARE) + reveal_share_op = kcal.create_operator(context, kcal.AlgorithmsType.REVEAL_SHARE) + add_op = kcal.create_operator(context, kcal.AlgorithmsType.ADD) + sub_op = kcal.create_operator(context, kcal.AlgorithmsType.SUB) mul_op = kcal.create_operator(context, kcal.AlgorithmsType.MUL) - input0 = ["4", "3", "2", "1"] - input1 = ["1", "3", "4", "5"] - output = [] + div_op = kcal.create_operator(context, kcal.AlgorithmsType.DIV) + + # Test data + input1 = [10, 20, 30, 40] + input2 = [5, 10, 15, 20] + import time start_time = time.time() + + share1 = kcal.MpcShare.Create() + share2 = kcal.MpcShare.Create() + if is_server: - makeshare_op.run() - makeshare_op.run() - mul_op.run() - revealshare_op.run() + print("Server: Processing arithmetic operations...") + + # Create shares for both inputs + make_share_op.run(input1, 1, share1) # isRecvShare = 1 + make_share_op.run(input2, 1, share2) + else: - makeshare_op.run() - makeshare_op.run() - mul_op.run() - revealshare_op.run() - print(len(output)) + print("Client: Processing arithmetic operations...") + + # Create shares for both inputs + make_share_op.run(input1, 0, share1) # isRecvShare = 0 + make_share_op.run(input2, 0, share2) + + # Test ADD: (10+5, 20+10, 30+15, 40+20) = [15, 30, 45, 60] + add_out_share = kcal.MpcShare.Create() + add_result = add_op.run([share1, share2], add_out_share) + + # Test SUB: (10-5, 20-10, 30-15, 40-20) = [5, 10, 15, 20] + sub_out_share = kcal.MpcShare.Create() + sub_result = sub_op.run([share1, share2], sub_out_share) + + # Test MUL: (10*5, 20*10, 30*15, 40*20) = [50, 200, 450, 800] + mul_out_share = kcal.MpcShare.Create() + mul_result = mul_op.run([share1, share2], mul_out_share ) + + # Test DIV: (10/5, 20/10, 30/15, 40/20) = [2, 2, 2, 2] + div_out_share = kcal.MpcShare.Create() + div_result = div_op.run([share1, share2], div_out_share) + + # Reveal results + add_output = [] + sub_output = [] + mul_output = [] + div_output = [] + + reveal_share_op.run(add_out_share, add_output) + reveal_share_op.run(sub_out_share, sub_output) + reveal_share_op.run(mul_out_share, mul_output) + reveal_share_op.run(div_out_share, div_output) + + print(f"ADD result: {add_output}") + print(f"SUB result: {sub_output}") + print(f"MUL result: {mul_output}") + print(f"DIV result: {div_output}") + end_time = time.time() - duration_ms = (end_time - start_time) * 1000 # ms - print(f"run cost: {duration_ms:.2f} ms") + duration_ms = (end_time - start_time) * 1000 + print(f"Basic arithmetic test completed in: {duration_ms:.2f} ms") -def main(argv=None): +def test_comparison_operations(context, is_server: bool): + """Test comparison operations: LESS, GREATER, EQUAL, etc.""" + print("\n=== Testing Comparison Operations ===") + + # Create operators + make_share_op = kcal.create_operator(context, kcal.AlgorithmsType.MAKE_SHARE) + reveal_share_op = kcal.create_operator(context, kcal.AlgorithmsType.REVEAL_SHARE) + lt_op = kcal.create_operator(context, kcal.AlgorithmsType.LESS) + gt_op = kcal.create_operator(context, kcal.AlgorithmsType.GREATER) + eq_op = kcal.create_operator(context, kcal.AlgorithmsType.EQUAL) + less_eq_op = kcal.create_operator(context, kcal.AlgorithmsType.LESS_EQUAL) + greater_eq_op = kcal.create_operator(context, kcal.AlgorithmsType.GREATER_EQUAL) + no_eq_op = kcal.create_operator(context, kcal.AlgorithmsType.NO_EQUAL) + + # Test data + input1 = [10, 20, 30, 40] + input2 = [15, 20, 25, 50] + + import time + start_time = time.time() + + share1 = kcal.MpcShare.Create() + share2 = kcal.MpcShare.Create() + + if is_server: + print("Server: Processing comparison operations...") + + # Create shares + make_share_op.run(input1, 1, share1) # isRecvShare = 1 + make_share_op.run(input2, 1, share2) + + else: + print("Client: Processing comparison operations...") + + make_share_op.run(input1, 0, share1) # isRecvShare = 0 + make_share_op.run(input2, 0, share2) + + # Test comparison operations + lt_out_share = kcal.MpcShare.Create() + lt_op.run([share1, share2], lt_out_share) # 10<15, 20<20, 30<25, 40<50 = [1,0,0,1] + + gt_out_share = kcal.MpcShare.Create() + gt_op.run([share1, share2], gt_out_share) # 10>15, 20>20, 30>25, 40>50 = [0,0,1,0] + + eq_out_share = kcal.MpcShare.Create() + eq_op.run([share1, share2], eq_out_share) # 10=15, 20=20, 30=25, 40=50 = [0,1,0,0] + + # Reveal results + lt_output = [] + gt_output = [] + eq_output = [] + + reveal_share_op.run(lt_out_share, lt_output) + reveal_share_op.run(gt_out_share, gt_output) + reveal_share_op.run(eq_out_share, eq_output) + + print(f"LESS (input1input2) result: {gt_output}") + print(f"EQUAL (input1==input2) result: {eq_output}") + + end_time = time.time() + duration_ms = (end_time - start_time) * 1000 + print(f"Comparison operations test completed in: {duration_ms:.2f} ms") + + +def test_aggregate_operations(context, is_server: bool): + """Test aggregate operations: SUM, AVG, MAX, MIN""" + print("\n=== Testing Aggregate Operations ===") + + # Create operators + make_share_op = kcal.create_operator(context, kcal.AlgorithmsType.MAKE_SHARE) + reveal_share_op = kcal.create_operator(context, kcal.AlgorithmsType.REVEAL_SHARE) + sum_op = kcal.create_operator(context, kcal.AlgorithmsType.SUM) + avg_op = kcal.create_operator(context, kcal.AlgorithmsType.AVG) + max_op = kcal.create_operator(context, kcal.AlgorithmsType.MAX) + min_op = kcal.create_operator(context, kcal.AlgorithmsType.MIN) + + # Test data from multiple parties + input1 = [10, 20, 30, 40] # Party 1 + input2 = [5, 15, 25, 35] # Party 2 + + import time + start_time = time.time() + + share1 = kcal.MpcShare.Create() + share2 = kcal.MpcShare.Create() + + if is_server: + print("Server: Processing aggregate operations...") + + # Create shares + make_share_op.run(input1, 1, share1) + make_share_op.run(input2, 1, share2) + + else: + print("Client: Processing aggregate operations...") + + # Create shares + make_share_op.run(input1, 0, share1) + make_share_op.run(input2, 0, share2) + + # Test aggregate operations (combine shares from both parties) + sum_out_share = kcal.MpcShare.Create() + avg_out_share = kcal.MpcShare.Create() + max_out_share = kcal.MpcShare.Create() + min_out_share = kcal.MpcShare.Create() + + sum_result = sum_op.run([share1, share2], sum_out_share) # [10+5, 20+15, 30+25, 40+35] = [15, 35, 55, 75] + avg_result = avg_op.run([share1, share2], avg_out_share) # Average of each position + max_result = max_op.run([share1, share2], max_out_share) # Max of each position: [10,20,30,40] + min_result = min_op.run([share1, share2], min_out_share) # Min of each position: [5,15,25,35] + + # Reveal results + sum_output = [] + avg_output = [] + max_output = [] + min_output = [] + + reveal_share_op.run(sum_out_share, sum_output) + reveal_share_op.run(avg_out_share, avg_output) + reveal_share_op.run(max_out_share, max_output) + reveal_share_op.run(min_out_share, min_output) + + print(f"SUM result: {sum_output}") + print(f"AVG result: {avg_output}") + print(f"MAX result: {max_output}") + print(f"MIN result: {min_output}") + + end_time = time.time() + duration_ms = (end_time - start_time) * 1000 + print(f"Aggregate operations test completed in: {duration_ms:.2f} ms") + + +def test_share_management(context, is_server: bool): + """Test share management operations: MAKE_SHARE, REVEAL_SHARE""" + print("\n=== Testing Share Management ===") + + make_share_op = kcal.create_operator(context, kcal.AlgorithmsType.MAKE_SHARE) + reveal_share_op = kcal.create_operator(context, kcal.AlgorithmsType.REVEAL_SHARE) + + # Test data with different types + test_inputs = [ + [1, 2, 3, 4, 5], # Small integers + [100, 200, 300, 400, 500], # Medium integers + [1000, 2000, 3000, 4000], # Large integers + ] + + for i, test_input in enumerate(test_inputs): + print(f"\nTest case {i+1}: {test_input}") + + import time + start_time = time.time() + + share = kcal.MpcShare.Create() + + if is_server: + # Create share (server always receives shares) + make_share_op.run(test_input, 1, share) + + # Reveal the share back to original values + output = [] + reveal_share_op.run(share, output) + + print(f"Revealed values: {output}") + + else: + # Create share (client doesn't receive shares) + make_share_op.run(test_input, 0, share) + + # Reveal the share back to original values + output = [] + reveal_share_op.run(share, output) + + print(f"Revealed values: {output}") + + end_time = time.time() + duration_ms = (end_time - start_time) * 1000 + print(f"Share management test completed in: {duration_ms:.2f} ms") + + +def run_comprehensive_tests(is_server: bool): + """Run all tests for the new arithmetic operators""" + print(f"=== KCAL Arithmetic Operators Test Suite ===") + print(f"Running as: {'Server' if is_server else 'Client'}") + + context = create_context(is_server) + + try: + # Run all test categories + test_share_management(context, is_server) + test_basic_arithmetic(context, is_server) + test_comparison_operations(context, is_server) + test_aggregate_operations(context, is_server) + + print("\n=== All Tests Completed Successfully! ===") + + except Exception as e: + print(f"Test failed with error: {e}") + import traceback + traceback.print_exc() + + + +def create_parser(): + """Create and return the argument parser""" parser = argparse.ArgumentParser(description="KCAL python wrapper demo.") try: + # Main mode selection group = parser.add_mutually_exclusive_group(required=True) group.add_argument("--server", action="store_true", default=False, help="start server") group.add_argument("--client", action="store_true", default=False, help="start client") - parser.add_argument("--host", type=str, default="127.0.0.1") - parser.add_argument("-p", "--port", type=int, required=True) - args = parser.parse_args(argv) + + # Test selection + test_group = parser.add_mutually_exclusive_group() + test_group.add_argument("--test-all", action="store_true", default=True, + help="run comprehensive tests for all arithmetic operators (default)") + test_group.add_argument("--test-basic", action="store_true", default=False, + help="test basic arithmetic operations (ADD, SUB, MUL, DIV)") + test_group.add_argument("--test-comparison", action="store_true", default=False, + help="test comparison operations (LESS, GREATER, EQUAL, etc.)") + test_group.add_argument("--test-aggregate", action="store_true", default=False, + help="test aggregate operations (SUM, AVG, MAX, MIN)") + test_group.add_argument("--test-shares", action="store_true", default=False, + help="test share management (MAKE_SHARE, REVEAL_SHARE)") + test_group.add_argument("--original", action="store_true", default=False, + help="run original demo for compatibility") + + # Network configuration + parser.add_argument("--host", type=str, default="127.0.0.1", + help="server host address (default: 127.0.0.1)") + parser.add_argument("-p", "--port", type=int, required=True, + help="port number for communication") + + return parser except argparse.ArgumentParser: parser.print_help() sys.exit(1) + +def main(argv=None): + parser = create_parser() + args = parser.parse_args(argv) + global _client_socket, _server_socket + if args.server: _client_socket = socket_util.init_server(args.host, args.port) - psi_demo(True) - _client_socket.close() + try: + if args.test_basic: + context = create_context(True) + test_basic_arithmetic(context, True) + elif args.test_comparison: + context = create_context(True) + test_comparison_operations(context, True) + elif args.test_aggregate: + context = create_context(True) + test_aggregate_operations(context, True) + elif args.test_shares: + context = create_context(True) + test_share_management(context, True) + else: # default: test_all + run_comprehensive_tests(True) + finally: + _client_socket.close() + elif args.client: _server_socket = socket_util.init_client(args.host, args.port) - psi_demo(False) - _server_socket.close() + try: + if args.test_basic: + context = create_context(False) + test_basic_arithmetic(context, False) + elif args.test_comparison: + context = create_context(False) + test_comparison_operations(context, False) + elif args.test_aggregate: + context = create_context(False) + test_aggregate_operations(context, False) + elif args.test_shares: + context = create_context(False) + test_share_management(context, False) + else: # default: test_all + run_comprehensive_tests(False) + finally: + _server_socket.close() if __name__ == "__main__": diff --git a/MPC/middleware/kcal/utils/io.cc b/MPC/middleware/kcal/utils/io.cc index c8a71f5..7f2f21f 100644 --- a/MPC/middleware/kcal/utils/io.cc +++ b/MPC/middleware/kcal/utils/io.cc @@ -11,6 +11,7 @@ */ #include "kcal/utils/io.h" + #include #include @@ -47,7 +48,7 @@ void DataHelper::ReleaseDgPairList(DG_PairList *pairList) delete pairList->dgPair[i].value; }; } - delete [] pairList->dgPair; + delete[] pairList->dgPair; pairList = nullptr; } } @@ -102,12 +103,6 @@ KcalMpcShare::~KcalMpcShare() } } -KcalMpcShare *KcalMpcShare::Create() -{ - auto share = std::make_unique(); - return share.release(); -} - // =========================== // KcalMpcShareSet impl // =========================== @@ -122,6 +117,19 @@ KcalMpcShareSet::~KcalMpcShareSet() } } +KcalMpcShareSet::KcalMpcShareSet(const std::vector> &shares) +{ + shareSet_ = new (std::nothrow) DG_MpcShareSet(); + shareSet_->size = shares.size(); + + auto shareDatas = std::make_unique(shareSet_->size); + shareSet_->shareSet = shareDatas.release(); + + for (size_t i = 0; i < shares.size(); ++i) { + shareSet_->shareSet[i] = *shares[i]->Get(); + } +} + KcalMpcShareSet KcalMpcShareSet::Create(const std::vector &shares) { KcalMpcShareSet shareSet{}; diff --git a/MPC/middleware/kcal/utils/io.h b/MPC/middleware/kcal/utils/io.h index 20b38d5..73b6a5d 100644 --- a/MPC/middleware/kcal/utils/io.h +++ b/MPC/middleware/kcal/utils/io.h @@ -13,11 +13,11 @@ #ifndef KCAL_MIDDLEWARE_IO_H #define KCAL_MIDDLEWARE_IO_H -#include +#include #include -#include "kcal/api/kcal_api.h" +#include -#include +#include "kcal/api/kcal_api.h" namespace kcal::io { @@ -33,17 +33,41 @@ public: class KcalMpcShare { public: KcalMpcShare() = default; - explicit KcalMpcShare(DG_MpcShare *share) : share_(share) {} + explicit KcalMpcShare(DG_MpcShare *share) : share_(share) + {} ~KcalMpcShare(); - static KcalMpcShare *Create(); - - void Set(DG_MpcShare *share) { share_ = share; } - DG_MpcShare *&Get() { return share_; } - DG_MpcShare *Get() const { return share_; } - - unsigned long Size() { return share_->size; } - DG_ShareType Type() { return share_->shareType; } + // static function that creates KcalMpcShare shared_ptr + static std::shared_ptr Create() + { + return std::make_shared(); + } + + // delete copy and assign constructor + KcalMpcShare(const KcalMpcShare &) = delete; + KcalMpcShare &operator=(const KcalMpcShare &) = delete; + + void Set(DG_MpcShare *share) + { + share_ = share; + } + DG_MpcShare *&Get() + { + return share_; + } + DG_MpcShare *Get() const + { + return share_; + } + + unsigned long Size() + { + return share_->size; + } + DG_ShareType Type() + { + return share_->shareType; + } private: DG_MpcShare *share_ = nullptr; // manage memory release @@ -52,10 +76,17 @@ private: class KcalMpcShareSet { public: KcalMpcShareSet() = default; + KcalMpcShareSet(const std::vector> &shares); ~KcalMpcShareSet(); - DG_MpcShareSet *Get() { return shareSet_; } - DG_MpcShareSet *Get() const { return shareSet_; } + DG_MpcShareSet *Get() + { + return shareSet_; + } + DG_MpcShareSet *Get() const + { + return shareSet_; + } static KcalMpcShareSet Create(const std::vector &shares); @@ -67,18 +98,34 @@ private: class KcalInput { public: KcalInput() = default; - explicit KcalInput(DG_TeeInput *input) : input_(input) {} - ~KcalInput() { DataHelper::ReleaseOutput(&input_); } + explicit KcalInput(DG_TeeInput *input) : input_(input) + {} + ~KcalInput() + { + DataHelper::ReleaseOutput(&input_); + } static KcalInput *Create(); - void Set(DG_TeeInput *input) { input_ = input; } - DG_TeeInput *Get() { return input_; } - DG_TeeInput **GetSecondaryPointer() { return &input_; } + void Set(DG_TeeInput *input) + { + input_ = input; + } + DG_TeeInput *Get() + { + return input_; + } + DG_TeeInput **GetSecondaryPointer() + { + return &input_; + } void Fill(const std::vector &data); - int Size() { return input_->size; } + int Size() + { + return input_->size; + } private: DG_TeeInput *input_ = nullptr; // manage memory release @@ -90,10 +137,20 @@ class KcalPairList { public: KcalPairList() = default; explicit KcalPairList(DG_PairList *pairList) : pairList_(pairList) {}; - ~KcalPairList() {DataHelper::ReleaseDgPairList(pairList_);}; + ~KcalPairList() + { + DataHelper::ReleaseDgPairList(pairList_); + }; static KcalPairList *Create(); - DG_PairList *Get() {return pairList_;}; - DG_PairList **GetSecondaryPointer() {return &pairList_;}; + DG_PairList *Get() + { + return pairList_; + }; + DG_PairList **GetSecondaryPointer() + { + return &pairList_; + }; + private: DG_PairList *pairList_ = nullptr; }; diff --git a/build.sh b/build.sh old mode 100644 new mode 100755 diff --git a/format-all.sh b/format-all.sh new file mode 100755 index 0000000..bfcbc17 --- /dev/null +++ b/format-all.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# Check if clang-format exists +if ! command -v clang-format &> /dev/null; then + echo "clang-format not found" + exit 1 +fi + +# Check if cmake-format exists +if ! command -v cmake-format &> /dev/null; then + echo "cmake-format not found" + exit 1 +fi + +# format c/c++ code +find . -name "*.cpp" -o -name "*.hpp" -o -name "*.h" | xargs clang-format -i + +# format cmake code +find . -name "CMakeLists.txt" -exec cmake-format -i {} \; -- Gitee From d631305a8797529a06dfef8e2e1988eb47d18126 Mon Sep 17 00:00:00 2001 From: Jamie Cui Date: Mon, 15 Dec 2025 19:57:23 +0800 Subject: [PATCH 2/3] refactor(test): simplify aggregate operations test with single input Remove redundant second input array and streamline MPC share creation for arithmetic demo. Test now uses single input data from one party instead of combining shares from two parties. --- MPC/kcal_python/test/arith_demo.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/MPC/kcal_python/test/arith_demo.py b/MPC/kcal_python/test/arith_demo.py index 038ddea..39a5e6a 100644 --- a/MPC/kcal_python/test/arith_demo.py +++ b/MPC/kcal_python/test/arith_demo.py @@ -202,28 +202,24 @@ def test_aggregate_operations(context, is_server: bool): min_op = kcal.create_operator(context, kcal.AlgorithmsType.MIN) # Test data from multiple parties - input1 = [10, 20, 30, 40] # Party 1 - input2 = [5, 15, 25, 35] # Party 2 + input = [10, 20, 30, 40] # Party 1 import time start_time = time.time() - share1 = kcal.MpcShare.Create() - share2 = kcal.MpcShare.Create() + share = kcal.MpcShare.Create() if is_server: print("Server: Processing aggregate operations...") # Create shares - make_share_op.run(input1, 1, share1) - make_share_op.run(input2, 1, share2) + make_share_op.run(input, 1, share) else: print("Client: Processing aggregate operations...") # Create shares - make_share_op.run(input1, 0, share1) - make_share_op.run(input2, 0, share2) + make_share_op.run(input, 0, share) # Test aggregate operations (combine shares from both parties) sum_out_share = kcal.MpcShare.Create() @@ -231,10 +227,10 @@ def test_aggregate_operations(context, is_server: bool): max_out_share = kcal.MpcShare.Create() min_out_share = kcal.MpcShare.Create() - sum_result = sum_op.run([share1, share2], sum_out_share) # [10+5, 20+15, 30+25, 40+35] = [15, 35, 55, 75] - avg_result = avg_op.run([share1, share2], avg_out_share) # Average of each position - max_result = max_op.run([share1, share2], max_out_share) # Max of each position: [10,20,30,40] - min_result = min_op.run([share1, share2], min_out_share) # Min of each position: [5,15,25,35] + sum_op.run([share], sum_out_share) + avg_op.run([share], avg_out_share) + max_op.run([share], max_out_share) + min_op.run([share], min_out_share) # Reveal results sum_output = [] -- Gitee From 3c9b6bb6cc105ef2f724660d594feae28a6c80a5 Mon Sep 17 00:00:00 2001 From: Jamie Cui Date: Thu, 18 Dec 2025 15:39:06 +0800 Subject: [PATCH 3/3] fix(kcal_wrapper): remove unnecessary blank line in error handling Remove extraneous newline character in FeedKcalPairList function that was causing formatting inconsistency. This change improves code readability and maintains consistent indentation patterns throughout the function. --- MPC/kcal_python/src/kcal_wrapper.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/MPC/kcal_python/src/kcal_wrapper.cc b/MPC/kcal_python/src/kcal_wrapper.cc index a501d87..810cbb1 100644 --- a/MPC/kcal_python/src/kcal_wrapper.cc +++ b/MPC/kcal_python/src/kcal_wrapper.cc @@ -104,7 +104,6 @@ void FeedKcalPairList(const py::list &key, const py::list &value, io::KcalPairLi Py_ssize_t sz; const char *utf8 = PyUnicode_AsUTF8AndSize(key[i].ptr(), &sz); if (!utf8) { - throw std::bad_alloc(); } pairList->Get()->dgPair[i].key->str = strdup(utf8); -- Gitee