From fdbce96b3f74e7ea635136e723662d9c5dc3a2b1 Mon Sep 17 00:00:00 2001 From: gongzequn Date: Sat, 12 Jul 2025 17:15:01 +0800 Subject: [PATCH] add dllm support, and some corresbonding bugfix patch (cherry picked from commit 4a2edc60414f3e163b6f823ef2207a251a3b302a) --- ...ower-version-setuptools-on-openeuler.patch | 27 + 0002-bugfix-prefix-cache.patch | 44 + 0003-bugfix-for-dllm-register.patch | 67 + 0004-feature-dllm-tools.patch | 13654 ++++++++++++++++ vllm.spec | 12 +- 5 files changed, 13802 insertions(+), 2 deletions(-) create mode 100644 0001-bugfix-support-lower-version-setuptools-on-openeuler.patch create mode 100644 0002-bugfix-prefix-cache.patch create mode 100644 0003-bugfix-for-dllm-register.patch create mode 100644 0004-feature-dllm-tools.patch diff --git a/0001-bugfix-support-lower-version-setuptools-on-openeuler.patch b/0001-bugfix-support-lower-version-setuptools-on-openeuler.patch new file mode 100644 index 0000000..dd6cc13 --- /dev/null +++ b/0001-bugfix-support-lower-version-setuptools-on-openeuler.patch @@ -0,0 +1,27 @@ +From 7c88c924f7969ca492313f8f5d0c61228e3fe7ad Mon Sep 17 00:00:00 2001 +From: gongzequn +Date: Fri, 25 Jul 2025 17:40:56 +0800 +Subject: [PATCH 1/4] bugfix: support lower version setuptools on openeuler + build + +--- + pyproject.toml | 3 +-- + 1 file changed, 1 insertion(+), 2 deletions(-) + +diff --git a/pyproject.toml b/pyproject.toml +index 307878f7e..a8021bbc0 100644 +--- a/pyproject.toml ++++ b/pyproject.toml +@@ -15,8 +15,7 @@ build-backend = "setuptools.build_meta" + [project] + name = "vllm" + authors = [{name = "vLLM Team"}] +-license = "Apache-2.0" +-license-files = ["LICENSE"] ++license = { text = "Apache-2.0" } + readme = "README.md" + description = "A high-throughput and memory-efficient inference and serving engine for LLMs" + classifiers = [ +-- +2.35.1.windows.2 + diff --git a/0002-bugfix-prefix-cache.patch b/0002-bugfix-prefix-cache.patch new file mode 100644 index 0000000..8982f71 --- /dev/null +++ b/0002-bugfix-prefix-cache.patch @@ -0,0 +1,44 @@ +From afa1f2465784aabb13e262d7dde83a66c543439b Mon Sep 17 00:00:00 2001 +From: chenwenjing +Date: Fri, 25 Jul 2025 17:42:22 +0800 +Subject: [PATCH 2/4] bugfix: prefix cache + +--- + vllm/v1/core/kv_cache_manager.py | 7 ++++--- + vllm/v1/core/sched/scheduler.py | 2 +- + 2 files changed, 5 insertions(+), 4 deletions(-) + +diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py +index 2e09f4c0a..763b958da 100644 +--- a/vllm/v1/core/kv_cache_manager.py ++++ b/vllm/v1/core/kv_cache_manager.py +@@ -382,9 +382,10 @@ class KVCacheManager: + + def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: + """Cache the blocks for the request.""" +- block_hashes = self.req_to_block_hashes[request.request_id] +- self.coordinator.cache_blocks(request, block_hashes, +- num_computed_tokens) ++ if self.enable_caching: ++ block_hashes = self.req_to_block_hashes[request.request_id] ++ self.coordinator.cache_blocks(request, block_hashes, ++ num_computed_tokens) + + def create_empty_block_list(self) -> KVCacheBlocks: + """Creates a new KVCacheBlocks instance with no blocks.""" +diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py +index 3d7bbe7e0..0ca5d6a7e 100644 +--- a/vllm/v1/core/sched/scheduler.py ++++ b/vllm/v1/core/sched/scheduler.py +@@ -724,7 +724,7 @@ class Scheduler(SchedulerInterface): + continue + + req_index = model_runner_output.req_id_to_index[req_id] +- generated_token_ids = sampled_token_ids[req_index] ++ generated_token_ids = sampled_token_ids[req_index] if sampled_token_ids else [] + + scheduled_spec_token_ids = ( + scheduler_output.scheduled_spec_decode_tokens.get(req_id)) +-- +2.35.1.windows.2 + diff --git a/0003-bugfix-for-dllm-register.patch b/0003-bugfix-for-dllm-register.patch new file mode 100644 index 0000000..ef87023 --- /dev/null +++ b/0003-bugfix-for-dllm-register.patch @@ -0,0 +1,67 @@ +From 3d75da0fb40ec2f8e6a2b69a81e435497eac9f3e Mon Sep 17 00:00:00 2001 +From: liujunhong +Date: Fri, 25 Jul 2025 17:45:23 +0800 +Subject: [PATCH 3/4] bugfix: for dllm register + +Co-authored-by: gongzequn +--- + vllm/distributed/kv_transfer/kv_connector/factory.py | 8 ++++++++ + vllm/v1/attention/backends/mla/common.py | 7 +++++-- + vllm/v1/worker/gpu_worker.py | 5 +++++ + 3 files changed, 18 insertions(+), 2 deletions(-) + +diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py +index 58dfa251c..6365ac70a 100644 +--- a/vllm/distributed/kv_transfer/kv_connector/factory.py ++++ b/vllm/distributed/kv_transfer/kv_connector/factory.py +@@ -126,3 +126,11 @@ KVConnectorFactory.register_connector( + "MultiConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.multi_connector", + "MultiConnector") ++KVConnectorFactory.register_connector( ++ "DLLMDsConnector", ++ "dllm.dkvc.v1.dllm_ds_connector", ++ "DLLMDsConnector") ++KVConnectorFactory.register_connector( ++ "DLLMDsD2DConnector", ++ "dllm.dkvc.v1.dllm_ds_d2d_connector", ++ "DLLMDsD2DConnector") +\ No newline at end of file +diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py +index e6b4f6404..efed2e5bb 100644 +--- a/vllm/v1/attention/backends/mla/common.py ++++ b/vllm/v1/attention/backends/mla/common.py +@@ -215,9 +215,12 @@ try: + from vllm.vllm_flash_attn import flash_attn_varlen_func + is_vllm_fa = True + except ImportError: +- # For rocm use upstream flash attention +- from flash_attn import flash_attn_varlen_func + is_vllm_fa = False ++ try: ++ # For rocm use upstream flash attention ++ from flash_attn import flash_attn_varlen_func ++ except ImportError: ++ flash_attn_varlen_func = None + + if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput +diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py +index b7d244f27..ed3ebc802 100644 +--- a/vllm/v1/worker/gpu_worker.py ++++ b/vllm/v1/worker/gpu_worker.py +@@ -47,6 +47,11 @@ class Worker(WorkerBase): + is_driver_worker: bool = False, + ): + ++ if vllm_config.kv_transfer_config and \ ++ vllm_config.kv_transfer_config.kv_connector in ["DLLMDsConnector", "DLLMDsD2DConnector"] and \ ++ vllm_config.kv_transfer_config.kv_connector_extra_config: ++ local_rank = vllm_config.kv_transfer_config.kv_connector_extra_config["device_ids"][rank] ++ + super().__init__(vllm_config=vllm_config, + local_rank=local_rank, + rank=rank, +-- +2.35.1.windows.2 + diff --git a/0004-feature-dllm-tools.patch b/0004-feature-dllm-tools.patch new file mode 100644 index 0000000..0842a9b --- /dev/null +++ b/0004-feature-dllm-tools.patch @@ -0,0 +1,13654 @@ +From 99253656b14804739512149f101376fb2c14bac0 Mon Sep 17 00:00:00 2001 +From: dllm team +Date: Fri, 25 Jul 2025 17:47:46 +0800 +Subject: [PATCH 4/4] feature: dllm tools + +--- + dllm_tools/.coveragerc | 3 + + dllm_tools/.gitignore | 182 ++++ + dllm_tools/CMakeLists.txt | 26 + + dllm_tools/MANIFEST.in | 1 + + dllm_tools/README.md | 84 ++ + dllm_tools/build.sh | 99 ++ + dllm_tools/clean.sh | 21 + + dllm_tools/cmake/dependency.cmake | 23 + + dllm_tools/cmake/external_libs/ascend.cmake | 8 + + dllm_tools/cmake/external_libs/spdlog.cmake | 2 + + dllm_tools/cmake/modules/FindAscend.cmake | 32 + + dllm_tools/cmake/util.cmake | 583 ++++++++++++ + dllm_tools/csrc/.clang-format | 78 ++ + dllm_tools/csrc/CMakeLists.txt | 23 + + dllm_tools/csrc/include/kvc/c_api.h | 116 +++ + dllm_tools/csrc/include/kvc/common.h | 59 ++ + dllm_tools/csrc/include/kvc/kvc_future.h | 139 +++ + dllm_tools/csrc/include/kvc/kvc_store.h | 65 ++ + dllm_tools/csrc/include/kvc/page_attn_utils.h | 46 + + dllm_tools/csrc/include/kvc/torch_adaptor.h | 252 ++++++ + dllm_tools/csrc/include/perf/perf_manager.h | 227 +++++ + dllm_tools/csrc/include/utils/logging.h | 80 ++ + dllm_tools/csrc/kvc/c_api.cpp | 359 ++++++++ + dllm_tools/csrc/kvc/kvc_future.cpp | 88 ++ + dllm_tools/csrc/kvc/kvc_store.cpp | 179 ++++ + dllm_tools/csrc/kvc/page_attn_utils.cpp | 76 ++ + dllm_tools/csrc/kvc/pybind.h | 191 ++++ + dllm_tools/csrc/perf/perf_manager.cpp | 221 +++++ + dllm_tools/csrc/perf/pybind.h | 40 + + dllm_tools/csrc/pybind_register.cpp | 35 + + dllm_tools/csrc/utils/expected.h | 291 ++++++ + dllm_tools/csrc/utils/file_util.h | 101 +++ + dllm_tools/csrc/utils/kvc_future.h | 59 ++ + dllm_tools/csrc/utils/logging.cpp | 117 +++ + dllm_tools/csrc/utils/pybind.h | 58 ++ + dllm_tools/csrc/utils/strings_util.h | 72 ++ + dllm_tools/csrc/utils/thread_pool.cpp | 179 ++++ + dllm_tools/csrc/utils/thread_pool.h | 245 +++++ + dllm_tools/csrc/utils/timer.h | 125 +++ + dllm_tools/dllm/__init__.py | 13 + + dllm_tools/dllm/balancer/README.md | 5 + + dllm_tools/dllm/balancer/__init__.py | 15 + + dllm_tools/dllm/balancer/balancer.py | 363 ++++++++ + dllm_tools/dllm/balancer/policy/README.md | 2 + + dllm_tools/dllm/balancer/policy/__init__.py | 12 + + dllm_tools/dllm/config.py | 164 ++++ + dllm_tools/dllm/constants.py | 42 + + dllm_tools/dllm/controller/README.md | 7 + + dllm_tools/dllm/controller/__init__.py | 15 + + dllm_tools/dllm/controller/controller.py | 528 +++++++++++ + dllm_tools/dllm/controller/endpoint.py | 282 ++++++ + dllm_tools/dllm/controller/vllm_instance.py | 372 ++++++++ + dllm_tools/dllm/cpp_ext/__init__.pyi | 8 + + dllm_tools/dllm/cpp_ext/kvc.pyi | 290 ++++++ + dllm_tools/dllm/cpp_ext/perf.pyi | 10 + + dllm_tools/dllm/cpp_ext/utils.pyi | 27 + + dllm_tools/dllm/dkvc/README.md | 142 +++ + dllm_tools/dllm/dkvc/__init__.py | 0 + dllm_tools/dllm/dkvc/cpp_ext.pyi | 124 +++ + dllm_tools/dllm/dkvc/cpu_cache_evictor.py | 57 ++ + dllm_tools/dllm/dkvc/dllm_cache_engine.py | 366 ++++++++ + dllm_tools/dllm/dkvc/dllm_connector.py | 277 ++++++ + .../prefix_sharing_multi_level/__init__.py | 0 + .../dkvc/prefix_sharing_multi_level/block.py | 67 ++ + .../core/__init__.py | 0 + .../core/block_manager.py | 706 +++++++++++++++ + .../core/scheduler.py | 521 +++++++++++ + .../core/swap_in_watcher.py | 121 +++ + .../prefix_sharing_multi_level/index/LRU.py | 81 ++ + .../index/__init__.py | 0 + .../index/index_tree_manager.py | 149 +++ + .../index/radix_cache.py | 619 +++++++++++++ + dllm_tools/dllm/dkvc/util.py | 17 + + dllm_tools/dllm/dkvc/v1/__init__.py | 0 + dllm_tools/dllm/dkvc/v1/dllm_ds_connector.py | 849 ++++++++++++++++++ + .../dllm/dkvc/v1/dllm_ds_d2d_connector.py | 798 ++++++++++++++++ + dllm_tools/dllm/entities.py | 114 +++ + dllm_tools/dllm/kvc/__init__.py | 15 + + dllm_tools/dllm/kvc/torch_adaptor.py | 139 +++ + dllm_tools/dllm/logging.py | 27 + + dllm_tools/dllm/monkey_patch/README.md | 22 + + dllm_tools/dllm/monkey_patch/__init__.py | 0 + .../dllm/monkey_patch/viz_profile/__init__.py | 0 + .../dllm/monkey_patch/viz_profile/common.py | 82 ++ + .../viz_profile/viz_profile_plugin.py | 20 + + .../viz_profile/vllm_api_server_patch.py | 66 ++ + .../viz_profile/vllm_engine_core_patch.py | 49 + + dllm_tools/dllm/scripts.py | 238 +++++ + dllm_tools/dllm/utils.py | 140 +++ + dllm_tools/launch_test.py | 103 +++ + dllm_tools/pyproject.toml | 51 ++ + dllm_tools/pytest.ini | 9 + + dllm_tools/requirements.txt | 12 + + dllm_tools/setup.py | 163 ++++ + 94 files changed, 12884 insertions(+) + create mode 100644 dllm_tools/.coveragerc + create mode 100644 dllm_tools/.gitignore + create mode 100644 dllm_tools/CMakeLists.txt + create mode 100644 dllm_tools/MANIFEST.in + create mode 100644 dllm_tools/README.md + create mode 100644 dllm_tools/build.sh + create mode 100644 dllm_tools/clean.sh + create mode 100644 dllm_tools/cmake/dependency.cmake + create mode 100644 dllm_tools/cmake/external_libs/ascend.cmake + create mode 100644 dllm_tools/cmake/external_libs/spdlog.cmake + create mode 100644 dllm_tools/cmake/modules/FindAscend.cmake + create mode 100644 dllm_tools/cmake/util.cmake + create mode 100644 dllm_tools/csrc/.clang-format + create mode 100644 dllm_tools/csrc/CMakeLists.txt + create mode 100644 dllm_tools/csrc/include/kvc/c_api.h + create mode 100644 dllm_tools/csrc/include/kvc/common.h + create mode 100644 dllm_tools/csrc/include/kvc/kvc_future.h + create mode 100644 dllm_tools/csrc/include/kvc/kvc_store.h + create mode 100644 dllm_tools/csrc/include/kvc/page_attn_utils.h + create mode 100644 dllm_tools/csrc/include/kvc/torch_adaptor.h + create mode 100644 dllm_tools/csrc/include/perf/perf_manager.h + create mode 100644 dllm_tools/csrc/include/utils/logging.h + create mode 100644 dllm_tools/csrc/kvc/c_api.cpp + create mode 100644 dllm_tools/csrc/kvc/kvc_future.cpp + create mode 100644 dllm_tools/csrc/kvc/kvc_store.cpp + create mode 100644 dllm_tools/csrc/kvc/page_attn_utils.cpp + create mode 100644 dllm_tools/csrc/kvc/pybind.h + create mode 100644 dllm_tools/csrc/perf/perf_manager.cpp + create mode 100644 dllm_tools/csrc/perf/pybind.h + create mode 100644 dllm_tools/csrc/pybind_register.cpp + create mode 100644 dllm_tools/csrc/utils/expected.h + create mode 100644 dllm_tools/csrc/utils/file_util.h + create mode 100644 dllm_tools/csrc/utils/kvc_future.h + create mode 100644 dllm_tools/csrc/utils/logging.cpp + create mode 100644 dllm_tools/csrc/utils/pybind.h + create mode 100644 dllm_tools/csrc/utils/strings_util.h + create mode 100644 dllm_tools/csrc/utils/thread_pool.cpp + create mode 100644 dllm_tools/csrc/utils/thread_pool.h + create mode 100644 dllm_tools/csrc/utils/timer.h + create mode 100644 dllm_tools/dllm/__init__.py + create mode 100644 dllm_tools/dllm/balancer/README.md + create mode 100644 dllm_tools/dllm/balancer/__init__.py + create mode 100644 dllm_tools/dllm/balancer/balancer.py + create mode 100644 dllm_tools/dllm/balancer/policy/README.md + create mode 100644 dllm_tools/dllm/balancer/policy/__init__.py + create mode 100644 dllm_tools/dllm/config.py + create mode 100644 dllm_tools/dllm/constants.py + create mode 100644 dllm_tools/dllm/controller/README.md + create mode 100644 dllm_tools/dllm/controller/__init__.py + create mode 100644 dllm_tools/dllm/controller/controller.py + create mode 100644 dllm_tools/dllm/controller/endpoint.py + create mode 100644 dllm_tools/dllm/controller/vllm_instance.py + create mode 100644 dllm_tools/dllm/cpp_ext/__init__.pyi + create mode 100644 dllm_tools/dllm/cpp_ext/kvc.pyi + create mode 100644 dllm_tools/dllm/cpp_ext/perf.pyi + create mode 100644 dllm_tools/dllm/cpp_ext/utils.pyi + create mode 100644 dllm_tools/dllm/dkvc/README.md + create mode 100644 dllm_tools/dllm/dkvc/__init__.py + create mode 100644 dllm_tools/dllm/dkvc/cpp_ext.pyi + create mode 100644 dllm_tools/dllm/dkvc/cpu_cache_evictor.py + create mode 100644 dllm_tools/dllm/dkvc/dllm_cache_engine.py + create mode 100644 dllm_tools/dllm/dkvc/dllm_connector.py + create mode 100644 dllm_tools/dllm/dkvc/prefix_sharing_multi_level/__init__.py + create mode 100644 dllm_tools/dllm/dkvc/prefix_sharing_multi_level/block.py + create mode 100644 dllm_tools/dllm/dkvc/prefix_sharing_multi_level/core/__init__.py + create mode 100644 dllm_tools/dllm/dkvc/prefix_sharing_multi_level/core/block_manager.py + create mode 100644 dllm_tools/dllm/dkvc/prefix_sharing_multi_level/core/scheduler.py + create mode 100644 dllm_tools/dllm/dkvc/prefix_sharing_multi_level/core/swap_in_watcher.py + create mode 100644 dllm_tools/dllm/dkvc/prefix_sharing_multi_level/index/LRU.py + create mode 100644 dllm_tools/dllm/dkvc/prefix_sharing_multi_level/index/__init__.py + create mode 100644 dllm_tools/dllm/dkvc/prefix_sharing_multi_level/index/index_tree_manager.py + create mode 100644 dllm_tools/dllm/dkvc/prefix_sharing_multi_level/index/radix_cache.py + create mode 100644 dllm_tools/dllm/dkvc/util.py + create mode 100644 dllm_tools/dllm/dkvc/v1/__init__.py + create mode 100644 dllm_tools/dllm/dkvc/v1/dllm_ds_connector.py + create mode 100644 dllm_tools/dllm/dkvc/v1/dllm_ds_d2d_connector.py + create mode 100644 dllm_tools/dllm/entities.py + create mode 100644 dllm_tools/dllm/kvc/__init__.py + create mode 100644 dllm_tools/dllm/kvc/torch_adaptor.py + create mode 100644 dllm_tools/dllm/logging.py + create mode 100644 dllm_tools/dllm/monkey_patch/README.md + create mode 100644 dllm_tools/dllm/monkey_patch/__init__.py + create mode 100644 dllm_tools/dllm/monkey_patch/viz_profile/__init__.py + create mode 100644 dllm_tools/dllm/monkey_patch/viz_profile/common.py + create mode 100644 dllm_tools/dllm/monkey_patch/viz_profile/viz_profile_plugin.py + create mode 100644 dllm_tools/dllm/monkey_patch/viz_profile/vllm_api_server_patch.py + create mode 100644 dllm_tools/dllm/monkey_patch/viz_profile/vllm_engine_core_patch.py + create mode 100644 dllm_tools/dllm/scripts.py + create mode 100644 dllm_tools/dllm/utils.py + create mode 100644 dllm_tools/launch_test.py + create mode 100644 dllm_tools/pyproject.toml + create mode 100644 dllm_tools/pytest.ini + create mode 100644 dllm_tools/requirements.txt + create mode 100644 dllm_tools/setup.py + +diff --git a/dllm_tools/.coveragerc b/dllm_tools/.coveragerc +new file mode 100644 +index 000000000..0c215b64d +--- /dev/null ++++ b/dllm_tools/.coveragerc +@@ -0,0 +1,3 @@ ++[run] ++omit = ++ dllm/dkvc/prefix_sharing_multi_level/* +\ No newline at end of file +diff --git a/dllm_tools/.gitignore b/dllm_tools/.gitignore +new file mode 100644 +index 000000000..7166e7f96 +--- /dev/null ++++ b/dllm_tools/.gitignore +@@ -0,0 +1,182 @@ ++# Byte-compiled / optimized / DLL files ++__pycache__/ ++*.py[cod] ++*$py.class ++ ++# C extensions ++*.so ++ ++# Distribution / packaging ++.Python ++build/ ++develop-eggs/ ++dist/ ++downloads/ ++eggs/ ++.eggs/ ++lib/ ++lib64/ ++parts/ ++sdist/ ++var/ ++wheels/ ++share/python-wheels/ ++*.egg-info/ ++.installed.cfg ++*.egg ++MANIFEST ++ ++# PyInstaller ++# Usually these files are written by a python script from a template ++# before PyInstaller builds the exe, so as to inject date/other infos into it. ++*.manifest ++*.spec ++ ++# Installer logs ++pip-log.txt ++pip-delete-this-directory.txt ++ ++# Unit test / coverage reports ++htmlcov/ ++.tox/ ++.nox/ ++.coverage ++.coverage.* ++.cache ++nosetests.xml ++coverage.xml ++*.cover ++*.py,cover ++.hypothesis/ ++.pytest_cache/ ++cover/ ++ ++# Translations ++*.mo ++*.pot ++ ++# Django stuff: ++*.log ++local_settings.py ++db.sqlite3 ++db.sqlite3-journal ++ ++# Flask stuff: ++instance/ ++.webassets-cache ++ ++# Scrapy stuff: ++.scrapy ++ ++# Sphinx documentation ++docs/_build/ ++ ++# PyBuilder ++.pybuilder/ ++target/ ++ ++# Jupyter Notebook ++.ipynb_checkpoints ++ ++# IPython ++profile_default/ ++ipython_config.py ++ ++# pyenv ++# For a library or package, you might want to ignore these files since the code is ++# intended to run in multiple environments; otherwise, check them in: ++# .python-version ++ ++# pipenv ++# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. ++# However, in case of collaboration, if having platform-specific dependencies or dependencies ++# having no cross-platform support, pipenv may install dependencies that don't work, or not ++# install all needed dependencies. ++#Pipfile.lock ++ ++# UV ++# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. ++# This is especially recommended for binary packages to ensure reproducibility, and is more ++# commonly ignored for libraries. ++#uv.lock ++ ++# poetry ++# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. ++# This is especially recommended for binary packages to ensure reproducibility, and is more ++# commonly ignored for libraries. ++# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control ++#poetry.lock ++ ++# pdm ++# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. ++#pdm.lock ++# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it ++# in version control. ++# https://pdm.fming.dev/latest/usage/project/#working-with-version-control ++.pdm.toml ++.pdm-python ++.pdm-build/ ++ ++# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm ++__pypackages__/ ++ ++# Celery stuff ++celerybeat-schedule ++celerybeat.pid ++ ++# SageMath parsed files ++*.sage.py ++ ++# Environments ++.env ++.venv ++env/ ++venv/ ++ENV/ ++env.bak/ ++venv.bak/ ++ ++# Spyder project settings ++.spyderproject ++.spyproject ++ ++# Rope project settings ++.ropeproject ++ ++# mkdocs documentation ++/site ++ ++# mypy ++.mypy_cache/ ++.dmypy.json ++dmypy.json ++ ++# Pyre type checker ++.pyre/ ++ ++# pytype static type analyzer ++.pytype/ ++ ++# Cython debug symbols ++cython_debug/ ++ ++# PyCharm ++# JetBrains specific template is maintained in a separate JetBrains.gitignore that can ++# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore ++# and can be added to the global gitignore or merged into this file. For a more nuclear ++# option (not recommended) you can uncomment the following to ignore the entire idea folder. ++#.idea/ ++ ++# Ruff stuff: ++.ruff_cache/ ++ ++# PyPI configuration file ++.pypirc ++ ++# IDE ++.idea ++.vscode ++ ++# DLLM ++stress_test.log ++test_results/ +diff --git a/dllm_tools/CMakeLists.txt b/dllm_tools/CMakeLists.txt +new file mode 100644 +index 000000000..f804345fc +--- /dev/null ++++ b/dllm_tools/CMakeLists.txt +@@ -0,0 +1,26 @@ ++cmake_minimum_required(VERSION 3.12) ++ ++project(dllm) ++ ++set(CMAKE_EXPORT_COMPILE_COMMANDS ON) ++# Use C++17 standard. ++set(CMAKE_CXX_STANDARD 17) ++ ++set(CMAKE_BUILD_TYPE RelWithDebInfo) ++set(CMAKE_POSITION_INDEPENDENT_CODE ON) ++set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -g3 -fsigned-char -Wextra -Wfloat-equal -fno-common -rdynamic") ++set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fstack-protector-strong -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack") ++set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_FORTIFY_SOURCE=2 -fPIE -pie -Wl,--build-id=none -g -FPIC") ++ ++if (ENABLE_PERF) ++ add_compile_definitions(ENABLE_PERF) ++ message(STATUS "Enable perf point log") ++endif () ++ ++set(CMAKE_HTTP_SSL_VERIFY OFF CACHE BOOL "Disable SSL verification") ++set(CMAKE_TLS_VERIFY OFF CACHE BOOL "Disable TLS verification") ++# Import the third party we depends. ++include(cmake/util.cmake) ++include(cmake/dependency.cmake) ++ ++add_subdirectory(csrc) +diff --git a/dllm_tools/MANIFEST.in b/dllm_tools/MANIFEST.in +new file mode 100644 +index 000000000..93efd9fd4 +--- /dev/null ++++ b/dllm_tools/MANIFEST.in +@@ -0,0 +1 @@ ++recursive-include dllm/include *.h +\ No newline at end of file +diff --git a/dllm_tools/README.md b/dllm_tools/README.md +new file mode 100644 +index 000000000..b3f7eb0b0 +--- /dev/null ++++ b/dllm_tools/README.md +@@ -0,0 +1,84 @@ ++# dllm ++ ++stand for "distributed llm", aims at providing better tools for distributed vllm serving framework. ++ ++## Build guide ++ ++> **TL;DR** ++> ++> ``` ++> yum install python3-pip gcc g++ cmake spdlog-devel -y ++> pip install --upgrade pip ++> pip install --upgrade wheel setuptools ninja pybind11 chariot-ds ++> ++> python3 setup.py bdist_wheel ++> ``` ++ ++### Build requires ++ ++**build tools** ++ ++* `gcc/g++/make/cmake`: can be installed by `yum install gcc g++ cmake -y` ++* `ninja`: can be installed by `pip install ninja` ++* `python/pip`: can be installed by `yum install python3-pip; pip install --upgrade pip;` ++* `wheel/setuptools`: can be installed by `pip install --upgrade pip wheel setuptools` ++ ++> NOTE: upgrade setuptools is necessary in most of OS ++ ++**dependencies** ++ ++* `spdlog`: can be installed by `yum install spdlog-devel -y` ++* `pybind11`: can be installed by `pip install pybind11` ++* `chariot-ds`: can be installed by `pip install chariot-ds` ++* `ascend cann`: access https://www.hiascend.com/software/cann for installation ++ ++### Build command ++ ++```bash ++bash build.sh ++# or python3 setup.py bdist_wheel ++``` ++ ++## Install guide ++ ++```bash ++pip install dist/dllm-*.whl ++``` ++ ++## Use guide ++ ++### deploy dependencies ++ ++> NOTE: After deploy chariot-ds, you need to set the envrionment `DS_WORKER_ADDR="{IP}:{PORT}"` on each node before start ray. ++ ++1. chariot-ds: follow https://pypi.org/project/chariot-ds/ ++2. Ray: follow https://docs.ray.io/en/latest/cluster/vms/user-guides/launching-clusters/on-premises.html#on-prem ++ ++### deploy dllm ++ ++use vllm-mindspore as an example, when use, ++ ++* 1 Prefill instance, with parallel config: [TP: 4, DP: 4, EP: 16] ++* 1 Decode instance, with parallel config: [TP: 4, DP: 4, EP: 16] ++ ++the command should be like: ++ ++```bash ++dllm deploy \ ++ --prefill-instances-num=1 \ ++ --decode-instances-num=1 \ ++ -ptp=4 -dtp=4 -pdp=4 -ddp=4 -pep=16 -dep=16 \ ++ --prefill-startup-params="vllm-mindspore serve --model=/workspace/models/qwen2.5_7B --trust_remote_code --max-num-seqs=256 --max_model_len=1024 --max-num-batched-tokens=1024 --block-size=128 --gpu-memory-utilization=0.93" \ ++ --decode-startup-params="vllm-mindspore serve --model=/workspace/models/qwen2.5_7B --trust_remote_code --max-num-seqs=256 --max_model_len=1024 --max-num-batched-tokens=1024 --block-size=128 --gpu-memory-utilization=0.93" ++``` ++ ++After deploy success, can access the localhost:8000 as a general openai api endpoint (which is fully compatible) ++ ++```bash ++curl -X POST "http://127.0.0.1:8000/v1/completions" -H "Content-Type: application/json" -H "Authorization: Bearer YOUR_API_KEY" -d '{ ++ "model": "/workspace/models/qwen2.5_7B", ++ "prompt": "Alice is ", ++ "max_tokens": 50, ++ "temperature": 0 ++}' ++``` +diff --git a/dllm_tools/build.sh b/dllm_tools/build.sh +new file mode 100644 +index 000000000..0c0a83f65 +--- /dev/null ++++ b/dllm_tools/build.sh +@@ -0,0 +1,99 @@ ++#!/bin/bash ++# ++# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. ++# ++ ++set -e ++set -o errexit ++set -o pipefail ++ ++BASE_DIR=$( ++ cd "$(dirname "$0")" ++ pwd ++) ++ ++BUILD_VERSION="3.11" ++VALID_VERSIONS=("3.9" "3.10" "3.11") ++ ++# 解析命令行参数 ++for i in "$@" ++do ++ case $i in ++ --BUILD_VERSION=*) ++ BUILD_VERSION="${i#*=}" ++ # 验证传入的版本是否有效 ++ if [[ ! " ${VALID_VERSIONS[@]} " =~ " ${BUILD_VERSION} " ]]; then ++ echo "Invalid Python version specified: $BUILD_VERSION" ++ exit 1 ++ fi ++ echo "Python version set to: $BUILD_VERSION" ++ shift # 处理当前参数,移除 ++ ;; ++ *) ++ echo "Unknown parameter: $i" ++ exit 1 ++ ;; ++ esac ++done ++ ++SRC_DIR=${BASE_DIR} ++if [ ! $ASCEND_HOME_PATH ]; then ++ source "/usr/local/Ascend8/ascend-toolkit/set_env.sh" ++fi ++ ++ ++ ++log_info() { ++ echo "[BUILD_INFO][$(date +%b\ %d\ %H:%M:%S)]$*" ++} ++ ++log_error() { ++ echo "[BUILD_ERROR][$(date +%b\ %d\ %H:%M:%S)]$*" ++} ++ ++to_lower() { ++ echo "$1" | tr '[:upper:]' '[:lower:]' ++} ++ ++function build_dllm() { ++ log_info "build dllm wheel" ++ cd ${SRC_DIR} ++ ++ if [ -n "$BUILD_FOR" ]; then ++ export DLLM_BUILD_FOR=${BUILD_FOR} ++ fi ++ ++ python=$(which python${BUILD_VERSION}) || log_error "Could not find python: ${BUILD_VERSION}" || exit 1 ++ if ! ${python} "${SRC_DIR}"/setup.py bdist_wheel; then ++ log_error "Failed to build wheel!" ++ exit 1 ++ fi ++ ++ # for running directly from the source directory ++ find ${SRC_DIR}/build -name '*.so' -exec cp {} ${SRC_DIR}/dllm \; ++ ++ log_info "Success build dllm wheel" ++} ++ ++function install_requirements() { ++ if [ -f "${SRC_DIR}/requirements.txt" ]; then ++ pip${BUILD_VERSION} install -r "${SRC_DIR}"/requirements.txt ++ else ++ echo "requirements.txt does not exist." ++ fi ++} ++ ++ ++ARCH=$(uname -m) ++export COMPILE_WITH_YR=true ++if [ "${ARCH}" == "aarch64" ]; then ++ bash ${BASE_DIR}/clean.sh ++ install_requirements ++ build_dllm & ++ pid_build_dllm=$! ++ wait $pid_build_dllm ++ log_info "Finished building vllm version" ++ exit 0 ++else ++ log_error "It is not system of aarch64" ++fi +\ No newline at end of file +diff --git a/dllm_tools/clean.sh b/dllm_tools/clean.sh +new file mode 100644 +index 000000000..de69cb802 +--- /dev/null ++++ b/dllm_tools/clean.sh +@@ -0,0 +1,21 @@ ++#!/bin/bash ++# ++# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. ++# ++ ++set -e ++BASE_DIR=$(cd "$(dirname "$0")"; pwd) ++SRC_DIR=${BASE_DIR} ++ ++find "${SRC_DIR}"/ -name "*.egg-info" -type d -exec rm -rf {} + ++find "${SRC_DIR}"/ -name build -maxdepth 1 -type d -exec rm -rf {} + ++find "${SRC_DIR}"/ -name "*.so" -type f -delete ++find "${SRC_DIR}"/ -name "*.whl" -type f -delete ++ ++if [ -d "${SRC_DIR}"/dist ]; then ++ rm -rf "${SRC_DIR}"/dist ++fi ++ ++if [ -d "${SRC_DIR}"/include ]; then ++ rm -rf "${SRC_DIR}"/include ++fi +diff --git a/dllm_tools/cmake/dependency.cmake b/dllm_tools/cmake/dependency.cmake +new file mode 100644 +index 000000000..c42a4895c +--- /dev/null ++++ b/dllm_tools/cmake/dependency.cmake +@@ -0,0 +1,23 @@ ++include(FindThreads) ++ ++# pybind11 ++find_package(pybind11 REQUIRED) ++ ++message("Datasystem_LIBRARY_DIR: ${Datasystem_LIBRARY_DIR}") ++ ++if (Datasystem_LIBRARY_DIR) ++ set(HAS_DS TRUE CACHE BOOL "Enable Building with datasystem") ++ link_directories(${Datasystem_LIBRARY_DIR}) ++ add_library(datasystem SHARED IMPORTED) ++ set_target_properties(datasystem PROPERTIES ++ IMPORTED_LOCATION "${Datasystem_LIBRARY_DIR}/libdatasystem.so" ++ IMPORTED_SONAME "libdatasystem.so" ++ ) ++ include_directories(${Datasystem_INCLUDE_DIR}) ++else () ++ set(HAS_DS FALSE CACHE BOOL "DISABLE Building with datasystem") ++ message(WARNING "Building distributed kv cache without YR") ++endif() ++ ++include(${CMAKE_SOURCE_DIR}/cmake/external_libs/spdlog.cmake) ++include(${CMAKE_SOURCE_DIR}/cmake/external_libs/ascend.cmake) +\ No newline at end of file +diff --git a/dllm_tools/cmake/external_libs/ascend.cmake b/dllm_tools/cmake/external_libs/ascend.cmake +new file mode 100644 +index 000000000..8610a6fc9 +--- /dev/null ++++ b/dllm_tools/cmake/external_libs/ascend.cmake +@@ -0,0 +1,8 @@ ++# The environment variable ASCEND_CUSTOM_PATH is used to locate the Ascend install path. ++# So cmake can find the header files and libraries in the compile stage. ++# If user don't set ASCEND_CUSTOM_PATH, find Ascend in `/usr/local/Ascend/ascend-toolkit/latest` in default. ++set(Ascend_ROOT $ENV{ASCEND_HOME_PATH}) ++ ++find_package(Ascend REQUIRED) ++ ++include_directories(SYSTEM ${ASCEND_INCLUDE_DIR}) +\ No newline at end of file +diff --git a/dllm_tools/cmake/external_libs/spdlog.cmake b/dllm_tools/cmake/external_libs/spdlog.cmake +new file mode 100644 +index 000000000..8d09f90fa +--- /dev/null ++++ b/dllm_tools/cmake/external_libs/spdlog.cmake +@@ -0,0 +1,2 @@ ++find_package(spdlog REQUIRED) ++message(STATUS "Found spdlog") +diff --git a/dllm_tools/cmake/modules/FindAscend.cmake b/dllm_tools/cmake/modules/FindAscend.cmake +new file mode 100644 +index 000000000..737be5ef9 +--- /dev/null ++++ b/dllm_tools/cmake/modules/FindAscend.cmake +@@ -0,0 +1,32 @@ ++# - Find ASCEND (acl_base.h, acl.h, libascendcl.so ) ++# This module defines ++# ASCEND_INCLUDE_DIR, directory containing headers ++# ASCEND_LIBRARY, Location of libascendcl's shared ++# ASCEND_FOUND, whether ascend has been found ++ ++find_path(ASCEND_INCLUDE_DIR acl/acl.h ++ DOC "Path to the ASCEND header file" ++ HINTS ${Ascend_ROOT}/include ++ NO_CMAKE_SYSTEM_PATH ++ NO_SYSTEM_ENVIRONMENT_PATH) ++ ++find_library(ASCEND_LIBRARY ${CMAKE_SHARED_LIBRARY_PREFIX}ascendcl${CMAKE_SHARED_LIBRARY_SUFFIX} ++ ${CMAKE_SHARED_LIBRARY_PREFIX}hccl${CMAKE_SHARED_LIBRARY_SUFFIX} ++ DOC "Path to Ascend library" ++ HINTS ${Ascend_ROOT}/lib64 ++ NO_CMAKE_SYSTEM_PATH ++ NO_SYSTEM_ENVIRONMENT_PATH) ++ ++find_library(HCCL_LIBRARY ++ ${CMAKE_SHARED_LIBRARY_PREFIX}hccl${CMAKE_SHARED_LIBRARY_SUFFIX} ++ DOC "Path to HCCL library" ++ HINTS ${Ascend_ROOT}/lib64 ++ NO_CMAKE_SYSTEM_PATH ++ NO_SYSTEM_ENVIRONMENT_PATH) ++ ++message("ascend lib: ${ASCEND_LIBRARY}") ++message("hccl lib: ${HCCL_LIBRARY}") ++ ++include(FindPackageHandleStandardArgs) ++find_package_handle_standard_args(Ascend REQUIRED_VARS ++ ASCEND_LIBRARY HCCL_LIBRARY ASCEND_INCLUDE_DIR) +diff --git a/dllm_tools/cmake/util.cmake b/dllm_tools/cmake/util.cmake +new file mode 100644 +index 000000000..7f7f6bc4b +--- /dev/null ++++ b/dllm_tools/cmake/util.cmake +@@ -0,0 +1,583 @@ ++include(FetchContent) ++ ++set(THIRDPARTY_SAFE_FLAGS "-fPIC -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-strong -ffunction-sections -fdata-sections -Wl,--gc-sections -Wl,--build-id=none -Wl,-z,relro,-z,noexecstack,-z,now ${EXT_FLAGS}") ++ ++list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/modules) ++ ++# We provide a way to cache the third party libs to avoid the repeated third-party compilation. ++# User can configure the cache by setting the environment variables or CMake configuration, and ++# the priority of the CMake configuration is higher. ++if(NOT YR_OPENSOURCE_DIR) ++ if(DEFINED ENV{YR_OPENSOURCE_DIR}) ++ set(YR_OPENSOURCE_DIR $ENV{YR_OPENSOURCE_DIR}) ++ else() ++ string(SHA256 _TEMP_PATH ${CMAKE_BINARY_DIR}) ++ set(YR_OPENSOURCE_DIR "/tmp/${_TEMP_PATH}") ++ endif() ++endif() ++get_filename_component(YR_OPENSOURCE_DIR ${YR_OPENSOURCE_DIR} ABSOLUTE) ++ ++if (NOT BUILD_THREAD_NUM) ++ set(BUILD_THREAD_NUM 8) ++endif() ++ ++message(STATUS "Cache the third party libs to ${YR_OPENSOURCE_DIR}, " ++ "build them with ${BUILD_THREAD_NUM} parallelism.") ++ ++find_program(Patch_EXECUTABLE patch) ++set(Patch_FOUND ${Patch_EXECUTABLE}) ++find_program(Meson_EXECUTABLE meson) ++set(Meson_FOUND ${Meson_EXECUTABLE}) ++find_program(Ninja_EXECUTABLE ninja) ++set(Ninja_FOUND ${Ninja_EXECUTABLE}) ++ ++function(__EXEC_COMMAND) ++ set(options) ++ set(one_value_args WORKING_DIRECTORY) ++ set(multi_value_args COMMAND) ++ cmake_parse_arguments(ARG "${options}" "${one_value_args}" "${multi_value_args}" ${ARGN}) ++ ++ execute_process(COMMAND ${ARG_COMMAND} ++ WORKING_DIRECTORY ${ARG_WORKING_DIRECTORY} ++ RESULT_VARIABLE _RET) ++ if(NOT _RET EQUAL "0") ++ message(FATAL_ERROR "Fail execute command: ${ARG_COMMAND}, error: ${_RET}") ++ endif() ++endfunction() ++ ++function(DOWNLOAD_LIB_PKG LIB_NAME URL SHA256) ++ # OpenEuler tiny package url end with "rpm" suffix, we need ++ # to uncompress it and get the real source code package. ++ if (URL MATCHES ".*\.src\.rpm$") ++ FetchContent_Declare( ++ "${LIB_NAME}_rpm" ++ URL ${URL} ++ URL_HASH SHA256=${SHA256} ++ ) ++ FetchContent_GetProperties("${LIB_NAME}_rpm") ++ FetchContent_Populate("${LIB_NAME}_rpm") ++ ++ # TODO: need to consider the end suffix with zip, tar and so on. ++ file(GLOB _URL_LIST "${${LIB_NAME}_rpm_SOURCE_DIR}/${LIB_NAME}*\.tar\.gz" "${${LIB_NAME}_rpm_SOURCE_DIR}/${LIB_NAME}*\.tar\.xz") ++ if (NOT _URL_LIST) ++ message(FATAL_ERROR "Failed to find source package from ${${LIB_NAME}_rpm_SOURCE_DIR}") ++ endif() ++ list(GET _URL_LIST 0 URL) ++ list(LENGTH _URL_LIST _URL_LIST_LEN) ++ if (_URL_LIST_LEN GREATER 1) ++ message(WARNING "Get source package is more than 1, but we only choose the first one: ${URL}") ++ endif() ++ ++ file(SHA256 "${URL}" SHA256) ++ endif() ++ ++ FetchContent_Declare( ++ ${LIB_NAME} ++ URL ${URL} ++ URL_HASH SHA256=${SHA256} ++ ) ++ FetchContent_GetProperties(${LIB_NAME}) ++ message(STATUS "Download ${LIB_NAME} from ${URL}") ++ if(NOT ${LIB_NAME}_POPULATED) ++ FetchContent_Populate(${LIB_NAME}) ++ set(${LIB_NAME}_SOURCE_DIR ${${LIB_NAME}_SOURCE_DIR} PARENT_SCOPE) ++ set(${LIB_NAME}_BINARY_DIR ${${LIB_NAME}_BINARY_DIR} PARENT_SCOPE) ++ endif() ++endfunction() ++ ++# Generate fake third party tar package, it is all about trustworthiness. ++# ++# Arguments: ++# NAME ++# Specify the package name of the third-party library. ++# ++# URL ++# Specify an output variable, path of the tar package is assigned to this variable. ++# ++# SHA256 ++# Specify an output variable, SHA256 sum code of the tar package is assigned to this variable. ++# ++# FAKE_SHA256 ++# A fake sha256, useful when building third party components from source code. ++# ++# VERSION ++# Specify an output variable, the version of the package is assigned to this variable if version.txt exists. ++# The version.txt is provided by the user/CI to specify the source code's version. ++# The version.txt contains the version of package that must be equal to the version value provided by ConfigVersion.cmake, ++# otherwise cmake find_package may fail ++# The location of version.txt is ${YR_PACKAGE}/${PACKAGFE_NAME}/version.txt, like /usr1/third_party/zlib/version.txt ++function(GEN_THIRDPARTY_PKG NAME URL SHA256 FAKE_SHA256 VERSION) ++ get_filename_component(_THIRDPARTY_DIR "$ENV{YR_PACKAGE}" ABSOLUTE) ++ set(_DIR "${_THIRDPARTY_DIR}/${NAME}") ++ ++ set(VERSION_TXT ${_DIR}/version.txt) ++ if (EXISTS ${VERSION_TXT}) ++ file(READ ${VERSION_TXT} _VERSION) ++ string(STRIP ${_VERSION} _VERSION) ++ endif() ++ ++ if ("${_VERSION}" STREQUAL "") ++ MESSAGE("The ${NAME} directory don't contain version.txt or it's empty") ++ else() ++ MESSAGE("Found thirdparty library ${NAME} version is ${_VERSION} in version.txt") ++ endif() ++ if (NOT EXISTS "${_DIR}" OR NOT IS_DIRECTORY "${_DIR}") ++ message(FATAL_ERROR "Specify path: ${_DIR} not exist or is not a directory!") ++ endif() ++ ++ set(_SUFFIX_LIST ".tar.gz" ".tar.xz" ".tar.bz2" ".zip") ++ if (${NAME} STREQUAL "re2" OR ${NAME} STREQUAL "absl") ++ list(TRANSFORM _SUFFIX_LIST PREPEND ${_DIR}/* OUTPUT_VARIABLE _DIR_SUFFIX_LIST) ++ else() ++ list(TRANSFORM _SUFFIX_LIST PREPEND ${_DIR}/${NAME}* OUTPUT_VARIABLE _DIR_SUFFIX_LIST) ++ endif() ++ file(GLOB _TAR_PKG_FILE ${_DIR_SUFFIX_LIST}) ++ ++ if (_TAR_PKG_FILE) ++ # OpenEuler would save the tar file in it's directory, unthinkable operation. ++ list(GET _TAR_PKG_FILE 0 _DEST_PATH) ++ list(LENGTH _TAR_PKG_FILE _TAR_PKG_LEN) ++ if (_TAR_PKG_LEN GREATER 1) ++ message(WARNING "Get tar file is more than 1, but we only choose the first one: ${_DEST_PATH}") ++ endif() ++ file(SHA256 "${_DEST_PATH}" _SHA256) ++ else() ++ # Step1: Generate sha256 based on source code. ++ execute_process(COMMAND sh -c "find ${_DIR} -path ${_DIR}/.git -prune -o -type f -print0 | sort -z | xargs -0 cat | sha256sum" ++ OUTPUT_VARIABLE _FINAL_SHA256_VALUE ++ RESULT_VARIABLE _RET) ++ if(NOT _RET EQUAL "0") ++ message(FATAL_ERROR "Fail to find files in source code, error: ${_RET}") ++ endif() ++ string(REGEX REPLACE "\ .*$" "" _FINAL_SHA256_VALUE "${_FINAL_SHA256_VALUE}") ++ # Step2: Generate fake *tar.gz ++ find_program(_TAR_EXECUTABLE tar) ++ if (NOT _TAR_EXECUTABLE) ++ message(FATAL_ERROR "tar command not found!") ++ endif() ++ ++ set(_DEST_PATH "${CMAKE_CURRENT_BINARY_DIR}/${NAME}.tar.gz") ++ if (EXISTS "${_DEST_PATH}") ++ file(REMOVE "${_DEST_PATH}") ++ endif() ++ __exec_command(COMMAND ${_TAR_EXECUTABLE} -zmcf "${_DEST_PATH}" -C "${_THIRDPARTY_DIR}" "${NAME}" ++ WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) ++ file(SHA256 "${_DEST_PATH}" _SHA256) ++ endif() ++ ++ # Set output variables. ++ set(${URL} "${_DEST_PATH}" PARENT_SCOPE) ++ set(${SHA256} "${_SHA256}" PARENT_SCOPE) ++ set(${FAKE_SHA256} "${_FINAL_SHA256_VALUE}" PARENT_SCOPE) ++ if (NOT "${_VERSION}" STREQUAL "") ++ set(${VERSION} "${_VERSION}" PARENT_SCOPE) ++ endif() ++endfunction() ++ ++# Add a third-party dependency library on which the Datasystem Depends. ++# ++# LIB_NAME is the name of the library. ++# ++# Additional optional arguments: ++# ++# URL ++# Specify download url, it can be a link or file path. ++# ++# SHA256 ++# Specify package provided by URL sha256 sum for check purpose. ++# ++# FAKE_SHA256 ++# A fake sha256, useful when building third party components from source code. ++# ++# VERSION ++# Specify library version. ++# ++# TOOLCHAIN ++# Specify compile toolchain, support cmake, configure and so on. ++# ++# CONF_PATH ++# Specify configure file path, it can be useful if the configure ++# file is not in root dir. ++# ++# COMPONENTS ... ++# Specify the components we need from third-party library, if we ++# don't need all components, we can just specify the components ++# what we need by this argument. ++# ++# CONF_OPTIONS ... ++# Specify the configure options, the value depends on which ++# toolchain we use. ++# ++# PRE_CONFIGURE ... ++# Specify the pre-configure command before execute the configure, ++# e.g. sh autogen.sh . ++# ++# PATHCES ... ++# Specify the patch files path, they would be apply before compile. ++# ++# CXX_FLAGS ... ++# Specify the CXX compile flags. ++# ++# C_FLAGS ... ++# Specify the C compile flags. ++# ++# LINK_FLAGS ... ++# Specify the link flags. ++# ++# EXTRA_MSGS ... ++# Specify the extra messages, it is helpful when third-party lib also ++# have dependent libraries. If dependent libraries changed, the lib ++# would be force to update. ++function(ADD_THIRDPARTY_LIB LIB_NAME) ++ set(options) ++ set(one_value_args URL SHA256 FAKE_SHA256 VERSION TOOLCHAIN CONF_PATH) ++ set(multi_value_args COMPONENTS CONF_OPTIONS PRE_CONFIGURE PATCHES CXX_FLAGS C_FLAGS LINK_FLAGS EXTRA_MSGS) ++ cmake_parse_arguments(ARG "${options}" "${one_value_args}" "${multi_value_args}" ${ARGN}) ++ ++ string(TOLOWER ${LIB_NAME} _LIB_NAME_LOWER) ++ ++ if(NOT ARG_TOOLCHAIN) ++ set(ARG_TOOLCHAIN "cmake") ++ endif() ++ ++ # Generate a unique install dir name, the impact factors are as follow: ++ # Lib name: ++ set(${LIB_NAME}_CONF_TXT "${_LIB_NAME_LOWER}") ++ if(NOT ${ARG_FAKE_SHA256} STREQUAL "") ++ set(${LIB_NAME}_CONF_TXT "${${LIB_NAME}_CONF_TXT}_${ARG_FAKE_SHA256}") ++ else() ++ set(${LIB_NAME}_CONF_TXT "${${LIB_NAME}_CONF_TXT}_${ARG_SHA256}") ++ endif() ++ # Version: ++ set(${LIB_NAME}_CONF_TXT "${${LIB_NAME}_CONF_TXT}_${ARG_VERSION}") ++ # Components: ++ set(${LIB_NAME}_CONF_TXT "${${LIB_NAME}_CONF_TXT}_${ARG_COMPONENTS}") ++ # Toolchain: ++ set(${LIB_NAME}_CONF_TXT "${${LIB_NAME}_CONF_TXT}_${ARG_TOOLCHAIN}") ++ # Configure options: ++ set(${LIB_NAME}_CONF_TXT "${${LIB_NAME}_CONF_TXT}_${ARG_CONF_OPTIONS}") ++ # CXX compiler version: ++ set(${LIB_NAME}_CONF_TXT "${${LIB_NAME}_CONF_TXT}_${CMAKE_CXX_COMPILER_VERSION}") ++ # C compiler version: ++ set(${LIB_NAME}_CONF_TXT "${${LIB_NAME}_CONF_TXT}_${CMAKE_C_COMPILER_VERSION}") ++ # CXX_FLAGS: ++ set(${LIB_NAME}_CONF_TXT "${${LIB_NAME}_CONF_TXT}_${ARG_CXX_FLAGS}") ++ # C_FLAGS: ++ set(${LIB_NAME}_CONF_TXT "${${LIB_NAME}_CONF_TXT}_${ARG_C_FLAGS}") ++ # LINK_FLAGS: ++ set(${LIB_NAME}_CONF_TXT "${${LIB_NAME}_CONF_TXT}_${ARG_LINK_FLAGS}") ++ # Patch files: ++ foreach(_PATCH ${ARG_PATCHES}) ++ file(SHA256 ${_PATCH} _PATCH_SHA256) ++ set(${LIB_NAME}_CONF_TXT "${${LIB_NAME}_CONF_TXT}}_${_PATCH_SHA256}") ++ endforeach() ++ # Extra messages: ++ foreach(_MSG ${ARG_EXTRA_MSGS}) ++ set(${LIB_NAME}_CONF_TXT "${${LIB_NAME}_CONF_TXT}}_${_MSG}") ++ endforeach() ++ string(REPLACE ";" "_" ${LIB_NAME}_CONF_TXT ${${LIB_NAME}_CONF_TXT}) ++ string(SHA256 _ROOT_SUFFIX ${${LIB_NAME}_CONF_TXT}) ++ set(${LIB_NAME}_ROOT "${YR_OPENSOURCE_DIR}/${_LIB_NAME_LOWER}_${_ROOT_SUFFIX}") ++ ++ # Check if we have cache the lib, if true, reuse it directly. ++ set(_VERIFY_FILE "${${LIB_NAME}_ROOT}/${LIB_NAME}_install.txt") ++ if(EXISTS ${${LIB_NAME}_ROOT}) ++ if (EXISTS ${_VERIFY_FILE}) ++ set(${LIB_NAME}_FOUND TRUE) ++ endif() ++ ++ if(${LIB_NAME}_FOUND) ++ message(STATUS "${LIB_NAME} found in ${${LIB_NAME}_ROOT}...") ++ if (EXISTS ${${LIB_NAME}_ROOT}/lib64) ++ set(${LIB_NAME}_LIB_PATH ${${LIB_NAME}_ROOT}/lib64 PARENT_SCOPE) ++ else() ++ set(${LIB_NAME}_LIB_PATH ${${LIB_NAME}_ROOT}/lib PARENT_SCOPE) ++ endif() ++ set(${LIB_NAME}_ROOT "${${LIB_NAME}_ROOT}" PARENT_SCOPE) ++ return() ++ else() ++ message(STATUS "${LIB_NAME} not found in ${${LIB_NAME}_ROOT}, need recompile...") ++ # Well, although the cache directory exists, it appears to be corrupted (because we can't find ++ # it via find_package). So remove the directory directly and we will recompile the lib. ++ file(REMOVE_RECURSE "${${LIB_NAME}_ROOT}") ++ endif() ++ endif() ++ ++ # Fetch the package first. ++ download_lib_pkg(${_LIB_NAME_LOWER} ${ARG_URL} ${ARG_SHA256}) ++ ++ # Apply the patches if need. ++ foreach(_PATCH ${ARG_PATCHES}) ++ if (NOT Patch_FOUND) ++ message(FATAL_ERROR "patch executable not found!") ++ endif() ++ execute_process(COMMAND ${Patch_EXECUTABLE} -p1 INPUT_FILE ${_PATCH} ++ WORKING_DIRECTORY ${${_LIB_NAME_LOWER}_SOURCE_DIR} ++ RESULT_VARIABLE _RET) ++ if(NOT _RET EQUAL "0") ++ message("Patch ${_PATCH} failed, error: ${_RET}") ++ endif() ++ endforeach() ++ ++ # Compile the source code and install. ++ message(STATUS "Compiling ${LIB_NAME} in ${${_LIB_NAME_LOWER}_BINARY_DIR}") ++ string(TOLOWER ${ARG_TOOLCHAIN} _TOOLCHAIN_LOWER) ++ ++ if(${_LIB_NAME_LOWER} STREQUAL "tbb" AND ${_TOOLCHAIN_LOWER} STREQUAL "make") ++ find_program(MAKE_PROGRAM make) # CMAKE_MAKE_PROGRAM may not be make. ++ if (NOT MAKE_PROGRAM) ++ message(FATAL_ERROR "make program not found! Please install make.") ++ endif() ++ if (ARG_CXX_FLAGS) ++ list(APPEND ${LIB_NAME}_MAKE_CXXFLAGS "CXXFLAGS=${ARG_CXX_FLAGS}") ++ endif() ++ if (ARG_C_FLAGS) ++ list(APPEND ${LIB_NAME}_MAKE_CFLAGS "CFLAGS=${ARG_C_FLAGS}") ++ endif() ++ if (ARG_LINK_FLAGS) ++ list(APPEND ${LIB_NAME}_MAKE_LDFLAGS "LDFLAGS=${ARG_LINK_FLAGS}") ++ endif() ++ __exec_command(COMMAND ${MAKE_PROGRAM} ++ ${${LIB_NAME}_MAKE_CFLAGS} ++ ${${LIB_NAME}_MAKE_CXXFLAGS} ++ ${${LIB_NAME}_MAKE_LDFLAGS} ++ -j ${BUILD_THREAD_NUM} ++ WORKING_DIRECTORY ${${_LIB_NAME_LOWER}_SOURCE_DIR}) ++ # Copy headers ++ file(COPY ${${_LIB_NAME_LOWER}_SOURCE_DIR}/include/tbb ++ DESTINATION ${${LIB_NAME}_ROOT}/include) ++ # Copy libs ++ file(GLOB_RECURSE _LIBS_LIST "${${_LIB_NAME_LOWER}_SOURCE_DIR}/*lib*.so*") ++ foreach(_LIB ${_LIBS_LIST}) ++ file(COPY "${_LIB}" DESTINATION ${${LIB_NAME}_ROOT}/lib) ++ endforeach() ++ elseif(${_TOOLCHAIN_LOWER} STREQUAL "cmake") ++ if (ARG_CXX_FLAGS) ++ list(APPEND ARG_CONF_OPTIONS "-DCMAKE_CXX_FLAGS=${ARG_CXX_FLAGS}") ++ endif() ++ if (ARG_C_FLAGS) ++ list(APPEND ARG_CONF_OPTIONS "-DCMAKE_C_FLAGS=${ARG_C_FLAGS}") ++ endif() ++ if (ARG_LINK_FLAGS) ++ list(APPEND ARG_CONF_OPTIONS "-DCMAKE_SHARED_LINKER_FLAGS=${ARG_LINK_FLAGS}") ++ endif() ++ list(APPEND ARG_CONF_OPTIONS "-DCMAKE_INSTALL_PREFIX=${${LIB_NAME}_ROOT}") ++ ++ __exec_command(COMMAND ${CMAKE_COMMAND} -G ${CMAKE_GENERATOR} ++ ${ARG_CONF_OPTIONS} ++ ${${_LIB_NAME_LOWER}_SOURCE_DIR}/${ARG_CONF_PATH} ++ WORKING_DIRECTORY ${${_LIB_NAME_LOWER}_BINARY_DIR}) ++ ++ __exec_command(COMMAND ${CMAKE_COMMAND} --build . --target install -- -j${BUILD_THREAD_NUM} ++ WORKING_DIRECTORY ${${_LIB_NAME_LOWER}_BINARY_DIR}) ++ ++ elseif(${_TOOLCHAIN_LOWER} STREQUAL "configure") ++ # If we need to do something before run ./configure, just do it. ++ if (ARG_PRE_CONFIGURE) ++ __exec_command(COMMAND ${ARG_PRE_CONFIGURE} ++ WORKING_DIRECTORY ${${_LIB_NAME_LOWER}_SOURCE_DIR}) ++ endif() ++ ++ # Add compile flags, install prefix and run configure. ++ if (ARG_CXX_FLAGS) ++ list(APPEND ARG_CONF_OPTIONS "CXXFLAGS=${ARG_CXX_FLAGS}") ++ endif() ++ if (ARG_C_FLAGS) ++ list(APPEND ARG_CONF_OPTIONS "CFLAGS=${ARG_C_FLAGS}") ++ endif() ++ if (ARG_LINK_FLAGS) ++ list(APPEND ARG_CONF_OPTIONS "LDFLAGS=${ARG_LINK_FLAGS}") ++ endif() ++ list(APPEND ARG_CONF_OPTIONS "--prefix=${${LIB_NAME}_ROOT}") ++ ++ if (EXISTS ${${_LIB_NAME_LOWER}_SOURCE_DIR}/config) ++ set(_CONFIG_FILE ${${_LIB_NAME_LOWER}_SOURCE_DIR}/config) ++ else() ++ set(_CONFIG_FILE ${${_LIB_NAME_LOWER}_SOURCE_DIR}/configure) ++ endif() ++ ++ __exec_command(COMMAND sh ${_CONFIG_FILE} ${ARG_CONF_OPTIONS} ++ WORKING_DIRECTORY ${${_LIB_NAME_LOWER}_SOURCE_DIR}) ++ ++ # make -j && make install ++ if (NOT CMAKE_MAKE_PROGRAM) ++ message(FATAL_ERROR "make program not found!") ++ endif() ++ __exec_command(COMMAND ${CMAKE_MAKE_PROGRAM} ++ ${${LIB_NAME}_MAKE_CFLAGS} ++ ${${LIB_NAME}_MAKE_CXXFLAGS} ++ ${${LIB_NAME}_MAKE_LDFLAGS} ++ -j ${BUILD_THREAD_NUM} ++ WORKING_DIRECTORY ${${_LIB_NAME_LOWER}_SOURCE_DIR}) ++ __exec_command(COMMAND ${CMAKE_MAKE_PROGRAM} install ++ WORKING_DIRECTORY ${${_LIB_NAME_LOWER}_SOURCE_DIR}) ++ else() ++ message(FATAL_ERROR "Unrecognized toolchain: ${ARG_TOOLCHAIN}") ++ endif() ++ ++ # Write install text to root dir for verify purpose. ++ file(WRITE "${_VERIFY_FILE}" "${${LIB_NAME}_CONF_TXT}") ++ ++ # For output package variables. ++ if (EXISTS ${${LIB_NAME}_ROOT}/lib64) ++ set(${LIB_NAME}_LIB_PATH ${${LIB_NAME}_ROOT}/lib64 PARENT_SCOPE) ++ else() ++ set(${LIB_NAME}_LIB_PATH ${${LIB_NAME}_ROOT}/lib PARENT_SCOPE) ++ endif() ++ ++ set(${LIB_NAME}_ROOT ${${LIB_NAME}_ROOT} PARENT_SCOPE) ++endfunction() ++ ++# Adjuice third-party dependency library version. ++# ++# LIB_NAME is the name of the library. ++# ++# Output variables: ++# ${LIB_NAME}_VERSION ++# third-party library version. ++# ++# ${LIB_NAME}_URL ++# third-party library download url. ++# ++# ${LIB_NAME}_SHA256 ++# third-party library SHA256 for verify. ++function(ADJUICE_THIRDPARTY_VERSION LIB_NAME) ++ if ("$ENV{YR_PACKAGE}" STREQUAL "") ++ if (${LIB_NAME}_VERSION) ++ list(FIND ${LIB_NAME}_VERSIONS "${${LIB_NAME}_VERSION}" ${LIB_NAME}_INDEX) ++ if (${LIB_NAME}_INDEX EQUAL -1) ++ message(FATAL_ERROR "Unsupported protobuf version: ${${LIB_NAME}_VERSION}, available versions are: ${${LIB_NAME}_VERSIONS}") ++ endif() ++ else() ++ set(${LIB_NAME}_INDEX 0) ++ endif() ++ list(GET ${LIB_NAME}_VERSIONS ${${LIB_NAME}_INDEX} ${LIB_NAME}_VERSION) ++ list(GET ${LIB_NAME}_URLS ${${LIB_NAME}_INDEX} ${LIB_NAME}_URL) ++ list(GET ${LIB_NAME}_SHA256S ${${LIB_NAME}_INDEX} ${LIB_NAME}_SHA256) ++ else() ++ gen_thirdparty_pkg(${LIB_NAME} ${LIB_NAME}_URL ${LIB_NAME}_SHA256 ${LIB_NAME}_FAKE_SHA256 ${LIB_NAME}_VERSION) ++ endif() ++ set(${LIB_NAME}_VERSION ${${LIB_NAME}_VERSION} PARENT_SCOPE) ++ set(${LIB_NAME}_URL ${${LIB_NAME}_URL} PARENT_SCOPE) ++ set(${LIB_NAME}_SHA256 ${${LIB_NAME}_SHA256} PARENT_SCOPE) ++ set(${LIB_NAME}_FAKE_SHA256 ${${LIB_NAME}_FAKE_SHA256} PARENT_SCOPE) ++endfunction() ++ ++function(INSTALL_DATASYSTEM_TARGET TARGET) ++ set(options) ++ set(one_value_args EXPORT_NAME) ++ set(multi_value_args) ++ cmake_parse_arguments(ARG "${options}" "${one_value_args}" "${multi_value_args}" ${ARGN}) ++ if (NOT ${TARGET}_INSTALL_LIBPATH) ++ set(${TARGET}_INSTALL_LIBPATH lib) ++ endif() ++ ++ if (NOT ${TARGET}_INSTALL_BINPATH) ++ set(${TARGET}_INSTALL_BINPATH bin) ++ endif() ++ ++ if (ARG_EXPORT_NAME) ++ install(TARGETS ${TARGET} ++ EXPORT ${ARG_EXPORT_NAME} ++ ARCHIVE DESTINATION ${${TARGET}_INSTALL_LIBPATH} ++ LIBRARY DESTINATION ${${TARGET}_INSTALL_LIBPATH} ++ RUNTIME DESTINATION ${${TARGET}_INSTALL_BINPATH}) ++ else() ++ install(TARGETS ${TARGET} ++ ARCHIVE DESTINATION ${${TARGET}_INSTALL_LIBPATH} ++ LIBRARY DESTINATION ${${TARGET}_INSTALL_LIBPATH} ++ RUNTIME DESTINATION ${${TARGET}_INSTALL_BINPATH}) ++ endif() ++endfunction() ++ ++# for git information ++macro(get_git_branch git_branch_out_var) ++ find_package(Git QUIET) ++ if (GIT_FOUND) ++ execute_process( ++ COMMAND ${GIT_EXECUTABLE} symbolic-ref --short -q HEAD ++ WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} ++ ERROR_QUIET ++ OUTPUT_VARIABLE ${git_branch_out_var} ++ OUTPUT_STRIP_TRAILING_WHITESPACE ++ ) ++ endif () ++endmacro() ++ ++macro(get_git_hash git_hash_out_var) ++ find_package(Git QUIET) ++ if (GIT_FOUND) ++ execute_process( ++ COMMAND ${GIT_EXECUTABLE} log -1 "--pretty=format:[%H] [%ai]" ++ WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} ++ ERROR_QUIET ++ OUTPUT_VARIABLE ${git_hash_out_var} ++ OUTPUT_STRIP_TRAILING_WHITESPACE ++ ) ++ endif () ++endmacro() ++ ++function(ADD_THIRDPARTY_SO LIB_NAME) ++ set(options) ++ set(one_value_args URL SHA256 VERSION TOOLCHAIN CONF_PATH) ++ set(multi_value_args COMPONENTS CONF_OPTIONS PRE_CONFIGURE PATCHES CXX_FLAGS C_FLAGS LINK_FLAGS EXTRA_MSGS) ++ cmake_parse_arguments(ARG "${options}" "${one_value_args}" "${multi_value_args}" ${ARGN}) ++ ++ string(TOLOWER ${LIB_NAME} _LIB_NAME_LOWER) ++ ++ # Generate a unique install dir name, the impact factors are as follow: ++ # Lib name: ++ set(${LIB_NAME}_CONF_TXT "${_LIB_NAME_LOWER}") ++ # SHA256: ++ set(${LIB_NAME}_CONF_TXT "${${LIB_NAME}_CONF_TXT}_${ARG_SHA256}") ++ # Version: ++ set(${LIB_NAME}_CONF_TXT "${${LIB_NAME}_CONF_TXT}_${ARG_VERSION}") ++ # Components: ++ set(${LIB_NAME}_CONF_TXT "${${LIB_NAME}_CONF_TXT}_${ARG_COMPONENTS}") ++ ++ string(REPLACE ";" "_" ${LIB_NAME}_CONF_TXT ${${LIB_NAME}_CONF_TXT}) ++ string(SHA256 _ROOT_SUFFIX ${${LIB_NAME}_CONF_TXT}) ++ set(${LIB_NAME}_ROOT "${YR_OPENSOURCE_DIR}/${_LIB_NAME_LOWER}_${_ROOT_SUFFIX}") ++ ++ # Check if we have cache the lib, if true, reuse it directly. ++ set(_VERIFY_FILE "${${LIB_NAME}_ROOT}/${LIB_NAME}_install.txt") ++ if(EXISTS ${${LIB_NAME}_ROOT}) ++ if (EXISTS ${_VERIFY_FILE}) ++ set(${LIB_NAME}_FOUND TRUE) ++ endif() ++ ++ if(${LIB_NAME}_FOUND) ++ message(STATUS "${LIB_NAME} found in ${${LIB_NAME}_ROOT}...") ++ if (EXISTS ${${LIB_NAME}_ROOT}/lib64) ++ set(${LIB_NAME}_LIB_PATH ${${LIB_NAME}_ROOT}/lib64 PARENT_SCOPE) ++ else() ++ set(${LIB_NAME}_LIB_PATH ${${LIB_NAME}_ROOT}/lib PARENT_SCOPE) ++ endif() ++ set(${LIB_NAME}_ROOT "${${LIB_NAME}_ROOT}" PARENT_SCOPE) ++ return() ++ else() ++ message(STATUS "${LIB_NAME} not found in ${${LIB_NAME}_ROOT}, need recompile...") ++ # Well, although the cache directory exists, it appears to be corrupted (because we can't find ++ # it via find_package). So remove the directory directly and we will recompile the lib. ++ file(REMOVE_RECURSE "${${LIB_NAME}_ROOT}") ++ endif() ++ endif() ++ ++ # Fetch the package first. ++ download_lib_pkg(${_LIB_NAME_LOWER} ${ARG_URL} ${ARG_SHA256}) ++ ++ file(COPY ${${_LIB_NAME_LOWER}_SOURCE_DIR}/lib/release/libStsSdk.so DESTINATION ${${LIB_NAME}_ROOT}/lib) ++ file(COPY ${${_LIB_NAME_LOWER}_SOURCE_DIR}/include DESTINATION ${${LIB_NAME}_ROOT}) ++ ++ # Write install text to root dir for verify purpose. ++ file(WRITE "${_VERIFY_FILE}" "${${LIB_NAME}_CONF_TXT}") ++ ++ # For output package variables. ++ if (EXISTS ${${LIB_NAME}_ROOT}/lib64) ++ set(${LIB_NAME}_LIB_PATH ${${LIB_NAME}_ROOT}/lib64 PARENT_SCOPE) ++ else() ++ set(${LIB_NAME}_LIB_PATH ${${LIB_NAME}_ROOT}/lib PARENT_SCOPE) ++ endif() ++ ++ set(${LIB_NAME}_ROOT ${${LIB_NAME}_ROOT} PARENT_SCOPE) ++endfunction() +diff --git a/dllm_tools/csrc/.clang-format b/dllm_tools/csrc/.clang-format +new file mode 100644 +index 000000000..08e61ec8f +--- /dev/null ++++ b/dllm_tools/csrc/.clang-format +@@ -0,0 +1,78 @@ ++--- ++Language: Cpp ++BasedOnStyle: Google ++AccessModifierOffset: -4 ++AlignAfterOpenBracket: true ++AlignEscapedNewlinesLeft: true ++AlignOperands: true ++AlignTrailingComments: true ++AllowAllParametersOfDeclarationOnNextLine: true ++AllowShortBlocksOnASingleLine: false ++AllowShortCaseLabelsOnASingleLine: false ++AllowShortIfStatementsOnASingleLine: false ++AllowShortLoopsOnASingleLine: false ++AllowShortFunctionsOnASingleLine: false ++AlwaysBreakAfterDefinitionReturnType: false ++AlwaysBreakTemplateDeclarations: true ++AlwaysBreakBeforeMultilineStrings: true ++BreakBeforeTernaryOperators: true ++BinPackParameters: true ++BinPackArguments: true ++ColumnLimit: 120 ++ConstructorInitializerAllOnOneLineOrOnePerLine: true ++ConstructorInitializerIndentWidth: 4 ++DerivePointerAlignment: false ++PointerAlignment: Right ++ExperimentalAutoDetectBinPacking: false ++IndentCaseLabels: true ++IndentWrappedFunctionNames: false ++IndentFunctionDeclarationAfterType: false ++MaxEmptyLinesToKeep: 1 ++KeepEmptyLinesAtTheStartOfBlocks: false ++NamespaceIndentation: None ++ObjCBlockIndentWidth: 4 ++ObjCSpaceAfterProperty: false ++ObjCSpaceBeforeProtocolList: false ++PenaltyBreakBeforeFirstCallParameter: 1 ++PenaltyBreakComment: 300 ++PenaltyBreakString: 1000 ++PenaltyBreakFirstLessLess: 120 ++PenaltyExcessCharacter: 1000000 ++PenaltyReturnTypeOnItsOwnLine: 200 ++Cpp11BracedListStyle: false ++Standard: Auto ++IndentWidth: 4 ++SortIncludes: false ++TabWidth: 4 ++UseTab: Never ++SpacesInParentheses: false ++SpacesInSquareBrackets: false ++SpacesInAngles: false ++SpaceInEmptyParentheses: false ++SpacesInCStyleCastParentheses: false ++SpaceAfterCStyleCast: false ++SpacesInContainerLiterals: true ++SpaceBeforeAssignmentOperators: true ++SpacesBeforeTrailingComments: 2 ++ContinuationIndentWidth: 4 ++CommentPragmas: '^ IWYU pragma:' ++ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ] ++SpaceBeforeParens: ControlStatements ++DisableFormat: false ++BraceWrapping: ++ AfterClass: false ++ AfterControlStatement: false ++ AfterEnum: false ++ AfterFunction: true ++ AfterNamespace: false ++ AfterObjCDeclaration: false ++ AfterStruct: false ++ AfterUnion: false ++ BeforeCatch: false ++ BeforeElse: false ++ IndentBraces: false ++BreakBeforeBinaryOperators: NonAssignment ++BreakBeforeBraces: Custom ++BreakConstructorInitializersBeforeComma: false ++... ++ +diff --git a/dllm_tools/csrc/CMakeLists.txt b/dllm_tools/csrc/CMakeLists.txt +new file mode 100644 +index 000000000..f55e4ba74 +--- /dev/null ++++ b/dllm_tools/csrc/CMakeLists.txt +@@ -0,0 +1,23 @@ ++if (NOT HAS_DS) ++ message(FATAL_ERROR "Can not build if datasystem SDK is not presented.") ++endif() ++ ++include_directories(./) ++ ++aux_source_directory(./ CPP_EXT_SOURCES) ++aux_source_directory(kvc CPP_EXT_SOURCES) ++aux_source_directory(perf CPP_EXT_SOURCES) ++aux_source_directory(utils CPP_EXT_SOURCES) ++ ++# for python user ++pybind11_add_module(cpp_ext MODULE ${CPP_EXT_SOURCES}) ++target_link_libraries(cpp_ext PRIVATE pybind11::module spdlog::spdlog datasystem) ++set_target_properties(cpp_ext PROPERTIES INSTALL_RPATH $ORIGIN) ++target_link_options(cpp_ext PRIVATE "-Wl,--disable-new-dtags") ++ ++# for c++ user ++list(REMOVE_ITEM CPP_EXT_SOURCES .//pybind_register.cpp) ++add_library(dllm_cpp_ext SHARED ${CPP_EXT_SOURCES}) ++target_link_libraries(dllm_cpp_ext PRIVATE spdlog::spdlog datasystem) ++set_target_properties(dllm_cpp_ext PROPERTIES INSTALL_RPATH $ORIGIN) ++target_link_options(dllm_cpp_ext PRIVATE "-Wl,--disable-new-dtags") +diff --git a/dllm_tools/csrc/include/kvc/c_api.h b/dllm_tools/csrc/include/kvc/c_api.h +new file mode 100644 +index 000000000..a55867f9b +--- /dev/null ++++ b/dllm_tools/csrc/include/kvc/c_api.h +@@ -0,0 +1,116 @@ ++ ++/** ++ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. ++ * ++ * This software is licensed under Mulan PSL v2. ++ * You can use this software according to the terms and conditions of the Mulan PSL v2. ++ * You may obtain a copy of Mulan PSL v2 at: ++ * ++ * http://license.coscl.org.cn/MulanPSL2 ++ * ++ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++ * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++ * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++ * See the Mulan PSL v2 for more details. ++ */ ++ ++#ifndef DLLM_KVC_C_API_H ++#define DLLM_KVC_C_API_H ++ ++#include ++#include ++ ++#define DLLMKVC_API __attribute__((visibility("default"))) ++ ++#ifdef __cplusplus ++extern "C" { ++#endif ++ ++typedef void *DLLMKVC_Handle; ++typedef void *DLLMKVC_List; ++ ++DLLMKVC_API DLLMKVC_List DLLMKVC_Str_List(size_t len); ++DLLMKVC_API void DLLMKVC_Str_FreeList(DLLMKVC_List list); ++DLLMKVC_API const char* DLLMKVC_Str_ListGetElem(DLLMKVC_List list, size_t index); ++DLLMKVC_API void DLLMKVC_Str_ListSetElem(DLLMKVC_List list, size_t index, const char* str); ++ ++// Blob ++DLLMKVC_API DLLMKVC_List DLLMKVC_Blob_List(size_t len); ++DLLMKVC_API void DLLMKVC_Blob_FreeList(DLLMKVC_List list); ++DLLMKVC_API DLLMKVC_Handle DLLMKVC_Blob_ListGetElem(DLLMKVC_List list, size_t index); ++DLLMKVC_API void DLLMKVC_Blob_Init(DLLMKVC_Handle blob, void* pointer, uint64_t size); ++ ++// DeviceBlobList ++DLLMKVC_API DLLMKVC_List DLLMKVC_DeviceBlobList_List(size_t len); ++DLLMKVC_API void DLLMKVC_DeviceBlobList_FreeList(DLLMKVC_List list); ++DLLMKVC_API DLLMKVC_Handle DLLMKVC_DeviceBlobList_ListGetElem(DLLMKVC_List list, size_t index); ++DLLMKVC_API void DLLMKVC_DeviceBlobList_Init(DLLMKVC_Handle devBlobList, int32_t deviceIdx, DLLMKVC_List blobs); ++ ++// KvcStore ++DLLMKVC_API DLLMKVC_Handle DLLMKVC_API DLLMKVC_Store(); ++DLLMKVC_API void DLLMKVC_Store_Free(DLLMKVC_Handle store); ++DLLMKVC_API void DLLMKVC_Store_Init(DLLMKVC_Handle store, const char* host, int32_t port, int32_t connTimeoutMs); ++DLLMKVC_API void DLLMKVC_Store_PutD2D( ++ DLLMKVC_Handle store, ++ DLLMKVC_List keys, ++ DLLMKVC_List devBlobLists, ++ DLLMKVC_List outFutures); ++DLLMKVC_API void DLLMKVC_Store_GetD2D( ++ DLLMKVC_Handle store, ++ DLLMKVC_List keys, ++ DLLMKVC_List devBlobLists, ++ DLLMKVC_List outFutures); ++DLLMKVC_API DLLMKVC_Handle DLLMKVC_Store_MGetH2D(DLLMKVC_Handle store, ++ DLLMKVC_List keys, ++ DLLMKVC_List devBlobLists); ++DLLMKVC_API DLLMKVC_Handle DLLMKVC_Store_MSetD2H(DLLMKVC_Handle store, ++ DLLMKVC_List keys, ++ DLLMKVC_List devBlobLists); ++DLLMKVC_API DLLMKVC_Handle DLLMKVC_Store_Delete(DLLMKVC_Handle store, DLLMKVC_List keys); ++DLLMKVC_API void DLLMKVC_Store_Exist(DLLMKVC_Handle store, ++ DLLMKVC_List keys, ++ bool* outExists, ++ size_t outExistsLen); ++ ++// KvcFuture ++DLLMKVC_API DLLMKVC_Handle DLLMKVC_Future(); ++DLLMKVC_API void DLLMKVC_Future_Free(DLLMKVC_Handle future); ++DLLMKVC_API DLLMKVC_List DLLMKVC_Future_List(size_t len); ++DLLMKVC_API void DLLMKVC_Future_FreeList(DLLMKVC_List list); ++DLLMKVC_API DLLMKVC_Handle DLLMKVC_Future_ListGetElem(DLLMKVC_List list, size_t index); ++DLLMKVC_API int32_t DLLMKVC_Future_WaitFor(DLLMKVC_Handle future, uint32_t ms); ++DLLMKVC_API DLLMKVC_Handle DLLMKVC_Future_Get(DLLMKVC_Handle future); ++ ++// KvcResult ++DLLMKVC_API void DLLMKVC_Result_Free(DLLMKVC_Handle result); ++DLLMKVC_API const char* DLLMKVC_Result_GetErrMsg(const DLLMKVC_Handle result); ++DLLMKVC_API int32_t DLLMKVC_Result_GetCode(DLLMKVC_Handle result); ++ ++ ++#define DLLMKVC_MAX_SHAPE_LEN 16 ++ ++struct DLLMKVC_Tensor { ++ uint64_t ptr; ++ uint32_t elemSize; ++ uint64_t shape[DLLMKVC_MAX_SHAPE_LEN]; ++ size_t shapeLen; ++}; ++ ++// PageAttnUtils ++DLLMKVC_API void DLLMKVC_PageAttn_LayerwiseDevBlobLists( ++ int32_t deviceIdx, ++ const DLLMKVC_Tensor* tensors, size_t numTensors, ++ const uint32_t* blockIds, size_t numBlocks, ++ DLLMKVC_List outDblList); ++ ++DLLMKVC_API void DLLMKVC_PageAttn_BlockwiseDevBlobLists( ++ int32_t deviceIdx, ++ const DLLMKVC_Tensor* tensors, size_t numTensors, ++ const uint32_t* blockIds, size_t numBlocks, ++ DLLMKVC_List outDblList); ++ ++ ++#ifdef __cplusplus ++} ++#endif ++#endif // DLLM_KVC_C_API_H +diff --git a/dllm_tools/csrc/include/kvc/common.h b/dllm_tools/csrc/include/kvc/common.h +new file mode 100644 +index 000000000..fee6bb761 +--- /dev/null ++++ b/dllm_tools/csrc/include/kvc/common.h +@@ -0,0 +1,59 @@ ++/** ++ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. ++ * ++ * This software is licensed under Mulan PSL v2. ++ * You can use this software according to the terms and conditions of the Mulan PSL v2. ++ * You may obtain a copy of Mulan PSL v2 at: ++ * ++ * http://license.coscl.org.cn/MulanPSL2 ++ * ++ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++ * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++ * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++ * See the Mulan PSL v2 for more details. ++ */ ++#ifndef DLLM_KVC_COMMON_H ++#define DLLM_KVC_COMMON_H ++ ++#include ++#include ++#include ++ ++namespace dllm { ++namespace kvc { ++ ++constexpr int DEFAULT_NUM_THREADS = 2; ++ ++struct Blob { ++ void* pointer = nullptr; ++ uint64_t size = 0; ++}; ++ ++struct DeviceBlobList { ++ std::vector blobs; ++ int32_t deviceIdx = -1; ++}; ++ ++struct KvcTensor { ++ uint64_t ptr; ++ uint32_t elemSize; ++ std::vector shape; ++}; ++ ++struct KvcResult { ++ static KvcResult Ok() {return KvcResult(0, {}, {});} ++ KvcResult() = default; ++ KvcResult(int32_t _statusCode, const std::string &_errorMessage, std::vector &&_failedList) ++ :statusCode(_statusCode), ++ errorMessage(_errorMessage), ++ failedList(std::move(_failedList)) ++ {} ++ int32_t statusCode; ++ std::string errorMessage; ++ std::vector failedList; ++}; ++ ++} // namespace kvc ++} // namespace dllm ++ ++#endif // DLLM_KVC_COMMON_H +\ No newline at end of file +diff --git a/dllm_tools/csrc/include/kvc/kvc_future.h b/dllm_tools/csrc/include/kvc/kvc_future.h +new file mode 100644 +index 000000000..c83fc7b1e +--- /dev/null ++++ b/dllm_tools/csrc/include/kvc/kvc_future.h +@@ -0,0 +1,139 @@ ++/** ++ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. ++ * ++ * This software is licensed under Mulan PSL v2. ++ * You can use this software according to the terms and conditions of the Mulan PSL v2. ++ * You may obtain a copy of Mulan PSL v2 at: ++ * ++ * http://license.coscl.org.cn/MulanPSL2 ++ * ++ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++ * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++ * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++ * See the Mulan PSL v2 for more details. ++ */ ++#ifndef DLLM_KVC_KVC_FUTURE_H ++#define DLLM_KVC_KVC_FUTURE_H ++ ++#include ++#include ++#include ++#include ++ ++#include "common.h" ++#include "c_api.h" ++ ++namespace datasystem { ++class Future; ++} ++ ++namespace dllm { ++namespace kvc { ++ ++class KvcFuture { ++public: ++ KvcFuture() = default; ++ ++ // mimic std::future interface ++ std::future_status wait_for(uint64_t ms); ++ KvcResult get(); ++ ++private: ++ friend class KvcStore; ++ ++ enum class From { ++ None, ++ Std, ++ Datasystem ++ }; ++ ++ explicit KvcFuture(KvcResult &&result); ++ explicit KvcFuture(std::shared_future &&future); ++ KvcFuture(const std::shared_ptr > &dsFutures, ++ size_t index, ++ const std::vector &keys); ++ ++ std::shared_future stdFuture_; ++ std::shared_ptr > dsFutures_; ++ size_t index_ = 0; ++ std::vector keys_; ++ KvcResult result_; ++ bool got_ = false; ++ From from_ = From::None; ++}; ++ ++// KvcFutureListHolder & KvcFutureWrapper is for easier C-API invocation ++struct KvcFutureListHolder { ++ explicit KvcFutureListHolder(DLLMKVC_List list): list_(list) {} ++ ~KvcFutureListHolder() ++ { ++ DLLMKVC_Future_FreeList(list_); ++ } ++ DLLMKVC_List list_; ++}; ++ ++using KvcFutureListRef = std::shared_ptr; ++ ++ ++struct KvcFutureHolder { ++ explicit KvcFutureHolder(DLLMKVC_Handle future): future_(future) {} ++ ~KvcFutureHolder() ++ { ++ DLLMKVC_Future_Free(future_); ++ } ++ DLLMKVC_Handle future_; ++}; ++ ++using KvcFutureRef = std::shared_ptr; ++ ++class KvcFutureWrapper { ++public: ++ static void FutureList2Wrappers(DLLMKVC_List list, size_t listSize, std::vector &outWrappers) ++ { ++ KvcFutureListRef listRef = std::make_shared(list); ++ for (size_t i = 0; i < listSize; ++i) { ++ outWrappers.emplace_back(listRef, DLLMKVC_Future_ListGetElem(list, i)); ++ } ++ } ++ ++ static KvcFutureWrapper Future2Wrapper(DLLMKVC_Handle future) ++ { ++ KvcFutureRef futureRef = std::make_shared(future); ++ return KvcFutureWrapper(futureRef); ++ } ++ ++ KvcFutureWrapper(); ++ KvcFutureWrapper(KvcFutureListRef listRef, DLLMKVC_Handle handle):listRef_(listRef), handle_(handle) {} ++ explicit KvcFutureWrapper(KvcFutureRef futureRef):futureRef_(futureRef), handle_(futureRef->future_) {} ++ ++ std::future_status wait_for(uint64_t ms) const ++ { ++ if (!handle_) { ++ throw std::runtime_error("invalid KvcFutureWrapper"); ++ } ++ return static_cast(DLLMKVC_Future_WaitFor(handle_, ms)); ++ } ++ ++ KvcResult get() const ++ { ++ if (!handle_) { ++ throw std::runtime_error("invalid KvcFutureWrapper"); ++ } ++ DLLMKVC_Handle resultHandle = DLLMKVC_Future_Get(handle_); ++ KvcResult result(DLLMKVC_Result_GetCode(resultHandle), ++ DLLMKVC_Result_GetErrMsg(resultHandle), { ++ }); ++ DLLMKVC_Result_Free(resultHandle); ++ return result; ++ } ++ ++private: ++ KvcFutureListRef listRef_; ++ KvcFutureRef futureRef_; ++ DLLMKVC_Handle handle_ = nullptr; ++}; ++ ++} // namespace kvc ++} // namespace dllm ++ ++#endif // DLLM_KVC_KVC_FUTURE_H +diff --git a/dllm_tools/csrc/include/kvc/kvc_store.h b/dllm_tools/csrc/include/kvc/kvc_store.h +new file mode 100644 +index 000000000..eecee35e4 +--- /dev/null ++++ b/dllm_tools/csrc/include/kvc/kvc_store.h +@@ -0,0 +1,65 @@ ++/** ++ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. ++ * ++ * This software is licensed under Mulan PSL v2. ++ * You can use this software according to the terms and conditions of the Mulan PSL v2. ++ * You may obtain a copy of Mulan PSL v2 at: ++ * ++ * http://license.coscl.org.cn/MulanPSL2 ++ * ++ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++ * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++ * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++ * See the Mulan PSL v2 for more details. ++ */ ++#ifndef DLLM_KVC_KVC_STORE_H ++#define DLLM_KVC_KVC_STORE_H ++ ++#include ++#include ++ ++#include "common.h" ++#include "kvc_future.h" ++ ++namespace datasystem { ++class HeteroClient; ++} ++ ++namespace dllm { ++ ++namespace utils { ++class ThreadPool; ++} ++ ++namespace kvc { ++ ++class KvcStore { ++public: ++ ++ KvcStore() = default; ++ KvcStore(KvcStore&& other); ++ ~KvcStore(); ++ KvcStore& operator=(KvcStore&& other); ++ ++ void Init(const std::string &host, int32_t port, int32_t connectTimeoutMs, ++ int32_t numThreads = DEFAULT_NUM_THREADS); ++ KvcFuture MGetH2D(const std::vector &keys, const std::vector &devBlobLists) const; ++ KvcFuture MSetD2H(const std::vector &keys, const std::vector &devBlobLists) const; ++ KvcFuture Delete(const std::vector &keys) const; ++ void Exist(const std::vector &keys, std::vector &outExists) const; ++ void PutD2D(const std::vector &keys, ++ const std::vector &devBlobLists, ++ std::vector &outFutures) const; ++ void GetD2D(const std::vector &keys, ++ const std::vector &devBlobLists, ++ std::vector &outFutures) const; ++ ++private: ++ datasystem::HeteroClient* client_ = nullptr; ++ utils::ThreadPool* threadPool_ = nullptr; ++}; ++ ++} // namespace kvc ++} // namespace dllm ++ ++#endif // DLLM_KVC_KVC_STORE_H +\ No newline at end of file +diff --git a/dllm_tools/csrc/include/kvc/page_attn_utils.h b/dllm_tools/csrc/include/kvc/page_attn_utils.h +new file mode 100644 +index 000000000..d447a9695 +--- /dev/null ++++ b/dllm_tools/csrc/include/kvc/page_attn_utils.h +@@ -0,0 +1,46 @@ ++/** ++ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. ++ * ++ * This software is licensed under Mulan PSL v2. ++ * You can use this software according to the terms and conditions of the Mulan PSL v2. ++ * You may obtain a copy of Mulan PSL v2 at: ++ * ++ * http://license.coscl.org.cn/MulanPSL2 ++ * ++ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++ * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++ * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++ * See the Mulan PSL v2 for more details. ++ */ ++#ifndef DLLM_KVC_PAGE_ATTN_UTILS_H ++#define DLLM_KVC_PAGE_ATTN_UTILS_H ++ ++#include "common.h" ++ ++namespace dllm { ++namespace kvc { ++ ++class PageAttnUtils { ++public: ++ PageAttnUtils() = delete; ++ ++ static Blob Blk2Blob(uint64_t ptr, size_t elemSize, size_t numBlockElem, uint32_t blockId); ++ ++ static DeviceBlobList Blks2DevBlobList(int32_t deviceIdx, uint64_t ptr, size_t elemSize, size_t numBlockElem, ++ const std::vector &blockIds); ++ ++ static void LayerwiseDevBlobLists(int32_t deviceIdx, ++ const std::vector &layerTensors, ++ const std::vector &blockIds, ++ std::vector &outDevBlobLists); ++ ++ static void BlockwiseDevBlobLists(int32_t deviceIdx, ++ const std::vector &layerTensors, ++ const std::vector &blockIds, ++ std::vector &outDevBlobLists); ++}; ++ ++} // namespace kvc ++} // namespace dllm ++ ++#endif // DLLM_KVC_PAGE_ATTN_UTILS_H +\ No newline at end of file +diff --git a/dllm_tools/csrc/include/kvc/torch_adaptor.h b/dllm_tools/csrc/include/kvc/torch_adaptor.h +new file mode 100644 +index 000000000..3606873bd +--- /dev/null ++++ b/dllm_tools/csrc/include/kvc/torch_adaptor.h +@@ -0,0 +1,252 @@ ++/** ++ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. ++ * ++ * This software is licensed under Mulan PSL v2. ++ * You can use this software according to the terms and conditions of the Mulan PSL v2. ++ * You may obtain a copy of Mulan PSL v2 at: ++ * ++ * http://license.coscl.org.cn/MulanPSL2 ++ * ++ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++ * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++ * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++ * See the Mulan PSL v2 for more details. ++ */ ++#ifndef DLLM_KVC_TORCH_ADAPTOR_H ++#define DLLM_KVC_TORCH_ADAPTOR_H ++ ++#include ++#include ++ ++#include "common.h" ++#include "kvc_future.h" ++#include "c_api.h" ++ ++namespace dllm { ++namespace kvc { ++ ++class TorchAdaptor { ++public: ++ static void *GetStartDataPtr(const at::Tensor &tensor) ++ { ++ uint64_t ptr = (uint64_t)tensor.storage().data_ptr().get(); ++ uint64_t offset = tensor.storage_offset(); ++ uint64_t itemSize = tensor.dtype().itemsize(); ++ uint64_t start = ptr + (offset * itemSize); ++ return (void*)start; ++ } ++ ++ static void Tensor2Blob(const at::Tensor &tensor, DLLMKVC_List blobList, size_t index) ++ { ++ DLLMKVC_Blob_Init(DLLMKVC_Blob_ListGetElem(blobList, index), GetStartDataPtr(tensor), tensor.nbytes()); ++ } ++ ++ static DLLMKVC_List PageAttnLayerwiseDbls(const std::vector &layerTensors, ++ const std::vector &blockIds) ++ { ++ if (layerTensors.empty()) { ++ throw std::invalid_argument("layerTensors is empty"); ++ } ++ if (blockIds.empty()) { ++ throw std::invalid_argument("blockIds is empty"); ++ } ++ for (auto &tensor: layerTensors) { ++ if (tensor.get_device() != layerTensors[0].get_device()) { ++ throw std::invalid_argument("Tensors not from a same device"); ++ } ++ } ++ ++ size_t size = sizeof(DLLMKVC_Tensor)*layerTensors.size(); ++ DLLMKVC_Tensor *tensors = (DLLMKVC_Tensor*)malloc(size); ++ memset((void*)tensors, 0, size); ++ ++ for (size_t i = 0; i < layerTensors.size(); ++i) { ++ tensors[i].ptr = (uint64_t)GetStartDataPtr(layerTensors[i]); ++ tensors[i].elemSize = layerTensors[i].element_size(); ++ ++ at::IntArrayRef shapeAry = layerTensors[i].sizes(); ++ if (shapeAry.size() > DLLMKVC_MAX_SHAPE_LEN) { ++ free((void*)tensors); ++ throw std::invalid_argument("Tensor shape too long"); ++ } ++ ++ for (size_t s = 0; s < shapeAry.size(); ++s) { ++ tensors[i].shape[s] = shapeAry.at(s); ++ } ++ tensors[i].shapeLen = shapeAry.size(); ++ } ++ DLLMKVC_List dblList = DLLMKVC_DeviceBlobList_List(0); ++ DLLMKVC_PageAttn_LayerwiseDevBlobLists((int32_t)layerTensors[0].get_device(), ++ tensors, layerTensors.size(), ++ blockIds.data(), blockIds.size(), ++ dblList); ++ ++ free((void*)tensors); ++ return dblList; ++ } ++ ++ static DLLMKVC_List StrVec2List(const std::vector &keys) ++ { ++ DLLMKVC_List list = DLLMKVC_Str_List(keys.size()); ++ for (size_t i = 0; i < keys.size(); ++i) { ++ DLLMKVC_Str_ListSetElem(list, i, keys[i].c_str()); ++ } ++ return list; ++ } ++ ++ static DLLMKVC_List TensorVec2DblList(const std::vector &tensors) ++ { ++ DLLMKVC_List dblList = DLLMKVC_DeviceBlobList_List(tensors.size()); ++ for (size_t i = 0; i < tensors.size(); ++i) { ++ DLLMKVC_List blobList = DLLMKVC_Blob_List(1); ++ DLLMKVC_Blob_Init(DLLMKVC_Blob_ListGetElem(blobList, 0), ++ GetStartDataPtr(tensors[i]), tensors[i].nbytes()); ++ DLLMKVC_DeviceBlobList_Init(DLLMKVC_DeviceBlobList_ListGetElem(dblList, i), ++ tensors[i].get_device(), blobList); ++ ++ DLLMKVC_Blob_FreeList(blobList); ++ } ++ return dblList; ++ } ++ ++ TorchAdaptor() = default; ++ ++ TorchAdaptor(DLLMKVC_Handle store): store_(store) {} ++ ++ ~TorchAdaptor() ++ { ++ if (store_) { ++ DLLMKVC_Store_Free(store_); ++ } ++ } ++ ++ TorchAdaptor(TorchAdaptor &&other) = delete; ++ TorchAdaptor(const TorchAdaptor &other) = delete; ++ TorchAdaptor& operator=(const TorchAdaptor &other) = delete; ++ TorchAdaptor& operator=(TorchAdaptor &&other) = delete; ++ ++ void Init(const std::string &host, int32_t port, int32_t connTimeoutMs) ++ { ++ if (!store_) { ++ store_ = DLLMKVC_Store(); ++ } ++ DLLMKVC_Store_Init(store_, host.c_str(), port, connTimeoutMs); ++ } ++ ++ void PutTensorsD2D(std::vector &keys, ++ std::vector &tensors, ++ std::vector &outFutures) ++ { ++ DLLMKVC_List keyList = StrVec2List(keys); ++ DLLMKVC_List dblList = TensorVec2DblList(tensors); ++ DLLMKVC_List futureList = DLLMKVC_Future_List(0); ++ DLLMKVC_Store_PutD2D(store_, ++ keyList, ++ dblList, ++ futureList); ++ DLLMKVC_Str_FreeList(keyList); ++ DLLMKVC_DeviceBlobList_FreeList(dblList); ++ KvcFutureWrapper::FutureList2Wrappers(futureList, keys.size(), outFutures); ++ } ++ ++ void GetTensorsD2D(std::vector &keys, ++ std::vector &tensors, ++ std::vector &outFutures) ++ { ++ DLLMKVC_List keyList = StrVec2List(keys); ++ DLLMKVC_List dblList = TensorVec2DblList(tensors); ++ DLLMKVC_List futureList = DLLMKVC_Future_List(0); ++ DLLMKVC_Store_GetD2D(store_, ++ keyList, ++ dblList, ++ futureList); ++ DLLMKVC_Str_FreeList(keyList); ++ DLLMKVC_DeviceBlobList_FreeList(dblList); ++ KvcFutureWrapper::FutureList2Wrappers(futureList, keys.size(), outFutures); ++ } ++ ++ void PutPageAttnLayerwiseD2D(std::vector &keys, ++ std::vector &layerTensors, ++ std::vector &blockIds, ++ std::vector &outFutures) ++ { ++ DLLMKVC_List keyList = StrVec2List(keys); ++ DLLMKVC_List dblList = PageAttnLayerwiseDbls(layerTensors, blockIds); ++ DLLMKVC_List futureList = DLLMKVC_Future_List(0); ++ DLLMKVC_Store_PutD2D(store_, ++ keyList, ++ dblList, ++ futureList); ++ DLLMKVC_Str_FreeList(keyList); ++ DLLMKVC_DeviceBlobList_FreeList(dblList); ++ KvcFutureWrapper::FutureList2Wrappers(futureList, keys.size(), outFutures); ++ } ++ ++ void GetPageAttnLayerwiseD2D(std::vector &keys, ++ std::vector &layerTensors, ++ std::vector &blockIds, ++ std::vector &outFutures) ++ { ++ DLLMKVC_List keyList = StrVec2List(keys); ++ DLLMKVC_List dblList = PageAttnLayerwiseDbls(layerTensors, blockIds); ++ DLLMKVC_List futureList = DLLMKVC_Future_List(0); ++ DLLMKVC_Store_GetD2D(store_, ++ keyList, ++ dblList, ++ futureList); ++ DLLMKVC_Str_FreeList(keyList); ++ DLLMKVC_DeviceBlobList_FreeList(dblList); ++ KvcFutureWrapper::FutureList2Wrappers(futureList, keys.size(), outFutures); ++ } ++ ++ KvcFutureWrapper MGetTensorsH2D(std::vector &keys, ++ std::vector &tensors) ++ { ++ DLLMKVC_List keyList = StrVec2List(keys); ++ DLLMKVC_List dblList = TensorVec2DblList(tensors); ++ DLLMKVC_Handle future = DLLMKVC_Store_MGetH2D(store_, ++ keyList, ++ dblList); ++ DLLMKVC_Str_FreeList(keyList); ++ DLLMKVC_DeviceBlobList_FreeList(dblList); ++ return KvcFutureWrapper::Future2Wrapper(future); ++ } ++ ++ KvcFutureWrapper MSetTensorsD2H(std::vector &keys, ++ std::vector &tensors) ++ { ++ DLLMKVC_List keyList = StrVec2List(keys); ++ DLLMKVC_List dblList = TensorVec2DblList(tensors); ++ DLLMKVC_Handle future = DLLMKVC_Store_MSetD2H(store_, ++ keyList, ++ dblList); ++ DLLMKVC_Str_FreeList(keyList); ++ DLLMKVC_DeviceBlobList_FreeList(dblList); ++ return KvcFutureWrapper::Future2Wrapper(future); ++ } ++ ++ KvcFutureWrapper Delete(std::vector &keys) ++ { ++ DLLMKVC_List keyList = StrVec2List(keys); ++ DLLMKVC_Handle future = DLLMKVC_Store_Delete(store_, keyList); ++ DLLMKVC_Str_FreeList(keyList); ++ return KvcFutureWrapper::Future2Wrapper(future); ++ } ++ ++ void Exist(std::vector &keys, std::vector &outExists) ++ { ++ DLLMKVC_List keyList = StrVec2List(keys); ++ bool* exists = (bool*)malloc(sizeof(bool)*keys.size()); ++ DLLMKVC_Store_Exist(store_, keyList, exists, keys.size()); ++ outExists.insert(outExists.end(), exists, exists + keys.size()); ++ free(exists); ++ } ++ ++private: ++ DLLMKVC_Handle store_ = nullptr; ++}; ++ ++} // namespace kvc ++} // namespace dllm ++ ++#endif // DLLM_KVC_TORCH_ADAPTOR_H +\ No newline at end of file +diff --git a/dllm_tools/csrc/include/perf/perf_manager.h b/dllm_tools/csrc/include/perf/perf_manager.h +new file mode 100644 +index 000000000..4576856bb +--- /dev/null ++++ b/dllm_tools/csrc/include/perf/perf_manager.h +@@ -0,0 +1,227 @@ ++/** ++ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. ++ * ++ * This software is licensed under Mulan PSL v2. ++ * You can use this software according to the terms and conditions of the Mulan PSL v2. ++ * You may obtain a copy of Mulan PSL v2 at: ++ * ++ * http://license.coscl.org.cn/MulanPSL2 ++ * ++ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++ * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++ * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++ * See the Mulan PSL v2 for more details. ++ */ ++ ++/** ++ * Description: perf manager. ++ */ ++#ifndef DLLM_PERF_PERF_MANAGER_H ++#define DLLM_PERF_PERF_MANAGER_H ++ ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++ ++namespace dllm { ++namespace perf { ++ ++/** ++ * @brief PerfInfo used for store the performance information. Time unit is nanoseconds! ++ */ ++struct PerfInfo { ++ std::atomic count; ++ std::atomic totalTime; ++ std::atomic tickCount; ++ std::atomic maxFrequency; ++ std::atomic maxTime; ++ std::atomic minTime = { ULONG_MAX }; ++ ++ PerfInfo() = default; ++ ++ explicit PerfInfo(const PerfInfo &info) ++ { ++ count.store(info.count.load()); ++ maxTime.store(info.maxTime.load()); ++ minTime.store(info.minTime.load()); ++ totalTime.store(info.totalTime.load()); ++ tickCount.store(info.tickCount.load()); ++ maxFrequency.store(info.maxFrequency.load()); ++ }; ++ ++ void Reset() ++ { ++ count = 0; ++ maxTime = 0; ++ minTime = ULONG_MAX; ++ totalTime = 0; ++ tickCount = 0; ++ maxFrequency = 0; ++ } ++ ++ PerfInfo &operator=(const PerfInfo &info) ++ { ++ if (this == &info) { ++ return *this; ++ } ++ count.store(info.count.load()); ++ maxTime.store(info.maxTime.load()); ++ minTime.store(info.minTime.load()); ++ totalTime.store(info.totalTime.load()); ++ tickCount.store(info.tickCount.load()); ++ maxFrequency.store(info.maxFrequency.load()); ++ return *this; ++ } ++ ++ std::string ToString(); ++}; ++ ++/** ++ * @brief PerfManager class used for managing the performance information. ++ */ ++class PerfManager { ++public: ++ /** ++ * @brief Get the Singleton performance manager instance. ++ * @return PerfManager instance. ++ */ ++ static PerfManager *Instance(); ++ ++ virtual ~PerfManager() = default; ++ ++ /** ++ * @brief Record one performance information according to the key. ++ * @param[in] key The key. ++ * @param[in] elapsed Time (unit nanoseconds) the info to add. ++ */ ++ void Add(const std::string &key, uint64_t elapsed); ++ ++ /** ++ * @brief Delete one performance information according to the key. ++ */ ++ void ResetPerfLog(); ++ ++ /** ++ * @brief Get performance logs. ++ * @return The perf log. ++ */ ++ std::string GetPerfLog(); ++ ++ /** ++ * @brief Print performance logs. ++ */ ++ void PrintPerfLog() const; ++ ++ /** ++ * @brief Trigger performance logs. Should call in main thread. ++ */ ++ void Tick(); ++ ++ /** ++ * @brief Get the performance info list. ++ * @param[out] perfInfoList The performance info list. ++ */ ++ void GetPerfInfoList(std::vector> &perfInfoList) const; ++ ++protected: ++ /** ++ * @brief We need to declare the constructor as protected ++ * function because PerfManager is a singleton. ++ */ ++ PerfManager(); ++ ++private: ++ using Clock = std::chrono::steady_clock; ++ ++ // All keys will add in constructor, so no need lock. ++ mutable std::unordered_map perfInfoList_; ++ mutable std::shared_timed_mutex perfMutex_; ++ ++ mutable std::chrono::time_point prevTickTime_; ++ mutable std::chrono::time_point prevLogTime_; ++}; ++ ++/** ++ * @brief Performance check point, the time unit is nanoseconds. ++ */ ++class PerfPoint { ++public: ++ /** ++ * @brief Construct the PerfPoint. ++ * @param[in] key PerfKey. ++ */ ++ explicit PerfPoint(std::string key) : beg_(Clock::now()), key_(key), isRecord_(false) ++ { ++ } ++ ++ /** ++ * @brief Call Record if isRecord_ is false. ++ */ ++ virtual ~PerfPoint() noexcept; ++ ++ /** ++ * @brief Add performance information to PerfManager. ++ */ ++ void Record(); ++ ++ /** ++ * @brief Reset begin time. ++ * @param[in] key PerfKey. ++ */ ++ void Reset(std::string key = ""); ++ ++ /** ++ * @brief Add performance information to PerfManager and reset the begin time and key. ++ * @param[in] key PerfKey. ++ */ ++ void RecordAndReset(std::string key); ++ ++ /** ++ * @brief Add elapsed time to PerfManager. ++ * @param[in] key PerfKey. ++ * @param[in] elapsed Time to add. ++ */ ++ static void RecordElapsed(std::string key, uint64_t elapsed); ++ ++private: ++ using Clock = std::chrono::steady_clock; ++ std::chrono::time_point beg_; ++ std::string key_; ++ bool isRecord_; ++}; ++ ++class PerfManagerPythonWrapper { ++public: ++ PerfManagerPythonWrapper() ++ { ++ manager_ = PerfManager::Instance(); ++ } ++ ++ void Print() const ++ { ++ manager_->PrintPerfLog(); ++ } ++ ++ ~PerfManagerPythonWrapper() ++ { ++ if (manager_ == nullptr) { ++ return; ++ } ++ manager_->PrintPerfLog(); ++ // flush log immediately ++ spdlog::default_logger()->flush(); ++ } ++ ++private: ++ PerfManager *manager_{ nullptr }; ++}; ++ ++} // namespace perf ++} // namespace dllm ++ ++#endif // DLLM_PERF_PERF_MANAGER_H +diff --git a/dllm_tools/csrc/include/utils/logging.h b/dllm_tools/csrc/include/utils/logging.h +new file mode 100644 +index 000000000..95f836ce0 +--- /dev/null ++++ b/dllm_tools/csrc/include/utils/logging.h +@@ -0,0 +1,80 @@ ++/** ++ * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. ++ * ++ * This software is licensed under Mulan PSL v2. ++ * You can use this software according to the terms and conditions of the Mulan PSL v2. ++ * You may obtain a copy of Mulan PSL v2 at: ++ * ++ * http://license.coscl.org.cn/MulanPSL2 ++ * ++ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++ * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++ * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++ * See the Mulan PSL v2 for more details. ++ */ ++#ifndef DLLM_UTILS_LOGGING_H ++#define DLLM_UTILS_LOGGING_H ++ ++#include ++#include ++#include ++ ++#include ++ ++namespace dllm { ++namespace utils { ++ ++inline const std::unordered_map LOG_LEVEL_MAP = { ++ { "INFO", spdlog::level::info }, { "DEBUG", spdlog::level::debug }, { "WARN", spdlog::level::warn }, ++ { "ERROR", spdlog::level::err }, { "OFF", spdlog::level::off }, ++}; ++constexpr size_t DEFAULT_LOG_SIZE = 10 * 1024 * 1024; // 10MB ++constexpr size_t DEFAULT_ROTATION_FILES = 5; ++constexpr char DEFAULT_BASE_NAME[] = "dllm"; ++ ++const spdlog::level::level_enum INFO = spdlog::level::info; ++const spdlog::level::level_enum WARN = spdlog::level::warn; ++const spdlog::level::level_enum ERROR = spdlog::level::err; ++ ++class LogMessage { ++public: ++ explicit LogMessage(spdlog::level::level_enum level) : level_(level) ++ { ++ } ++ ++ template ++ LogMessage &operator<<(const T &msg) ++ { ++ stream_ << msg; ++ return *this; ++ } ++ ++ ~LogMessage() ++ { ++ spdlog::log(level_, stream_.str()); ++ } ++ ++private: ++ spdlog::level::level_enum level_; ++ std::stringstream stream_; ++}; ++ ++class Logger { ++public: ++ void SetupRootLogging(const std::string &fileDirectory = "", const std::string &baseName = DEFAULT_BASE_NAME, ++ size_t maxLogSize = DEFAULT_LOG_SIZE, size_t maxFilesNum = DEFAULT_ROTATION_FILES, ++ bool enableConsoleLogging = false); ++ ++ std::string SetLogDir(const std::string &fileDirectory) const; ++ ++ void SetLogLevel(std::string logLevel); ++ ++ void SetFlushLevel(std::string logLevel); ++ ++private: ++ std::shared_ptr logger_; ++}; ++ ++} // namespace utils ++} // namespace dllm ++#endif // DLLM_UTILS_LOGGING_H +\ No newline at end of file +diff --git a/dllm_tools/csrc/kvc/c_api.cpp b/dllm_tools/csrc/kvc/c_api.cpp +new file mode 100644 +index 000000000..e4699ec02 +--- /dev/null ++++ b/dllm_tools/csrc/kvc/c_api.cpp +@@ -0,0 +1,359 @@ ++/** ++ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. ++ * ++ * This software is licensed under Mulan PSL v2. ++ * You can use this software according to the terms and conditions of the Mulan PSL v2. ++ * You may obtain a copy of Mulan PSL v2 at: ++ * ++ * http://license.coscl.org.cn/MulanPSL2 ++ * ++ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++ * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++ * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++ * See the Mulan PSL v2 for more details. ++ */ ++#include ++ ++#include "include/kvc/common.h" ++#include "include/kvc/kvc_store.h" ++#include "include/kvc/page_attn_utils.h" ++#include "include/kvc/c_api.h" ++ ++using namespace dllm::kvc; ++ ++inline DLLMKVC_List DLLMKVC_VEC_2_LIST(std::vector* ptr) ++{ ++ return static_cast(ptr); ++} ++ ++inline DLLMKVC_List DLLMKVC_VEC_2_LIST(std::vector* ptr) ++{ ++ return static_cast(ptr); ++} ++ ++inline DLLMKVC_List DLLMKVC_VEC_2_LIST(std::vector* ptr) ++{ ++ return static_cast(ptr); ++} ++ ++inline DLLMKVC_List DLLMKVC_VEC_2_LIST(std::vector* ptr) ++{ ++ return static_cast(ptr); ++} ++ ++inline DLLMKVC_Handle DLLMKVC_OBJ_2_HANDLE(KvcFuture* ptr) ++{ ++ return static_cast(ptr); ++} ++ ++inline DLLMKVC_Handle DLLMKVC_OBJ_2_HANDLE(Blob* ptr) ++{ ++ return static_cast(ptr); ++} ++ ++inline DLLMKVC_Handle DLLMKVC_OBJ_2_HANDLE(DeviceBlobList* ptr) ++{ ++ return static_cast(ptr); ++} ++ ++inline DLLMKVC_Handle DLLMKVC_OBJ_2_HANDLE(KvcStore* ptr) ++{ ++ return static_cast(ptr); ++} ++ ++inline DLLMKVC_Handle DLLMKVC_OBJ_2_HANDLE(KvcResult* ptr) ++{ ++ return static_cast(ptr); ++} ++ ++inline std::vector* DLLMKVC_LIST_2_STR_VEC(DLLMKVC_List list) ++{ ++ return static_cast*>(list); ++} ++ ++inline std::vector* DLLMKVC_LIST_2_BLOB_VEC(DLLMKVC_List list) ++{ ++ return static_cast*>(list); ++} ++ ++inline std::vector* DLLMKVC_LIST_2_DEVBLOBLIST_VEC(DLLMKVC_List list) ++{ ++ return static_cast*>(list); ++} ++ ++inline std::vector* DLLMKVC_LIST_2_FUTURE_VEC(DLLMKVC_List list) ++{ ++ return static_cast*>(list); ++} ++ ++inline KvcFuture* DLLMKVC_HANDLE_2_FUTURE(DLLMKVC_Handle h) ++{ ++ return static_cast(h); ++} ++ ++inline Blob* DLLMKVC_HANDLE_2_BLOB(DLLMKVC_Handle h) ++{ ++ return static_cast(h); ++} ++ ++inline DeviceBlobList* DLLMKVC_HANDLE_2_DEVBLOBLIST(DLLMKVC_Handle h) ++{ ++ return static_cast(h); ++} ++ ++inline KvcStore* DLLMKVC_HANDLE_2_STORE(DLLMKVC_Handle h) ++{ ++ return static_cast(h); ++} ++ ++inline KvcResult* DLLMKVC_HANDLE_2_RESULT(DLLMKVC_Handle h) ++{ ++ return static_cast(h); ++} ++ ++inline const KvcResult* DLLMKVC_CONST_HANDLE_2_RESULT(const DLLMKVC_Handle h) ++{ ++ return static_cast(h); ++} ++ ++DLLMKVC_List DLLMKVC_Str_List(size_t len) ++{ ++ return DLLMKVC_VEC_2_LIST(new std::vector(len)); ++} ++ ++void DLLMKVC_Str_FreeList(DLLMKVC_List list) ++{ ++ delete DLLMKVC_LIST_2_STR_VEC(list); ++} ++ ++const char* DLLMKVC_Str_ListGetElem(DLLMKVC_List list, size_t index) ++{ ++ return DLLMKVC_LIST_2_STR_VEC(list)->operator[](index).c_str(); ++} ++ ++void DLLMKVC_Str_ListSetElem(DLLMKVC_List list, size_t index, const char* str) ++{ ++ DLLMKVC_LIST_2_STR_VEC(list)->operator[](index) = str; ++} ++ ++ ++// Blob ++ ++DLLMKVC_List DLLMKVC_Blob_List(size_t len) ++{ ++ return DLLMKVC_VEC_2_LIST(new std::vector(len)); ++} ++ ++void DLLMKVC_Blob_FreeList(DLLMKVC_List list) ++{ ++ delete DLLMKVC_LIST_2_BLOB_VEC(list); ++} ++ ++DLLMKVC_Handle DLLMKVC_Blob_ListGetElem(DLLMKVC_List list, size_t index) ++{ ++ return DLLMKVC_OBJ_2_HANDLE(&(DLLMKVC_LIST_2_BLOB_VEC(list)->operator[](index))); ++} ++ ++void DLLMKVC_Blob_Init(DLLMKVC_Handle blob, void* pointer, uint64_t size) ++{ ++ Blob *ptr = DLLMKVC_HANDLE_2_BLOB(blob); ++ ptr->pointer = pointer; ++ ptr->size = size; ++} ++ ++// DeviceBlobList ++ ++DLLMKVC_List DLLMKVC_DeviceBlobList_List(size_t len) ++{ ++ return DLLMKVC_VEC_2_LIST(new std::vector(len)); ++} ++ ++void DLLMKVC_DeviceBlobList_FreeList(DLLMKVC_List list) ++{ ++ delete DLLMKVC_LIST_2_DEVBLOBLIST_VEC(list); ++} ++ ++DLLMKVC_Handle DLLMKVC_DeviceBlobList_ListGetElem(DLLMKVC_List list, size_t index) ++{ ++ return DLLMKVC_OBJ_2_HANDLE(&(DLLMKVC_LIST_2_DEVBLOBLIST_VEC(list)->operator[](index))); ++} ++ ++void DLLMKVC_DeviceBlobList_Init(DLLMKVC_Handle devBlobList, int32_t deviceIdx, DLLMKVC_List blobs) ++{ ++ DeviceBlobList* ptr = DLLMKVC_HANDLE_2_DEVBLOBLIST(devBlobList); ++ ptr->deviceIdx = deviceIdx; ++ ptr->blobs = std::move(*(DLLMKVC_LIST_2_BLOB_VEC(blobs))); ++} ++ ++// KvcStore ++ ++DLLMKVC_Handle DLLMKVC_Store() ++{ ++ return DLLMKVC_OBJ_2_HANDLE(new KvcStore()); ++} ++ ++void DLLMKVC_Store_Free(DLLMKVC_Handle store) ++{ ++ delete DLLMKVC_HANDLE_2_STORE(store); ++} ++ ++void DLLMKVC_Store_Init(DLLMKVC_Handle store, const char* host, int32_t port, int32_t connTimeoutMs) ++{ ++ DLLMKVC_HANDLE_2_STORE(store)->Init(host, port, connTimeoutMs); ++} ++ ++void DLLMKVC_Store_PutD2D( ++ DLLMKVC_Handle store, ++ DLLMKVC_List keys, ++ DLLMKVC_List devBlobLists, ++ DLLMKVC_List outFutures) ++{ ++ DLLMKVC_HANDLE_2_STORE(store)->PutD2D(*DLLMKVC_LIST_2_STR_VEC(keys), ++ *DLLMKVC_LIST_2_DEVBLOBLIST_VEC(devBlobLists), ++ *DLLMKVC_LIST_2_FUTURE_VEC(outFutures)); ++} ++ ++void DLLMKVC_Store_GetD2D( ++ DLLMKVC_Handle store, ++ DLLMKVC_List keys, ++ DLLMKVC_List devBlobLists, ++ DLLMKVC_List outFutures) ++{ ++ DLLMKVC_HANDLE_2_STORE(store)->GetD2D(*DLLMKVC_LIST_2_STR_VEC(keys), ++ *DLLMKVC_LIST_2_DEVBLOBLIST_VEC(devBlobLists), ++ *DLLMKVC_LIST_2_FUTURE_VEC(outFutures)); ++} ++ ++DLLMKVC_Handle DLLMKVC_Store_MGetH2D(DLLMKVC_Handle store, ++ DLLMKVC_List keys, ++ DLLMKVC_List devBlobLists) ++{ ++ KvcFuture future = DLLMKVC_HANDLE_2_STORE(store)->MGetH2D(*DLLMKVC_LIST_2_STR_VEC(keys), ++ *DLLMKVC_LIST_2_DEVBLOBLIST_VEC(devBlobLists)); ++ DLLMKVC_Handle out = DLLMKVC_Future(); ++ *DLLMKVC_HANDLE_2_FUTURE(out) = future; ++ return out; ++} ++ ++DLLMKVC_Handle DLLMKVC_Store_MSetD2H(DLLMKVC_Handle store, ++ DLLMKVC_List keys, ++ DLLMKVC_List devBlobLists) ++{ ++ KvcFuture future = DLLMKVC_HANDLE_2_STORE(store)->MSetD2H(*DLLMKVC_LIST_2_STR_VEC(keys), ++ *DLLMKVC_LIST_2_DEVBLOBLIST_VEC(devBlobLists)); ++ DLLMKVC_Handle out = DLLMKVC_Future(); ++ *DLLMKVC_HANDLE_2_FUTURE(out) = future; ++ return out; ++} ++ ++DLLMKVC_Handle DLLMKVC_Store_Delete(DLLMKVC_Handle store, DLLMKVC_List keys) ++{ ++ KvcFuture future = DLLMKVC_HANDLE_2_STORE(store)->Delete(*DLLMKVC_LIST_2_STR_VEC(keys)); ++ DLLMKVC_Handle out = DLLMKVC_Future(); ++ *DLLMKVC_HANDLE_2_FUTURE(out) = future; ++ return out; ++} ++ ++void DLLMKVC_Store_Exist(DLLMKVC_Handle store, DLLMKVC_List keys, bool* outExists, size_t outExistsLen) ++{ ++ std::vector exists; ++ DLLMKVC_HANDLE_2_STORE(store)->Exist(*DLLMKVC_LIST_2_STR_VEC(keys), exists); ++ size_t outLen = outExistsLen < exists.size() ? outExistsLen : exists.size(); ++ // cannot use memcpy() coz std::vector doesn't have data() function. ++ for (size_t i = 0; i < outLen; ++i) { ++ outExists[i] = exists[i]; ++ } ++} ++ ++// KvcFuture ++DLLMKVC_Handle DLLMKVC_Future() ++{ ++ return DLLMKVC_OBJ_2_HANDLE(new KvcFuture()); ++} ++ ++void DLLMKVC_Future_Free(DLLMKVC_Handle future) ++{ ++ delete DLLMKVC_HANDLE_2_FUTURE(future); ++} ++ ++DLLMKVC_List DLLMKVC_Future_List(size_t len) ++{ ++ return DLLMKVC_VEC_2_LIST(new std::vector(len)); ++} ++ ++void DLLMKVC_Future_FreeList(DLLMKVC_List list) ++{ ++ delete DLLMKVC_LIST_2_FUTURE_VEC(list); ++} ++ ++DLLMKVC_Handle DLLMKVC_Future_ListGetElem(DLLMKVC_List list, size_t index) ++{ ++ return DLLMKVC_OBJ_2_HANDLE(&(DLLMKVC_LIST_2_FUTURE_VEC(list)->operator[](index))); ++} ++ ++int32_t DLLMKVC_Future_WaitFor(DLLMKVC_Handle future, uint32_t ms) ++{ ++ return static_cast(DLLMKVC_HANDLE_2_FUTURE(future)->wait_for(ms)); ++} ++ ++DLLMKVC_Handle DLLMKVC_Future_Get(DLLMKVC_Handle future) ++{ ++ KvcResult *result = new KvcResult(); ++ *result = DLLMKVC_HANDLE_2_FUTURE(future)->get(); ++ return result; ++} ++ ++// KvcResult ++ ++void DLLMKVC_Result_Free(DLLMKVC_Handle result) ++{ ++ delete DLLMKVC_HANDLE_2_RESULT(result); ++} ++ ++const char* DLLMKVC_Result_GetErrMsg(const DLLMKVC_Handle result) ++{ ++ return DLLMKVC_CONST_HANDLE_2_RESULT(result)->errorMessage.c_str(); ++} ++ ++int32_t DLLMKVC_Result_GetCode(DLLMKVC_Handle result) ++{ ++ return DLLMKVC_HANDLE_2_RESULT(result)->statusCode; ++} ++ ++void DLLMKVC_PageAttn_LayerwiseDevBlobLists( ++ int32_t deviceIdx, ++ const DLLMKVC_Tensor* tensors, size_t numTensors, ++ const uint32_t* blockIds, size_t numBlocks, ++ DLLMKVC_List outDblList) ++{ ++ std::vector layerTensors(numTensors); ++ for (size_t i = 0; i < numTensors; ++i) { ++ const DLLMKVC_Tensor* tensor = tensors + i; ++ layerTensors[i] = KvcTensor{.ptr=tensor->ptr, ++ .elemSize=tensor->elemSize, ++ .shape=std::vector(tensor->shape, tensor->shape + tensor->shapeLen)}; ++ } ++ ++ PageAttnUtils::LayerwiseDevBlobLists(deviceIdx, layerTensors, ++ std::vector(blockIds, blockIds + numBlocks), ++ *DLLMKVC_LIST_2_DEVBLOBLIST_VEC(outDblList)); ++} ++ ++void DLLMKVC_PageAttn_BlockwiseDevBlobLists( ++ int32_t deviceIdx, ++ const DLLMKVC_Tensor* tensors, size_t numTensors, ++ const uint32_t* blockIds, size_t numBlocks, ++ DLLMKVC_List outDblList) ++{ ++ std::vector layerTensors(numTensors); ++ for (size_t i = 0; i < numTensors; ++i) { ++ const DLLMKVC_Tensor* tensor = tensors + i; ++ layerTensors[i] = KvcTensor{.ptr=tensor->ptr, ++ .elemSize=tensor->elemSize, ++ .shape=std::vector(tensor->shape, tensor->shape + tensor->shapeLen)}; ++ } ++ ++ PageAttnUtils::BlockwiseDevBlobLists(deviceIdx, layerTensors, ++ std::vector(blockIds, blockIds + numBlocks), ++ *DLLMKVC_LIST_2_DEVBLOBLIST_VEC(outDblList)); ++} +diff --git a/dllm_tools/csrc/kvc/kvc_future.cpp b/dllm_tools/csrc/kvc/kvc_future.cpp +new file mode 100644 +index 000000000..cb9ea2d02 +--- /dev/null ++++ b/dllm_tools/csrc/kvc/kvc_future.cpp +@@ -0,0 +1,88 @@ ++/** ++ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. ++ * ++ * This software is licensed under Mulan PSL v2. ++ * You can use this software according to the terms and conditions of the Mulan PSL v2. ++ * You may obtain a copy of Mulan PSL v2 at: ++ * ++ * http://license.coscl.org.cn/MulanPSL2 ++ * ++ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++ * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++ * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++ * See the Mulan PSL v2 for more details. ++ */ ++#include ++#include ++#include ++ ++#include "include/kvc/kvc_future.h" ++ ++namespace dllm { ++namespace kvc { ++ ++KvcFuture::KvcFuture(KvcResult &&result):result_(std::move(result)), got_(true) ++{ ++} ++ ++KvcFuture::KvcFuture(std::shared_future &&future) ++ :stdFuture_(std::move(future)), ++ from_(From::Std) ++{ ++} ++ ++KvcFuture::KvcFuture(const std::shared_ptr > &dsFutures, ++ size_t index, ++ const std::vector &keys) ++ :dsFutures_(dsFutures), ++ index_(index), ++ keys_(keys), ++ from_(From::Datasystem) ++{ ++} ++ ++std::future_status KvcFuture::wait_for(uint64_t ms) ++{ ++ if (got_) { ++ return std::future_status::ready; ++ } ++ if (from_ == From::Datasystem) { ++ datasystem::Status rc = (*dsFutures_)[index_].Get(ms); ++ if (rc.GetCode() == datasystem::StatusCode::K_FUTURE_TIMEOUT) { ++ return std::future_status::timeout; ++ } ++ if (rc.IsOk()) { ++ result_ = KvcResult::Ok(); ++ } else { ++ result_ = KvcResult(rc.GetCode(), rc.GetMsg(), std::move(keys_)); ++ } ++ got_ = true; ++ return std::future_status::ready; ++ } else if (from_ == From::Std) { ++ return stdFuture_.wait_for(std::chrono::milliseconds(ms)); ++ } ++ throw std::runtime_error("invalid KvcFuture"); ++} ++ ++KvcResult KvcFuture::get() ++{ ++ if (got_) { ++ return result_; ++ } ++ if (from_ == From::Datasystem) { ++ datasystem::Status rc = (*dsFutures_)[index_].Get(); ++ if (rc.IsOk()) { ++ result_ = KvcResult::Ok(); ++ } else { ++ result_ = KvcResult(rc.GetCode(), rc.GetMsg(), std::move(keys_)); ++ } ++ got_ = true; ++ return result_; ++ } else if (from_ == From::Std) { ++ return stdFuture_.get(); ++ } ++ throw std::runtime_error("invalid KvcFuture"); ++} ++ ++} // namespace kvc ++} // namespace dllm +diff --git a/dllm_tools/csrc/kvc/kvc_store.cpp b/dllm_tools/csrc/kvc/kvc_store.cpp +new file mode 100644 +index 000000000..bfea381d1 +--- /dev/null ++++ b/dllm_tools/csrc/kvc/kvc_store.cpp +@@ -0,0 +1,179 @@ ++/** ++ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. ++ * ++ * This software is licensed under Mulan PSL v2. ++ * You can use this software according to the terms and conditions of the Mulan PSL v2. ++ * You may obtain a copy of Mulan PSL v2 at: ++ * ++ * http://license.coscl.org.cn/MulanPSL2 ++ * ++ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++ * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++ * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++ * See the Mulan PSL v2 for more details. ++ */ ++#include ++#include ++#include ++ ++#include "utils/thread_pool.h" ++#include "include/kvc/kvc_store.h" ++ ++namespace dllm { ++namespace kvc { ++ ++namespace ds = datasystem; ++ ++inline std::vector ToDsDevBlobLists(const std::vector &devBlobLists) ++{ ++ std::vector dsDbls(devBlobLists.size()); ++ for (size_t i = 0; i < devBlobLists.size(); ++i) { ++ const DeviceBlobList &devBlobList = devBlobLists[i]; ++ std::vector dsBlobs(devBlobList.blobs.size()); ++ for (size_t j = 0; j < devBlobList.blobs.size(); ++j) { ++ const Blob &blob = devBlobList.blobs[j]; ++ dsBlobs[j] = {.pointer=blob.pointer, .size=blob.size}; ++ } ++ dsDbls[i] = {.blobs=std::move(dsBlobs), .deviceIdx=devBlobList.deviceIdx}; ++ } ++ return dsDbls; ++} ++ ++inline void ThrowIfStatusErr(ds::Status status, const char *file, int line) ++{ ++ if (status.IsError()) { ++ std::string errMsg = "Datasystem error: " + status.ToString() + "\nin " + file + ":" + std::to_string(line); ++ spdlog::error(errMsg); ++ throw std::runtime_error(std::move(errMsg)); ++ } ++} ++ ++KvcStore::KvcStore(KvcStore&& other) ++{ ++ *this = std::move(other); ++} ++ ++KvcStore::~KvcStore() ++{ ++ if (client_) { ++ client_->ShutDown(); ++ delete client_; ++ } ++ if (threadPool_) { ++ delete threadPool_; ++ } ++} ++ ++KvcStore& KvcStore::operator=(KvcStore&& other) ++{ ++ if (this == &other) { ++ return *this; ++ } ++ if (client_) { ++ client_->ShutDown(); ++ delete client_; ++ } ++ if (threadPool_) { ++ delete threadPool_; ++ } ++ client_ = other.client_; ++ threadPool_ = other.threadPool_; ++ other.client_ = nullptr; ++ other.threadPool_ = nullptr; ++ return *this; ++} ++ ++void KvcStore::Init(const std::string &host, int32_t port, int32_t connectTimeoutMs, int32_t numThreads) ++{ ++ ds::ConnectOptions connectOpts{ .host = host, .port = port, .connectTimeoutMs = connectTimeoutMs }; ++ client_ = new ds::HeteroClient(connectOpts); ++ ThrowIfStatusErr((client_->Init()), __FILE__, __LINE__); ++ threadPool_ = new utils::ThreadPool(numThreads); ++} ++ ++KvcFuture KvcStore::MGetH2D(const std::vector &keys, ++ const std::vector &devBlobLists) const ++{ ++ // must capture by value ++ return KvcFuture(threadPool_->Submit([client_ = client_, keys = keys, devBlobLists = devBlobLists]() { ++ std::vector failedList; ++ ds::Status rc = client_->MGetH2D(keys, ToDsDevBlobLists(devBlobLists), failedList, 10000); ++ return KvcResult(rc.GetCode(), rc.GetMsg(), std::move(failedList)); ++ })); ++} ++ ++KvcFuture KvcStore::MSetD2H(const std::vector &keys, ++ const std::vector &devBlobLists) const ++{ ++ // must capture by value ++ return KvcFuture(threadPool_->Submit([client_ = client_, keys = keys, devBlobLists = devBlobLists]() { ++ ds::Status rc = client_->MSetD2H(keys, ToDsDevBlobLists(devBlobLists)); ++ return KvcResult(rc.GetCode(), rc.GetMsg(), {}); ++ })); ++} ++ ++KvcFuture KvcStore::Delete(const std::vector &keys) const ++{ ++ // must capture by value ++ return KvcFuture(threadPool_->Submit([client_ = client_, keys = keys]() { ++ std::vector failedList; ++ ds::Status rc = client_->Delete(keys, failedList); ++ return KvcResult(rc.GetCode(), rc.GetMsg(), std::move(failedList)); ++ })); ++} ++ ++void KvcStore::PutD2D(const std::vector &keys, ++ const std::vector &devBlobLists, ++ std::vector &outFutures) const ++{ ++ size_t oldFutsSize = outFutures.size(); ++ outFutures.resize(oldFutsSize + keys.size()); ++ std::shared_ptr > dsFutures = std::make_shared >(); ++ ds::Status rc = client_->DevPublish(keys, ToDsDevBlobLists(devBlobLists), *(dsFutures.get())); ++ if (rc.IsError()) { ++ for (size_t i = 0; i < keys.size(); ++i) { ++ outFutures[oldFutsSize + i] = KvcFuture(KvcResult(rc.GetCode(), rc.GetMsg(), {keys[i]})); ++ } ++ return; ++ } ++ ++ if (dsFutures->size() != keys.size()) { ++ throw std::runtime_error("datasystem futures and keys size not match"); ++ } ++ ++ for (size_t i = 0; i < keys.size(); ++i) { ++ outFutures[oldFutsSize + i] = KvcFuture(dsFutures, i, {keys[i]}); ++ } ++} ++ ++void KvcStore::GetD2D(const std::vector &keys, ++ const std::vector &devBlobLists, ++ std::vector &outFutures) const ++{ ++ size_t oldFutsSize = outFutures.size(); ++ outFutures.resize(oldFutsSize + keys.size()); ++ std::shared_ptr > dsFutures = std::make_shared >(); ++ ds::Status rc = client_->DevSubscribe(keys, ToDsDevBlobLists(devBlobLists), *(dsFutures.get())); ++ if (rc.IsError()) { ++ for (size_t i = 0; i < keys.size(); ++i) { ++ outFutures[oldFutsSize + i] = KvcFuture(KvcResult(rc.GetCode(), rc.GetMsg(), {keys[i]})); ++ } ++ } ++ ++ if (dsFutures->size() != keys.size()) { ++ throw std::runtime_error("datasystem futures and keys size not match"); ++ } ++ ++ for (size_t i = 0; i < keys.size(); ++i) { ++ outFutures[oldFutsSize + i] = KvcFuture(dsFutures, i, {keys[i]}); ++ } ++} ++ ++void KvcStore::Exist(const std::vector &keys, std::vector &outExists) const ++{ ++ auto rc = client_->Exist(keys, outExists); ++ ThrowIfStatusErr((rc), __FILE__, __LINE__); ++} ++ ++} // namespace kvc ++} // namespace dllm +diff --git a/dllm_tools/csrc/kvc/page_attn_utils.cpp b/dllm_tools/csrc/kvc/page_attn_utils.cpp +new file mode 100644 +index 000000000..691282c3f +--- /dev/null ++++ b/dllm_tools/csrc/kvc/page_attn_utils.cpp +@@ -0,0 +1,76 @@ ++/** ++ * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. ++ * ++ * This software is licensed under Mulan PSL v2. ++ * You can use this software according to the terms and conditions of the Mulan PSL v2. ++ * You may obtain a copy of Mulan PSL v2 at: ++ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++ * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++ * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++ * See the Mulan PSL v2 for more details. ++ */ ++ ++#include "include/kvc/page_attn_utils.h" ++ ++namespace dllm { ++namespace kvc { ++ ++Blob PageAttnUtils::Blk2Blob(uint64_t ptr, size_t elemSize, size_t numBlockElem, uint32_t blockId) ++{ ++ size_t blockSize = elemSize * numBlockElem; ++ intptr_t address = ptr + (blockSize * blockId); ++ void *pointer = reinterpret_cast(address); ++ return Blob{.pointer=pointer, .size=blockSize}; ++} ++ ++DeviceBlobList PageAttnUtils::Blks2DevBlobList(int32_t deviceIdx, uint64_t ptr, size_t elemSize, size_t numBlockElem, ++ const std::vector &blockIds) ++{ ++ std::vector blobs(blockIds.size()); ++ for (size_t i = 0; i < blockIds.size(); ++i) { ++ blobs[i] = Blk2Blob(ptr, elemSize, numBlockElem, blockIds[i]); ++ } ++ return DeviceBlobList{.blobs = blobs, deviceIdx = deviceIdx}; ++} ++ ++void PageAttnUtils::LayerwiseDevBlobLists(int32_t deviceIdx, ++ const std::vector &layerTensors, ++ const std::vector &blockIds, ++ std::vector &outDevBlobLists) ++{ ++ size_t oriSize = outDevBlobLists.size(); ++ outDevBlobLists.resize(oriSize + layerTensors.size()); ++ for (size_t i = 0; i < layerTensors.size(); ++i) { ++ const KvcTensor& layer = layerTensors[i]; ++ size_t numBlockElem = 1; ++ for (size_t j = 1; j < layer.shape.size(); ++j) { ++ numBlockElem *= layer.shape[j]; ++ } ++ outDevBlobLists[oriSize + i] = Blks2DevBlobList(deviceIdx, layer.ptr, layer.elemSize, numBlockElem, blockIds); ++ } ++} ++ ++void PageAttnUtils::BlockwiseDevBlobLists(int32_t deviceIdx, ++ const std::vector &layerTensors, ++ const std::vector &blockIds, ++ std::vector &outDevBlobLists) ++{ ++ size_t oriSize = outDevBlobLists.size(); ++ outDevBlobLists.resize(oriSize + blockIds.size()); ++ for (size_t i = 0; i < blockIds.size(); ++i) { ++ std::vector blobs(layerTensors.size()); ++ for (size_t j = 0; j < layerTensors.size(); ++j) { ++ const KvcTensor& layer = layerTensors[j]; ++ size_t numBlockElem = 1; ++ for (size_t j = 1; j < layer.shape.size(); ++j) { ++ numBlockElem *= layer.shape[j]; ++ } ++ blobs[j] = Blk2Blob(layer.ptr, layer.elemSize, numBlockElem, blockIds[i]); ++ } ++ outDevBlobLists[oriSize + i] = DeviceBlobList{.blobs = blobs, deviceIdx = deviceIdx}; ++ } ++} ++ ++ ++} // namespace kvc ++} // namespace dllm +diff --git a/dllm_tools/csrc/kvc/pybind.h b/dllm_tools/csrc/kvc/pybind.h +new file mode 100644 +index 000000000..26ed6cc0a +--- /dev/null ++++ b/dllm_tools/csrc/kvc/pybind.h +@@ -0,0 +1,191 @@ ++/** ++ * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. ++ * ++ * This software is licensed under Mulan PSL v2. ++ * You can use this software according to the terms and conditions of the Mulan PSL v2. ++ * You may obtain a copy of Mulan PSL v2 at: ++ * ++ * http://license.coscl.org.cn/MulanPSL2 ++ * ++ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++ * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++ * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++ * See the Mulan PSL v2 for more details. ++ */ ++#ifndef DLLM_KVC_PYBIND_H ++#define DLLM_KVC_PYBIND_H ++ ++#include ++#include ++#include ++#include // for fmt:: ++#include ++#include ++ ++#include "include/kvc/common.h" ++#include "include/kvc/kvc_store.h" ++#include "include/kvc/kvc_future.h" ++#include "include/kvc/page_attn_utils.h" ++ ++ ++namespace py = pybind11; ++ ++namespace dllm { ++namespace kvc { ++ ++const int TIMEOUT_MS = 1000; ++ ++void PyBind(py::module_ m) ++{ ++ using namespace pybind11::literals; ++ m.doc() = "pybind11 for KV Cache"; ++ py::class_(m, "Blob") ++ .def(py::init<>([](uint64_t pointer, uint64_t size) { ++ return Blob{ .pointer = reinterpret_cast(static_cast(pointer)), .size = size }; ++ }), ++ "pointer"_a, "size"_a) ++ .def_property_readonly("pointer", [](const Blob &blob) { ++ return static_cast(reinterpret_cast(blob.pointer)); ++ }) ++ .def_readonly("size", &Blob::size) ++ .def("__repr__", ++ [](const Blob &blob) { ++ return fmt::format("Blob[pointer:{}, size:{}]", ++ static_cast(reinterpret_cast(blob.pointer)), ++ blob.size); ++ }) ++ .def( ++ "__eq__", ++ [](const Blob &lhs, const Blob &rhs) { return lhs.pointer == rhs.pointer && lhs.size == rhs.size; }, ++ "other"_a); ++ ++ py::class_(m, "DeviceBlobList") ++ .def(py::init<>([](std::vector blobs, int32_t deviceIdx) { ++ return DeviceBlobList{ std::move(blobs), deviceIdx }; ++ }), ++ "blobs"_a, "device_idx"_a) ++ .def_readwrite("device_idx", &DeviceBlobList::deviceIdx) ++ .def_readwrite("blobs", &DeviceBlobList::blobs) ++ .def("append", [](DeviceBlobList &self, Blob blob) { self.blobs.emplace_back(std::move(blob)); }, "blob"_a) ++ .def("to_array", [](DeviceBlobList &self) { ++ std::vector> ret; ++ for (auto &item : self.blobs) { ++ std::vector list; ++ list.emplace_back(static_cast(reinterpret_cast(item.pointer))); ++ list.emplace_back(item.size); ++ ret.emplace_back(list); ++ } ++ return ret; ++ }); ++ ++ py::class_(m, "KvcTensor") ++ .def(py::init<>([](uint64_t ptr, uint32_t elemSize, const std::vector &shape) { ++ return KvcTensor{ .ptr = ptr, .elemSize = elemSize, .shape = shape }; ++ }), ++ "ptr"_a, "elem_size"_a, "shape"_a) ++ .def_readwrite("ptr", &KvcTensor::ptr) ++ .def_readwrite("elem_size", &KvcTensor::elemSize) ++ .def_readwrite("shape", &KvcTensor::shape) ++ .def("__repr__", [](const KvcTensor &self) { ++ std::ostringstream oss; ++ for (auto dim: self.shape) { ++ oss << dim << ","; ++ } ++ return fmt::format("KvcTensor[ptr:{}, elem_size:{}, shape:{}]", self.ptr, ++ self.elemSize, oss.str()); ++ }) ++ .def("__eq__", [](const KvcTensor &lhs, const KvcTensor &rhs) { ++ return lhs.ptr == rhs.ptr && lhs.elemSize == rhs.elemSize && ++ std::equal(lhs.shape.begin(), lhs.shape.end(), rhs.shape.begin()); ++ }, "other"_a); ++ ++ py::class_(m, "KvcResult") ++ .def_readwrite("status_code", &KvcResult::statusCode) ++ .def_readwrite("error_message", &KvcResult::errorMessage) ++ .def_readwrite("failed_list", &KvcResult::failedList) ++ .def("__repr__", [](const KvcResult &self) { ++ std::ostringstream oss; ++ for (auto failed: self.failedList) { ++ oss << failed << ","; ++ } ++ return fmt::format("kvcResult[status_code:{}, error_message:{}, failed_list:{}]", ++ self.statusCode, self.errorMessage, oss.str()); ++ }); ++ ++ py::class_(m, "KvcFuture") ++ .def("running", [](KvcFuture &self) { ++ return self.wait_for(0) != std::future_status::ready; ++ }) ++ .def("done", [](KvcFuture &self) { ++ return self.wait_for(0) == std::future_status::ready; ++ }) ++ .def("result", [](KvcFuture &self, float timeout) { ++ if (timeout < 0 ++ || self.wait_for(int(timeout * TIMEOUT_MS)) == std::future_status::ready) { ++ KvcResult res = self.get(); ++ return res; ++ } ++ py::set_error(PyExc_TimeoutError, "timeout"); ++ throw py::error_already_set(); ++ }, "timeout"_a = -1, py::return_value_policy::copy); ++ ++ py::class_(m, "KvcStore") ++ .def(py::init<>([]() { return KvcStore();})) ++ .def("init", &KvcStore::Init, "host"_a, "port"_a, ++ "conn_timeout_ms"_a, "num_threads"_a = DEFAULT_NUM_THREADS, ++ py::call_guard()) ++ .def("mget_h2d", &KvcStore::MGetH2D, "keys"_a, "dev_blob_lists"_a, ++ py::call_guard()) ++ .def("mset_d2h", &KvcStore::MSetD2H, "keys"_a, "dev_blob_lists"_a, ++ py::call_guard()) ++ .def("delete", &KvcStore::Delete, "keys"_a, ++ py::call_guard()) ++ .def("exist", [](KvcStore &self, ++ const std::vector &keys) { ++ std::vector outExists; ++ self.Exist(keys, outExists); ++ return outExists; ++ }, "keys"_a, ++ py::call_guard()) ++ .def("put_d2d", [](KvcStore &self, ++ const std::vector &keys, ++ const std::vector &devBlobLists) { ++ std::vector outFutures; ++ self.PutD2D(keys, devBlobLists, outFutures); ++ return outFutures; ++ }, "keys"_a, "dev_blob_lists"_a, ++ py::call_guard()) ++ .def("get_d2d", [](KvcStore &self, ++ const std::vector &keys, ++ const std::vector &devBlobLists) { ++ std::vector outFutures; ++ self.GetD2D(keys, devBlobLists, outFutures); ++ return outFutures; ++ }, "keys"_a, "dev_blob_lists"_a, ++ py::call_guard()); ++ ++ py::class_(m, "PageAttnUtils") ++ .def_static("blk_2_blob", &PageAttnUtils::Blk2Blob, ++ "ptr"_a, "elem_size"_a, "num_block_elem"_a, "block_id"_a) ++ .def_static("blks_2_dev_blob_list", &PageAttnUtils::Blks2DevBlobList, ++ "device_idx"_a, "ptr"_a, "elem_size"_a, "num_block_elem"_a, "block_ids"_a) ++ .def_static("layerwise_dev_blob_lists", [](int32_t deviceIdx, ++ const std::vector &layerTensors, ++ const std::vector &blockIds) { ++ std::vector outDblList; ++ PageAttnUtils::LayerwiseDevBlobLists(deviceIdx, layerTensors, blockIds, outDblList); ++ return outDblList; ++ }, "device_idx"_a, "layer_tensors"_a, "block_ids"_a) ++ .def_static("blockwise_dev_blob_lists", [](int32_t deviceIdx, ++ const std::vector &layerTensors, ++ const std::vector &blockIds) { ++ std::vector outDblList; ++ PageAttnUtils::BlockwiseDevBlobLists(deviceIdx, layerTensors, blockIds, outDblList); ++ return outDblList; ++ }, "device_idx"_a, "layer_tensors"_a, "block_ids"_a); ++} ++ ++} // namespace kvc ++} // namespace dllm ++ ++#endif // DLLM_KVC_PYBIND_H +diff --git a/dllm_tools/csrc/perf/perf_manager.cpp b/dllm_tools/csrc/perf/perf_manager.cpp +new file mode 100644 +index 000000000..85c03062d +--- /dev/null ++++ b/dllm_tools/csrc/perf/perf_manager.cpp +@@ -0,0 +1,221 @@ ++/** ++ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. ++ * ++ * This software is licensed under Mulan PSL v2. ++ * You can use this software according to the terms and conditions of the Mulan PSL v2. ++ * You may obtain a copy of Mulan PSL v2 at: ++ * ++ * http://license.coscl.org.cn/MulanPSL2 ++ * ++ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++ * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++ * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++ * See the Mulan PSL v2 for more details. ++ */ ++#include ++ ++/** ++ * Description: Perf manager. ++ */ ++ ++#include "include/perf/perf_manager.h" ++ ++#include ++#include ++#include ++#include ++ ++#include "include/utils/logging.h" ++ ++namespace dllm { ++namespace perf { ++ ++const uint64_t SECONDS_TO_NANO_UNIT = 1000ul * 1000ul * 1000ul; ++const uint64_t TRIGGER_PERF_LOG_NANO_INTERVAL = 60ul * SECONDS_TO_NANO_UNIT; ++ ++std::string PerfInfo::ToString() ++{ ++ uint64_t _count = this->count.load(); ++ uint64_t _totalTime = this->totalTime.load(); ++ uint64_t avgTime = 0; ++ uint64_t avgWithoutExtremes = 0; ++ ++ if (_count > 0) { ++ avgTime = _totalTime / _count; ++ } ++ const int32_t extremePointCount = 2; ++ if (_count > extremePointCount) { ++ avgWithoutExtremes = (_totalTime - this->minTime.load() - this->maxTime.load()) / (_count - extremePointCount); ++ } else { ++ avgWithoutExtremes = avgTime; ++ } ++ ++ std::ostringstream oss; ++ oss << "{" ++ << "\"count\": " << _count << ", " ++ << "\"minTime\": " << this->minTime.load() << ", " ++ << "\"maxTime\": " << this->maxTime.load() << ", " ++ << "\"totalTime\": " << _totalTime << ", " ++ << "\"avgTime\": " << avgTime << ", " ++ << "\"avgTimeWithoutExtremes\": " << avgWithoutExtremes << ", " ++ << "\"maxFrequency\": " << this->maxFrequency.load() ++ << "}"; ++ ++ return oss.str(); ++} ++ ++PerfManager *PerfManager::Instance() ++{ ++ static PerfManager inst; ++ return &inst; ++} ++ ++PerfManager::PerfManager() ++{ ++ spdlog::info("distributed kvc enable perf."); ++ prevTickTime_ = Clock::now(); ++ prevLogTime_ = Clock::now(); ++} ++ ++void PerfManager::Add(const std::string &key, uint64_t elapsed) ++{ ++ PerfInfo *info = nullptr; ++ { ++ std::lock_guard lock(perfMutex_); ++ info = &perfInfoList_[key]; ++ } ++ std::shared_lock lock(perfMutex_); ++ ++ info->count.fetch_add(1, std::memory_order_relaxed); ++ info->totalTime.fetch_add(elapsed, std::memory_order_relaxed); ++ info->tickCount.fetch_add(1, std::memory_order_relaxed); ++ ++ uint64_t preValue = info->maxTime.load(); ++ while (elapsed > preValue && !info->maxTime.compare_exchange_weak(preValue, elapsed)) { ++ // empty ++ } ++ ++ preValue = info->minTime.load(); ++ while (elapsed < preValue && !info->minTime.compare_exchange_weak(preValue, elapsed)) { ++ // empty ++ } ++} ++ ++void PerfManager::ResetPerfLog() ++{ ++ std::lock_guard lock(perfMutex_); ++ perfInfoList_.clear(); ++ dllm::utils::LogMessage(dllm::utils::INFO) << "Reset PerfLog in perf manager......"; ++} ++ ++std::string PerfManager::GetPerfLog() ++{ ++ std::stringstream ss; ++ std::string prefix; ++ std::shared_lock lock(perfMutex_); ++ for (auto &info : perfInfoList_) { ++ auto keyName = info.first; ++ uint64_t count = info.second.count.load(); ++ if (count > 0) { ++ ss << prefix << keyName << ": " << info.second.ToString(); ++ if (prefix.empty()) { ++ prefix = '\n'; ++ } ++ } ++ } ++ return ss.str(); ++} ++ ++void PerfManager::GetPerfInfoList(std::vector> &perfInfoList) const ++{ ++ std::shared_lock lock(perfMutex_); ++ for (auto &pair : perfInfoList_) { ++ const PerfInfo &info = pair.second; ++ const auto &keyName = pair.first; ++ uint64_t count = info.count.load(); ++ if (count > 0) { ++ perfInfoList.emplace_back(keyName, info); ++ } ++ } ++} ++ ++void PerfManager::PrintPerfLog() const ++{ ++ std::string perfLog = PerfManager::Instance()->GetPerfLog(); ++ dllm::utils::LogMessage(dllm::utils::INFO) << "[Perf Log]:\n" << perfLog; ++} ++ ++void PerfManager::Tick() ++{ ++ std::chrono::time_point nowTime = Clock::now(); ++ uint64_t elapsed = std::chrono::duration_cast(nowTime - prevLogTime_).count(); ++ uint64_t tickElapsed = std::chrono::duration_cast(nowTime - prevTickTime_).count(); ++ if (tickElapsed > 0) { ++ // Get the max frequency. ++ prevTickTime_ = nowTime; ++ std::shared_lock lock(perfMutex_); ++ for (auto &pair : perfInfoList_) { ++ PerfInfo &info = pair.second; ++ uint64_t tickCount = info.tickCount.load(); ++ uint64_t maxFrequency = info.maxFrequency.load(); ++ uint64_t frequency = tickCount * SECONDS_TO_NANO_UNIT / tickElapsed; ++ if (frequency > maxFrequency) { ++ info.maxFrequency.store(frequency); ++ } ++ info.tickCount.store(0); ++ } ++ } ++ ++ if (elapsed >= TRIGGER_PERF_LOG_NANO_INTERVAL) { ++ prevLogTime_ = nowTime; ++ std::string perfLog = GetPerfLog(); ++ dllm::utils::LogMessage(dllm::utils::INFO) << "[Perf Log]:\n" << perfLog; ++ } ++} ++ ++PerfPoint::~PerfPoint() noexcept ++{ ++ if (!isRecord_) { ++ try { ++ Record(); ++ } catch (const std::exception& e) { ++ std::cerr << "[PerfPoint] Exception in destructor: " << e.what() << std::endl; ++ } ++ } ++} ++ ++void PerfPoint::Record() ++{ ++ int64_t elapsed = std::chrono::duration_cast(Clock::now() - beg_).count(); ++ PerfManager *perfManager = PerfManager::Instance(); ++ if (perfManager != nullptr) { ++ perfManager->Add(key_, elapsed); ++ } ++ isRecord_ = true; ++} ++ ++void PerfPoint::Reset(std::string key) ++{ ++ beg_ = Clock::now(); ++ isRecord_ = false; ++ if (!key.empty()) { ++ key_ = key; ++ } ++} ++ ++void PerfPoint::RecordAndReset(std::string key) ++{ ++ Record(); ++ Reset(key); ++} ++ ++void PerfPoint::RecordElapsed(std::string key, uint64_t elapsed) ++{ ++ PerfManager *perfManager = PerfManager::Instance(); ++ if (perfManager != nullptr) { ++ perfManager->Add(key, elapsed); ++ } ++} ++ ++} // namespace perf ++} // namespace dllm +diff --git a/dllm_tools/csrc/perf/pybind.h b/dllm_tools/csrc/perf/pybind.h +new file mode 100644 +index 000000000..6c5123499 +--- /dev/null ++++ b/dllm_tools/csrc/perf/pybind.h +@@ -0,0 +1,40 @@ ++/** ++ * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. ++ * ++ * This software is licensed under Mulan PSL v2. ++ * You can use this software according to the terms and conditions of the Mulan PSL v2. ++ * You may obtain a copy of Mulan PSL v2 at: ++ * ++ * http://license.coscl.org.cn/MulanPSL2 ++ * ++ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++ * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++ * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++ * See the Mulan PSL v2 for more details. ++ */ ++#ifndef DLLM_PERF_PYBIND_H ++#define DLLM_PERF_PYBIND_H ++ ++#include ++#include ++ ++#include "include/perf/perf_manager.h" ++ ++ ++namespace dllm { ++namespace perf { ++ ++namespace py = pybind11; ++ ++void PyBind(py::module_ m) ++{ ++ m.doc() = "pybind11 for Performance"; ++ py::class_>(m, "PerfManager") ++ .def(py::init([]() { return std::make_shared(); })) ++ .def("print", &PerfManagerPythonWrapper::Print); ++} ++ ++} // namespace perf ++} // namespace dllm ++ ++#endif // DLLM_PERF_PYBIND_H +\ No newline at end of file +diff --git a/dllm_tools/csrc/pybind_register.cpp b/dllm_tools/csrc/pybind_register.cpp +new file mode 100644 +index 000000000..7ca311959 +--- /dev/null ++++ b/dllm_tools/csrc/pybind_register.cpp +@@ -0,0 +1,35 @@ ++/** ++ * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. ++ * ++ * This software is licensed under Mulan PSL v2. ++ * You can use this software according to the terms and conditions of the Mulan PSL v2. ++ * You may obtain a copy of Mulan PSL v2 at: ++ * ++ * http://license.coscl.org.cn/MulanPSL2 ++ * ++ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++ * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++ * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++ * See the Mulan PSL v2 for more details. ++ */ ++#include ++#include ++ ++#include "kvc/pybind.h" ++#include "perf/pybind.h" ++#include "utils/pybind.h" ++ ++ ++namespace py = pybind11; ++ ++namespace dllm { ++ ++PYBIND11_MODULE(cpp_ext, m) ++{ ++ m.doc() = "pybind11 for DLLM"; ++ kvc::PyBind(m.def_submodule("kvc", "KV Cache")); ++ perf::PyBind(m.def_submodule("perf", "Performance")); ++ utils::PyBind(m.def_submodule("utils", "Utilities")); ++} ++ ++} // namespace dllm +diff --git a/dllm_tools/csrc/utils/expected.h b/dllm_tools/csrc/utils/expected.h +new file mode 100644 +index 000000000..9d75d45b7 +--- /dev/null ++++ b/dllm_tools/csrc/utils/expected.h +@@ -0,0 +1,291 @@ ++/** ++ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. ++ * ++ * This software is licensed under Mulan PSL v2. ++ * You can use this software according to the terms and conditions of the Mulan ++ * PSL v2. You may obtain a copy of Mulan PSL v2 at: ++ * ++ * http://license.coscl.org.cn/MulanPSL2 ++ * ++ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY ++ * KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ++ * NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. See the ++ * Mulan PSL v2 for more details. ++ */ ++ ++/** ++ * Description: Thread pool. ++ */ ++#ifndef DLLM_UTILS_EXPECTED_H ++#define DLLM_UTILS_EXPECTED_H ++ ++#include ++#include ++#include ++#include ++#include ++#include ++ ++namespace dllm { ++namespace utils { ++ ++inline std::logic_error g_getValueWhenError = std::logic_error("Accessing value in error state"); ++inline std::logic_error g_getErrorWhenOk = std::logic_error("Accessing error in value state"); ++ ++template ++class Unexpected { ++public: ++ static_assert(!std::is_same::value, "E must not be void"); ++ ++ Unexpected() = delete; ++ constexpr explicit Unexpected(const E &e) : value_(e) ++ { ++ } ++ ++ constexpr explicit Unexpected(E &&e) : value_(std::move(e)) ++ { ++ } ++ ++ template ::value>::type * = nullptr> ++ constexpr explicit Unexpected(Args &&...args) : value_(std::forward(args)...) ++ { ++ } ++ template &, Args &&...>::value>::type * = ++ nullptr> ++ constexpr explicit Unexpected(std::initializer_list l, Args &&...args) : value_(l, std::forward(args)...) ++ { ++ } ++ ++ constexpr const E &Value() const & ++ { ++ return value_; ++ } ++ constexpr E &Value() & ++ { ++ return value_; ++ } ++ constexpr E &&Value() && ++ { ++ return std::move(value_); ++ } ++ constexpr const E &&Value() const && ++ { ++ return std::move(value_); ++ } ++ ++private: ++ E value_; ++}; ++ ++// Expected class ++template ++class Expected { ++ static_assert(!std::is_reference::value, "T must not be a reference"); ++ static_assert(!std::is_same::type>::value, "T must not be in_place_t"); ++ static_assert(!std::is_same>::type>::value, "T must not be Unexpected"); ++ static_assert(!std::is_reference::value, "E must not be a reference"); ++ ++public: ++ // Type aliases ++ using ValueType = T; ++ using ErrorType = E; ++ ++ // Constructors for success ++ explicit constexpr Expected() : data_(std::in_place_index<0>, T{}) ++ { ++ } ++ ++ explicit constexpr Expected(const T &value) : data_(std::in_place_index<0>, value) ++ { ++ } ++ explicit constexpr Expected(T &&value) : data_(std::in_place_index<0>, std::forward(value)) ++ { ++ } ++ ++ // Constructors for error ++ explicit constexpr Expected(Unexpected error) : data_(std::in_place_index<1>, std::move(error.Value())) ++ { ++ } ++ ++ // Copy and move constructors ++ constexpr Expected(const Expected &) = default; ++ constexpr Expected(Expected &&) noexcept = default; ++ ++ // Assignment operators ++ Expected &operator=(const Expected &) = default; ++ Expected &operator=(Expected &&) noexcept = default; ++ ++ // Check if the expected contains a value ++ constexpr bool HasValue() const noexcept ++ { ++ return data_.index() == 0; ++ } ++ ++ // Access value ++ constexpr T &Value() & ++ { ++ if (!HasValue()) { ++ throw g_getValueWhenError; ++ } ++ return std::get<0>(data_); ++ } ++ ++ constexpr const T &Value() const & ++ { ++ return Value(); ++ } ++ ++ constexpr T &&Value() && ++ { ++ if (!HasValue()) { ++ throw g_getValueWhenError; ++ } ++ return std::move(std::get<0>(data_)); ++ } ++ ++ // Access error ++ constexpr E &Error() & ++ { ++ if (HasValue()) { ++ throw g_getErrorWhenOk; ++ } ++ return std::get<1>(data_); ++ } ++ ++ constexpr const E &Error() const & ++ { ++ return Error(); ++ } ++ ++ constexpr E &&Error() && ++ { ++ if (HasValue()) { ++ throw std::logic_error("Accessing error in value state"); ++ } ++ return std::move(std::get<1>(data_)); ++ } ++ ++ // Emplace a new value ++ template ++ void Emplace(Args &&...args) ++ { ++ data_.template emplace<0>(std::forward(args)...); ++ } ++ ++ // Transform value ++ template ++ auto Transform(F &&f) const ++ { ++ using Result = std::invoke_result_t; ++ if (HasValue()) { ++ return Expected(std::invoke(std::forward(f), Value())); ++ } ++ return Expected(Unexpected(Error())); ++ } ++ ++ // Transform error ++ template ++ auto TransformError(F &&f) const ++ { ++ using Result = std::invoke_result_t; ++ if (!HasValue()) { ++ return Expected(Unexpected(std::invoke(std::forward(f), Error()))); ++ } ++ return Expected(Value()); ++ } ++ ++ const T ValueOr(T defaultValue) const & ++ { ++ return HasValue() ? Value() : defaultValue; ++ } ++ ++ const T ValueOr(T defaultValue) && ++ { ++ return HasValue() ? std::move(Value()) : std::move(defaultValue); ++ } ++ ++ const E ErrorOr(E defaultErr) const & ++ { ++ return HasValue() ? defaultErr : this->Error(); ++ } ++ ++ const E ErrorOr(E defaultErr) && ++ { ++ return HasValue() ? std::move(this->Error()) : std::move(defaultErr); ++ } ++ ++ template ++ auto OrElse(F &&f) const & ++ { ++ if (HasValue()) { ++ return *this; ++ } else { ++ return std::invoke(std::forward(f), Error()); ++ } ++ } ++ ++ template ++ auto OrElse(F &&f) && ++ { ++ if (HasValue()) { ++ return std::move(*this); ++ } else { ++ return std::invoke(std::forward(f), std::move(Error())); ++ } ++ } ++ ++ template ++ auto AndThen(F &&f) const & ++ { ++ if (HasValue()) { ++ return std::invoke(std::forward(f), Value()); ++ } else { ++ return *this; ++ } ++ } ++ ++ template ++ auto AndThen(F &&f) && ++ { ++ if (HasValue()) { ++ return std::invoke(std::forward(f), std::move(Value())); ++ } else { ++ return std::move(*this); ++ } ++ } ++ ++private: ++ std::variant data_; ++}; ++ ++// Factory functions ++template ++Expected MakeExpected(T &&value) ++{ ++ return Expected, E>(std::forward(value)); ++} ++ ++template ++Expected MakeUnexpected(E &&error) ++{ ++ return Expected>(Unexpected>(std::forward(error))); ++} ++ ++// Factory functions ++template ++Expected Ok(T &&value) ++{ ++ return MakeExpected, E>(std::forward(value)); ++} ++ ++template ++Expected Err(E &&error) ++{ ++ return MakeUnexpected>(std::forward(error)); ++} ++ ++} // namespace utils ++} // namespace dllm ++ ++#endif // DLLM_UTILS_EXPECTED_H +\ No newline at end of file +diff --git a/dllm_tools/csrc/utils/file_util.h b/dllm_tools/csrc/utils/file_util.h +new file mode 100644 +index 000000000..333916eaf +--- /dev/null ++++ b/dllm_tools/csrc/utils/file_util.h +@@ -0,0 +1,101 @@ ++/** ++ * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. ++ * ++ * This software is licensed under Mulan PSL v2. ++ * You can use this software according to the terms and conditions of the Mulan PSL v2. ++ * You may obtain a copy of Mulan PSL v2 at: ++ * ++ * http://license.coscl.org.cn/MulanPSL2 ++ * ++ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++ * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++ * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++ * See the Mulan PSL v2 for more details. ++ */ ++#ifndef DLLM_UTILS_FILE_UTIL_H ++#define DLLM_UTILS_FILE_UTIL_H ++ ++#include ++#include ++#include ++ ++#include ++#include ++ ++#include "strings_util.h" ++#include "file_util.h" ++ ++namespace dllm { ++namespace utils { ++ ++inline std::string GetCurrentDirectory() ++{ ++ char buffer[PATH_MAX]; ++ if (getcwd(buffer, sizeof(buffer)) != nullptr) { ++ return std::string(buffer); ++ } else { ++ perror("getcwd failed"); ++ return ""; ++ } ++} ++ ++inline std::string JoinPath(const std::string &directory, const std::string &filename) ++{ ++ if (directory.empty()) { ++ return filename; ++ } ++ if (directory.back() == '/') { ++ return directory + filename; ++ } ++ return directory + "/" + filename; ++} ++ ++inline bool FileExist(const std::string &filename, int mode = F_OK) ++{ ++ return access(filename.c_str(), mode) == 0; ++} ++ ++inline std::vector SplitPath(const std::string &path) ++{ ++ std::vector parts; ++ std::stringstream ss(path); ++ std::string item; ++ char delimiter = '/'; ++ ++ while (std::getline(ss, item, delimiter)) { ++ if (!item.empty()) { // 跳过空部分 ++ parts.push_back(item); ++ } ++ } ++ return parts; ++} ++ ++inline void CreateDir(const std::string &dir, bool recursively, uint32_t mode) ++{ ++ if (!recursively) { ++ int ret = mkdir(dir.c_str(), mode); ++ if (ret != 0) { ++ std::stringstream ss; ++ ss << "mkdir path: " << dir << " failed with code: " << ret << ", errno: " << errno ++ << ", errmsg: " << StrErr(errno); ++ throw std::runtime_error(ss.str()); ++ } ++ } else { ++ std::vector segments = SplitPath(dir); ++ std::string partialPath; ++ for (const auto &segment : segments) { ++ partialPath = (partialPath.empty() || partialPath == "/") ? partialPath.append(segment) ++ : partialPath.append("/").append(segment); ++ // Check whether the partialPath is directory. ++ struct stat statBuf {}; ++ if (stat(partialPath.c_str(), &statBuf) == 0 && S_ISDIR(statBuf.st_mode)) { ++ continue; ++ } ++ CreateDir(partialPath, false, mode); ++ } ++ } ++} ++ ++} // namespace utils ++} // namespace dllm ++#endif // DLLM_UTILS_FILE_UTIL_H +\ No newline at end of file +diff --git a/dllm_tools/csrc/utils/kvc_future.h b/dllm_tools/csrc/utils/kvc_future.h +new file mode 100644 +index 000000000..2dcf6292c +--- /dev/null ++++ b/dllm_tools/csrc/utils/kvc_future.h +@@ -0,0 +1,59 @@ ++/** ++ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. ++ * ++ * This software is licensed under Mulan PSL v2. ++ * You can use this software according to the terms and conditions of the Mulan PSL v2. ++ * You may obtain a copy of Mulan PSL v2 at: ++ * ++ * http://license.coscl.org.cn/MulanPSL2 ++ * ++ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++ * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++ * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++ * See the Mulan PSL v2 for more details. ++ */ ++ ++/** ++ * Description: A simple extension of shared_future. ++ */ ++#ifndef DLLM_UTILS_KVC_FUTURE_H ++#define DLLM_UTILS_KVC_FUTURE_H ++ ++#include ++#include ++#include ++ ++namespace dllm { ++namespace utils { ++ ++class FutureTimeoutException : public std::runtime_error { ++public: ++ explicit FutureTimeoutException(const std::string &message) : std::runtime_error(message) ++ { ++ } ++}; ++ ++template ++class KVCFuture : public std::shared_future { ++public: ++ using std::shared_future::shared_future; ++ T Get(uint64_t timeoutMs) const ++ { ++ if (!this->valid()) { ++ throw std::runtime_error("The future is invalid."); ++ } ++ if (this->wait_for(std::chrono::milliseconds(timeoutMs)) == std::future_status::timeout) { ++ throw FutureTimeoutException("The future is not ready, please try again."); ++ } ++ return this->get(); ++ } ++ bool IsReady() const ++ { ++ return this->wait_for(std::chrono::milliseconds(0)) == std::future_status::ready; ++ } ++}; ++ ++} // namespace utils ++} // namespace dllm ++ ++#endif // DLLM_UTILS_KVC_FUTURE_H +\ No newline at end of file +diff --git a/dllm_tools/csrc/utils/logging.cpp b/dllm_tools/csrc/utils/logging.cpp +new file mode 100644 +index 000000000..039457a5d +--- /dev/null ++++ b/dllm_tools/csrc/utils/logging.cpp +@@ -0,0 +1,117 @@ ++/** ++ * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. ++ * ++ * This software is licensed under Mulan PSL v2. ++ * You can use this software according to the terms and conditions of the Mulan PSL v2. ++ * You may obtain a copy of Mulan PSL v2 at: ++ * ++ * http://license.coscl.org.cn/MulanPSL2 ++ * ++ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++ * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++ * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++ * See the Mulan PSL v2 for more details. ++ */ ++#include "include/utils/logging.h" ++ ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++ ++#include "file_util.h" ++ ++namespace dllm { ++namespace utils { ++ ++const auto DEFAULT_LOG_LEVEL = spdlog::level::info; ++const auto DEFAULT_FLUSH_LEVEL = spdlog::level::err; ++ ++std::string GetCurrentTime() ++{ ++ auto now = std::chrono::system_clock::now(); ++ auto inTimeT = std::chrono::system_clock::to_time_t(now); ++ std::stringstream ss; ++ ss << std::put_time(std::localtime(&inTimeT), "%Y-%m-%d_%H-%M-%S"); ++ return ss.str(); ++} ++ ++void Logger::SetupRootLogging(const std::string &fileDirectory, const std::string &baseName, size_t maxLogSize, ++ size_t maxFilesNum, bool enableConsoleLogging) ++{ ++ auto pid = std::to_string(getpid()); ++ std::string timestamp = GetCurrentTime(); ++ auto logDir = SetLogDir(fileDirectory); ++ std::string fileName = fmt::format("{}_{}_{}.log", baseName, pid, timestamp); ++ auto fullPathFileName = JoinPath(logDir, fileName); ++ auto asyncFileSink = ++ std::make_shared(fullPathFileName, maxLogSize, maxFilesNum); ++ std::vector sinks{ asyncFileSink }; ++ if (enableConsoleLogging) { ++ auto consoleSink = std::make_shared(); ++ sinks.emplace_back(std::move(consoleSink)); ++ } ++ logger_ = std::make_shared(baseName, std::begin(sinks), std::end(sinks)); ++ logger_->set_level(DEFAULT_LOG_LEVEL); ++ logger_->flush_on(DEFAULT_FLUSH_LEVEL); ++ logger_->set_pattern("%Y-%m-%dT%H:%M:%S.%e | %l | %P:%t | %v"); ++ spdlog::info("KVC SetupRootLogging, file: {}, enable_console: {}, flush_level: {}, log_level: {}", fullPathFileName, ++ enableConsoleLogging, spdlog::level::to_string_view(DEFAULT_LOG_LEVEL), ++ spdlog::level::to_string_view(DEFAULT_FLUSH_LEVEL)); ++ spdlog::set_default_logger(logger_); ++ std::atexit(&spdlog::shutdown); ++} ++ ++std::string Logger::SetLogDir(const std::string &fileDirectory) const ++{ ++ std::string logDir; ++ // Use fileDirectory first ++ if (!fileDirectory.empty()) { ++ logDir = fileDirectory; ++ } else { ++ // Use env DS_LOG_PATH second ++ auto dsLogPath = std::getenv("DS_LOG_PATH"); ++ if (dsLogPath != nullptr) { ++ logDir = dsLogPath; ++ } else { ++ // Use current path finally ++ logDir = GetCurrentDirectory(); ++ } ++ } ++ if (!FileExist(logDir)) { ++ const int32_t dirPrivilege = 0755; ++ CreateDir(logDir, true, dirPrivilege); ++ } ++ return logDir; ++} ++ ++void Logger::SetLogLevel(std::string logLevel) ++{ ++ std::transform(logLevel.begin(), logLevel.end(), logLevel.begin(), ::toupper); ++ auto it = LOG_LEVEL_MAP.find(logLevel); ++ if (it != LOG_LEVEL_MAP.end()) { ++ logger_->set_level(it->second); ++ spdlog::info("Logger set_level : {}", logLevel); ++ } else { ++ throw std::runtime_error(fmt::format("Invalid log level: {}", logLevel)); ++ } ++} ++ ++void Logger::SetFlushLevel(std::string logLevel) ++{ ++ std::transform(logLevel.begin(), logLevel.end(), logLevel.begin(), ::toupper); ++ auto it = LOG_LEVEL_MAP.find(logLevel); ++ if (it != LOG_LEVEL_MAP.end()) { ++ logger_->flush_on(it->second); ++ spdlog::info("Logger flush_on : {}", logLevel); ++ } else { ++ throw std::runtime_error(fmt::format("Invalid log level: {}", logLevel)); ++ } ++} ++ ++} // namespace utils ++} // namespace dllm +\ No newline at end of file +diff --git a/dllm_tools/csrc/utils/pybind.h b/dllm_tools/csrc/utils/pybind.h +new file mode 100644 +index 000000000..7af95c0a3 +--- /dev/null ++++ b/dllm_tools/csrc/utils/pybind.h +@@ -0,0 +1,58 @@ ++/** ++ * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. ++ * ++ * This software is licensed under Mulan PSL v2. ++ * You can use this software according to the terms and conditions of the Mulan PSL v2. ++ * You may obtain a copy of Mulan PSL v2 at: ++ * ++ * http://license.coscl.org.cn/MulanPSL2 ++ * ++ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++ * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++ * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++ * See the Mulan PSL v2 for more details. ++ */ ++#ifndef DLLM_UTILS_PYBIND_H ++#define DLLM_UTILS_PYBIND_H ++ ++#include ++#include ++ ++#include "include/utils/logging.h" ++ ++ ++namespace py = pybind11; ++ ++namespace dllm { ++namespace utils { ++ ++void PyBind(py::module_ m) ++{ ++ using namespace pybind11::literals; ++ m.doc() = "pybind11 for Utilities"; ++ py::class_>(m, "Logger") ++ .def(py::init<>([](const std::string fileDirectory, const std::string &baseName, size_t maxLogSize, ++ size_t maxFilesNum, bool enableConsoleLogging) { ++ auto logger = std::make_shared(); ++ logger->SetupRootLogging(fileDirectory, baseName, maxLogSize, maxFilesNum, enableConsoleLogging); ++ return logger; ++ }), ++ "log_dir"_a = "", "base_name"_a = DEFAULT_BASE_NAME, "max_log_size"_a = DEFAULT_LOG_SIZE, ++ "max_files_num"_a = DEFAULT_ROTATION_FILES, "enable_console_logging"_a = false) ++ .def("set_log_level", &Logger::SetLogLevel, "log_level"_a) ++ .def("set_flush_level", &Logger::SetFlushLevel, "log_level"_a) ++ .def_static( ++ "info", [](const std::string &msg) { spdlog::info(msg); }, "msg"_a) ++ .def_static( ++ "error", [](const std::string &msg) { spdlog::error(msg); }, "msg"_a) ++ .def_static( ++ "debug", [](const std::string &msg) { spdlog::info(msg); }, "msg"_a) ++ .def_static( ++ "warn", [](const std::string &msg) { spdlog::warn(msg); }, "msg"_a) ++ .def_static("log_off", []() { spdlog::set_level(spdlog::level::off); }); ++} ++ ++} // namespace utils ++} // namespace dllm ++ ++#endif // DLLM_UTILS_PYBIND_H +\ No newline at end of file +diff --git a/dllm_tools/csrc/utils/strings_util.h b/dllm_tools/csrc/utils/strings_util.h +new file mode 100644 +index 000000000..c562494e1 +--- /dev/null ++++ b/dllm_tools/csrc/utils/strings_util.h +@@ -0,0 +1,72 @@ ++/** ++ * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. ++ * ++ * This software is licensed under Mulan PSL v2. ++ * You can use this software according to the terms and conditions of the Mulan PSL v2. ++ * You may obtain a copy of Mulan PSL v2 at: ++ * ++ * http://license.coscl.org.cn/MulanPSL2 ++ * ++ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++ * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++ * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++ * See the Mulan PSL v2 for more details. ++ */ ++#ifndef DLLM_UTILS_STRINGS_UTIL_H ++#define DLLM_UTILS_STRINGS_UTIL_H ++ ++#include ++#include ++#include ++ ++namespace dllm { ++namespace utils { ++ ++const std::streampos LOG_MAX_SIZE_LIMIT = 25000; ++ ++inline std::streampos GetSize(std::iostream *ss) ++{ ++ if (!ss) { ++ return 0; ++ } ++ auto currentPos = ss->tellg(); ++ ss->seekg(0, ss->end); ++ auto size = ss->tellg(); ++ ss->seekg(currentPos, ss->beg); ++ return size; ++} ++ ++/** ++ * @brief Print vector. ++ * @param[in] vec Vector to print. ++ * @return Return string. ++ */ ++template ++std::string VectorToString(const Vec &vec) ++{ ++ std::stringstream out; ++ auto totalCount = vec.size(); ++ decltype(totalCount) count = 0; ++ for (auto &item : vec) { ++ out << item << " "; ++ auto length = GetSize(&out); ++ count++; ++ if (length > LOG_MAX_SIZE_LIMIT) { ++ out << "...(" << (totalCount - count) << ")"; ++ break; ++ } ++ } ++ return out.str(); ++} ++ ++inline std::string StrErr(int errNum) ++{ ++ char errBuf[256]; ++ errBuf[0] = '\0'; ++ return strerror_r(errNum, errBuf, sizeof errBuf); ++} ++ ++} // namespace utils ++} // namespace dllm ++ ++#endif // DLLM_UTILS_STRINGS_UTIL_H +\ No newline at end of file +diff --git a/dllm_tools/csrc/utils/thread_pool.cpp b/dllm_tools/csrc/utils/thread_pool.cpp +new file mode 100644 +index 000000000..c29d6e9bb +--- /dev/null ++++ b/dllm_tools/csrc/utils/thread_pool.cpp +@@ -0,0 +1,179 @@ ++/** ++ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. ++ * ++ * This software is licensed under Mulan PSL v2. ++ * You can use this software according to the terms and conditions of the Mulan PSL v2. ++ * You may obtain a copy of Mulan PSL v2 at: ++ * ++ * http://license.coscl.org.cn/MulanPSL2 ++ * ++ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++ * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++ * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++ * See the Mulan PSL v2 for more details. ++ */ ++ ++/** ++ * Description: Thread pool. ++ */ ++#include "thread_pool.h" ++ ++#include ++ ++namespace dllm { ++namespace utils { ++ ++// If ThreadPool throw std::system_error when constructing, ThreadPool's destructor won't execute, ++// thread will release without join and cause abort. ++// Wrap ThreadWorkers and rewrite its destructor to keep all threads will join normally when ThreadPool construct fail. ++ThreadWorkers::~ThreadWorkers() ++{ ++ this->Join(); ++} ++ ++void ThreadWorkers::Join() ++{ ++ for (auto &workerPair : *this) { ++ auto thread = &workerPair.second; ++ if (thread->Joinable()) { ++ thread->Join(); ++ } ++ } ++} ++ ++size_t ThreadPool::GetThreadsNum() ++{ ++ std::shared_lock workerLock(workersMtx_); ++ return workers_.size(); ++} ++ ++void ThreadPool::DoThreadWork() ++{ ++ while (true) { ++ std::function task; ++ { ++ // 1st: Proceed Condition. ++ std::unique_lock lock(this->mtx_); ++ // After threadIdleTimeoutMs_, if taskQ is still empty, try to destroy this thread. ++ if (!this->proceedCV_.wait_for(lock, threadIdleTimeoutMs_, ++ [this] { return this->shutDown_ || !this->taskQ_.empty(); })) { ++ if (GetThreadsNum() > minThreadNum_) { ++ auto tid = std::this_thread::get_id(); ++ DestroyUnuseWorker(tid); ++ return; ++ } else { ++ continue; ++ } ++ } ++ ++ // ShutDown and Finished. ++ if (this->shutDown_ && (droppable_ || this->taskQ_.empty())) { ++ return; ++ } ++ ++ // 2nd: Dequeue Task. ++ task = std::move(this->taskQ_.front()); ++ this->taskQ_.pop(); ++ } ++ { ++ // 3rd: Execute Task. ++ runningThreadsNum_++; ++ task(); ++ runningThreadsNum_--; ++ } ++ } ++} ++ ++void ThreadPool::AddThread() ++{ ++ std::lock_guard workerLock(workersMtx_); ++ auto thread = Thread([this] { this->DoThreadWork(); }); ++ thread.SetName(name_); ++ workers_[thread.GetId()] = std::move(thread); ++} ++ ++void ThreadPool::DestroyUnuseWorker(std::thread::id tid) ++{ ++ std::lock_guard workerLock(workersMtx_); ++ if (!shutDown_ && workers_.find(tid) != workers_.end()) { ++ if (workers_[tid].Joinable()) { ++ workers_[tid].Detach(); ++ } ++ (void)workers_.erase(tid); ++ } ++} ++ ++void ThreadPool::TryToAddThreadIfNeeded() ++{ ++ { ++ std::shared_lock lock(workersMtx_); ++ auto threadNum = workers_.size(); ++ if (threadNum >= maxThreadNum_ || threadNum >= taskQ_.size() + runningThreadsNum_) { ++ return; ++ } ++ } ++ AddThread(); ++} ++ ++ThreadPool::ThreadPool(size_t minThreadNum, size_t maxThreadNum, std::string name, bool droppable, ++ int threadIdleTimeoutMs) ++ : shutDown_(false), ++ joined_(false), ++ droppable_(droppable), ++ minThreadNum_(minThreadNum), ++ maxThreadNum_(maxThreadNum), ++ name_(name), ++ threadIdleTimeoutMs_(threadIdleTimeoutMs) ++{ ++ if (maxThreadNum_ == 0) { ++ if (minThreadNum_ == 0) { ++ throw std::runtime_error("ThreadPool: minThreadNum == maxThreadNum == 0, won't create any thread."); ++ } ++ maxThreadNum_ = minThreadNum_; ++ } ++ if (minThreadNum_ > maxThreadNum_) { ++ throw std::runtime_error("ThreadPool: minThreadNum > maxThreadNum"); ++ } ++ // create core workers when construct ++ workers_.reserve(minThreadNum_); ++ for (size_t i = 0; i < minThreadNum_; ++i) { ++ AddThread(); ++ } ++} ++ ++// The destructor joins all threads. ++ThreadPool::~ThreadPool() ++{ ++ bool isShutDown = false; ++ bool isJoined = false; ++ { ++ std::unique_lock lock(mtx_); ++ isShutDown = shutDown_; ++ isJoined = joined_; ++ } ++ if (!isShutDown) { ++ ShutDown(); ++ } ++ if (!isJoined) { ++ Join(); ++ } ++} ++ ++void ThreadPool::Join() ++{ ++ workers_.Join(); ++ joined_ = true; ++} ++ ++void ThreadPool::ShutDown() ++{ ++ { ++ std::unique_lock lock(mtx_); ++ shutDown_ = true; ++ } ++ // Here, either shutdown correctly checked or have already been blocking. ++ // Thus, safe to unprotected by lock(mtx_). ++ proceedCV_.notify_all(); ++} ++} // namespace utils ++} // namespace dllm +\ No newline at end of file +diff --git a/dllm_tools/csrc/utils/thread_pool.h b/dllm_tools/csrc/utils/thread_pool.h +new file mode 100644 +index 000000000..6f9f2953f +--- /dev/null ++++ b/dllm_tools/csrc/utils/thread_pool.h +@@ -0,0 +1,245 @@ ++/** ++ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. ++ * ++ * This software is licensed under Mulan PSL v2. ++ * You can use this software according to the terms and conditions of the Mulan PSL v2. ++ * You may obtain a copy of Mulan PSL v2 at: ++ * ++ * http://license.coscl.org.cn/MulanPSL2 ++ * ++ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++ * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++ * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++ * See the Mulan PSL v2 for more details. ++ */ ++ ++/** ++ * Description: Thread pool. ++ */ ++#ifndef DLLM_UTILS_THREAD_POOL_H ++#define DLLM_UTILS_THREAD_POOL_H ++ ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++ ++namespace dllm { ++namespace utils { ++ ++class Thread { ++public: ++ Thread() noexcept = default; ++ Thread(Thread &) = delete; ++ Thread(const Thread &) = delete; ++ Thread(const Thread &&) = delete; ++ ++ Thread(Thread &&other) noexcept ++ { ++ Swap(other); ++ } ++ ~Thread() = default; ++ ++ template ++ explicit Thread(F &&f, Args &&...args) ++ : thread_(WrapFn, std::decay_t...>, std::forward(f), std::forward(args)...) ++ { ++ } ++ ++ Thread &operator=(const Thread &) = delete; ++ ++ Thread &operator=(Thread &&other) noexcept ++ { ++ if (this == &other) { ++ return *this; ++ } ++ thread_ = std::move(other.thread_); ++ return *this; ++ } ++ ++ std::thread::id GetId() const noexcept ++ { ++ return thread_.get_id(); ++ } ++ ++ bool Joinable() const noexcept ++ { ++ return thread_.joinable(); ++ } ++ ++ void Join() ++ { ++ thread_.join(); ++ } ++ ++ void Detach() ++ { ++ thread_.detach(); ++ } ++ ++ void Swap(Thread &other) noexcept ++ { ++ thread_.swap(other.thread_); ++ } ++ ++ void SetName(const std::string &name) ++ { ++ const size_t taskCommLen = 15; ++ auto truncateName = name.substr(0, taskCommLen); ++ auto handle = thread_.native_handle(); ++ (void)pthread_setname_np(handle, truncateName.c_str()); ++ } ++ ++private: ++ // If an unhandled exception occurs in an std::thread, the stack is unwound before std::terminate is called, which ++ // makes it impossible to find the location of the exception. The supposed fix was to use noexcept on the internal ++ // thread main function ++ template ++ static auto WrapFn(F &&f, Args &&...args) noexcept -> decltype(std::ref(f)(std::forward(args)...)) ++ { ++ return std::ref(f)(std::forward(args)...); ++ } ++ ++ std::thread thread_; ++}; ++ ++class ThreadWorkers : public std::unordered_map { ++public: ++ ~ThreadWorkers(); ++ ++ void Join(); ++}; ++ ++class ThreadPool { ++public: ++ ThreadPool() = delete; ++ ++ ThreadPool(const ThreadPool &) = delete; ++ ++ ThreadPool(ThreadPool &&) = delete; ++ ++ ThreadPool &operator=(ThreadPool &&) = delete; ++ ++ ThreadPool &operator=(const ThreadPool &) = delete; ++ ++ explicit ThreadPool(size_t minThreadNum, size_t maxThreadNum = 0, std::string name = "", bool droppable = false, ++ int threadIdleTimeoutMs = 60 * 1000); ++ ++ // Using a variable in the return type that has not been declared yet ++ // (because the return type declaration goes before the parameters type declaration). ++ // Add new work item to the pool. ++ template ++ auto Submit(F &&f, Args &&...args) -> std::shared_future::type> ++ { ++ using RetType = typename std::result_of::type; ++ ++ // Wrapper over promise, or single-element-blocking-queue. ++ auto task = ++ std::make_shared>(std::bind(std::forward(f), std::forward(args)...)); ++ ++ std::shared_future res = task->get_future(); ++ { ++ std::unique_lock lock(mtx_); ++ if (shutDown_) { ++ const std::string error = "Submit after Shutdown Error."; ++ throw std::runtime_error(error.c_str()); ++ } ++ // Future is set after during (*task)(), a synchronous way to notify others waiting for it. ++ taskQ_.emplace([task]() { (*task)(); }); ++ TryToAddThreadIfNeeded(); ++ } ++ // Here, impossible to be empty; so no dead wait occurs. ++ // Thus, safe to unprotected by lock(mtx_). ++ proceedCV_.notify_one(); ++ return res; ++ } ++ ++ template ++ void Execute(F &&f, Args &&...args) ++ { ++ using RetType = typename std::result_of::type; ++ auto task = std::bind(std::forward(f), std::forward(args)...); ++ static_assert(std::is_void::value, "Return value type must be void!"); ++ ++ std::unique_lock lock(mtx_); ++ if (shutDown_) { ++ throw std::runtime_error("Submit after Shutdown Error."); ++ } ++ taskQ_.emplace(std::move(task)); ++ TryToAddThreadIfNeeded(); ++ proceedCV_.notify_one(); ++ } ++ ++ /** ++ * @brief Get the number of threads. ++ * @return The number of threads created by ThreadPool. ++ */ ++ size_t GetThreadsNum(); ++ ++ ~ThreadPool(); ++ ++protected: ++ void ShutDown(); ++ ++ void Join(); ++ ++ void DoThreadWork(); ++ ++ /** ++ * @brief Try to add thread if needed, will ignore error if threads resource is not enough. ++ */ ++ void TryToAddThreadIfNeeded(); ++ ++ /** ++ * @brief Join and erase unused thread in workers_ ++ * @param[in] tid The Thread id ready to destroy. ++ */ ++ void DestroyUnuseWorker(std::thread::id tid); ++ ++ /** ++ * @brief Add thread directly, may throw system error if threads resource is not enough. ++ */ ++ void AddThread(); ++ ++private: ++ using Task = std::function; ++ ThreadWorkers workers_; ++ ++ std::queue taskQ_; ++ ++ std::mutex mtx_; ++ ++ // The mutext protecting workers_ get size, erase, add concurrently ++ std::shared_timed_mutex workersMtx_; ++ std::condition_variable proceedCV_; ++ ++ bool shutDown_; ++ bool joined_; ++ bool droppable_; ++ ++ size_t minThreadNum_; ++ size_t maxThreadNum_; ++ ++ std::string name_; ++ ++ // The num of threads which is running task. ++ std::atomic runningThreadsNum_{ 0 }; ++ ++ // If a threads wait for threadIdleTimeoutMs_ and no task need to execute, try to destroy it. ++ std::chrono::milliseconds threadIdleTimeoutMs_; ++}; ++ ++} // namespace utils ++} // namespace dllm ++ ++#endif // DLLM_UTILS_THREAD_POOL_H +\ No newline at end of file +diff --git a/dllm_tools/csrc/utils/timer.h b/dllm_tools/csrc/utils/timer.h +new file mode 100644 +index 000000000..5a460bb2b +--- /dev/null ++++ b/dllm_tools/csrc/utils/timer.h +@@ -0,0 +1,125 @@ ++/** ++ * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. ++ * ++ * This software is licensed under Mulan PSL v2. ++ * You can use this software according to the terms and conditions of the Mulan PSL v2. ++ * You may obtain a copy of Mulan PSL v2 at: ++ * ++ * http://license.coscl.org.cn/MulanPSL2 ++ * ++ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++ * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++ * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++ * See the Mulan PSL v2 for more details. ++ */ ++ ++/** ++ * Description: Timer utils. ++ */ ++#ifndef DLLM_UTILS_TIMER_H ++#define DLLM_UTILS_TIMER_H ++ ++#include ++#include ++#include ++ ++namespace dllm { ++namespace utils { ++ ++inline uint64_t TsToNs(struct timespec &ts) ++{ ++ const uint64_t NS_TO_SECS = 1'000'000'000ul; ++ return ts.tv_sec * (NS_TO_SECS) + ts.tv_nsec; ++} ++ ++inline struct timespec NsToTs(uint64_t ns) ++{ ++ const uint64_t NS_TO_SECS = 1'000'000'000ul; ++ return ++ { ++ .tv_sec = static_cast(ns / NS_TO_SECS), .tv_nsec = static_cast(ns % NS_TO_SECS) ++ }; ++} ++ ++inline std::time_t GetSteadyClockTimeStampUs() ++{ ++ // Attention: System clock is not monotonic. ++ // Instead, steady clock is monotonic. ++ return std::chrono::time_point_cast(std::chrono::steady_clock::now()) ++ .time_since_epoch() ++ .count(); ++} ++ ++inline std::time_t GetSystemClockTimeStampUs() ++{ ++ // Attention: System clock is not monotonic. ++ return std::chrono::time_point_cast(std::chrono::system_clock::now()) ++ .time_since_epoch() ++ .count(); ++} ++ ++class Timer { ++public: ++ Timer() : beg_(clock::now()), timeoutMs_(0) ++ { ++ } ++ ++ Timer(int64_t timeoutMs) : beg_(clock::now()), timeoutMs_(timeoutMs) ++ { ++ } ++ ++ ~Timer() = default; ++ ++ void Reset() ++ { ++ beg_ = clock::now(); ++ } ++ ++ double ElapsedSecond() const ++ { ++ return std::chrono::duration_cast(clock::now() - beg_).count(); ++ } ++ ++ double ElapsedMilliSecond() const ++ { ++ return std::chrono::duration_cast(clock::now() - beg_).count(); ++ } ++ ++ double ElapsedMicroSecond() const ++ { ++ return std::chrono::duration_cast(clock::now() - beg_).count(); ++ } ++ ++ double ElapsedSecondAndReset() ++ { ++ double elapsed = std::chrono::duration_cast(clock::now() - beg_).count(); ++ beg_ = clock::now(); ++ return elapsed; ++ } ++ ++ double ElapsedMilliSecondAndReset() ++ { ++ double elapsed = std::chrono::duration_cast(clock::now() - beg_).count(); ++ beg_ = clock::now(); ++ return elapsed; ++ } ++ ++ int64_t GetRemainingTimeMs() ++ { ++ int64_t remaining = timeoutMs_ - ElapsedMilliSecond(); ++ return std::max((int64_t)0, remaining); ++ } ++ ++private: ++ typedef std::chrono::steady_clock clock; ++ typedef std::chrono::duration > second; ++ typedef std::chrono::duration millisecond; ++ typedef std::chrono::duration microsecond; ++ std::chrono::time_point beg_; ++ int64_t timeoutMs_; ++}; ++ ++} // namespace utils ++} // namespace dllm ++ ++#endif // DLLM_UTILS_TIMER_H +\ No newline at end of file +diff --git a/dllm_tools/dllm/__init__.py b/dllm_tools/dllm/__init__.py +new file mode 100644 +index 000000000..90685a702 +--- /dev/null ++++ b/dllm_tools/dllm/__init__.py +@@ -0,0 +1,13 @@ ++#!/usr/bin/env python3 ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. +diff --git a/dllm_tools/dllm/balancer/README.md b/dllm_tools/dllm/balancer/README.md +new file mode 100644 +index 000000000..bf6832672 +--- /dev/null ++++ b/dllm_tools/dllm/balancer/README.md +@@ -0,0 +1,5 @@ ++放所有调度器相关的python代码 ++1、PD调度 ++2、DP负载均衡 ++3、EP负载均衡 ++4、KVC位置调度 +\ No newline at end of file +diff --git a/dllm_tools/dllm/balancer/__init__.py b/dllm_tools/dllm/balancer/__init__.py +new file mode 100644 +index 000000000..0c6215390 +--- /dev/null ++++ b/dllm_tools/dllm/balancer/__init__.py +@@ -0,0 +1,15 @@ ++#!/usr/bin/env python3 ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++ ++from dllm.balancer.balancer import Balancer +diff --git a/dllm_tools/dllm/balancer/balancer.py b/dllm_tools/dllm/balancer/balancer.py +new file mode 100644 +index 000000000..522eb3d3b +--- /dev/null ++++ b/dllm_tools/dllm/balancer/balancer.py +@@ -0,0 +1,363 @@ ++#!/usr/bin/env python3 ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++import logging ++import asyncio ++from typing import List, Dict ++import time ++ ++import ray ++import aiohttp ++from prometheus_client.parser import text_string_to_metric_families ++ ++from dllm import constants ++from dllm.entities import Role, SchedulerPolicy, VllmInstanceInfo, DispatchResult, MetricsInfo, VllmInstanceStatus ++from dllm.constants import CONTROLLER_ACTOR_NAME, DLLM_NAMESPACE ++ ++ ++logger = logging.getLogger(__name__) ++ ++ ++class Balancer: ++ def __init__( ++ self, ++ policy: SchedulerPolicy = SchedulerPolicy.ROUND_ROBIN, ++ ): ++ self.policy = policy ++ self.role_2_instances: Dict[Role, VllmInstanceInfo] = {} # prefill/decode/mixed => VllmInstanceInfo ++ self.instance_infos: Dict[str, VllmInstanceInfo] = {} # id -> VllmInstanceInfo ++ self.monitoring_instance: Dict[str, VllmInstanceInfo] = {} # id -> VllmInstanceInfo ++ self.instance_metrics: Dict[str, MetricsInfo] = {} # id -> MetricsInfo ++ self._round_robin_index_p = 0 ++ self._round_robin_index_d = 0 ++ self._round_robin_index_m = 0 ++ self.last_heartbeat: Dict[str, float] = {} ++ self._controller_handle: ray.actor.ActorHandle = None ++ self.reported_failures = set() # 用于跟踪已报告的失败实例 ++ self.lock = asyncio.Lock() # 用于并发控制 ++ ++ # start update metrics loop ++ loop = asyncio.get_event_loop() ++ loop.create_task(self.update_vllm_instance_metrics()) ++ ++ async def update_vllm_instance_metrics(self): ++ """ ++ 更新 vLLM 实例的性能指标。 ++ ++ 该函数用于采集并更新与 vLLM 实例相关的监控指标, ++ 以便后续进行负载均衡和性能调优。 ++ ++ Args: ++ ++ Returns: ++ ++ """ ++ while True: ++ try: ++ async with aiohttp.ClientSession() as session: ++ await asyncio.gather( ++ *[ ++ self._query_instance_metrics(session, instance_info) ++ for instance_info in self.instance_infos.values() ++ if instance_info.uri is not None ++ ], ++ return_exceptions=True, ++ ) ++ await asyncio.sleep(constants.METRICS_UPDATE_CYCLE) ++ except Exception as e: ++ logger.error("create request session error: %s", e) ++ ++ def dispatch_request(self) -> DispatchResult: ++ """ ++ 分发请求给相应的处理模块。 ++ ++ 根据请求内容或类型,选择合适的处理逻辑进行处理。 ++ ++ Args: ++ ++ Returns: ++ 处理结果, 包含prefill_vllm_instance_uri和decode_vllm_instance_uri ++ """ ++ if self.policy == SchedulerPolicy.ROUND_ROBIN: ++ return self._round_robin_pair() ++ raise ValueError(f"Unsupported policy: {self.policy}") ++ ++ def get_all_instance(self) -> Dict[str, VllmInstanceInfo]: ++ '''Return all vllm instance.''' ++ return {key: item for key, item in self.instance_infos.items() if item.uri is not None} ++ ++ async def viz_profile(self, is_start: True): ++ '''Interfaces to start and stop profiling.''' ++ from dllm.monkey_patch.viz_profile.common import viz_profile_basic ++ viz_profile_basic("dllm_balancer", is_start) ++ ++ def update_vllm_instance_info(self, infos: List[VllmInstanceInfo]): ++ """ ++ 更新 vLLM 实例的相关信息。 ++ ++ 该方法负责获取并刷新 vLLM 实例的状态或配置信息, ++ 以支持负载均衡和监控逻辑。 ++ ++ Args: ++ vLLM 实例的相关信息类 ++ ++ Returns: ++ ++ """ ++ for item in infos: ++ self.instance_infos[item.id] = item ++ self.instance_metrics[item.id] = MetricsInfo() ++ ++ # reconstruct the role map ++ self.role_2_instances.clear() ++ for _, instance_info in self.instance_infos.items(): ++ if instance_info.role not in self.role_2_instances: ++ self.role_2_instances[instance_info.role] = [] ++ self.role_2_instances[instance_info.role].append(instance_info) ++ ++ async def add_vllm_instance(self, vllm_instance_info: VllmInstanceInfo): ++ """ ++ Add a VLLM instance information to the balancer. ++ ++ Args: ++ instance_info (VllmInstanceInfo): The information of the VLLM instance to add. ++ ++ Updates the internal state by adding the instance information, ++ metrics, and last heartbeat time. ++ """ ++ async with self.lock: ++ instance_id = vllm_instance_info.id ++ ++ if instance_id in self.reported_failures: ++ self.reported_failures.remove(instance_id) ++ ++ self.instance_infos[instance_id] = vllm_instance_info ++ self.instance_metrics[instance_id] = MetricsInfo() ++ self.last_heartbeat[instance_id] = time.time() ++ ++ self.role_2_instances.clear() ++ for _, instance_info in self.instance_infos.items(): ++ role = instance_info.role ++ if role not in self.role_2_instances: ++ self.role_2_instances[role] = [] ++ self.role_2_instances[role].append(instance_info) ++ ++ asyncio.create_task(self.monitor_instance_status(instance_id)) ++ ++ ++ async def monitor_instance_status(self, instance_id): ++ """ ++ Monitor the status of the VLLM instance until it becomes RUNNING. ++ """ ++ start_time = time.time() ++ timeout = 300 # 5 minutes timeout ++ ++ # Check if instance is ready ++ while self.instance_infos[instance_id].status != VllmInstanceStatus.RUNNING: ++ # Check for timeout ++ if time.time() - start_time > timeout: ++ logger.error(f"Timeout waiting for instance {instance_id} to become RUNNING") ++ return ++ ++ logger.info('Waiting for new instance to become ready') ++ await asyncio.sleep(8) ++ ++ # Once the instance is ready, add it to monitoring_instance and begin monitoring its health ++ self.monitoring_instance[instance_id] = self.instance_infos[instance_id] ++ logger.info(f"Instance {instance_id} is now RUNNING and added to monitoring") ++ ++ async def update_vllm_instance_health(self, vllm_instance_info: List[VllmInstanceInfo]) -> bool: ++ """ ++ Update health status of VLLM instances. ++ ++ Args: ++ vllm_instance_info: List of VllmInstanceInfo objects containing information ++ ++ Returns: ++ bool: True if update was successful ++ """ ++ ++ current_time = time.time() ++ for item in vllm_instance_info: ++ self.instance_infos[item.id] = item ++ self.last_heartbeat[item.id] = current_time ++ return True ++ ++ async def check_instances_ready_and_monitor(self): ++ """ ++ Wait until all VLLM actor running status ++ ++ Returns: ++ No return value. End of function when all actor ready,. ++ """ ++ logger.info(f"Start checking if all instances are ready.") ++ if not self._controller_handle: ++ try: ++ self._controller_handle = ray.get_actor(name=CONTROLLER_ACTOR_NAME, namespace=DLLM_NAMESPACE) ++ except BaseException: ++ logger.error('get _controller_handle fail') ++ _get_expected_vllm_actors_num = await self._controller_handle._get_expected_vllm_actors_num.remote() ++ while self._get_ready_vllm_actors_num() < _get_expected_vllm_actors_num: ++ try: ++ logger.debug("expect %d waiting vllm actor, %s", ++ self._get_ready_vllm_actors_num(), self.instance_infos) ++ for s in self.instance_infos.values(): ++ if s.status == VllmInstanceStatus.SUBPROCESS_EXITED: ++ raise RuntimeError(f"vllm instance: {s} exited unexpectedly") ++ await asyncio.sleep(1) ++ except Exception as e: ++ logger.error(f"An error when waiting vllm instances ready: {e}") ++ return ++ logger.info(f"All instances are already") ++ asyncio.create_task(self._monitor_instance_health()) ++ ++ async def remove_failed_instance(self, instance_id: str): ++ """ ++ Controller调用此方法通知Balancer删除实例 ++ """ ++ async with self.lock: ++ if instance_id not in self.reported_failures: ++ self.reported_failures.add(instance_id) ++ ++ if instance_id in self.instance_infos: ++ self.instance_infos.pop(instance_id, None) ++ ++ if instance_id in self.last_heartbeat: ++ self.last_heartbeat.pop(instance_id, None) ++ ++ if instance_id in self.instance_metrics: ++ self.instance_metrics.pop(instance_id, None) ++ ++ if instance_id in self.monitoring_instance: ++ self.monitoring_instance.pop(instance_id, None) ++ # 重建role映射 ++ self.role_2_instances.clear() ++ for _, instance_info in self.instance_infos.items(): ++ role = instance_info.role ++ if role not in self.role_2_instances: ++ self.role_2_instances[role] = [] ++ self.role_2_instances[role].append(instance_info) ++ logger.info(f"Removed failed instance {instance_id} from balancer") ++ ++ async def report_failures_to_controller(self, failed_instances: List[VllmInstanceInfo]) -> None: ++ """批量报告失败的实例""" ++ for instance in failed_instances: ++ instance_id = instance.id ++ try: ++ await self._controller_handle.report_failure_from_balancer.remote(instance_id) ++ logger.info(f"Sccuessful rebuild instance.") ++ except Exception as e: ++ logger.error(f"Failed to report failure for instance {instance_id}: {str(e)}") ++ ++ def _round_robin_selection(self, role: Role) -> str: ++ instances = [item.uri for i, item in self.instance_infos.items() if item.role == role and item.uri is not None] ++ if role == Role.PREFILL: ++ instance = instances[self._round_robin_index_p] ++ self._round_robin_index_p = (self._round_robin_index_p + 1) % len(instances) ++ if role == Role.DECODE: ++ instance = instances[self._round_robin_index_d] ++ self._round_robin_index_d = (self._round_robin_index_d + 1) % len(instances) ++ if role == Role.MIXED: ++ instance = instances[self._round_robin_index_m] ++ self._round_robin_index_m = (self._round_robin_index_m + 1) % len(instances) ++ return instance ++ ++ async def _query_instance_metrics(self, session, instance_info): ++ ins_uri = instance_info.uri ++ ins_id = instance_info.id ++ async with session.post(f"{ins_uri}/metrics", timeout=3) as resp: ++ resp_code = resp.status ++ if resp_code != constants.HTTP_OK: ++ logger.error(f"get metrics failed, uri:{ins_uri}, code:{resp_code}") ++ return ++ resp_body = await resp.text() ++ metrics_dict = { ++ metric_family.name: metric_family.samples[0].value ++ for metric_family in text_string_to_metric_families(resp_body) ++ if metric_family.name in MetricsInfo.METRIC_NAME_MAPPING.values() and metric_family.samples ++ } ++ if not metrics_dict: ++ return ++ if ins_id not in self.instance_metrics: ++ self.instance_metrics[ins_id] = MetricsInfo() ++ metric_info = self.instance_metrics[ins_id] ++ for param_name, metric_name in MetricsInfo.METRIC_NAME_MAPPING.items(): ++ if metric_name not in metrics_dict: ++ continue ++ # data type conversion ++ target_type = metric_info.__annotations__[param_name] ++ setattr(metric_info, param_name, target_type(metrics_dict[metric_name])) ++ logger.debug("instance metrics info: %s", self.instance_metrics) ++ ++ def _round_robin_pair(self) -> DispatchResult: ++ # current policy: if has mixed, use mixed ++ is_pd_disagged = (Role.MIXED not in self.role_2_instances ++ or len(self.role_2_instances.get(Role.MIXED, None)) == 0) ++ if not is_pd_disagged: ++ mixed_uri = self._round_robin_selection(Role.MIXED) ++ return DispatchResult(prefill_vllm_instance_uri=None, decode_vllm_instance_uri=mixed_uri) ++ ++ prefill_uri = self._round_robin_selection(Role.PREFILL) ++ decode_uri = self._round_robin_selection(Role.DECODE) ++ return DispatchResult(prefill_vllm_instance_uri=prefill_uri, decode_vllm_instance_uri=decode_uri) ++ ++ async def _monitor_instance_health(self): ++ """ ++ Monitor instance health, report to controller if >20s no response / failed status ++ """ ++ self.monitoring_instance = dict(self.instance_infos) ++ while True: ++ current_time = time.time() ++ instances_to_report = [] ++ async with self.lock: ++ for instance_id, instance_info in self.monitoring_instance.items(): ++ if instance_id in self.reported_failures: ++ continue ++ logger.debug("Monitoring ID: %d, Status: %s", instance_id, instance_info.status) ++ ++ if instance_info.status == VllmInstanceStatus.HEALTHCHECK_FAILED: ++ logger.error(f"Instance {instance_id} has failed health check.") ++ instances_to_report.append(instance_info) ++ self.reported_failures.add(instance_id) ++ ++ # Consider instance unhealthy if no heartbeat ++ elif current_time - self.last_heartbeat.get(instance_id, 0) > 20: ++ logger.error(f"Instance {instance_id} is unhealthy (no heartbeat).") ++ instances_to_report.append(instance_info) ++ self.reported_failures.add(instance_id) ++ ++ logger.info(f'instances_to_report is {instances_to_report}') ++ if instances_to_report: ++ logger.info(f"Reporting failed instances") ++ await self.report_failures_to_controller(instances_to_report) ++ instances_to_report = [] ++ ++ await asyncio.sleep(5) ++ ++ def _get_ready_vllm_actors_num(self): ++ """ ++ Get the number of ready VLLM instances. ++ ++ Returns: ++ Number of ready VLLM instances. ++ """ ++ return sum(info.status == VllmInstanceStatus.RUNNING for info in self.instance_infos.values()) ++ ++ def _get_unready_vllm_actors_num(self): ++ """ ++ Get the number of unready VLLM instances. ++ ++ Returns: ++ Number of unready VLLM instances. ++ """ ++ return sum(info.status != VllmInstanceStatus.RUNNING for info in self.instance_infos.values()) +diff --git a/dllm_tools/dllm/balancer/policy/README.md b/dllm_tools/dllm/balancer/policy/README.md +new file mode 100644 +index 000000000..86fb70311 +--- /dev/null ++++ b/dllm_tools/dllm/balancer/policy/README.md +@@ -0,0 +1,2 @@ ++扩展各个调度器的各种策略 ++继承已有调度器的基类,实现策略 +\ No newline at end of file +diff --git a/dllm_tools/dllm/balancer/policy/__init__.py b/dllm_tools/dllm/balancer/policy/__init__.py +new file mode 100644 +index 000000000..c63e836b7 +--- /dev/null ++++ b/dllm_tools/dllm/balancer/policy/__init__.py +@@ -0,0 +1,12 @@ ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. +diff --git a/dllm_tools/dllm/config.py b/dllm_tools/dllm/config.py +new file mode 100644 +index 000000000..5f29aab5a +--- /dev/null ++++ b/dllm_tools/dllm/config.py +@@ -0,0 +1,164 @@ ++#!/usr/bin/env python3 ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++ ++from dataclasses import dataclass ++from typing import List, Optional ++ ++from dllm.entities import SchedulerPolicy, Role ++ ++ ++# [ (long options) , (short options) ] ++vllm_options_should_be_filled_by_dllm = [ ++ # host and port ++ ("host",), ++ ("port",), ++ # tp config, which should use --p-tp/--d-tp ++ ("tensor-parallel-size", "tp"), ++ # dp config, which should use --p-dp/--d-dp ++ ("data-parallel-size", "dp"), ++ ("data-parallel-size-local", "dpl"), ++ ("data-parallel-start-rank", "dpr"), ++ ("data-parallel-address", "dpa"), ++ ("data-parallel-rpc-port", "dpp"), ++ ("headless",), ++ # ep config ++ ("enable-expert-parallel",), ++ ("no-enable-expert-parallel",), ++ # pd disagg config, auto used if enable pd disagg ++ ("kv-transfer-config",), ++] ++ ++ ++class AutoValidator: ++ def __post_init__(self): ++ for name, _ in self.__dataclass_fields__.items(): ++ method = getattr(self, f"_validate_{name}", None) ++ if method: ++ method() ++ ++ ++@dataclass ++class InferenceInstanceConfig(AutoValidator): ++ startup_params: List[str] ++ startup_env: Optional[str] ++ tp: Optional[int] ++ dp: Optional[int] ++ ep: Optional[int] ++ ++ def _validate_startup_params(self): ++ def __contain_long_options(opname, params): ++ underline_op = opname.replace('-', '_') ++ return any(p == f"--{opname}" or p.startswith(f"--{opname}=") or ++ p == f"--{underline_op}" or p.startswith(f"--{underline_op}=") ++ for p in params) ++ ++ def __contain_short_options(opname, params): ++ underline_op = opname.replace('-', '_') ++ return any(p == f"-{opname}" or p.startswith(f"-{opname}=") or ++ p == f"--{underline_op}" or p.startswith(f"--{underline_op}=") ++ for p in params) ++ ++ not_acceptable_options = [] ++ for opt in vllm_options_should_be_filled_by_dllm: ++ if len(opt) > 0 and __contain_long_options(opt[0], self.startup_params): ++ not_acceptable_options.append(opt[0]) ++ if len(opt) > 1 and __contain_short_options(opt[1], self.startup_params): ++ not_acceptable_options.append(opt[1]) ++ ++ if len(not_acceptable_options) > 0: ++ raise ValueError( ++ f"Options {not_acceptable_options} are reserved and should not be specified in start up command; " ++ "they must be populated by the dllm" ++ ) ++ ++ def _validate_ep(self): ++ if self.ep < 0: ++ raise ValueError("expert parallel size should be 0(disable) or >1(enable)") ++ ++ def _validate_dp(self): ++ if not self.dp > 0: ++ raise ValueError("data parallel size should be greater than 0") ++ ++ def _validate_tp(self): ++ if not self.tp > 0: ++ raise ValueError("tensor parallel size should be greater than 0") ++ ++ ++@dataclass ++class ControllerConfig(AutoValidator): ++ """ ++ prefill_instances_num: Number of P (Prefill) instances to start ++ prefill_startup_params: Common startup parameters for P instances ++ decode_instances_num: Number of D (Decode) instances to start ++ decode_startup_params: Common startup parameters for D instances ++ scheduler_policy: Scheduling policy enum ++ """ ++ ++ scheduler_policy: SchedulerPolicy ++ ++ prefill_instances_num: int ++ p_inference_instance_config: InferenceInstanceConfig ++ decode_instances_num: int ++ d_inference_instance_config: InferenceInstanceConfig ++ ++ def _validate_prefill_instances_num(self): ++ if self.prefill_instances_num < 0: ++ raise ValueError("prefill instance num should be equal to or greater than 0") ++ ++ def _validate_decode_instances_num(self): ++ if self.decode_instances_num < 0: ++ raise ValueError("decode instance num should be equal to or greater than 0") ++ ++ ++@dataclass ++class VllmInstancePDConfig(AutoValidator): ++ role: Role ++ pd_rank: Optional[int] = 0 ++ pd_size: Optional[int] = 0 ++ ++ def is_disaggregated_p_d(self): ++ """judge if in the pd disaggregated mode""" ++ return self.role != Role.MIXED ++ ++ ++@dataclass ++class VllmInstanceDPConfig(AutoValidator): ++ dp_rank: Optional[int] = 0 ++ dp_size: Optional[int] = 0 ++ dp_local_size: Optional[int] = 1 ++ dp_master_ip: Optional[str] = "" ++ dp_master_port: Optional[int] = 0 ++ ++ def is_dp_enabled(self): ++ """judge if dp is enabled""" ++ return self.dp_size > 0 ++ ++ ++@dataclass ++class VllmInstanceEPConfig(AutoValidator): ++ ep_size: Optional[int] = 0 ++ ++ def is_ep_enabled(self): ++ """judge if ep is enabled""" ++ return self.ep_size > 0 ++ ++ ++@dataclass ++class VllmInstanceConfig(AutoValidator): ++ exec_cmd: List[str] ++ env: Optional[str] = None ++ tp: Optional[int] = 0 ++ pd_config: Optional[VllmInstancePDConfig] = None ++ dp_config: Optional[VllmInstanceDPConfig] = None ++ ep_config: Optional[VllmInstanceEPConfig] = None +diff --git a/dllm_tools/dllm/constants.py b/dllm_tools/dllm/constants.py +new file mode 100644 +index 000000000..c61781796 +--- /dev/null ++++ b/dllm_tools/dllm/constants.py +@@ -0,0 +1,42 @@ ++#!/usr/bin/env python3 ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++ ++DLLM_NAMESPACE = "dllm" ++CONTROLLER_ACTOR_NAME = "controller" ++BALANCER_ACTOR_NAME = "balancer" ++ ++ENDPOINT_APPLICATION_NAME = "dllm-pd-endpoint" ++ENDPOINT_PROXY_DEPLOYMENT_NAME = "dllm-pd-endpoint" ++ ++VLLM_INSTANCE_HEALTH_CHECK_INTERVAL_S = 10 ++ ++HTTP_OK = 200 ++ ++HTTP_PARAM_INVALID = 400 ++ ++HTTP_TOO_MANY_REQUESTS = 429 ++ ++HTTP_INTERNAL_ERROR = 500 ++ ++# The number of running requests on VLLM instances ++NUM_REQUESTS_RUNNING = "vllm:num_requests_running" ++ ++# The number of waiting requests on VLLM instances ++NUM_REQUESTS_WAITING = "vllm:num_requests_waiting" ++ ++# The usage of gpu cache on VLLM instances ++GPU_CACHE_USAGE_PERC = "vllm:gpu_cache_usage_perc" ++ ++# Time Unit: s ++METRICS_UPDATE_CYCLE = 0.5 +diff --git a/dllm_tools/dllm/controller/README.md b/dllm_tools/dllm/controller/README.md +new file mode 100644 +index 000000000..5ea727787 +--- /dev/null ++++ b/dllm_tools/dllm/controller/README.md +@@ -0,0 +1,7 @@ ++推理集群生命周期管理: ++1、支持各种并行策略的部署,包括模型加载、建链等功能 ++2、服务发现 ++3、健康检查 ++4、故障重启 ++5、扩缩容 ++6、快速弹性 +\ No newline at end of file +diff --git a/dllm_tools/dllm/controller/__init__.py b/dllm_tools/dllm/controller/__init__.py +new file mode 100644 +index 000000000..4b63cd569 +--- /dev/null ++++ b/dllm_tools/dllm/controller/__init__.py +@@ -0,0 +1,15 @@ ++#!/usr/bin/env python3 ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++ ++from dllm.controller.controller import Controller +diff --git a/dllm_tools/dllm/controller/controller.py b/dllm_tools/dllm/controller/controller.py +new file mode 100644 +index 000000000..31c643132 +--- /dev/null ++++ b/dllm_tools/dllm/controller/controller.py +@@ -0,0 +1,528 @@ ++#!/usr/bin/env python3 ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++ ++import itertools ++import asyncio ++from typing import List, Dict ++import logging ++import ray ++from ray.util.placement_group import PlacementGroup ++ ++from dllm.entities import ActorInstanceInfo, InstanceInfo, Role, VllmInstanceInfo ++from dllm.constants import BALANCER_ACTOR_NAME ++from dllm.config import ( ++ ControllerConfig, ++ VllmInstanceConfig, ++ VllmInstancePDConfig, ++ VllmInstanceDPConfig, ++ VllmInstanceEPConfig, ++ InferenceInstanceConfig, ++) ++from dllm.controller.vllm_instance import start_vllm_instance ++from dllm.balancer.balancer import Balancer ++ ++logger = logging.getLogger(__name__) ++ ++ ++def flatten_list(multi_level_list): ++ """ ++ 将嵌套的列表展平成一维列表。 ++ ++ Args: ++ multi_level_list Any: 一个二维列表或嵌套列表。 ++ ++ Returns: ++ list[Any]: 展平后的一维列表,包含原列表中所有元素,保持顺序。 ++ """ ++ return list(itertools.chain(*multi_level_list)) ++ ++ ++def _get_npu_num_per_ray_node(): ++ npu_nums = [] ++ for e in ray.nodes(): ++ num = e.get("Resources", {}).get("NPU", None) ++ if num is not None: ++ npu_nums.append(int(num)) ++ return max(npu_nums) ++ ++ ++def split_dp_resources(tp_size: int, dp_size: int, npu_pack_max_size: int = 8) -> List[int]: ++ """ ++ split dp resources into some packed groups, like ++ ++ | DP | TP | total | 910C | 910B | ++ | ---- | ---- | ----- | ----- | ------- | ++ | 4 | 2 | 8 | 8 | 8 | ++ | 3 | 3 | 9 | 9 | 6+3 | ++ | 4 | 4 | 16 | 16 | 8+8 | ++ | 32 | 1 | 32 | 16+16 | 8+8+8+8 | ++ | 64 | 1 | 64 | 16*4 | 8*8 | ++ ++ right now, we don't care about resource fragments ++ ++ Returns: ++ list of npu nums ++ """ ++ if tp_size <= npu_pack_max_size: ++ # TP size is within the allowed limit; no action needed ++ pass ++ else: ++ raise ValueError(f"When enabling DP, the TP size ({tp_size}) should not exceed the " ++ f"maximum allowed size ({npu_pack_max_size}) on a single machine.") ++ total_npu = dp_size * tp_size ++ group_size = ( ++ npu_pack_max_size - (npu_pack_max_size % tp_size) if npu_pack_max_size % tp_size != 0 else npu_pack_max_size ++ ) ++ num_groups = total_npu // group_size ++ remainder = total_npu % group_size ++ packs = [group_size] * num_groups ++ if remainder > 0: ++ packs.append(remainder) ++ return packs ++ ++ ++class Controller: ++ def __init__(self, controller_config: ControllerConfig): ++ """ ++ Initialize the global controller. ++ ++ Args: ++ controller_config: ControllerConfig ++ """ ++ self.config = controller_config ++ ++ self.p_instances_actors: List[ray.actor.ActorHandle] = [] ++ self.d_instances_actors: List[ray.actor.ActorHandle] = [] ++ self.vllm_instances_info: Dict[str, VllmInstanceInfo] = {} ++ self.actor_instance_info: Dict[str, ActorInstanceInfo] = {} ++ self.balancer = None ++ ++ self.dp_groups: Dict[str, Dict] = {} # group_key (string) -> group info ++ self.instance_to_group: Dict[str, str] = {} # instance_id -> group_key ++ ++ async def make_inference_instance( ++ self, pd_role: Role, pd_rank: int, inference_instance_config: InferenceInstanceConfig ++ ) -> List[ray.actor.ActorHandle]: ++ """make inference instance (PREFILL instance, or DECODE instance) ++ 1. if dp enabled, ==> start dp group ++ 2. if dp not enabled, ==> just start vllm instance ++ ++ Returns: ++ all vllm instances actors in this inference instance ++ """ ++ if inference_instance_config.dp > 1: ++ # enable dp ++ return await make_dp_group( ++ controller=self, ++ pd_role=pd_role, ++ pd_idx=pd_rank, ++ tp_size=inference_instance_config.tp, ++ dp_size=inference_instance_config.dp, ++ ep_size=inference_instance_config.ep, ++ start_params=inference_instance_config.startup_params, ++ env=inference_instance_config.startup_env, ++ ) ++ ++ # no dp ++ return [ ++ self.create_vllm_instance( ++ VllmInstanceConfig( ++ exec_cmd=inference_instance_config.startup_params, ++ env=inference_instance_config.startup_env, ++ tp=inference_instance_config.tp, ++ pd_config=VllmInstancePDConfig(role=pd_role, pd_rank=pd_rank), ++ dp_config=VllmInstanceDPConfig(), ++ ep_config=VllmInstanceEPConfig(inference_instance_config.ep), ++ ) ++ ) ++ ] ++ ++ async def make_balancer(self) -> List[ray.actor.ActorHandle]: ++ """make balancer, and send all vllm instance info to the balancer ++ ++ Returns: ++ balancer handle ++ """ ++ balancer = ray.remote(Balancer).options(name=BALANCER_ACTOR_NAME).remote(policy=self.config.scheduler_policy) ++ return balancer ++ ++ async def initialize(self): ++ """initialize all vllm instances, construct pd/dp groups""" ++ logger.info(f"initialize with config: {self.config}") ++ # Dictionary to track VLLM instances health status ++ self.vllm_instances_info: Dict[str, VllmInstanceInfo] = {} # ++ ++ # start VllmInstance ++ # start Prefill Instances ++ is_disaggregated_pd = self.config.prefill_instances_num > 0 and self.config.decode_instances_num > 0 ++ for p_pd_rank in range(self.config.prefill_instances_num): ++ p_actors = self.make_inference_instance( ++ pd_rank=p_pd_rank, ++ pd_role=Role.PREFILL if is_disaggregated_pd else Role.MIXED, ++ inference_instance_config=self.config.p_inference_instance_config, ++ ) ++ self.p_instances_actors.extend(await p_actors) ++ ++ # start Decode Instances ++ for d_pd_rank in range(self.config.decode_instances_num): ++ d_actors = self.make_inference_instance( ++ pd_rank=d_pd_rank, ++ pd_role=Role.DECODE if is_disaggregated_pd else Role.MIXED, ++ inference_instance_config=self.config.d_inference_instance_config, ++ ) ++ self.d_instances_actors.extend(await d_actors) ++ ++ logger.info(f"Create Balancer") ++ self.balancer = await self.make_balancer() ++ logger.info(f"Finished create Balancer") ++ ++ # init all vllm instances ++ for vllm_instance_actor in [*self.p_instances_actors, *self.d_instances_actors]: ++ vllm_instance_actor.initialize.remote() ++ ++ # wait for all instances ready ++ await self.balancer.check_instances_ready_and_monitor.remote() ++ ++ logger.info(f"All instances ready, VllmInstance num: {len(self.vllm_instances_info)}, updating Balancer") ++ ++ # update Balancer ++ self.balancer.update_vllm_instance_info.remote(list(self.vllm_instances_info.values())) ++ ++ logger.info( ++ f"Controller initialized with {self.config.prefill_instances_num} P instances and " ++ f"{self.config.decode_instances_num} D instances" ++ ) ++ ++ async def terminate(self, timeout_s=5) -> None: ++ """ ++ TODO: clean all dllm actors started by controller ++ """ ++ if self.balancer: ++ ray.kill(self.balancer) ++ ++ terminate_futures = [] ++ for actor in [*self.p_instances_actors, *self.d_instances_actors]: ++ terminate_futures.append(actor.terminate.remote(timeout_s=timeout_s)) ++ await asyncio.gather(*terminate_futures) ++ ++ for actor in [*self.p_instances_actors, *self.d_instances_actors]: ++ ray.kill(actor) ++ ++ ++ async def report_failure_from_balancer(self, instance_id: str) -> None: ++ """ ++ Report a failure for an instance from the balancer. ++ ++ Args: ++ instance_id (str): The ID of the instance that has failed. ++ ++ Returns: ++ bool: True if the report was handled successfully, False otherwise. ++ """ ++ logger.info(f"Received report from balancer, instance_id is {instance_id} ") ++ actor_info = self.actor_instance_info.get(instance_id) ++ if actor_info.pg is None: ++ await self.balancer.remove_failed_instance.remote(instance_id) ++ await self.terminate_instance(actor_info.actor) ++ logger.info("Actor %s terminated, will restart the new actor", instance_id) ++ ++ actor_handle = actor_info.actor ++ in_p_instances_actors = actor_handle in self.p_instances_actors ++ in_d_instances_actors = actor_handle in self.d_instances_actors ++ if in_p_instances_actors: ++ self.p_instances_actors = [a for a in self.p_instances_actors if a != actor_handle] ++ else: ++ self.d_instances_actors = [a for a in self.d_instances_actors if a != actor_handle] ++ ++ self.actor_instance_info.pop(instance_id, None) ++ ++ max_retries = 3 ++ retry_count = 0 ++ restart_success = False ++ new_actor = None ++ while retry_count < max_retries and not restart_success: ++ retry_count += 1 ++ try: ++ new_actor = await self.restart_instance(actor_info) ++ except Exception as restart_err: ++ if new_actor is not None: ++ ray.kill(new_actor) ++ new_actor = None ++ logger.error(f"Attempt {retry_count} to restart actor {instance_id} failed:") ++ logger.error(f"Restart error: {str(restart_err)}") ++ if retry_count < max_retries: ++ logger.info(f"Retrying restart for actor {instance_id}...") ++ else: ++ logger.error(f"CRITICAL: FAILED to restart actor {instance_id} after {max_retries} attempts") ++ else: ++ logger.info(f"New actor created successfully: {new_actor}") ++ restart_success = True ++ ++ if in_p_instances_actors: ++ self.p_instances_actors.append(new_actor) ++ elif in_d_instances_actors: ++ self.d_instances_actors.append(new_actor) ++ ++ if restart_success: ++ logger.info("New Actor has been restarted") ++ else: ++ #rebuild the placement group ++ logger.info(f"Actor {instance_id} in a placement group, will rebuild entire group") ++ group_key = self.instance_to_group.get(instance_id) ++ # 获取整个 group 的实例 ID ++ group_info = self.dp_groups.get(group_key, {}) ++ group_instance_ids = group_info.get("instance_ids", []) ++ for single_instance_id in group_instance_ids: ++ await self.balancer.remove_failed_instance.remote(single_instance_id) ++ await self.rebuild_dp_group(group_key) ++ logger.info("New DP group has been restarted") ++ ++ ++ ++ def create_vllm_instance(self, vllm_instance_config: VllmInstanceConfig, pg: PlacementGroup = None ++ ) -> ray.actor.ActorHandle: ++ """ ++ Create and start a new VLLM instance. ++ ++ Args: ++ vllm_instance_config (VllmInstanceConfig): Configuration for the VLLM instance. ++ pg (PlacementGroup, optional): Placement group for the instance. ++ ++ Returns: ++ ray.actor.ActorHandle: The handle of the created VLLM instance. ++ """ ++ role = vllm_instance_config.pd_config.role.name ++ pd_rank = vllm_instance_config.pd_config.pd_rank ++ instance_id = f"vllm-instance-{role}-{pd_rank}" ++ ++ if vllm_instance_config.dp_config.dp_size > 1: ++ # DP env should be set by `init_dp_config` method ++ dp_rank = vllm_instance_config.dp_config.dp_rank ++ dp_local_size = vllm_instance_config.dp_config.dp_local_size ++ ++ instance_id = ( ++ f"{instance_id}-DP-{dp_rank}-" ++ f"{dp_rank+ dp_local_size}" ++ ) ++ group_key = f"DP-{role}-{pd_rank}" ++ ++ if group_key not in self.dp_groups: ++ self.dp_groups[group_key] = { ++ "instance_ids": [], ++ "config": vllm_instance_config, ++ "pd_role": role, ++ "pd_rank": pd_rank ++ } ++ ++ self.instance_to_group[instance_id] = group_key ++ self.dp_groups[group_key]["instance_ids"].append(instance_id) ++ ++ actor = start_vllm_instance(vllm_instance_config=vllm_instance_config, pg=pg, name=instance_id) ++ self.actor_instance_info[instance_id] = ActorInstanceInfo( ++ actor=actor, ++ config=vllm_instance_config, ++ pg=pg, ++ instance_id=instance_id ++ ) ++ return actor ++ ++ ++ async def terminate_instance(self, actor_to_terminate: ray.actor.ActorHandle) -> None: ++ """ ++ Terminate a given actor instance. ++ ++ Args: ++ actor_to_terminate (ray.actor.ActorHandle): The actor instance to terminate. ++ """ ++ try: ++ logger.info("start terminated") ++ await asyncio.gather(actor_to_terminate.terminate.remote(timeout_s=5)) ++ ray.kill(actor_to_terminate) ++ logger.info("finished terminated") ++ except Exception as e: ++ logger.error(f"Cannot terminate instance, it maybe already killed: {str(e)}") ++ ++ ++ async def restart_instance(self, actor_info: Dict[str, ActorInstanceInfo]) -> ray.actor.ActorHandle: ++ """ ++ Restart a VLLM instance based on the provided actor information (Should be same as the before one). ++ ++ Args: ++ actor_info (dict): Information about the instance to restart. ++ ++ Returns: ++ ray.actor.ActorHandle: The newly created actor instance. ++ """ ++ new_actor: ray.actor.ActorHandle = self.create_vllm_instance( ++ vllm_instance_config=actor_info.config, ++ pg=actor_info.pg ++ ) ++ new_actor.initialize.remote() ++ logger.info("New actor restarted successfully") ++ return new_actor ++ ++ async def rebuild_dp_group(self, group_key: str): ++ """ ++ Rebuild an entire data parallel (DP) group after failure. ++ ++ This method handles the complete reconstruction of a DP group by: ++ 1. Terminating all existing actors in the group ++ 2. Cleaning up related state information ++ 3. Recreating new actors using the original configuration ++ 4. Reinitializing the new actors ++ 5. Rebuilding the DP group's metadata and actor lists ++ ++ Args: ++ group_key (str): Unique identifier for the DP group to rebuild. ++ Format should be "DP-{role}-{pd_rank}". ++ ++ Returns: ++ bool: True if the group was successfully rebuilt, ++ False if the group wasn't found or an error occurred. ++ """ ++ if group_key not in self.dp_groups: ++ logger.error(f"DP group {group_key} not found") ++ return ++ ++ group_info = self.dp_groups[group_key] ++ ++ instance_list_info = await self._collect_and_terminate_group_instances(group_key, group_info) ++ ++ for instance_id in group_info["instance_ids"]: ++ if instance_id in self.actor_instance_info: ++ del self.actor_instance_info[instance_id] ++ if instance_id in self.instance_to_group: ++ del self.instance_to_group[instance_id] ++ ++ for _, info in instance_list_info.items(): ++ if info["in_p"]: ++ self.p_instances_actors = [a for a in self.p_instances_actors if a != info["actor"]] ++ if info["in_d"]: ++ self.d_instances_actors = [a for a in self.d_instances_actors if a != info["actor"]] ++ # 使用原配置重建 ++ new_actors = await make_dp_group( ++ controller=self, ++ pd_role=group_info["pd_role"], ++ pd_idx=group_info["pd_rank"], ++ tp_size=group_info["config"].tp, ++ dp_size=group_info["config"].dp_config.dp_size, ++ ep_size=group_info["config"].ep_config.ep_size, ++ start_params=group_info["config"].exec_cmd, ++ env=group_info["config"].env ++ ) ++ ++ self.dp_groups[group_key]["instance_ids"] = [ ++ info["instance_id"] ++ for info in self.actor_instance_info.values() ++ if self.instance_to_group.get(info["instance_id"]) == group_key ++ ] ++ ++ init_tasks = [actor.initialize.remote() for actor in new_actors] ++ await asyncio.gather(*init_tasks) ++ ++ # 根据之前记录的信息,将新actor添加到正确的列表 ++ for i, new_actor in enumerate(new_actors): ++ old_instance_id = list(instance_list_info.keys())[i] ++ list_info = instance_list_info[old_instance_id] ++ ++ if list_info["in_p"]: ++ self.p_instances_actors.append(new_actor) ++ if list_info["in_d"]: ++ self.d_instances_actors.append(new_actor) ++ ++ logger.info(f"Successfully rebuilt DP group {group_key} with {len(new_actors)} instances") ++ ++ def _get_expected_vllm_actors_num(self): ++ return len(self.p_instances_actors) + len(self.d_instances_actors) ++ ++ async def _collect_and_terminate_group_instances(self, group_key: str, group_info: Dict[str, List[str]] ++ ) -> Dict[str, InstanceInfo]: ++ """Collect instance info and terminate all actors in the group""" ++ instance_list_info: Dict[str, InstanceInfo] = {} ++ terminate_tasks: List[asyncio.Task] = [] ++ ++ for instance_id in group_info["instance_ids"]: ++ actor_info = self.actor_instance_info[instance_id] ++ actor_handle: ray.actor.ActorHandle = actor_info.actor ++ ++ # Record which lists the instance belongs to ++ instance_list_info[instance_id] = InstanceInfo( ++ actor=actor_handle, ++ in_p=actor_handle in self.p_instances_actors, ++ in_d=actor_handle in self.d_instances_actors ++ ) ++ ++ # Schedule termination ++ logger.info(f"Terminating instance {instance_id} in group {group_key}") ++ terminate_tasks.append(self.terminate_instance(actor_handle)) ++ ++ # Execute all terminations in parallel ++ if terminate_tasks: ++ await asyncio.gather(*terminate_tasks) ++ ++ return instance_list_info ++ ++ ++async def make_dp_group( ++ controller: Controller, pd_role: Role, pd_idx: int, tp_size: int, dp_size: int, ep_size: int, ++ start_params: List[str], env: str = None ++) -> List[ray.actor.ActorHandle]: ++ """ ++ prepare one dp group ++ 1. start dp master vllm instance ++ 1.1. find dp master ip and a free port as dp master port ++ 1.2. init dp master vllm instance's dp config ++ 2. start other dp vllm instances with dp master ip and dp master port ++ """ ++ packs = split_dp_resources( ++ tp_size=tp_size, dp_size=dp_size, npu_pack_max_size=_get_npu_num_per_ray_node() ++ ) ++ pg = ray.util.placement_group(bundles=[{"NPU": p} for p in packs], strategy="PACK", name=f"DP-{pd_role}-{pd_idx}") ++ await pg.ready() ++ ++ actors = [] ++ dp_master_vllm_instance_config = VllmInstanceConfig( ++ exec_cmd=start_params, ++ env=env, ++ tp=tp_size, ++ pd_config=VllmInstancePDConfig(role=pd_role, pd_rank=pd_idx), ++ dp_config=VllmInstanceDPConfig(dp_rank=0, dp_size=dp_size, dp_local_size=packs[0] // tp_size), ++ ep_config=VllmInstanceEPConfig(ep_size=ep_size), ++ ) ++ dp_master_actor = controller.create_vllm_instance(vllm_instance_config=dp_master_vllm_instance_config, pg=pg) ++ actors.append(dp_master_actor) ++ ++ dp_master_ip, dp_master_port = await dp_master_actor.init_dp_master_ip_port.remote() ++ dp_master_vllm_instance_config.dp_config.dp_master_ip = dp_master_ip ++ dp_master_vllm_instance_config.dp_config.dp_master_port = dp_master_port ++ await dp_master_actor.init_dp_config.remote(dp_master_vllm_instance_config.dp_config) ++ ++ dp_rank = packs[0] // tp_size ++ for idx in range(1, len(packs)): ++ dp_vllm_instance_config = VllmInstanceConfig( ++ exec_cmd=start_params, ++ env=env, ++ tp=tp_size, ++ pd_config=VllmInstancePDConfig(role=pd_role, pd_rank=pd_idx), ++ dp_config=VllmInstanceDPConfig( ++ dp_rank=dp_rank, dp_size=dp_size, dp_master_ip=dp_master_ip, dp_master_port=dp_master_port, ++ dp_local_size=packs[idx] // tp_size, ++ ), ++ ep_config=VllmInstanceEPConfig(ep_size=ep_size), ++ ) ++ dp_rank += packs[idx] // tp_size ++ actor = controller.create_vllm_instance(vllm_instance_config=dp_vllm_instance_config, pg=pg) ++ await actor.init_dp_config.remote(dp_vllm_instance_config.dp_config) ++ actors.append(actor) ++ return actors +diff --git a/dllm_tools/dllm/controller/endpoint.py b/dllm_tools/dllm/controller/endpoint.py +new file mode 100644 +index 000000000..69694ad2a +--- /dev/null ++++ b/dllm_tools/dllm/controller/endpoint.py +@@ -0,0 +1,282 @@ ++#!/usr/bin/env python3 ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++import json ++import os ++from typing import Union ++ ++import asyncio ++import uuid ++import logging ++import aiohttp ++import ray ++from ray import serve ++from fastapi import FastAPI, Request ++from fastapi.responses import StreamingResponse, Response ++ ++from dllm.constants import ( ++ BALANCER_ACTOR_NAME, ++ DLLM_NAMESPACE, ++ ENDPOINT_PROXY_DEPLOYMENT_NAME, ++ ENDPOINT_APPLICATION_NAME, ++) ++from dllm.entities import DispatchResult ++ ++logger = logging.getLogger(__name__) ++ ++app = FastAPI() ++ ++ ++@serve.deployment( ++ name=ENDPOINT_PROXY_DEPLOYMENT_NAME, ++ num_replicas=1, ++ max_ongoing_requests=4096, ++) ++@serve.ingress(app) ++class ProxyDeployment: ++ #: the balancer handle ++ _balancer_handle: Union[ray.actor.ActorHandle, None] ++ ++ def __init__(self): ++ self._balancer_handle = None ++ ++ @staticmethod ++ async def record_exception_info(e): ++ """ ++ record exception info ++ Args: ++ e: exception info ++ """ ++ import sys ++ import traceback ++ exc_info = sys.exc_info() ++ logger.info("Error occurred in disagg prefill proxy server") ++ logger.info(e) ++ logger.info("".join(traceback.format_exception(*exc_info))) ++ ++ async def forward_request(self, url: str, headers: dict, data: dict): ++ """ ++ Send request to the inference instance, return the AsyncGenerator reading the content ++ Args: ++ url: request url ++ headers: request header ++ data: request data ++ Returns: ++ AsyncGenerator: the first iteration is the status code, and subsequent iterations are the response content ++ """ ++ async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=6 * 60 * 60)) as session: ++ async with session.post(url=url, json=data, headers=headers) as response: ++ # Return status code in advance ++ yield response.status ++ if response.status == 200: ++ async for chunk_bytes in response.content.iter_chunked(1024): ++ yield chunk_bytes ++ else: ++ content = await response.read() ++ yield content ++ ++ async def forward_request_without_yield(self, url: str, headers: dict, data: dict): ++ """ ++ Asynchronously sends a POST request with JSON data and returns HTTP status code with raw content. ++ ++ This method uses aiohttp.ClientSession with a 6-hour total timeout to POST data to the specified URL. ++ Headers and JSON payload are provided by the caller without validation. Response content is returned ++ as bytes, requiring manual decoding by the consumer. ++ ++ Args: ++ url (str): Target endpoint including protocol and path (e.g. https://api.example.com/path) ++ headers (dict): Custom HTTP headers provided as dictionary. Authentication headers should be ++ explicitly included when required. ++ data (dict): Request body data to be JSON-serialized. Must be a serializable dictionary object. ++ ++ Returns: ++ Tuple[Any, Any]: 2-element tuple containing: ++ response.status: HTTP status code (e.g. 200, 404) ++ content: Raw response body content as byte array ++ ++ Raises: ++ aiohttp.ClientError: For network-related errors like failed DNS resolution or connection issues ++ asyncio.TimeoutError: When the request exceeds 6-hour timeout limit ++ ++ Example: ++ status, content = await forward_request_without_yield( ++ 'https://api.example.com/submit', ++ {'Authorization': 'Bearer token123'}, ++ {'key1': 'value1', 'key2': ['list_value']} ++ ) ++ if status == 200: ++ json_data = json.loads(content.decode('utf-8')) ++ ++ Requires: ++ aiohttp library must be installed for asynchronous HTTP client functionality ++ """ ++ async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=6 * 60 * 60)) as session: ++ async with session.post(url=url, json=data, headers=headers) as response: ++ content = await response.read() ++ return response.status, content ++ ++ async def schedule(self, prompt: str) -> DispatchResult: ++ """ ++ Async scheduling method for model inference requests. ++ ++ Args: ++ prompt (str): Input text prompt to be processed by LLM. ++ ++ Returns: ++ DispatchResult: Object containing: ++ prefill_vllm_instance_uri: str ++ decode_vllm_instance_uri: str ++ """ ++ if self._balancer_handle is None: ++ self._balancer_handle = ray.get_actor(name=BALANCER_ACTOR_NAME, namespace=DLLM_NAMESPACE) ++ dispatch_result = await self._balancer_handle.dispatch_request.remote() ++ return dispatch_result ++ ++ @app.post("/health") ++ async def health(self, request: Request): ++ """ ++ Healthcheck endpoint to verify service availability. ++ ++ Returns a 200 OK response with "healthy" status to confirm the service ++ is running correctly. This is typically used by Kubernetes or load balancers ++ for liveness/readiness probes. ++ ++ Args: ++ request (Request): FastAPI request object ++ ++ Returns: ++ Response: HTTP 200 response with content "healthy" ++ ++ """ ++ return Response(status_code=200, content="healthy") ++ ++ async def viz_profile_internal(self, raw_request: Request, is_start: bool): ++ '''Start/stop viztracer profile, and send request to vllm''' ++ if os.environ.get("ENABLE_VIZTRACER_PROFILE", '0') != '1': ++ return Response(status_code=500, content="set env ENABLE_VIZTRACER_PROFILE = 1 to enable viz profile.") ++ from dllm.monkey_patch.viz_profile.common import viz_profile_basic ++ _success, _message = viz_profile_basic("dllm_endpoint", is_start) ++ if not _success: ++ return Response(status_code=500, content=_message) ++ ++ _balancer = ray.get_actor(name=BALANCER_ACTOR_NAME, namespace=DLLM_NAMESPACE) ++ _result = await _balancer.get_all_instance.remote() ++ _headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} ++ _tasks = [] ++ for _, item in _result.items(): ++ task = asyncio.create_task(self.forward_request_without_yield( ++ f"{item.uri}/{'start_viz_profile' if is_start else 'stop_viz_profile'}", ++ headers=_headers, ++ data={})) ++ _tasks.append(task) ++ _results = await asyncio.gather(*_tasks) ++ # also need to consider balancer process ++ _balancer = ray.get_actor(name=BALANCER_ACTOR_NAME, namespace=DLLM_NAMESPACE) ++ await _balancer.viz_profile.remote(is_start) ++ ++ return Response(status_code=200, content="success") ++ ++ ++ @app.post("/start_viz_profile") ++ async def start_viz_profile(self, raw_request: Request): ++ '''Start viztracer profile''' ++ return await self.viz_profile_internal(raw_request, True) ++ ++ @app.post("/stop_viz_profile") ++ async def stop_viz_profile(self, raw_request: Request): ++ '''Stop viztracer profile''' ++ return await self.viz_profile_internal(raw_request, False) ++ ++ @app.post("/v1/completions") ++ async def openai_completions(self, raw_request: Request): ++ """ ++ https://github.com/vllm-project/vllm/blob/main/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py ++ """ ++ import pydantic ++ from vllm.entrypoints.openai.protocol import CompletionRequest ++ ++ request_body = await raw_request.json() ++ headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", ++ "X-Request-Id": raw_request.headers.get("X-Request-Id") or str(uuid.uuid4())} ++ ++ try: ++ request = CompletionRequest(**request_body) ++ except pydantic.ValidationError as e: ++ return Response(status_code=500, content={"error": str(e)}) ++ ++ dispatch_result = await self.schedule(request.prompt) ++ logger.info( ++ f"({headers['X-Request-Id']}) recv request: {request.prompt}, " ++ f"prefill to: {dispatch_result.prefill_vllm_instance_uri}," ++ f"decode to {dispatch_result.decode_vllm_instance_uri}" ++ ) ++ ++ try: ++ prefill_request = request_body.copy() ++ prefill_request['kv_transfer_params'] = { ++ "do_remote_decode": True, ++ "do_remote_prefill": False, ++ "remote_engine_id": None, ++ "remote_block_ids": None, ++ "remote_host": None, ++ "remote_port": None ++ } ++ prefill_request["max_tokens"] = 1 ++ prefill_request["stream"] = False ++ if "stream_options" in prefill_request: ++ del prefill_request["stream_options"] ++ if dispatch_result.prefill_vllm_instance_uri: ++ status_code, prefill_result = await self.forward_request_without_yield( ++ f"{dispatch_result.prefill_vllm_instance_uri}/v1/completions", ++ headers=headers, ++ data=prefill_request, ++ ) ++ if status_code != 200: ++ logger.error(f"prefill request failed, status code:{status_code}, content:{prefill_result}") ++ kv_transfer_params = json.loads(prefill_result.decode('utf-8')).get('kv_transfer_params', {}) ++ if kv_transfer_params: ++ request_body["kv_transfer_params"] = kv_transfer_params ++ ++ decode_token_generator = self.forward_request( ++ f"{dispatch_result.decode_vllm_instance_uri}/v1/completions", ++ headers=headers, ++ data=request_body, ++ ) ++ status_code = 200 ++ # Only iterate once, get the status code and transmit it transparently ++ async for status in decode_token_generator: ++ status_code = status ++ break ++ return StreamingResponse( ++ decode_token_generator, ++ status_code=status_code, ++ media_type="application/octet-stream", ++ ) ++ except Exception as e: ++ await self.record_exception_info(e) ++ raise ++ ++ ++def deploy_endpoint_to_cluster(host: str = "0.0.0.0", port: int = 8000): ++ """ ++ Deploys an API endpoint as a service in a distributed cluster. ++ ++ Args: ++ host (str): IP address for HTTP server to listen on. Default is 0.0.0.0 for unrestricted access. ++ port (int): Port number for HTTP server. Default is 8000. Must be available on target nodes. ++ ++ Returns: ++ None: Returns nothing since ray.serve.run() blocks indefinitely. ++ """ ++ serve.start(http_options=serve.HTTPOptions(host=host, port=port)) ++ serve.run(ProxyDeployment.bind(), name=ENDPOINT_APPLICATION_NAME) +diff --git a/dllm_tools/dllm/controller/vllm_instance.py b/dllm_tools/dllm/controller/vllm_instance.py +new file mode 100644 +index 000000000..ee48b84bf +--- /dev/null ++++ b/dllm_tools/dllm/controller/vllm_instance.py +@@ -0,0 +1,372 @@ ++#!/usr/bin/env python3 ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++ ++import asyncio ++import json ++import uuid ++ ++import subprocess ++import sys ++import os ++import signal ++import logging ++from asyncio import Task ++import ray ++import aiohttp ++ ++from ray.util.placement_group import PlacementGroup ++from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy ++ ++from dllm.constants import DLLM_NAMESPACE, VLLM_INSTANCE_HEALTH_CHECK_INTERVAL_S, BALANCER_ACTOR_NAME ++from dllm.entities import Role, VllmInstanceInfo, VllmInstanceStatus ++from dllm.config import VllmInstanceConfig, VllmInstanceDPConfig ++from dllm.utils import find_node_ip, find_free_port, find_interface_by_ip, find_ip_by_interface, get_num_npus ++ ++logger = logging.getLogger(__name__) ++ ++ ++def select_distributed_torch_interface(): ++ """ ++ Determines the preferred network interface for distributed PyTorch communication. ++ ++ Args: ++ [Function takes no explicit arguments but inspects environment] ++ ++ GLOO_SOCKET_IFNAME (str): Environment variable specifying the network interface ++ for GLOO-based communication. Takes precedence over NCCL. ++ NCCL_SOCKET_IFNAME (str): Environment variable specifying the interface for ++ NCCL-based communication. ++ ++ Returns: ++ Optional[str]: Returns either: ++ - Value of GLOO_SOCKET_IFNAME if set ++ - Value of NCCL_SOCKET_IFNAME if set ++ - None if neither is specified ++ """ ++ for env in ["GLOO_SOCKET_IFNAME", "NCCL_SOCKET_IFNAME"]: ++ if env in os.environ: ++ return os.environ[env] ++ return None ++ ++ ++class VllmInstance: ++ """ ++ VllmInstance is a vllm engine wrapped by a ray actor, responsibilities: ++ 1. start vllm api server (and pass some args) ++ ref: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#vllm-serve ++ 2. do the health check job (report to Controller if any failure) ++ """ ++ ++ _vllm_instance_config: VllmInstanceConfig ++ _vllm_instance_info: VllmInstanceInfo ++ #: the actor handle of balancer ++ _balancer_handle: ray.actor.ActorHandle ++ _vllm_api_server_process: subprocess.Popen ++ _vllm_api_server_health_monitor_task: Task[None] ++ ++ def __init__(self, name: str, vllm_config: VllmInstanceConfig): ++ """ ++ Args: ++ env: the environment variables pass to subprocess ++ exec_cmd: the vllm api server startup command, e.g. ["vllm", "serve", "--a=1", "--b=2"] ++ """ ++ self._vllm_instance_config = vllm_config ++ self._vllm_instance_info = VllmInstanceInfo(id=name, uri=None, role=vllm_config.pd_config.role) ++ self._balancer_handle = None ++ self._vllm_api_server_process = None ++ self._vllm_api_server_health_monitor_task = None ++ self._env = dict(os.environ) ++ self._env["HCCL_IF_BASE_PORT"] = os.environ.get('HCCL_IF_BASE_PORT', "50000") ++ ++ self.__has_process_started = False ++ ++ async def init_dp_master_ip_port(self): ++ """ ++ if dp config is None, init dp master ++ """ ++ intf = select_distributed_torch_interface() ++ if intf: ++ ip = find_ip_by_interface(intf) ++ else: ++ ip = find_node_ip() ++ intf = find_interface_by_ip(ip) ++ self._env["GLOO_SOCKET_IFNAME"] = intf ++ self._env["NCCL_SOCKET_IFNAME"] = intf ++ master_port = find_free_port(ip) ++ return ip, master_port ++ ++ async def init_dp_config(self, dp_config: VllmInstanceDPConfig = None): ++ """ ++ if dp config is None, init dp master ++ """ ++ self._vllm_instance_info.dp_master_ip = dp_config.dp_master_ip ++ self._vllm_instance_info.dp_master_port = dp_config.dp_master_port ++ self._vllm_instance_config.dp_config = dp_config ++ ++ async def initialize(self) -> None: ++ """launch subprocess""" ++ logger.info(f"initialize with ASCEND_RT_VISIBLE_DEVICES: {os.environ.get('ASCEND_RT_VISIBLE_DEVICES')}") ++ ++ # normalize and set some env vars ++ self._resort_ascend_rt_visible_devices_env() ++ self._env["VLLM_USE_V1"] = "1" ++ ++ # api server options ++ # dp slaves have no http api server ++ if self._vllm_instance_config.dp_config.dp_size == 0 or self._vllm_instance_config.dp_config.dp_rank == 0: ++ protocal = "http" ++ ip = find_node_ip() ++ port = find_free_port() ++ self._vllm_instance_info.uri = f"{protocal}://{ip}:{port}" ++ self._vllm_instance_config.exec_cmd.extend(["--host", ip, "--port", str(port)]) ++ ++ # tp, pd, and dp options ++ self._vllm_instance_config.exec_cmd.extend(["--tensor-parallel-size", str(self._vllm_instance_config.tp)]) ++ self._add_pd_command_options() ++ self._add_dp_command_options() ++ self._add_ep_command_options() ++ self._add_env() ++ ++ logger.info(f"initialize with command: {self._vllm_instance_config.exec_cmd}, env:{self._env}") ++ self._vllm_api_server_process = subprocess.Popen( ++ self._vllm_instance_config.exec_cmd, ++ stdout=sys.stdout, ++ stdin=sys.stdin, ++ stderr=sys.stderr, ++ text=True, ++ preexec_fn=os.setpgrp, ++ env=self._env, ++ ) ++ ++ # use a thread to check and report health status ++ # thread safety issue: https://github.com/ray-project/ray/issues/2385 ++ self._vllm_api_server_health_monitor_task = asyncio.create_task(self._monitor_health()) ++ ++ async def terminate(self, timeout_s=5): ++ """terminate""" ++ if self._vllm_api_server_process is None: ++ return ++ ++ try: ++ pgid = os.getpgid(self._vllm_api_server_process.pid) ++ os.killpg(pgid, signal.SIGTERM) ++ except ProcessLookupError: ++ logger.info("process already exited") ++ return ++ ++ # Another way is "self._vllm_api_server_process.terminate()" ++ try: ++ self._vllm_api_server_process.wait(timeout_s) ++ except (TimeoutError, subprocess.TimeoutExpired): ++ pass ++ finally: ++ if self._vllm_api_server_process.poll() is None: ++ # Another way is "self._vllm_api_server_process.kill()" ++ os.killpg(pgid, signal.SIGKILL) ++ ++ def _resort_ascend_rt_visible_devices_env(self): ++ if "ASCEND_RT_VISIBLE_DEVICES" not in os.environ: ++ return ++ try: ++ device_ids = [int(id.strip()) for id in os.environ["ASCEND_RT_VISIBLE_DEVICES"].split(",")] ++ except ValueError: ++ return ++ os.environ["ASCEND_RT_VISIBLE_DEVICES"] = ",".join(map(str, sorted(device_ids))) ++ self._env["ASCEND_RT_VISIBLE_DEVICES"] = ",".join(map(str, sorted(device_ids))) ++ ++ def _add_pd_command_options(self): ++ connector_type = int(os.environ.get("USING_CONNECTOR_TYPE", 0)) ++ if self._vllm_instance_config.pd_config.is_disaggregated_p_d() and connector_type == 1: ++ self._vllm_instance_config.exec_cmd.extend( ++ [ ++ "--kv-transfer-config", ++ json.dumps( ++ { ++ "kv_connector": "DLLMDsD2DConnector", ++ "engine_id": str(uuid.uuid4()), ++ "kv_role": ( ++ "kv_producer" ++ if self._vllm_instance_config.pd_config.role is Role.PREFILL ++ else "kv_consumer" ++ ), ++ "kv_rank": 0 if self._vllm_instance_config.pd_config.role is Role.PREFILL else 1, ++ "kv_parallel_size": 2, ++ "kv_connector_extra_config": { ++ "device_ids": [int(i) for i in os.environ["ASCEND_RT_VISIBLE_DEVICES"].split(",")] ++ }, ++ } ++ ), ++ ] ++ ) ++ else: ++ enable_prefix_connector = int(os.environ.get("USING_PREFIX_CONNECTOR", 0)) ++ if (not self._vllm_instance_config.pd_config.is_disaggregated_p_d()) and enable_prefix_connector: ++ self._vllm_instance_config.exec_cmd.extend( ++ [ ++ "--kv-transfer-config", ++ json.dumps( ++ { ++ "kv_connector": "DLLMDsConnector", ++ "kv_role": ( ++ "kv_both" ++ ), ++ "kv_connector_extra_config": { ++ "device_ids": [int(i) for i in os.environ["ASCEND_RT_VISIBLE_DEVICES"].split(",")] ++ }, ++ } ++ ), ++ ] ++ ) ++ elif self._vllm_instance_config.pd_config.is_disaggregated_p_d(): ++ self._vllm_instance_config.exec_cmd.extend( ++ [ ++ "--kv-transfer-config", ++ json.dumps( ++ { ++ "kv_connector": "DLLMDsConnector", ++ "kv_role": ( ++ "kv_producer" ++ if self._vllm_instance_config.pd_config.role is Role.PREFILL ++ else "kv_consumer" ++ ), ++ "kv_parallel_size": 2, ++ "kv_rank": 0 if self._vllm_instance_config.pd_config.role is Role.PREFILL else 1, ++ "kv_connector_extra_config": { ++ "device_ids": [int(i) for i in os.environ["ASCEND_RT_VISIBLE_DEVICES"].split(",")] ++ }, ++ } ++ ), ++ ] ++ ) ++ else: ++ return ++ ++ # NOTE ++ # If using disaggregated prefill, the DLLM connector must access all NPUs to transfer KV cache normally. ++ # The actual device IDs are set by `DLLMConnector.kv_transfer_config.kv_connector_extra_config.device_ids`, ++ # they have been retrieved from the `ASCEND_RT_VISIBLE_DEVICES` environment variable as code above ++ self._env["ASCEND_RT_VISIBLE_DEVICES"] = ",".join(map(str, range(get_num_npus()))) ++ ++ def _add_dp_command_options(self): ++ if not self._vllm_instance_config.dp_config.is_dp_enabled(): ++ return ++ ++ self._vllm_instance_config.exec_cmd.extend( ++ [ ++ "--data-parallel-size", ++ str(self._vllm_instance_config.dp_config.dp_size), ++ "--data-parallel-size-local", ++ str(self._vllm_instance_config.dp_config.dp_local_size), ++ "--data-parallel-start-rank", ++ str(self._vllm_instance_config.dp_config.dp_rank), ++ "--data-parallel-address", ++ str(self._vllm_instance_config.dp_config.dp_master_ip), ++ "--data-parallel-rpc-port", ++ str(self._vllm_instance_config.dp_config.dp_master_port), ++ ] ++ ) ++ ++ if self._vllm_instance_config.dp_config.dp_rank > 0: ++ self._vllm_instance_config.exec_cmd.extend(["--headless"]) ++ ++ def _add_ep_command_options(self): ++ if not self._vllm_instance_config.ep_config.is_ep_enabled(): ++ return ++ ++ self._vllm_instance_config.exec_cmd.extend( ++ [ ++ "--enable-expert-parallel", ++ ] ++ ) ++ ++ def _add_env(self): ++ if self._vllm_instance_config.env is None: ++ return ++ ++ env_dict = dict(item.split('=') for item in self._vllm_instance_config.env.split()) ++ for env_key, env_value in env_dict.items(): ++ self._env[env_key] = env_value ++ ++ async def _monitor_health(self): ++ """Asynchronously monitor subprocess health and report to controller""" ++ while not self._balancer_handle: ++ try: ++ self._balancer_handle = ray.get_actor(name=BALANCER_ACTOR_NAME, namespace=DLLM_NAMESPACE) ++ except: ++ logger.warning(f'Instance get _balancer_handle failed, wait for 1 second and retry.') ++ await asyncio.sleep(1) ++ ++ await self._balancer_handle.add_vllm_instance.remote(self._vllm_instance_info) ++ async with aiohttp.ClientSession() as session: ++ last_report_time = asyncio.get_event_loop().time() ++ last_status = self._vllm_instance_info.status ++ while True: ++ self._vllm_instance_info.status = VllmInstanceStatus.RUNNING ++ if self._vllm_api_server_process.poll() is not None: ++ self._vllm_instance_info.status = VllmInstanceStatus.SUBPROCESS_EXITED ++ elif self._vllm_instance_info.uri is not None: # only check DP master's healthy ++ try: ++ async with session.get( ++ f"{self._vllm_instance_info.uri}/health", timeout=aiohttp.ClientTimeout(total=2) ++ ) as response: ++ self._vllm_instance_info.status = ( ++ VllmInstanceStatus.HEALTHCHECK_FAILED ++ if response.status != 200 ++ else VllmInstanceStatus.RUNNING ++ ) ++ except (aiohttp.ClientError, asyncio.TimeoutError): ++ self._vllm_instance_info.status = VllmInstanceStatus.HEALTHCHECK_FAILED ++ if ( ++ # not healthy ++ self._vllm_instance_info.status != VllmInstanceStatus.RUNNING ++ # or changed ++ or self._vllm_instance_info.status != last_status ++ # or past quite long time, we should let controller know that we are still alive ++ or asyncio.get_event_loop().time() - last_report_time > VLLM_INSTANCE_HEALTH_CHECK_INTERVAL_S ++ ): ++ await self._balancer_handle.update_vllm_instance_health.remote([self._vllm_instance_info]) ++ last_report_time = asyncio.get_event_loop().time() ++ last_status = self._vllm_instance_info.status ++ ++ if self._vllm_instance_info.status == VllmInstanceStatus.SUBPROCESS_EXITED: ++ # terminate self ++ logger.info(f"vllm subprocess exited unexpectedly, VllmInstance exit with vllm together") ++ await asyncio.sleep(5) ++ ++ ++def start_vllm_instance(vllm_instance_config: VllmInstanceConfig, pg: PlacementGroup = None, name:str = None ++ ) -> ray.actor.ActorHandle: ++ """ ++ Start a VLLM instance. ++ ++ Args: ++ vllm_instance_config (VllmInstanceConfig): Configuration for the VLLM instance. ++ pg (PlacementGroup, optional): Placement group for scheduling. Defaults to None. ++ name (str, optional): Name of the actor. Defaults to None. ++ ++ Returns: ++ ray.actor.ActorHandle: Handle to the newly created VLLM instance actor. ++ """ ++ actor_options = { ++ "resources": {"NPU": vllm_instance_config.dp_config.dp_local_size * vllm_instance_config.tp}, ++ "name": name, ++ "num_cpus": 0, ++ } ++ if pg: ++ actor_options["scheduling_strategy"] = PlacementGroupSchedulingStrategy( ++ placement_group=pg, ++ ) ++ ++ vllm_instance_actor = ray.remote(VllmInstance).options(**actor_options).remote(name, vllm_instance_config) ++ return vllm_instance_actor +diff --git a/dllm_tools/dllm/cpp_ext/__init__.pyi b/dllm_tools/dllm/cpp_ext/__init__.pyi +new file mode 100644 +index 000000000..a574b24e2 +--- /dev/null ++++ b/dllm_tools/dllm/cpp_ext/__init__.pyi +@@ -0,0 +1,8 @@ ++""" ++pybind11 for DLLM ++""" ++from __future__ import annotations ++from . import kvc ++from . import perf ++from . import utils ++__all__ = ['kvc', 'perf', 'utils'] +diff --git a/dllm_tools/dllm/cpp_ext/kvc.pyi b/dllm_tools/dllm/cpp_ext/kvc.pyi +new file mode 100644 +index 000000000..913f83afd +--- /dev/null ++++ b/dllm_tools/dllm/cpp_ext/kvc.pyi +@@ -0,0 +1,290 @@ ++""" ++pybind11 for KV Cache ++""" ++from __future__ import annotations ++import typing ++__all__ = ['Blob', 'DeviceBlobList', 'KvcFuture', 'KvcResult', 'KvcStore', 'KvcTensor', 'PageAttnUtils'] ++class Blob: ++ """ ++ A class for representing a continuous range of device memory. ++ """ ++ __hash__: typing.ClassVar[None] = None ++ def __eq__(self, other: Blob) -> bool: ++ ... ++ def __init__(self, pointer: int, size: int) -> None: ++ """ ++ Constructor. ++ ++ Args: ++ pointer (int): The beginning address on the device memory. ++ size (int): The size in byte. ++ """ ++ ... ++ def __repr__(self) -> str: ++ ... ++ @property ++ def pointer(self) -> int: ++ """ ++ The beginning address on the device memory. ++ """ ++ ... ++ @property ++ def size(self) -> int: ++ """ ++ The size in byte. ++ """ ++ ... ++class DeviceBlobList: ++ """ ++ A class for representing a list of continuous range of memory on a certain device. ++ ++ Attributes: ++ blobs (list[Blob]): A list of continuous range of device memory. ++ device_idx (int): Device index. ++ """ ++ blobs: list[Blob] ++ device_idx: int ++ def __init__(self, blobs: list[Blob], device_idx: int) -> None: ++ """ ++ Constructor. ++ ++ Args: ++ blobs (list[Blob]): Ranges of device memory. ++ device_idx (int): Device index. ++ """ ++ ... ++ def append(self, blob: Blob) -> None: ++ """ ++ Append a blob to the blobs list. ++ ++ Args: ++ blob (Blob): Blob to be appended. ++ """ ++ ... ++ def to_array(self) -> list[list[int]]: ++ """ ++ Convert to a 2D int list. ++ ++ Returns: ++ 2D int list ++ """ ++ ... ++class KvcFuture: ++ """ ++ Future for a KvcStore operation. ++ """ ++ def done(self) -> bool: ++ """ ++ Check if the operation is done. ++ """ ++ ... ++ def result(self, timeout: float = -1) -> KvcResult: ++ """ ++ Get the operation result. ++ ++ Args: ++ timeout (float): Timeout in milliseconds, negative means infinite. ++ ++ Returns: ++ The operation result. ++ ++ Raise: ++ TimeoutError: If the waited time exceeds the specified timeout. ++ """ ++ ... ++ def running(self) -> bool: ++ """ ++ Check if the operation is still running. ++ """ ++ ... ++class KvcResult: ++ """ ++ Result of a KvcStore operation. ++ ++ Attributes: ++ error_message (str): Error message. ++ failed_list (list[str]): Failed key list. ++ status_code (int): Status code, 0 means success, otherwise means failed. ++ """ ++ error_message: str ++ failed_list: list[str] ++ status_code: int ++ def __repr__(self) -> str: ++ ... ++class KvcStore: ++ """ ++ Distributed Store for device memory content. ++ """ ++ def __init__(self) -> None: ++ ... ++ def delete(self, keys: list[str]) -> KvcFuture: ++ """ ++ Delete objects placed by mset_d2h() on the distribute host memory. ++ ++ Args: ++ keys (list[str]): Keys of objects. ++ ++ Returns: ++ Future of deletions. ++ """ ++ ... ++ def exist(self, keys: list[str]) -> list[bool]: ++ """ ++ Check if the objects exist. ++ ++ Args: ++ keys (list[str]): Keys of objects. ++ ++ Returns: ++ Flags of existences. ++ """ ++ ... ++ def get_d2d(self, keys: list[str], dev_blob_lists: list[DeviceBlobList]) -> list[KvcFuture]: ++ """ ++ Get objects placed by put_d2d(). ++ ++ Args: ++ keys (list[str]): Keys of objects. ++ dev_blob_lists (list[DeviceBlobList]): Allocated spaces for the objects. ++ ++ Returns: ++ Futures of getting. ++ """ ++ ... ++ def init(self, host: str, port: int, conn_timeout_ms: int, num_threads: int=2) -> None: ++ """ ++ Connect to Datasystem service. ++ ++ Args: ++ host (str): Host ip of the service. ++ port (int): Port number of the service. ++ conn_timeout_ms (int): Connection timeout in milliseconds. ++ num_threads (int): Number of worker threads. ++ """ ++ ... ++ def mget_h2d(self, keys: list[str], dev_blob_lists: list[DeviceBlobList]) -> KvcFuture: ++ """ ++ Copy objects from the distribute host memory to device memory. ++ ++ Args: ++ keys (list[str]): Keys of objects. ++ dev_blob_lists (list[DeviceBlobList]): Allocated spaces for the objects. ++ ++ Returns: ++ Future of copying. ++ """ ++ ... ++ def mset_d2h(self, keys: list[str], dev_blob_lists: list[DeviceBlobList]) -> KvcFuture: ++ """ ++ Copy objects from device memory to the distribute host memory. ++ ++ Args: ++ keys (list[str]): Keys of objects. ++ dev_blob_lists (list[DeviceBlobList]): Device memory ranges of the objects. ++ ++ Returns: ++ Future of copying. ++ """ ++ ... ++ def put_d2d(self, keys: list[str], dev_blob_lists: list[DeviceBlobList]) -> list[KvcFuture]: ++ """ ++ Give objects of the source device memory to the request side device memory host memory of get_d2d(). ++ ++ Args: ++ keys (list[str]): Keys of objects. ++ dev_blob_lists (list[DeviceBlobList]): Device memory ranges of the objects. ++ ++ Returns: ++ Futures of giving. ++ """ ++ ... ++class KvcTensor: ++ """ ++ Abstraction of a tensor. ++ ++ Attributes: ++ elem_size (int): Element size. ++ ptr (int): Device memory address. ++ shape (list[int]): Tensor shape. ++ """ ++ __hash__: typing.ClassVar[None] = None ++ elem_size: int ++ ptr: int ++ shape: list[int] ++ def __eq__(self, other: KvcTensor) -> bool: ++ ... ++ def __init__(self, ptr: int, elem_size: int, shape: list[int]) -> None: ++ """ ++ Constructor. ++ ++ Args: ++ ptr (int): Device memory address. ++ elem_size (int): Element size. ++ shape (list[int]): Tensor shape. ++ """ ++ ... ++ def __repr__(self) -> str: ++ ... ++class PageAttnUtils: ++ """ ++ Page Attention Utilities. ++ """ ++ @staticmethod ++ def blk_2_blob(ptr: int, elem_size: int, num_block_elem: int, block_id: int) -> Blob: ++ """ ++ Convert a block to a blob. ++ ++ Args: ++ ptr (int): Device memory address. ++ elem_size (int): Element size. ++ num_block_elem (int): Number of block elements. ++ block_id (int): Block id. ++ ++ Returns: ++ The converted blob. ++ """ ++ ... ++ @staticmethod ++ def blks_2_dev_blob_list(device_idx: int, ptr: int, elem_size: int, num_block_elem: int, block_ids: list[int]) -> DeviceBlobList: ++ """ ++ Convert a block list to a device blob list. ++ ++ Args: ++ device_idx (int): Device index. ++ ptr (int): Device memory address. ++ elem_size (int): Element size. ++ num_block_elem (int): Number of block elements. ++ block_ids (list[int]): Block id list. ++ ++ Returns: ++ The converted device blob list. ++ """ ++ ... ++ @staticmethod ++ def blockwise_dev_blob_lists(device_idx: int, layer_tensors: list[KvcTensor], block_ids: list[int]) -> list[DeviceBlobList]: ++ """ ++ Convert a block list of layers to a device blob list, each block gives a device blob list. ++ ++ Args: ++ device_idx (int): Device index. ++ layer_tensors (list[KvcTensor]): Layer tensors. ++ block_ids (list[int]): block id lists. ++ ++ Returns: ++ The converted device blob list. ++ """ ++ ... ++ @staticmethod ++ def layerwise_dev_blob_lists(device_idx: int, layer_tensors: list[KvcTensor], block_ids: list[int]) -> list[DeviceBlobList]: ++ """ ++ Convert a block list of layers to a device blob list, each layer gives a device blob list. ++ ++ Args: ++ device_idx (int): Device index. ++ layer_tensors (list[KvcTensor]): Layer tensors. ++ block_ids (list[int]): block id lists. ++ ++ Returns: ++ The converted device blob list. ++ """ ++ ... +diff --git a/dllm_tools/dllm/cpp_ext/perf.pyi b/dllm_tools/dllm/cpp_ext/perf.pyi +new file mode 100644 +index 000000000..27c9f0051 +--- /dev/null ++++ b/dllm_tools/dllm/cpp_ext/perf.pyi +@@ -0,0 +1,10 @@ ++""" ++pybind11 for Performance ++""" ++from __future__ import annotations ++__all__ = ['PerfManager'] ++class PerfManager: ++ def __init__(self) -> None: ++ ... ++ def print(self) -> None: ++ ... +diff --git a/dllm_tools/dllm/cpp_ext/utils.pyi b/dllm_tools/dllm/cpp_ext/utils.pyi +new file mode 100644 +index 000000000..d9c46c73b +--- /dev/null ++++ b/dllm_tools/dllm/cpp_ext/utils.pyi +@@ -0,0 +1,27 @@ ++""" ++pybind11 for Utilities ++""" ++from __future__ import annotations ++__all__ = ['Logger'] ++class Logger: ++ @staticmethod ++ def debug(msg: str) -> None: ++ ... ++ @staticmethod ++ def error(msg: str) -> None: ++ ... ++ @staticmethod ++ def info(msg: str) -> None: ++ ... ++ @staticmethod ++ def log_off() -> None: ++ ... ++ @staticmethod ++ def warn(msg: str) -> None: ++ ... ++ def __init__(self, log_dir: str = '', base_name: str = 'dllm', max_log_size: int = 10485760, max_files_num: int = 5, enable_console_logging: bool = False) -> None: ++ ... ++ def set_flush_level(self, log_level: str) -> None: ++ ... ++ def set_log_level(self, log_level: str) -> None: ++ ... +diff --git a/dllm_tools/dllm/dkvc/README.md b/dllm_tools/dllm/dkvc/README.md +new file mode 100644 +index 000000000..d26a7152c +--- /dev/null ++++ b/dllm_tools/dllm/dkvc/README.md +@@ -0,0 +1,142 @@ ++ ++## 分布式KV Cache 编译 ++ ++### 不依赖元戎的构建流程 ++若您希望进行不依赖元戎的编译,在下面的过程中请勿设置 `export COMPILE_WITH_YR=TRUE`,您也可以手动指定为 `export COMPILE_WITH_YR=FALSE`. ++ ++对于不依赖元戎的构建,datasystem 功能无法使用。所以请勿使用 `--ds-worker-addr`和`--enable-datasystem` ++ ++### 正常构建流程 ++ ++前提条件: 请确保编译 ascend-pytorch-inference 之前,已经 `pip install yr` ++ ++``` ++# 依赖元戎编译(默认为 FALSE,不依赖元戎) ++export COMPILE_WITH_YR=TRUE ++cd ascend-pytorch-inference/tools/dllm.dkvc ++python setup.py bdist_wheel ++``` ++ ++ ++### 开发调试流程 ++ ++以develop 模式调试分布式KV Cache ++ ++``` ++# 依赖元戎编译(默认为 FALSE,不依赖元戎) ++export COMPILE_WITH_YR=TRUE ++# 移动到分布式KVC源码所在的路径 ++cd tools/distributed_kv_cache ++# 先删除原来已安装的库,确保develop模式能正常生效 ++pip uninstall distributed_kv_cache ++# 如果需要开启性能打点,可以配置 ENABLE_PERF=true ++export ENABLE_PERF=true ++python setup.py develop ++ ++``` ++ ++#### 开发调试时手动指定元戎数据系统的路径 ++ ++在编译前,可以通过添加 `DS_DIR` 环境变量来指定链接的数据系统路径, cmake 会在 `${DS_DIR}/sdk/cpp/lib/cmake/Datasystem` 找到数据系统的 CMake 配置文件。 ++ ++ ++``` ++# 不设置DS_DIR时,默认链接 yr wheel 包里的数据系统SDK ++# 手动指定数据系统仓里的SDK ++export DS_DIR=datasystem/output ++# 手动指定元戎发布包里的SDK ++export DS_DIR=yuanrong/data_system ++ ++python setup.py develop ++``` ++ ++## 分布式KV Cache使用数据系统作为CPU Cache ++ ++配置参数: ++ +++ `--prefix-sharing-type gpu_cpu` : 前缀匹配使能二级缓存 +++ `--ds-worker-addr "127.0.0.1:4396"` : 连接数据系统ds-worker的地址, **PD分离时不需要设置** +++ `--enable-datasystem` : 使能数据系统。(注意:由于python argpaser实现机制的问题,配置`--enable-datasystem = false` 并不能**关掉**数据系统,如果要关掉建议还是直接去掉该参数) +++ `--multipath-devices "1,2,3"` : 可选,开启多路径传输使用的卡ID列表,卡ID之间用`,`隔开 ++ ++示例: ++ ++```bash ++python vllm/entrypoints/api_server.py \ ++--host 127.0.0.1 \ ++--port 58191 \ ++--tokenizer "/workspace/models/llama-2-7b-chat-hf" \ ++--first-token-timeout 300 \ ++--tensor-parallel-size 1 \ ++--model "/workspace/models/llama-2-7b-chat-hf" \ ++--scheduler-budget-len 8192 \ ++--max-num-seq 32 --prefix-sharing-type gpu_cpu \ ++--ds-worker-addr "127.0.0.1:4396" --enable-datasystem True \ ++--swap-space 32 &> output.log ++``` ++ ++## C++扩展的python接口提示生成 ++ ++``` ++cd ascend-pytorch-inference/tools/distributed_kv_cache/distributed_kv_cache ++# pybind11-stubgen 需要手动 pip install 一下 ++pybind11-stubgen distributed_kv_cache.cpp_ext -o . ++pybind11-stubgen distributed_kv_cache.cpp_ext_abi -o . ++``` ++ ++## 数据系统的部署 ++### 数据系统使用前需要在宿主机/容器配置大页内存 ++ ++数据系统使用共享内存作为CPU Cache,在和NPU Cache 进行memory copy 时,由于共享内存在ds-worker申请,而ds-worker 并不存在昇腾上下文,无法通过 CANN的 `aclrtMallocHost` 接口来申请大页内存,因此需要使用linux 的显式大页内存 hugetlb,作为cpu cache,来实现高效的memory copy。 ++ ++在容器 或者 宿主机配置大页内存: ++ ++``` ++# 配置30w个2MB的大页 ++echo 300000 > /proc/sys/vm/nr_hugepages ++``` ++ ++检查生效的大页数量 ++ ++``` ++cat /proc/sys/vm/nr_hugepages ++ ++``` ++ ++注意!:没有配置大页的话,会导致 worker 启动分配不到足够的大页内存,导致元戎集群拉起失败 ++ ++### PD合并需要`yr start`手动拉起元戎集群 ++ ++PD 合并时,如果要使用数据系统,需要先拉起一个元戎集群,元戎集群会启动数据系统worker,使得推理实例里的数据系统client能否访问到worker,从而可以把npu block 换到 ds-worker 的共享内存中缓存。 ++ ++手动拉起元戎集群的命令: ++ ++``` ++yr start --master --etcd_port=30000 --etcd_peer_port=30001 --global_scheduler_port=30002 --ds_arena_per_tenant=1 --ds_enable_fallocate=false --ds_enable_huge_tlb=true --shared_memory_num=262144 ++ ++``` ++ ++注意:shared_memory_num 的值,应该保持在 '推理实例 * tp数量 * swap_space * 1024 <= shared_memory_num <= 2 * /proc/sys/vm/nr_hugepages' ++ ++举个例子:假如拉起4个推理实例,每个实例TP为2,swap_space 为32GB,那么需要cpu cache 256GB,那么shared_memory_num 需要大于 256GB才够用,否则会走到 spill流程,导致换入换出变慢。 ++ ++### 使能多路径需要保证进程内卡ID 可见 ++ ++##### ASCEND_RT_VISIBLE_DEVICES 配合使用的时候,multipath-devices 需要使用相对的rank id。 ++ ++这里举个例子说明一下: ++比如 进程配置 ASCEND_RT_VISIBLE_DEVICES = 4,5,6,7。 ++那么对于这个进程来说,它看到的逻辑上的device_id 是 0,1,2,3,对应物理上的4,5,6,7 ++multipath-devices 也应该配置成 0,1,2,3,而不是4,5,6,7。 否则会导致multipath set_device 失败,进程无法正常启动。 ++ ++##### multipath-devices 能看到的卡的范围,也受限于 ASCEND_RT_VISIBLE_DEVICES 。 ++ ++比如说 ASCEND_RT_VISIBLE_DEVICES 设置了4,5,那么multipath-devices 只能用0,1,而不能使用0,1,2,3。 ++ ++##### 配置 `RAY_LOCAL_VISIBLE_DEVICES` 时,正常使用 物理卡ID即可。 ++ ++比如 `RAY_LOCAL_VISIBLE_DEVICES` 设置 2,3,6,7 ,那么multipath-devices 也是用2,3,6,7 ++ ++ ++## 设计文档 ++https://wiki.huawei.com/domains/94866/wiki/170901/WIKI202412205472997 +\ No newline at end of file +diff --git a/dllm_tools/dllm/dkvc/__init__.py b/dllm_tools/dllm/dkvc/__init__.py +new file mode 100644 +index 000000000..e69de29bb +diff --git a/dllm_tools/dllm/dkvc/cpp_ext.pyi b/dllm_tools/dllm/dkvc/cpp_ext.pyi +new file mode 100644 +index 000000000..0a3349b09 +--- /dev/null ++++ b/dllm_tools/dllm/dkvc/cpp_ext.pyi +@@ -0,0 +1,124 @@ ++""" ++pybind11 for distributed kv cache ++""" ++from __future__ import annotations ++import typing ++__all__ = ['Blob', 'DeviceBlobList', 'FutureTimeoutException', 'KVCLogger', 'KvcTensorHandler', 'KVCacheStore', ++ 'PerfManager', 'SwapResult', "Future", "KvcResultFuture"] ++class Blob: ++ __hash__: typing.ClassVar[None] = None ++ @staticmethod ++ def _pybind11_conduit_v1_(*args, **kwargs): ++ ... ++ def __eq__(self, other: Blob) -> bool: ++ ... ++ def __init__(self, pointer: int, size: int) -> None: ++ ... ++ def __repr__(self) -> str: ++ ... ++ @property ++ def pointer(self) -> int: ++ ... ++ @property ++ def size(self) -> int: ++ ... ++class DeviceBlobList: ++ blobs: list[Blob] ++ device_id: int ++ @staticmethod ++ def _pybind11_conduit_v1_(*args, **kwargs): ++ ... ++ def __init__(self, blobs: list[Blob], device_id: int) -> None: ++ ... ++ def append(self, blob: Blob) -> None: ++ ... ++class FutureTimeoutException(Exception): ++ pass ++class KVCLogger: ++ @staticmethod ++ def _pybind11_conduit_v1_(*args, **kwargs): ++ ... ++ @staticmethod ++ def debug(msg: str) -> None: ++ ... ++ @staticmethod ++ def error(msg: str) -> None: ++ ... ++ @staticmethod ++ def info(msg: str) -> None: ++ ... ++ @staticmethod ++ def log_off() -> None: ++ ... ++ @staticmethod ++ def warn(msg: str) -> None: ++ ... ++ def __init__(self, log_dir: str = '', base_name: str = 'distributed_kvc', max_log_size: int = 10485760, max_files_num: int = 5, enable_console_logging: bool = False) -> None: ++ ... ++ def set_flush_level(self, log_level: str) -> None: ++ ... ++ def set_log_level(self, log_level: str) -> None: ++ ... ++class KvcTensorHandler: ++ @staticmethod ++ def _pybind11_conduit_v1_(*args, **kwargs): ++ ... ++class KVCacheStore: ++ @staticmethod ++ def _pybind11_conduit_v1_(*args, **kwargs): ++ ... ++ @staticmethod ++ def multipath_destory() -> None: ++ ... ++ @staticmethod ++ def multipath_init(devices: list[int]) -> None: ++ ... ++ def __init__(self, host: str, port: int, connect_timeout_ms: int = 60000) -> None: ++ ... ++ def async_delete(self, keys: list[str]) -> SwapResult: ++ ... ++ def swap_in(self, keys: list[str], blob_list: list[DeviceBlobList]) -> SwapResult: ++ ... ++ def swap_out(self, keys: list[str], blob_list: list[DeviceBlobList]) -> SwapResult: ++ ... ++ def dev_publish(self, keys: list[str], blob_list: list[DeviceBlobList]) -> list[Future]: ++ ... ++ def dev_subscribe(self, keys: list[str], blob_list: list[DeviceBlobList]) -> list[Future]: ++ ... ++ def send_tensor(self, key: str, ptr: int, size: int, dev_id: int) -> list[Future]: ++ ... ++ def recv_tensor(self, key: str, ptr: int, size: int, dev_id: int) -> list[Future]: ++ ... ++ def send_kvcache(self, keys: list[str], handler: KvcTensorHandler, dev_id: int, blob_list: list[DeviceBlobList], layer_idxs: list[int] = []) -> list[Future]: ++ ... ++ def recv_kvcache(self, keys: list[str], handler: KvcTensorHandler, dev_id: int, blob_list: list[DeviceBlobList], layer_idxs: list[int] = []) -> list[Future]: ++ ... ++class PerfManager: ++ @staticmethod ++ def _pybind11_conduit_v1_(*args, **kwargs): ++ ... ++ def __init__(self) -> None: ++ ... ++ def print(self) -> None: ++ ... ++class SwapResult: ++ @staticmethod ++ def _pybind11_conduit_v1_(*args, **kwargs): ++ ... ++ def get_result(self, timeout_ms: int = 60000) -> list[str]: ++ ... ++ def is_ready(self) -> bool: ++ ... ++class Future: ++ @staticmethod ++ def _pybind11_conduit_v1_(*args, **kwargs): ++ ... ++ def get(self, timeout_ms: int = 30000) -> str: ++ ... ++ ++class KvcResultFuture: ++ @staticmethod ++ def _pybind11_conduit_v1_(*args, **kwargs): ++ ... ++ def get(self, timeout_ms: int = 30000) -> str: ++ ... +\ No newline at end of file +diff --git a/dllm_tools/dllm/dkvc/cpu_cache_evictor.py b/dllm_tools/dllm/dkvc/cpu_cache_evictor.py +new file mode 100644 +index 000000000..10aa745d8 +--- /dev/null ++++ b/dllm_tools/dllm/dkvc/cpu_cache_evictor.py +@@ -0,0 +1,57 @@ ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++"""Evict KV cache.""" ++ ++from collections import OrderedDict ++import os ++ ++from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size ++ ++from dllm.cpp_ext.kvc import KvcStore ++ ++DS_CPU_CACHE_SPACE = os.environ.get("DS_CPU_CACHE_SPACE", 0) ++ ++ ++class DSCPUCacheManager: ++ """ DSCPUCacheManager """ ++ def __init__(self, max_block_nums: int): ++ self._ds_cached_blocks: OrderedDict[int, bool] = OrderedDict() ++ # NPU block 的倍数 ++ self.max_block_nums = max_block_nums ++ ds_worker_addr = os.getenv("DS_WORKER_ADDR", "172.17.0.4:9000") ++ host, port = ds_worker_addr.split(":") ++ self.kvc_store = KvcStore() ++ self.kvc_store.init(host, int(port), 5000) ++ ++ def check_block_cached(self, content_hash: int) -> bool: ++ """check block cached""" ++ if content_hash not in self._ds_cached_blocks: ++ return False ++ self._ds_cached_blocks.move_to_end(content_hash) ++ return True ++ ++ def add_cache(self, content_hash: int): ++ """add cache""" ++ self._ds_cached_blocks[content_hash] = True ++ if len(self._ds_cached_blocks) > self.max_block_nums: ++ prefix_hash, _ = self._ds_cached_blocks.popitem(last=False) ++ self._del_in_ds(prefix_hash) ++ ++ def _del_in_ds(self, prefix_hash: int): ++ tp_size = get_tensor_model_parallel_world_size() ++ if tp_size == 0: ++ tp_size = 1 ++ ds_key_list = [] ++ for i in range(tp_size): ++ ds_key_list.append(f"rank{i}+{prefix_hash}") ++ self.kvc_store.delete(ds_key_list) +diff --git a/dllm_tools/dllm/dkvc/dllm_cache_engine.py b/dllm_tools/dllm/dkvc/dllm_cache_engine.py +new file mode 100644 +index 000000000..7b8eafa4f +--- /dev/null ++++ b/dllm_tools/dllm/dkvc/dllm_cache_engine.py +@@ -0,0 +1,366 @@ ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++ ++"""CacheEngine class for managing the KV cache.""" ++ ++from collections import defaultdict ++import os ++from typing import DefaultDict, Dict, List, Set, Tuple, Union ++ ++import torch ++from vllm.config import CacheConfig, ModelConfig, ParallelConfig ++from vllm.logger import init_logger ++from vllm.worker.cache_engine import CacheEngine ++ ++from dllm.cpp_ext.utils import Logger ++from dllm.cpp_ext.perf import PerfManager ++from dllm.cpp_ext.kvc import KvcStore, KvcFuture ++from dllm.kvc import TorchAdaptor ++ ++logger = init_logger(f"{__name__}") ++ ++KVCache = Tuple[torch.Tensor, torch.Tensor] ++ ++ ++STR_DTYPE_TO_TENSOR_DTYPE = { ++ "half": torch.half, ++ "float16": torch.half, ++ "bfloat16": torch.bfloat16, ++ "float": torch.float, ++ "fp8": torch.uint8, ++ "fp8_e4m3": torch.uint8, ++ "fp8_e5m2": torch.uint8, ++} ++ ++ ++class CacheMoveResult: ++ """Describe the send receive result. ++ ++ This class is responsible for describe the kvcahe or activition move result. ++ """ ++ def __init__(self) -> None: ++ self.send_kvcache_success: List[str] = [] ++ self.receive_kvcache_success: List[str] = [] ++ self.send_activition_success: List[str] = [] ++ self.receive_activition_success: List[str] = [] ++ ++ def result_append(self, key: str, is_kv_cache: bool = True, is_send: bool = True): ++ """Append success key to result""" ++ if is_kv_cache: ++ if is_send: ++ self.send_kvcache_success.append(key) ++ else: ++ self.receive_kvcache_success.append(key) ++ else: ++ if is_send: ++ self.send_activition_success.append(key) ++ else: ++ self.receive_activition_success.append(key) ++ ++ ++class DLLMCacheEngine(CacheEngine): ++ """Manages the KV cache. ++ ++ This class is responsible for initializing and managing the GPU and CPU KV ++ caches. It also provides methods for performing KV cache operations, such ++ as swapping and copying. ++ """ ++ ++ def __init__( ++ self, ++ cache_config: CacheConfig, ++ model_config: ModelConfig, ++ parallel_config: ParallelConfig, ++ local_rank: int, ++ gpu_cache: List, ++ ) -> None: ++ self.cache_config = cache_config ++ self.model_config = model_config ++ self.parallel_config = parallel_config ++ ++ self.head_size = model_config.get_head_size() ++ self.num_layers = model_config.get_num_layers(parallel_config) ++ self.num_heads = model_config.get_num_kv_heads(parallel_config) ++ self.dtype = model_config.dtype ++ ++ self.block_size = cache_config.block_size ++ self.num_gpu_blocks = cache_config.num_gpu_blocks ++ self.num_cpu_blocks = cache_config.num_cpu_blocks ++ ++ self.cache_size_per_block = ( ++ self.block_size ++ * self.num_heads ++ * self.head_size ++ * DLLMCacheEngine._get_dtype_size(self.dtype) ++ ) ++ ++ # Initialize the cache. ++ # self.gpu_cache = self.allocate_gpu_cache() # 原本是个 List[(tensor,tensor)] 长度是层数 ++ # gpu_cache 的下标是层数,每个元素 shape 是[block_num, block_size, kv_heads, head_size] ++ self.gpu_cache = gpu_cache ++ if os.environ.get("vLLM_MODEL_BACKEND", None) and os.environ["vLLM_MODEL_BACKEND"] == "MindFormers": ++ from mindspore.communication import get_rank ++ self.device = get_rank() ++ else: ++ self.device = torch.npu.current_device() ++ logger.info(f"DLLMCacheEngine init: {self.device}") ++ self.local_rank = local_rank ++ self.swap_hash_count: Dict[int, int] = {} # key: hash_val, value: count ++ self.kvcache_send_events: Dict[str, List[KvcFuture]] = {} ++ self.kvcache_receive_events: Dict[str, List[KvcFuture]] = {} ++ self._init_distributed_kvc() ++ self.swap_out_manager = AsyncSwapManager() ++ ++ def __del__(self): ++ self.perf_manager.print() ++ if self.remain_keys: ++ swap_result = self.kvc_store.delete(list(self.remain_keys)) ++ swap_result.result() ++ ++ @staticmethod ++ def _get_dtype_size(dtype: torch.dtype) -> int: ++ if isinstance(dtype, str): ++ dtype = STR_DTYPE_TO_TENSOR_DTYPE[dtype] ++ return torch.tensor([], dtype=dtype).element_size() ++ ++ # Host to device ++ # cpu_block_ids, npu_blockids, ++ def swap_in(self, src_to_dst: List[Tuple[int, int]], key: str) -> None: ++ """swap in by key""" ++ key_list = [] ++ block_token_ids, npu_block_ids = zip(*src_to_dst) ++ for token_id in block_token_ids: ++ key_list.append(self._gen_swap_key(token_id)) ++ self.swap_out_manager.wait_swap_finished(key_list) ++ future = self.kvc_store.mget_page_attn_blockwise_h2d(key_list, self.gpu_cache, npu_block_ids) ++ try: ++ future.result(10) # 同步调用,等待搬运完成事件, 10s 超时 ++ except Exception as e: ++ self.__del__() ++ raise e ++ self.swap_out_manager.check_swap_finished() ++ ++ # Send device to device ++ # keys, blocks_ids ++ def send_kvcache(self, key: str, block_list: List[int]): ++ """Send kv cache by key """ ++ keys = self._prepare_keys(key) ++ future_list = self.kvc_store.put_page_attn_layerwise_d2d(keys, self.gpu_cache, block_list) ++ self.kvcache_send_events[key] = future_list ++ ++ # Get device from device ++ # keys, blocks_ids ++ def receive_kvcache(self, key: str, block_list: List[int]): ++ """Receive kv cache by key """ ++ keys = self._prepare_keys(key) ++ future_list = self.kvc_store.get_page_attn_layerwise_d2d(keys, self.gpu_cache, block_list) ++ self.kvcache_receive_events[key] = future_list ++ ++ # Check is device data transfer finished ++ def check_transfer_finished(self) -> CacheMoveResult: ++ """Get activation/kvcache transfer result""" ++ result = CacheMoveResult() ++ result.send_kvcache_success.extend(self._future_check(self.kvcache_send_events)) ++ result.receive_kvcache_success.extend(self._future_check(self.kvcache_receive_events)) ++ return result ++ ++ # Device to host, or host to null (evict) ++ # npu_block_ids, cpu_blockids, ++ def swap_out(self, src_to_dst: List[Tuple[int, int]], key: str) -> None: ++ """swap out by key""" ++ request_id = key ++ self.logger.debug(f"swap out {request_id}, with blocks num: {len(src_to_dst)}") ++ # npu_block_ids: block_number, block_token_ids: hash ++ npu_block_ids, block_token_ids = zip(*src_to_dst) ++ if all([x == -1 for x in npu_block_ids]): ++ # delete cpu block when all npu_block_id is -1 ++ key_list = [] ++ for hash_val in block_token_ids: ++ # 减计数,减到0才移除 ++ if self._dec_swap_hash_count(hash_val): ++ key_list.append(self._gen_swap_key(hash_val)) ++ self.kvc_store.delete(key_list) ++ for _key in key_list: ++ self.remain_keys.discard(_key) ++ else: ++ npu_block_ids = [] ++ block_token_ids = [] ++ for npu_blk_id, hash_val in src_to_dst: ++ npu_block_ids.append(npu_blk_id) ++ block_token_ids.append(hash_val) ++ if not npu_block_ids: ++ return ++ key_list = [] ++ for token_id in block_token_ids: ++ key_list.append(self._gen_swap_key(token_id)) ++ future = self.kvc_store.mset_page_attn_layerwise_d2h(key_list, self.gpu_cache, npu_block_ids) ++ self.swap_out_manager.add_swap_result(key_list, key, future) ++ for _key in key_list: ++ self.remain_keys.add(_key) ++ self.swap_out_manager.check_swap_finished() ++ ++ def _init_distributed_kvc(self): ++ self.perf_manager = PerfManager() ++ self.logger = Logger(enable_console_logging=False) ++ self.logger.set_log_level(os.getenv("LOG_LEVEL", "DEBUG")) ++ ip, port, timeout_ms = self._get_kvc_store_connect_opts() ++ logger.info(f"start init KVCacheStore, ip = {ip}, port = {port}, timeout_ms = {timeout_ms}") ++ ++ store = KvcStore() ++ store.init(ip, port, timeout_ms) ++ self.kvc_store = TorchAdaptor(store) ++ self.multipath_devices = self.cache_config.prefix_sharing_kwargs.get( ++ "multipath_devices", [] ++ ) ++ logger.info(f"self.multipath_devices = {self.multipath_devices}") ++ self.remain_keys: Set[str] = set() ++ self.logger.info("distributed kv cache init ok") ++ ++ def _get_kvc_store_connect_opts(self): ++ ds_worker_addr = self.cache_config.prefix_sharing_kwargs.get( ++ "ds_worker_addr", None ++ ) ++ if not ds_worker_addr: ++ raise RuntimeError("ds_worker_addr is not set in prefix sharing_kwargs") ++ host, port = ds_worker_addr.split(":") ++ return host, int(port), 60000 ++ ++ def _gen_swap_key(self, token_id: Union[int, str]) -> str: ++ return f"rank{self.local_rank}+{token_id}" ++ ++ def _dec_swap_hash_count(self, hash_val: int): ++ """ ++ Decrease hash count ++ Args: ++ hash: ++ ++ Returns: ++ True if decrease to 0. ++ """ ++ if hash_val in self.swap_hash_count: ++ if self.swap_hash_count[hash_val] == 1: ++ self.logger.debug(f"{hash_val} dec ref_count to 0") ++ del self.swap_hash_count[hash_val] ++ return True ++ ++ self.swap_hash_count[hash_val] -= 1 ++ self.logger.debug(f"dec {hash_val}. ref_count: {self.swap_hash_count[hash_val]}") ++ return False ++ self.logger.warn(f"attempt to decrease hash {hash_val} not exist in swap_hash_count dict") ++ return False ++ ++ def _inc_swap_hash_count(self, hash_val: int): ++ """ ++ Increase hash count ++ Args: ++ hash: ++ ++ Returns: ++ True if first time increase. ++ """ ++ if hash_val not in self.swap_hash_count: ++ self.logger.debug(f"first time inc {hash_val}") ++ self.swap_hash_count[hash_val] = 1 ++ return True ++ ++ self.swap_hash_count[hash_val] += 1 ++ self.logger.debug(f"inc {hash_val}. ref_count: {self.swap_hash_count[hash_val]}") ++ return False ++ ++ def _prepare_keys(self, prefix: str) -> List[str]: ++ keys = [] ++ for layer_id in range(self.num_layers): ++ keys.append(f"{prefix}_tprank:{self.local_rank}_layerid:{layer_id}") ++ return keys ++ ++ def _check_swap_result(self, swap_result_map: Dict[str, KvcFuture]) -> List[str]: ++ finished_req_ids: List[str] = [] ++ for key, swap_result in swap_result_map.items(): ++ if swap_result.done(): ++ self.logger.debug(f"{key} is swap ready") ++ finished_req_ids.append(key) ++ else: ++ break ++ for key in finished_req_ids: ++ swap_result_map.pop(key) ++ return finished_req_ids ++ ++ def _future_check(self, future_map: Dict[str, List[KvcFuture]]) -> List[str]: ++ finished_kyes: List[str] = [] ++ for key in list(future_map.keys()): ++ all_done_ok = True ++ for fut in future_map[key]: ++ try: ++ result = fut.result(0) ++ if result.status_code != 0: ++ all_done_ok = False ++ break ++ except TimeoutError: ++ all_done_ok = False ++ break ++ if all_done_ok: ++ finished_kyes.append(key) ++ del future_map[key] ++ return finished_kyes ++ ++ ++class AsyncSwapManager: ++ def __init__(self): ++ # key -> Set of batch id ++ self._key_to_batch: DefaultDict[str, Set[str]] = defaultdict(set) ++ # batch id -> (set of key, future) ++ self._pending_batch: Dict[str, Tuple[Set[str], KvcFuture]] = {} ++ ++ def add_swap_result(self, keys_list: List[str], batch_id: str, swap_result: KvcFuture): ++ """add swap result by key list""" ++ keys_set = set(keys_list) ++ self._pending_batch[batch_id] = (keys_set, swap_result) ++ for key in keys_set: ++ self._key_to_batch[key].add(batch_id) ++ ++ # 阻塞检查 ++ def wait_swap_finished(self, keys_list: List[str], timeout_ms: int = 10000): ++ """wait until swap finished""" ++ for key in keys_list: ++ batch_id_set = self._key_to_batch.get(key) ++ if batch_id_set is None: ++ continue ++ batch_id_set_copy = batch_id_set.copy() ++ for batch_id in batch_id_set_copy: ++ value = self._pending_batch.get(batch_id) ++ if value is None: ++ batch_id_set.discard(batch_id) ++ continue ++ keys_set, swap_result = value ++ swap_result.result(timeout_ms/1000) ++ self._clean_up_when_swap_finished(batch_id, keys_set) ++ ++ # 非阻塞 ++ def check_swap_finished(self): ++ """check if swap finished""" ++ ready_batch_list = [] ++ for batch_id in self._pending_batch: ++ keys_set, swap_result = self._pending_batch[batch_id] ++ if swap_result.done(): ++ logger.debug(f"batch_id {batch_id} is swap ready") ++ ready_batch_list.append(batch_id) ++ for batch_id in ready_batch_list: ++ self._clean_up_when_swap_finished(batch_id, keys_set) ++ ++ def _clean_up_when_swap_finished(self, batch_id: str, keys_set: Set[str]): ++ if batch_id in self._pending_batch: ++ del self._pending_batch[batch_id] ++ for key in keys_set: ++ self._key_to_batch[key].discard(batch_id) ++ if not self._key_to_batch[key]: ++ del self._key_to_batch[key] +diff --git a/dllm_tools/dllm/dkvc/dllm_connector.py b/dllm_tools/dllm/dkvc/dllm_connector.py +new file mode 100644 +index 000000000..b11fc8080 +--- /dev/null ++++ b/dllm_tools/dllm/dkvc/dllm_connector.py +@@ -0,0 +1,277 @@ ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++""" DLLM Connector """ ++ ++import os ++import hashlib ++from typing import TYPE_CHECKING, List, Tuple, Union ++ ++import torch ++ ++from vllm import envs ++from vllm.config import VllmConfig ++from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase ++from vllm.logger import init_logger ++from vllm.distributed.parallel_state import get_pp_group ++from vllm.sequence import IntermediateTensors ++ ++from dllm.cpp_ext.kvc import KvcStore ++from dllm.kvc import TorchAdaptor ++ ++if TYPE_CHECKING: ++ from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata ++ ++logger = init_logger(f"vllm.{__name__}") ++ ++ ++class DLLMConnector(KVConnectorBase): ++ ++ def __init__( ++ self, ++ rank: int, ++ local_rank: int, ++ config: VllmConfig, ++ ): ++ logger.info(f"init DLLMConnector") ++ self.config = config.kv_transfer_config ++ self.tp_size = config.parallel_config.tensor_parallel_size ++ self.is_deepseek_mla = config.model_config.is_deepseek_mla ++ self.use_mla_opt = not envs.VLLM_MLA_DISABLE ++ ++ self.tp_rank = rank ++ logger.info(f"tp_rank = {self.tp_rank}, device_id = {local_rank}") ++ logger.info(f"is_deepseek_mla = {self.is_deepseek_mla}, use_mla_opt = {self.use_mla_opt}") ++ # bypass the kv_store setup and checking right now ++ self.device = local_rank ++ ++ ds_worker_addr = os.getenv("DS_WORKER_ADDR", "172.17.0.4:9000") ++ ip_port = ds_worker_addr.split(":") ++ ip = ip_port[0] ++ port = int(ip_port[1]) ++ logger.info(f"ip = {ip}, port = {port}") ++ ###> self.kvc_store = TransferEngine(ip=ip, port=port, device_id=self.device) ++ store = KvcStore() ++ store.init(ip, port, 6000) ++ self.kvc_store = TorchAdaptor(store) ++ ++ @staticmethod ++ def tensor_hash(tensor: torch.Tensor) -> int: ++ """Calculate the hash value of the tensor.""" ++ tensor_bytes = tensor.clone().detach().cpu().numpy().tobytes() ++ hash_object = hashlib.blake2b(tensor_bytes) ++ hash_hex = hash_object.hexdigest() ++ return int(hash_hex[:16], 16) ++ ++ def close(self) -> None: ++ """Close the buffer and release resources. ++ This method is responsible for cleaning up resources related to the ++ connector when it is no longer needed. ++ Raises: ++ NotImplementedError: This method must be implemented in subclasses. ++ """ ++ logger.info("connecter close") ++ ++ def send_kv_caches_and_hidden_states( ++ self, ++ model_executable: torch.nn.Module, ++ model_input: "ModelInputForGPUWithSamplingMetadata", ++ kv_caches: List[torch.Tensor], ++ hidden_or_intermediate_states: Union[torch.Tensor, ++ IntermediateTensors], ++ ) -> None: ++ """send kv_caches and hidden_states""" ++ input_tokens_tensor = model_input.input_tokens ++ seq_lens = model_input.attn_metadata.seq_lens ++ slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() ++ mf_model = False ++ if hasattr(model_executable, "mf_model_config"): # vllm-mindspore ++ model_config = model_executable.mf_model_config ++ from vllm.distributed.utils import get_pp_indices ++ start_layer, end_layer = get_pp_indices(model_config.num_layers, ++ get_pp_group().rank_in_group, ++ get_pp_group().world_size) ++ num_heads = int((model_config.n_kv_heads if model_config.n_kv_heads else model_config.num_heads) ++ / self.tp_size) ++ hidden_size = model_config.hidden_size ++ num_attention_heads = model_config.num_heads ++ mf_model = True ++ else: # vllm-ascend ++ model_config = model_executable.model.config ++ start_layer = model_executable.model.start_layer ++ end_layer = model_executable.model.end_layer ++ num_heads = int(model_config.num_key_value_heads / self.tp_size) ++ hidden_size = model_config.hidden_size ++ num_attention_heads = model_config.num_attention_heads ++ head_size = int(hidden_size / num_attention_heads) ++ if self.is_deepseek_mla and self.use_mla_opt: ++ head_size = model_config.kv_lora_rank + model_config.qk_rope_head_dim ++ num_heads = 1 ++ elif self.is_deepseek_mla and not self.use_mla_opt: ++ head_size = model_config.qk_nope_head_dim + model_config.qk_rope_head_dim ++ else: ++ head_size = getattr(model_config, "head_dim", int(hidden_size // num_attention_heads)) ++ ++ for idx, slen in enumerate(seq_lens): ++ start_pos = sum(seq_lens[:idx]) ++ end_pos = start_pos + slen ++ ++ current_tokens = input_tokens_tensor[start_pos:end_pos] ++ store_key_prefix = self.tensor_hash(current_tokens) ++ ++ keys, values = [], [] ++ ++ for layer_id in range(start_layer, end_layer): ++ kv_cache = kv_caches[layer_id - start_layer] ++ ++ if self.is_deepseek_mla and self.use_mla_opt: ++ key_cache = kv_cache[0].reshape(-1, num_heads, head_size) ++ value_cache = kv_cache[0].reshape(-1, num_heads, head_size) ++ else: ++ key_cache = kv_cache[0].reshape(-1, num_heads, head_size) ++ value_cache = kv_cache[1].reshape(-1, num_heads, head_size) ++ ++ current_slot_mapping = slot_mapping_flat[start_pos:end_pos] ++ keys.append(key_cache[current_slot_mapping].unsqueeze(0)) ++ values.append(value_cache[current_slot_mapping].unsqueeze(0)) ++ keys = torch.cat(keys, dim=0) ++ values = torch.cat(values, dim=0) ++ kvcache_to_sent = torch.stack((keys, values), dim=0) ++ store_kvcache_key = f"{store_key_prefix}_{self.tp_rank}" ++ logger.debug(f"store_kvcache_key: {store_kvcache_key}, kvcache_to_sent shape: {kvcache_to_sent.shape}") ++ self.kvc_store.put_tensors_d2d([store_kvcache_key], [kvcache_to_sent]) ++ ++ hidden_states_key = f"hidden_states_{store_key_prefix}_{self.tp_rank}" ++ logger.debug(f"sending hidden state. shape= {hidden_or_intermediate_states[start_pos:end_pos].shape}") ++ if mf_model: ++ self.kvc_store.put_tensors_d2d([hidden_states_key], ++ [hidden_or_intermediate_states[start_pos:end_pos].copy()]) ++ else: ++ self.kvc_store.put_tensors_d2d([hidden_states_key], ++ [hidden_or_intermediate_states[start_pos:end_pos].clone()]) ++ logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) ++ ++ def recv_kv_caches_and_hidden_states( ++ self, model_executable: torch.nn.Module, ++ model_input: "ModelInputForGPUWithSamplingMetadata", ++ kv_caches: List[torch.Tensor] ++ ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, ++ "ModelInputForGPUWithSamplingMetadata"]: ++ """recv kv_caches and hidden_states""" ++ bypass_model_exec = True ++ ++ input_tokens_tensor = model_input.input_tokens ++ seq_lens = model_input.attn_metadata.seq_lens ++ slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() ++ num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens ++ if hasattr(model_executable, "mf_model_config"): # vllm-mindspore ++ model_config = model_executable.mf_model_config ++ from vllm.distributed.utils import get_pp_indices ++ start_layer, end_layer = get_pp_indices(model_config.num_layers, ++ get_pp_group().rank_in_group, ++ get_pp_group().world_size) ++ num_heads = int((model_config.n_kv_heads if model_config.n_kv_heads else model_config.num_heads) ++ / self.tp_size) ++ hidden_size = model_config.hidden_size ++ num_attention_heads = model_config.num_heads ++ else: # vllm-ascend ++ model_config = model_executable.model.config ++ start_layer = model_executable.model.start_layer ++ end_layer = model_executable.model.end_layer ++ num_heads = int(model_config.num_key_value_heads / self.tp_size) ++ hidden_size = model_config.hidden_size ++ num_attention_heads = model_config.num_attention_heads ++ head_size = int(hidden_size / num_attention_heads) ++ if self.is_deepseek_mla and self.use_mla_opt: ++ head_size = model_config.kv_lora_rank + model_config.qk_rope_head_dim ++ num_heads = 1 ++ elif self.is_deepseek_mla and not self.use_mla_opt: ++ head_size = model_config.qk_nope_head_dim + model_config.qk_rope_head_dim ++ else: ++ head_size = getattr(model_config, "head_dim", int(hidden_size // num_attention_heads)) ++ ++ hidden_or_intermediate_states_for_one_req = [] ++ ++ for idx, slen in enumerate(seq_lens): ++ start_pos = sum(seq_lens[:idx]) ++ end_pos = start_pos + slen ++ if start_pos >= num_prefill_tokens: ++ # This can happen during inflight batching. See: ++ # vllm/worker/model_runner.py::_prepare_model_input_tensors: ++ # - input_tokens[:num_prefill_tokens] contains prefill tokens. ++ # - input_tokens[num_prefill_tokens:] contains decode tokens. ++ logger.warning("You should set --enable_chunked_prefill=False " ++ "and --max_num_batched_tokens " ++ "should be equal to max_seq_len_to_capture") ++ bypass_model_exec = False ++ if start_pos != num_prefill_tokens: ++ logger.error(f"start_pos={start_pos} is not equals to num_prefill_tokens={num_prefill_tokens}") ++ raise ValueError("start_pos is not equals to num_prefill_tokens") ++ ++ current_tokens = input_tokens_tensor[start_pos:end_pos] ++ num_computed_tokens = current_tokens.shape[0] ++ # get roi for current seq ++ load_key_prefix = self.tensor_hash(current_tokens) ++ load_kvcache_key = f"{load_key_prefix}_{self.tp_rank}" ++ temp_k_cache = torch.zeros((end_layer - start_layer, num_computed_tokens, num_heads, head_size), ++ dtype=kv_caches[0][0].dtype, device=self.device) ++ temp_v_cache = torch.zeros((end_layer - start_layer, num_computed_tokens, num_heads, head_size), ++ dtype=kv_caches[0][0].dtype, device=self.device) ++ kvcache_to_recv = torch.stack((temp_k_cache, temp_v_cache), dim=0) ++ logger.debug(f"load_kvcache_key: {load_kvcache_key}, kvcache_to_recv shape: {kvcache_to_recv.shape}") ++ ###> self.kvc_store.recv_tensor(load_kvcache_key, kvcache_to_recv) ++ ++ self.kvc_store.get_tensors_d2d([load_kvcache_key], [kvcache_to_recv]) ++ ++ keys, values = kvcache_to_recv[0], kvcache_to_recv[1] ++ hidden = torch.zeros(num_computed_tokens, hidden_size, dtype=kv_caches[0][0].dtype, device=self.device) ++ hidden_key = f"hidden_states_{load_key_prefix}_{self.tp_rank}" ++ ###> self.kvc_store.recv_tensor(hidden_key, hidden) ++ ++ self.kvc_store.get_tensors_d2d([hidden_key], [hidden]) ++ ++ logger.debug(f"received hidden state shape= {hidden.shape}") ++ hidden_or_intermediate_states_for_one_req.append(hidden) ++ ++ current_slot_mapping = slot_mapping_flat[start_pos:end_pos] ++ ++ for layer_id in range(start_layer, end_layer): ++ kv_cache = kv_caches[layer_id - start_layer] ++ ++ if self.is_deepseek_mla and self.use_mla_opt: ++ key_cache = kv_cache[0].reshape(-1, num_heads, head_size) ++ value_cache = kv_cache[0].reshape(-1, num_heads, head_size) ++ else: ++ key_cache = kv_cache[0].reshape(-1, num_heads, head_size) ++ value_cache = kv_cache[1].reshape(-1, num_heads, head_size) ++ ++ for i, slot_idx in enumerate(current_slot_mapping): ++ key_cache[slot_idx] = keys[layer_id][i] ++ value_cache[slot_idx] = values[layer_id][i] ++ ++ if not bypass_model_exec: ++ # Some of the KV cache is not retrieved ++ # Here we will fall back to normal model forwarding ++ # But optionally you can adjust model_input so that you only do ++ # prefilling on those tokens that are missing KV caches. ++ logger.debug( ++ "[rank%d]: Failed to receive all KVs and hidden " ++ "states, redo model forwarding.", torch.distributed.get_rank()) ++ hidden_or_intermediate_states = None ++ else: ++ logger.debug( ++ "[rank%d]: Successfully received all KVs and hidden " ++ "states, skip model forwarding.", torch.distributed.get_rank()) ++ hidden_or_intermediate_states = torch.cat( ++ hidden_or_intermediate_states_for_one_req, dim=0) ++ ++ return hidden_or_intermediate_states, bypass_model_exec, model_input +diff --git a/dllm_tools/dllm/dkvc/prefix_sharing_multi_level/__init__.py b/dllm_tools/dllm/dkvc/prefix_sharing_multi_level/__init__.py +new file mode 100644 +index 000000000..e69de29bb +diff --git a/dllm_tools/dllm/dkvc/prefix_sharing_multi_level/block.py b/dllm_tools/dllm/dkvc/prefix_sharing_multi_level/block.py +new file mode 100644 +index 000000000..bb96ecff0 +--- /dev/null ++++ b/dllm_tools/dllm/dkvc/prefix_sharing_multi_level/block.py +@@ -0,0 +1,67 @@ ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++ ++"""Extend PhysicalTokenBlock""" ++import enum ++ ++from vllm.block import PhysicalTokenBlock ++from vllm.utils import Device ++ ++ ++class KVCacheProgressStatus(enum.Enum): ++ """Status of a SequenceData in Radix Cache.""" ++ # 初始化状态 ++ INITIALIZE = enum.auto() ++ # 正在换出 ++ SWAPPING_OUT = enum.auto() ++ # 正在换入 ++ SWAPPING_IN = enum.auto() ++ ++ def __str__(self) -> str: ++ if self == KVCacheProgressStatus.INITIALIZE: ++ return "Initialize" ++ if self == KVCacheProgressStatus.SWAPPING_OUT: ++ return "swapping out" ++ return "swapping in" ++ ++ ++class PhysicalTokenBlockExt(PhysicalTokenBlock): ++ def __init__( ++ self, ++ device: Device, ++ block_number: int, ++ block_size: int, ++ ) -> None: ++ super().__init__(device, block_number, block_size) ++ # ref_count will not count radix cache ref. Only count seqs that not swapped out. ++ self.ref_set_of_seq = set() # ALL seqs which is referencing this Block ++ self.in_radix_cache = False # Whether this Block is referenced by Radix Cache ++ self.depth = None # type: int ++ self.progress_status = KVCacheProgressStatus.INITIALIZE ++ ++ def __repr__(self) -> str: ++ return (f'PhysicalTokenBlock(device={self.device}, ' ++ f'block_number={self.block_number}, ' ++ f'ref_count={self.ref_count},' ++ f'block_hash={self.block_hash},' ++ f'depth={self.depth},' ++ f'in_radix={self.in_radix_cache},' ++ f'used_by={[_s.seq_id for _s in self.ref_set_of_seq]}),') ++ ++ def ref_count_up(self): ++ """increase ref_count""" ++ self.ref_count += 1 ++ ++ def ref_count_down(self): ++ """decrease ref_count""" ++ self.ref_count -= 1 +diff --git a/dllm_tools/dllm/dkvc/prefix_sharing_multi_level/core/__init__.py b/dllm_tools/dllm/dkvc/prefix_sharing_multi_level/core/__init__.py +new file mode 100644 +index 000000000..e69de29bb +diff --git a/dllm_tools/dllm/dkvc/prefix_sharing_multi_level/core/block_manager.py b/dllm_tools/dllm/dkvc/prefix_sharing_multi_level/core/block_manager.py +new file mode 100644 +index 000000000..1df267979 +--- /dev/null ++++ b/dllm_tools/dllm/dkvc/prefix_sharing_multi_level/core/block_manager.py +@@ -0,0 +1,706 @@ ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++ ++"""A block manager that manages the block allocation or free, inherit the vllm's block manager.""" ++ ++import random ++import string ++from collections import defaultdict ++from threading import Lock ++from typing import Dict, List, Tuple ++ ++from vllm.block import PhysicalTokenBlock ++from vllm.core.block_manager import (BlockAllocator, BlockTable, SwapType, ++ local_thread) ++from vllm.engine.prefix_sharing_type import PrefixSharingType ++from vllm.logger import init_logger ++from vllm.sequence import Sequence, SequenceGroup, SequenceStatus ++from vllm.utils import Device, round_up ++from vllm.core import block_manager ++ ++from dllm.dkvc.prefix_sharing_multi_level.block import ( ++ KVCacheProgressStatus, PhysicalTokenBlockExt) ++from dllm.dkvc.prefix_sharing_multi_level.core.swap_in_watcher import \ ++ ReqSwapInWatcher ++from dllm.dkvc.prefix_sharing_multi_level.index.index_tree_manager import \ ++ IndexTreeManager ++from dllm.dkvc.util import assert_or_raise ++ ++logger = init_logger(f"vllm.{__name__}") ++ ++ ++class BlockAllocatorExt(BlockAllocator): ++ def __init__( ++ self, ++ device: Device, ++ block_size: int, ++ num_blocks: int, ++ ) -> None: ++ self.device = device ++ self.block_size = block_size ++ self.num_blocks = num_blocks ++ ++ # Initialize the free blocks. ++ self.free_blocks: List[PhysicalTokenBlockExt] = [] ++ numbers = list(range(num_blocks)) ++ random.shuffle(numbers) ++ for i in numbers: ++ block = PhysicalTokenBlockExt( ++ device=device, block_number=i, block_size=block_size ++ ) ++ self.free_blocks.append(block) ++ ++ def allocate(self) -> PhysicalTokenBlockExt: ++ """分配block块""" ++ sync_swap_out_blocks = [] ++ if hasattr(local_thread, "sync_swap_out_blocks"): ++ sync_swap_out_blocks = local_thread.sync_swap_out_blocks ++ ++ if not self.free_blocks and not sync_swap_out_blocks: ++ raise ValueError("Out of memory! No free blocks are available.") ++ ++ pop_from_free_blocks = True ++ block = None ++ if self.free_blocks: ++ block = self.free_blocks.pop(0) ++ else: ++ block = sync_swap_out_blocks.pop(0) ++ pop_from_free_blocks = False ++ ++ assert_or_raise( ++ not block.ref_set_of_seq, ++ f"allocate err! block: {block}, pop from free_blosk: {pop_from_free_blocks}", ++ ) ++ block.ref_count = 0 # 这里改成 = 0 拿到块后再增加 ++ block.in_radix_cache = False ++ block.progress_status = KVCacheProgressStatus.INITIALIZE ++ block.depth = None ++ block.block_hash = "" ++ return block ++ ++ def free(self, block: PhysicalTokenBlockExt, sync_swap_out: bool = False) -> None: ++ """释放block块""" ++ # block.ref_count -= 1 # 这里不做减计数了,外面就要做减计数和解除绑定 ++ if ( ++ block.ref_count <= 0 ++ and not block.ref_set_of_seq ++ and not block.in_radix_cache ++ ): ++ if sync_swap_out: ++ if not hasattr(local_thread, "sync_swap_out_blocks"): ++ local_thread.sync_swap_out_blocks = [] ++ ++ local_thread.sync_swap_out_blocks.append(block) ++ else: ++ self.free_blocks.append(block) ++ else: ++ logger.debug(f"cannot free block: {block}") ++ ++ ++class AutoPrefixSharingBlockManager(block_manager.DefaultBlockManager): ++ ++ def __init__( ++ self, ++ block_size: int, ++ scheduler_budget_len: int, ++ num_gpu_blocks: int, ++ num_cpu_blocks: int, ++ watermark: float = 0.01, ++ block_sliding_window: int = 256, ++ sink_block_num: int = 0, ++ prefix_sharing_type: PrefixSharingType = PrefixSharingType.GPU_CPU, ++ using_datasystem: bool = False, ++ ) -> None: ++ """Manages the mapping between logical and physical token blocks. ++ ++ Args: inherit the vllm's block manager ++ """ ++ self.block_size = block_size ++ self.scheduler_budget_len = scheduler_budget_len ++ self.num_total_gpu_blocks = num_gpu_blocks ++ self.num_total_cpu_blocks = num_cpu_blocks ++ ++ self.sink_block_num = sink_block_num ++ self.block_sliding_window = block_sliding_window ++ self.window_size = ( ++ self.sink_block_num + self.block_sliding_window ++ ) * self.block_size ++ ++ self.watermark = watermark ++ assert_or_raise(watermark >= 0.0) ++ ++ self.watermark_blocks = int(watermark * num_gpu_blocks) ++ self.gpu_allocator = BlockAllocatorExt(Device.GPU, block_size, num_gpu_blocks) ++ self.cpu_allocator = BlockAllocatorExt(Device.CPU, block_size, num_cpu_blocks) ++ # Mapping: seq_id -> BlockTable. ++ self.block_tables: Dict[int, BlockTable] = {} ++ ++ self.sys_prefix_token_ids: Dict[int, List[int]] = defaultdict(list) ++ ++ self.seq_group_swapping_src_blocks: Dict[str, BlockTable] = {} ++ ++ # Used in allocate memory for speculate token ids ++ self.max_speculate_len = 0 ++ self.blocks_to_copy: List[Tuple[int, int]] = [] ++ ++ self.req_swap_in_watcher = ReqSwapInWatcher() ++ ++ self.prefix_sharing_type = prefix_sharing_type ++ num_cpu_blocks_cache = num_cpu_blocks ++ num_gpu_blocks_cache = num_gpu_blocks ++ logger.info( ++ f"Init multi-level prefix sharing block manager with total gpu blocks: " ++ f"{num_gpu_blocks_cache}, total cpu blocks: {num_cpu_blocks_cache}" ++ ) ++ self.index_tree_manager = IndexTreeManager( ++ block_size, num_cpu_blocks_cache, num_gpu_blocks_cache ++ ) ++ # key 是req_id,Tuple[PhysicalTokenBlock, int] 分别表示NPU块,和cpu块号 ++ self.swapping_out_map_cpu2npu: Dict[ ++ str, List[Tuple[PhysicalTokenBlock, int]] ++ ] = {} ++ # key 是req_id,Tuple[PhysicalTokenBlock, int] 分别表示NPU块号,和cpu块 ++ self.swapping_in_map_cpu2npu: Dict[ ++ str, List[Tuple[int, PhysicalTokenBlock]] ++ ] = {} ++ self.index_tree_lock = Lock() ++ self.using_datasystem = using_datasystem ++ ++ def append_slot(self, seq: Sequence, blocks_to_copy: List[Tuple[int, int]]): ++ """Allocate a physical slot for a new token.""" ++ block_table = self.block_tables[seq.seq_id] ++ num_need_slots = self.max_speculate_len + 1 ++ ++ seq.position_start += seq.scheduled_len ++ seq.scheduled_len = num_need_slots ++ sliding_len = seq.cached_len + num_need_slots - self.window_size ++ ++ if sliding_len > 0: ++ self.slide_block_table(seq, sliding_len) ++ left_cap = self.get_block_table_left_cap(seq) ++ ++ if left_cap < num_need_slots: ++ block = self.gpu_allocator.allocate() ++ block.ref_count_up() ++ self.bind_seq_with_block(seq, block) ++ return ++ ++ # We want to append the token to the last physical block. ++ last_block = block_table[-1] ++ ++ # Not shared with other sequences. Appendable. ++ if last_block.ref_count <= 1: ++ return ++ ++ logger.error( ++ f"Should not be here, last_block.ref_count = {last_block.ref_count} {seq.cached_len}" ++ ) ++ # The last block is shared with other sequences. ++ # Copy on Write: Allocate a new block and copy the tokens. ++ new_block = self.gpu_allocator.allocate() ++ new_block.ref_count_up() ++ self.bind_seq_with_block(seq, new_block) ++ last_block.ref_count_down() ++ self.unbind_seq_with_block(seq, last_block) ++ ++ self.gpu_allocator.free(last_block) ++ blocks_to_copy.append((last_block.block_number, new_block.block_number)) ++ ++ def auto_prefix_sharing_multi_level( ++ self, seq: Sequence ++ ) -> (int, BlockTable, BlockTable, BlockTable): ++ """前缀匹配并处理在npu上的前缀缓存 ++ 匹配到的缓存块可能在npu和cpu上都有,该方法只处理在npu上的块。cpu上的块需要和scheduler的swap操作配合。 ++ Args: ++ seq: the sequence to be prefilled ++ ++ Returns: ++ cached_cpu_len: the length of cpu cache to be processed ++ block_table_on_npu ++ block_table_on_cpu ++ blocks_to_copy_on_write ++ """ ++ block_table = self.block_tables.get(seq.seq_id) ++ if not block_table: ++ with self.index_tree_lock: ++ total_cached_len, prefix_cache_npu, prefix_cache_cpu, blocks_to_copy = ( ++ self.index_tree_manager.match_prefix(seq) ++ ) ++ ++ npu_cached_len, cpu_cached_len = self._parse_cached_len( ++ total_cached_len, prefix_cache_npu, prefix_cache_cpu, blocks_to_copy ++ ) ++ logger.debug( ++ f"[stat] seq: {seq.seq_id}, prompt_len: {seq.get_prompt_len()}," ++ f" npu_cached_len: {npu_cached_len}, cpu_cached_len: {cpu_cached_len}" ++ ) ++ ++ seq.update_cached_len(npu_cached_len) ++ seq.position_start = npu_cached_len ++ seq.first_block_start_offset = 0 ++ ++ if npu_cached_len > 0 or cpu_cached_len > 0: ++ seq.prefix_matched = True ++ ++ return cpu_cached_len, prefix_cache_npu, prefix_cache_cpu, blocks_to_copy ++ ++ return 0, block_table, [], [] ++ ++ def generate_hash(self, seq: Sequence): ++ """生成hash值""" ++ all_tokens = seq.get_token_ids() ++ block_table = self.block_tables.get(seq.seq_id, []) ++ prev_block_hash = "" ++ for block in block_table: ++ token_begin_pos = block.depth * self.index_tree_manager.block_size ++ token_end_pos = token_begin_pos + self.index_tree_manager.block_size ++ cur_block_hash = self.index_tree_manager.generate_hash_key( ++ block.depth == 0, ++ prev_block_hash, ++ all_tokens[token_begin_pos:token_end_pos], ++ ) ++ block.set_hash(cur_block_hash) ++ prev_block_hash = cur_block_hash ++ ++ def generate_swap_mapping(self, seq_group: SequenceGroup, swap_type: SwapType): ++ """建立换出block-> block 映射关系""" ++ block_mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} ++ in_radix_cache_block_mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} ++ swapping_src_blocks: BlockTable = [] ++ allocator = ( ++ self.gpu_allocator if swap_type == SwapType.SWAP_IN else self.cpu_allocator ++ ) ++ seq_status = ( ++ SequenceStatus.SWAPPED ++ if swap_type == SwapType.SWAP_IN ++ else SequenceStatus.RUNNING ++ ) ++ logger.debug( ++ f"generate swap mapping, type: {'SWAP_OUT' if swap_type == SwapType.SWAP_OUT else 'SWAP_IN'} " ++ f"for seq {seq_group.request_id}" ++ ) ++ if swap_type == SwapType.SWAP_OUT: ++ for seq in seq_group.get_seqs(status=seq_status): ++ block_table: List[PhysicalTokenBlockExt] = self.block_tables.get( ++ seq.seq_id, [] ++ ) ++ self.generate_hash(seq) ++ for src_block in block_table: ++ src_block.ref_count -= 1 ++ assert_or_raise( ++ src_block.ref_count >= 0, f"err src block: {src_block}" ++ ) ++ if ( ++ src_block.device == Device.GPU ++ and src_block.ref_count == 0 ++ and src_block.progress_status ++ == KVCacheProgressStatus.INITIALIZE ++ ): ++ # 可以换出这个块 ++ assert_or_raise( ++ src_block not in block_mapping ++ ) # 减计数减到 0 了才走到这一步 ++ src_block.progress_status = KVCacheProgressStatus.SWAPPING_OUT ++ dst_block = allocator.allocate() ++ block_mapping[src_block] = dst_block ++ # 双向关系改动 ++ # 这里就移动了 ref_set_of_seq ,那么当请求 abort 时,是找 dst_block,因为双向关系改动已发生,下面修改了该 seq 对应的 block_table ++ dst_block.ref_set_of_seq = src_block.ref_set_of_seq ++ dst_block.depth = src_block.depth ++ dst_block.in_radix_cache = src_block.in_radix_cache ++ src_block.in_radix_cache = False ++ dst_block.set_hash(src_block.block_hash) ++ assert_or_raise(dst_block.depth >= 0) ++ for seq2 in dst_block.ref_set_of_seq: ++ block_table_inner = self.block_tables.get(seq2.seq_id, []) ++ block_table_inner[dst_block.depth] = dst_block ++ # src npu, dst cpu ++ if dst_block.in_radix_cache: ++ in_radix_cache_block_mapping[src_block] = dst_block ++ ++ # 清空原块的 引用关系 ++ src_block.ref_set_of_seq = set() ++ # 分配出的 CPU 块引用计数为0 ++ swapping_src_blocks.append(src_block) ++ else: ++ logger.debug( ++ f"req_id: {seq_group.request_id} can't swap out src block: " ++ f"{'GPU' if src_block.device == Device.GPU else src_block.device == Device.CPU}, " ++ f"ref_count:{src_block.ref_count}, " ++ f"progress_status:{src_block.progress_status}, " ++ f"used_by:{[seq.seq_id for seq in src_block.ref_set_of_seq]}" ++ ) ++ # 对于在 radix_cache 里的 block,swap out 成功时,要触发 radix_cache 的 swapped_out 方法 ++ self.swapping_out_map_cpu2npu[seq_group.request_id] = [ ++ (_cpu_blk, _npu_blk.block_number) ++ for _npu_blk, _cpu_blk in in_radix_cache_block_mapping.items() ++ ] ++ logger.debug( ++ f"generate swap out map: src -> dst:{list(block_mapping.items())}" ++ ) ++ if swap_type == SwapType.SWAP_IN: ++ dependency_cpu_block = set() # 一切依赖的 CPU block ++ for seq in seq_group.get_seqs(status=seq_status): ++ block_table: BlockTable = self.block_tables.get(seq.seq_id, []) ++ for src_block in block_table: ++ self.generate_hash(seq) ++ if src_block.device == Device.GPU: ++ # 不应当在换入流程中还有正在换出的块 ++ assert_or_raise( ++ src_block.progress_status ++ == KVCacheProgressStatus.INITIALIZE ++ ) ++ src_block.ref_count += 1 ++ else: ++ # 是 CPU 块,检查CPU块是否正在换入 ++ dependency_cpu_block.add(src_block) # 依赖于这个块 ++ if ( ++ src_block.progress_status ++ == KVCacheProgressStatus.SWAPPING_IN ++ ): ++ src_block.ref_count += 1 ++ logger.debug( ++ f"block is SWAPPING_IN: {src_block}, added ref_count. " ++ f"used by seqs: {[_s.seq_id for _s in src_block.ref_set_of_seq]}" ++ ) ++ else: ++ src_block.progress_status = ( ++ KVCacheProgressStatus.SWAPPING_IN ++ ) ++ dst_block = allocator.allocate() ++ dst_block.depth = src_block.depth ++ dst_block.in_radix_cache = src_block.in_radix_cache ++ src_block.in_radix_cache = False ++ block_mapping[src_block] = dst_block ++ dst_block.set_hash(src_block.block_hash) ++ # ref_count 在搬移成功后再迁移到 dst NPU 块,以统计搬移过程中,使用这个块的 seq 数量 ++ # ref_set_of_seq 在搬移成功后再迁移到 dst NPU 块 ++ swapping_src_blocks.append(src_block) ++ ++ if dst_block.in_radix_cache: ++ in_radix_cache_block_mapping[src_block] = dst_block ++ # 对于在 radix_cache 里的 block,swap in 成功时,要触发 radix_cache 的 swapped_in 方法 ++ self.swapping_in_map_cpu2npu[seq_group.request_id] = [ ++ (_cpu_blk.block_number, _npu_blk) ++ for _cpu_blk, _npu_blk in in_radix_cache_block_mapping.items() ++ ] ++ logger.debug( ++ f"generate swap in map: src -> dst:{list(block_mapping.items())}" ++ ) ++ ++ # watcher 为这个 seq_group 增加 dependency_cpu_block ++ self.req_swap_in_watcher.add_dependency( ++ seq_group, list(dependency_cpu_block), block_mapping ++ ) ++ ++ logger.debug(f"swapping_src_blocks {swapping_src_blocks}") ++ self.seq_group_swapping_src_blocks[seq_group.request_id] = swapping_src_blocks ++ return block_mapping ++ ++ def allocate( ++ self, seq_group: SequenceGroup, allocated_len: int, is_adjust: bool = False ++ ) -> None: ++ """分配block块""" ++ ++ # NOTE: Here we assume that all sequences in the group have the same ++ # prompt. ++ seq = seq_group.seqs[0] ++ if not is_adjust: ++ seq.position_start += seq.scheduled_len ++ seq.scheduled_len = allocated_len ++ ++ left_cap = self.get_block_table_left_cap(seq) ++ if left_cap < allocated_len: ++ block_num = round_up(allocated_len - left_cap, self.block_size) ++ for _ in range(block_num): ++ block = self.gpu_allocator.allocate() ++ self.bind_seq_with_block(seq, block) ++ block.ref_count_up() ++ ++ def free(self, seq: Sequence) -> None: ++ """reimplement the free ++ 为seq释放上下文,同时保存block ++ """ ++ # the last output token has no kv cache, so dont handle the last token ++ block_table, token_ids = self._get_verified_data(seq) ++ if not block_table: ++ return ++ ++ with self.index_tree_lock: ++ blocks_to_free = self.index_tree_manager.insert_nodes( ++ token_ids, block_table ++ ) ++ for block in block_table: ++ block.ref_count_down() ++ self.unbind_seq_with_block(seq, block) ++ if blocks_to_free: ++ self._free_block_table(blocks_to_free) ++ del self.block_tables[seq.seq_id] ++ ++ def generate_swap_mapping_for_cache( ++ self, src_blocks: BlockTable, is_swap_in: bool = True ++ ): ++ """生成对缓存的block-> block 映射关系""" ++ mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} ++ allocator = self.gpu_allocator if is_swap_in else self.cpu_allocator ++ logger.debug( ++ f'generate_swap_mapping_for_cache {"SWAP_IN" if is_swap_in else "SWAP_OUT"}' ++ ) ++ ++ for src_block in src_blocks: ++ if src_block in mapping: ++ dst_block = mapping[src_block] ++ else: ++ dst_block = allocator.allocate() ++ dst_block.block_hash = src_block.block_hash ++ mapping[src_block] = dst_block ++ dst_block.set_hash(src_block.block_hash) ++ dst_block.depth = src_block.depth ++ dst_block.in_radix_cache = src_block.in_radix_cache ++ src_block.in_radix_cache = False ++ if is_swap_in: ++ src_block.progress_status = KVCacheProgressStatus.SWAPPING_IN ++ # 换出肯定没有 seq, 且 ref_count == 0。 换入成功后有 swap_in_watcher 管 双向引用 ++ # 什么时候使 src_block.ref_set_of_seq 置空:换出完成后,在 swap_in_watcher.update_swapped_in_blocks 里 ++ # 对于匹配到的 CPU 块,也应当立即地使其 block_table[depth] = block ++ ++ return mapping ++ ++ def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: ++ """生成swap in block number 映射关系""" ++ ++ mapping = self.generate_swap_mapping(seq_group, SwapType.SWAP_IN) ++ if self.using_datasystem: ++ block_number_mapping = [ ++ (gpu_block.block_hash, gpu_block.block_number) ++ for _, gpu_block in mapping.items() ++ ] ++ else: ++ block_number_mapping = [ ++ (cpu_block.block_number, gpu_block.block_number) ++ for cpu_block, gpu_block in mapping.items() ++ ] ++ ++ return block_number_mapping ++ ++ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: ++ """生成swap out block number 映射关系""" ++ mapping = self.generate_swap_mapping(seq_group, SwapType.SWAP_OUT) ++ if self.using_datasystem: ++ block_number_mapping = [ ++ (gpu_block.block_number, gpu_block.block_hash) ++ for gpu_block, _ in mapping.items() ++ ] ++ else: ++ block_number_mapping = [ ++ (gpu_block.block_number, cpu_block.block_number) ++ for gpu_block, cpu_block in mapping.items() ++ ] ++ return block_number_mapping ++ ++ def swap_in_cache( ++ self, seq_group: SequenceGroup, src_blocks: BlockTable, cached_len_cpu: int ++ ) -> List[Tuple[int, int]]: ++ """生成swap in cache 的block number 映射关系""" ++ mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = ( ++ self.generate_swap_mapping_for_cache(src_blocks, is_swap_in=True) ++ ) ++ if self.using_datasystem: ++ block_number_mapping = [ ++ (gpu_block.block_hash, gpu_block.block_number) ++ for _, gpu_block in mapping.items() ++ ] ++ else: ++ block_number_mapping = [ ++ (cpu_block.block_number, gpu_block.block_number) ++ for cpu_block, gpu_block in mapping.items() ++ ] ++ cpu2npu_mapping = [ ++ (cpu_block.block_number, gpu_block) ++ for cpu_block, gpu_block in mapping.items() ++ ] ++ # cpu块换入前会自动改 block_tables,所以这里不必改了 ++ seq_group.seqs[0].update_cached_len(cached_len_cpu) ++ seq_group.seqs[0].position_start = seq_group.seqs[0].cached_prompt_len ++ # 这个 seq_group 在 watcher 里记录,在 _check_finished_swappings 的最后一组检查成功后移除 ++ # cache seq 的换入,它指示换入的CPU块、期待的CPU块,都是一样的,都是 mapping.keys() ++ self.req_swap_in_watcher.add_dependency( ++ seq_group, mapping.keys(), mapping, True ++ ) ++ self.seq_group_swapping_src_blocks[seq_group.request_id] = src_blocks ++ self.swapping_in_map_cpu2npu[seq_group.request_id] = cpu2npu_mapping ++ logger.debug( ++ f"cache swap in: dummy_seq_group request_id:{seq_group.request_id} mapping: src -> dst {mapping}" ++ ) ++ ++ return block_number_mapping ++ ++ def swap_out_cache( ++ self, seq_group: SequenceGroup, src_blocks: BlockTable ++ ) -> List[Tuple[int, int]]: ++ """生成swap out cache 的block number 映射关系""" ++ mapping = self.generate_swap_mapping_for_cache(src_blocks, is_swap_in=False) ++ if self.using_datasystem: ++ block_number_mapping = [ ++ (gpu_block.block_number, gpu_block.block_hash) ++ for gpu_block, _ in mapping.items() ++ ] ++ else: ++ block_number_mapping = [ ++ (gpu_block.block_number, cpu_block.block_number) ++ for gpu_block, cpu_block in mapping.items() ++ ] ++ cpu2npu_mapping: List[Tuple[PhysicalTokenBlock, int]] = [ ++ (cpu_block, gpu_block.block_number) ++ for gpu_block, cpu_block in mapping.items() ++ ] ++ ++ self.seq_group_swapping_src_blocks[seq_group.request_id] = src_blocks ++ self.swapping_out_map_cpu2npu[seq_group.request_id] = cpu2npu_mapping ++ ++ logger.debug( ++ f"cache swap out: dummy_seq_group request_id:{seq_group.request_id} mapping: src -> dst {mapping}" ++ ) ++ ++ return block_number_mapping ++ ++ def free_prefix_cache(self, block_table: BlockTable): ++ """释放前缀缓存""" ++ self._free_block_table(block_table) ++ ++ def free_cpu_prefix_cache( ++ self, cpu_evict_block_list: BlockTable, blocks_to_swap_out: Dict ++ ) -> None: ++ """释放cpu前缀缓存""" ++ self.free_prefix_cache(cpu_evict_block_list) ++ if self.using_datasystem: ++ cpu_del_block_list = filter( ++ lambda x: x.ref_count == 0, cpu_evict_block_list ++ ) ++ request_id = "CPU_CACHE_DELETE_" + "".join( ++ random.choices(string.ascii_letters + string.digits, k=5) ++ ) ++ blocks_to_swap_out[request_id] = [ ++ (-1, cpu_block.block_hash) ++ for cpu_block in cpu_del_block_list ++ ] ++ ++ def fetch_finish_swap_in_seq_group(self, req_id): ++ """ ++ 通过 req_id 获得该 seq_group 指示换入的 CPU Block,然后通过 watcher 确定哪些 req 换入成功。 ++ Args: ++ req_id: ++ ++ Returns: ++ 换入成功的 List[req_id] ++ ++ """ ++ blocks = self.seq_group_swapping_src_blocks.get(req_id, None) ++ logger.debug(f"cpu blocks finished swapped in: {blocks}") ++ # 这些请求换入成功了 ++ newly_swapped_in_reqs = self.req_swap_in_watcher.update_swapped_in_blocks( ++ blocks, self.block_tables ++ ) ++ return newly_swapped_in_reqs ++ ++ def free_swapping_in_blocks(self, req_id): ++ """ ++ Radix Cache 匹配到的 cpu 块换入成功了 ++ Args: ++ req_id: ++ ++ Returns: ++ ++ """ ++ with self.index_tree_lock: ++ self.index_tree_manager.swapped_in( ++ self.swapping_in_map_cpu2npu.get(req_id, []) ++ ) ++ ++ self.free_swapping_blocks(req_id) ++ if req_id in self.swapping_in_map_cpu2npu: ++ del self.swapping_in_map_cpu2npu[req_id] ++ ++ def bind_seq_with_block( ++ self, seq: Sequence, block: PhysicalTokenBlockExt, depth=None ++ ): ++ """将seq与对应物理块绑定""" ++ # Here does not add ref_count ++ assert_or_raise(block not in self.gpu_allocator.free_blocks, ++ f"block {block} in free_blocks") ++ block_table: BlockTable = self.block_tables.get(seq.seq_id, []) ++ if depth is None: ++ # append by default ++ depth = len(block_table) ++ if len(block_table) == depth: ++ block_table.append(block) ++ else: ++ block_table[depth] = block ++ block.depth = depth ++ block.ref_set_of_seq.add(seq) ++ self.block_tables[seq.seq_id] = block_table ++ ++ def unbind_seq_with_block(self, seq: Sequence, block: PhysicalTokenBlockExt): ++ """将seq 与对应块解绑""" ++ # Here does not minus ref_count ++ self.block_tables[seq.seq_id][block.depth] = None ++ block.ref_set_of_seq.discard(seq) # 删除且不抛异常 ++ ++ def _parse_cached_len( ++ self, total_cached_len, prefix_cache_npu, prefix_cache_cpu, blocks_to_copy ++ ): ++ npu_cache_to_copy, cpu_cache_to_copy = ( ++ blocks_to_copy if blocks_to_copy else (None, None) ++ ) ++ ++ need_cow = npu_cache_to_copy or cpu_cache_to_copy ++ if prefix_cache_npu and prefix_cache_cpu: ++ npu_cached_len = len(prefix_cache_npu) * self.block_size ++ if need_cow: ++ cpu_cached_len = len(prefix_cache_cpu) * self.block_size ++ else: ++ cpu_cached_len = total_cached_len - npu_cached_len ++ elif prefix_cache_cpu: ++ if need_cow: ++ cpu_cached_len = len(prefix_cache_cpu) * self.block_size ++ else: ++ cpu_cached_len = total_cached_len ++ npu_cached_len = 0 ++ elif prefix_cache_npu: ++ if need_cow: ++ npu_cached_len = len(prefix_cache_npu) * self.block_size ++ else: ++ npu_cached_len = total_cached_len ++ cpu_cached_len = 0 ++ else: ++ npu_cached_len = 0 ++ cpu_cached_len = 0 ++ ++ return npu_cached_len, cpu_cached_len ++ ++ def _get_verified_data(self, seq: Sequence) -> (BlockTable, List[int]): ++ """ ++ Given a sequence, filter out the block table that contains speculative tokens. ++ """ ++ block_table = self.block_tables.get(seq.seq_id, []) ++ tokens_verified = seq.data.prompt_token_ids + seq.data.output_token_ids[:-1] ++ if not seq.data.speculate_data: ++ return block_table, tokens_verified ++ ++ len_of_verified_kv = ( ++ len(seq.data.prompt_token_ids) ++ + len(seq.data.output_token_ids) ++ - len(seq.data.speculate_data.get_id_and_position_list()[0]) ++ ) ++ tokens_verified = (seq.data.prompt_token_ids + seq.data.output_token_ids)[:len_of_verified_kv] ++ return block_table, tokens_verified +diff --git a/dllm_tools/dllm/dkvc/prefix_sharing_multi_level/core/scheduler.py b/dllm_tools/dllm/dkvc/prefix_sharing_multi_level/core/scheduler.py +new file mode 100644 +index 000000000..2d007b686 +--- /dev/null ++++ b/dllm_tools/dllm/dkvc/prefix_sharing_multi_level/core/scheduler.py +@@ -0,0 +1,521 @@ ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++ ++"""A scheduler that schedule the seq served in this iteration, inherit the vllm's scheduler.""" ++ ++import random ++import string ++from typing import Dict, List, Optional, Tuple ++ ++from vllm.config import CacheConfig, SchedulerConfig ++from vllm.core.block_manager import ActionType, BlockTable ++from vllm.core.scheduler import PreemptionMode ++from vllm.engine.prefix_sharing_type import PrefixSharingType ++from vllm.logger import init_logger ++from vllm.sequence import SequenceGroup, SequenceGroupStatus, SequenceStatus ++from vllm.utils import Device, SchedulingMode ++from vllm.core import scheduler ++ ++from dllm.dkvc.prefix_sharing_multi_level.core.block_manager import \ ++ AutoPrefixSharingBlockManager ++from dllm.dkvc.prefix_sharing_multi_level.index.radix_cache import \ ++ TreeNode ++from dllm.dkvc.util import assert_or_raise ++ ++logger = init_logger(f"vllm.{__name__}") ++ ++ ++class SwappingSequenceGroup(scheduler.SwappingSequenceGroup): ++ def __init__( ++ self, ++ seq_group: SequenceGroup, ++ num_swapping_workers: int, ++ nodes: List[TreeNode] = None ++ ) -> None: ++ super().__init__(seq_group, num_swapping_workers) ++ self.nodes = nodes ++ ++ ++class AutoPrefixSharingScheduler(scheduler.DefaultScheduler): ++ def __init__( ++ self, ++ scheduler_config: SchedulerConfig, ++ cache_config: CacheConfig ++ ) -> None: ++ super().__init__( ++ scheduler_config=scheduler_config, ++ cache_config=cache_config) ++ ++ if cache_config.swap_space_bytes == 0: ++ logger.warning("No Swap Space for multi-level Prefix Cache") ++ self.gpu_usage_threshold: float = cache_config.prefix_sharing_kwargs.get( ++ "gpu_usage_threshold", 0.7 ++ ) ++ self.evict_gpu_usage: float = cache_config.prefix_sharing_kwargs.get( ++ "evict_gpu_usage", 0.2 ++ ) ++ self.memory_usage_threshold: float = cache_config.prefix_sharing_kwargs.get( ++ "memory_usage_threshold", 0.5 ++ ) ++ self.evict_memory_usage: float = cache_config.prefix_sharing_kwargs.get( ++ "evict_memory_usage", 0.2 ++ ) ++ self.using_datasystem = cache_config.prefix_sharing_kwargs.get( ++ "enable_datasystem", False ++ ) ++ self.block_manager = AutoPrefixSharingBlockManager( ++ block_size=self.cache_config.block_size, ++ scheduler_budget_len=self.scheduler_config.scheduler_budget_len, ++ num_gpu_blocks=self.cache_config.num_gpu_blocks, ++ num_cpu_blocks=self.cache_config.num_cpu_blocks, ++ block_sliding_window=self.cache_config.block_sliding_window, ++ sink_block_num=self.cache_config.sink_block_num, ++ prefix_sharing_type=self.prefix_sharing_type, ++ using_datasystem=self.using_datasystem, ++ ) ++ logger.info(f'Auto prefix sharing gpu_usage_threshold = {self.gpu_usage_threshold}, ' ++ f'evict_gpu_usage = {self.evict_gpu_usage}, ' ++ f'cpu_usage_threshold = {self.memory_usage_threshold}, ' ++ f'evict_cpu_usage = {self.evict_memory_usage}') ++ if self.block_manager.num_total_gpu_blocks * (self.evict_gpu_usage + 0.1) > \ ++ self.block_manager.num_total_cpu_blocks * (1-self.memory_usage_threshold): ++ raise RuntimeError(f"The memory_usage_threshold setting is too high, resulting in " ++ "insufficient CPU cache to store the number of NPU cache entries " ++ "evicted in one go. Please consider lowering the " ++ "memory_usage_threshold ,or reducing evict_gpu_usage, " ++ "or setting a larger swap_cache.") ++ ++ self.index_tree_manager = self.block_manager.index_tree_manager ++ self.swapping_in_prefix_cache: List[SwappingSequenceGroup] = [] ++ self.swapping_in_prefix_cache_record = set() ++ self.prefilling_prefix_cache = [] ++ self.prefill_left_budget: int = self.scheduler_config.scheduler_budget_len ++ self.swapping_out_prefix_cache = [] ++ ++ def get_needed_gpu_block_num(self, seq_group: SequenceGroup, action_type: ActionType) -> int: ++ """获取需要evict 的GPU block数量""" ++ num_free_gpu_blocks = self.block_manager.gpu_allocator.get_num_free_blocks() ++ num_need_gpu_blocks = self.block_manager.get_needed_gpu_blocks_num(seq_group, action_type) ++ ++ # lcwdbg print debug info ++ seq = seq_group.get_seqs()[0] ++ logger.debug(f"seq: {seq_group.request_id}, token_len: {seq.get_prompt_len()}, " ++ f"num_need_gpu_blocks: {num_need_gpu_blocks}, num_free_gpu_blocks: {num_free_gpu_blocks}") ++ ++ num_blocks_to_evict = num_need_gpu_blocks - num_free_gpu_blocks ++ return num_blocks_to_evict if num_blocks_to_evict > 0 else 0 ++ ++ def schedule_prefillings(self, blocks_to_copy: List[Tuple[int, int]], ++ blocks_to_swap_out: Dict[str, List[Tuple[int, int]]], ++ ignored_seq_groups: List[SequenceGroup], ++ scheduling: List[SequenceGroup], ++ left_budget: int, ++ scheduling_mode: SchedulingMode = None, ++ blocks_to_swap_in: Dict[str, List[Tuple[int, int]]] = None, ++ ) -> None: ++ """对 prefill请求进行调度""" ++ budget_len = None ++ is_strict = False # In strict mode, the left_budget must be longer than or equal to uncached_prompt_len ++ while self.prefilling and not self.swapping_out: ++ if left_budget <= 0: ++ break ++ ++ seq_group = self.prefilling[0] ++ if seq_group.multi_modal_data is not None and seq_group.seqs[0].get_uncached_prompt_len() > left_budget: ++ break ++ ++ if seq_group.prompt_len >= self.scheduler_config.max_num_batched_tokens: ++ logger.debug('Request %s prompt_len %d is overflow of max_num_batched_tokens %d', ++ seq_group.request_id, seq_group.prompt_len, ++ self.scheduler_config.max_num_batched_tokens) ++ seq_group.seqs[0].status = SequenceStatus.FINISHED_IGNORED ++ ignored_seq_groups.append(seq_group) ++ self.prefilling.pop(0) ++ continue ++ ++ # If the sequence group cannot be allocated, preempt. ++ num_new_seqs = seq_group.get_max_num_running_seqs() ++ ++ # prefill模式,需传入blocks_to_swap_in,配合cpu->npu缓存拷贝 ++ if not self.try_running_prefilling(seq_group, blocks_to_copy, blocks_to_swap_in, blocks_to_swap_out, ++ ignored_seq_groups, num_new_seqs, ++ budget_len=budget_len, is_strict=is_strict): ++ logger.debug(f"lcwdbg: not run seq {seq_group.request_id} this turn.") ++ break ++ ++ # 先移除waiting队列,放到临时的调度队列,在最后adjust_prefill_by_budget中觉得是否重新放入waiting还是running ++ seq_group = self.prefilling.pop(0) ++ scheduling.insert(0, seq_group) ++ ++ seq = seq_group.seqs[0] ++ prefilling_len = min(seq.get_uncached_prompt_len() + self.max_speculate_len, ++ left_budget if budget_len is None else budget_len) ++ if seq_group.multi_modal_data is not None and left_budget < prefilling_len: ++ scheduling.pop(0) ++ return ++ self._allocate(seq_group, prefilling_len) ++ left_budget -= prefilling_len ++ self.num_batched_prefill_tokens += prefilling_len ++ ++ seq_group.is_prompt = True ++ self.num_curr_seqs += num_new_seqs ++ ++ def avoid_swapped_out_seqs(self, seq_group: SequenceGroup): ++ """ ++ 在有 swapped_out 的请求的情况下,判断当前空闲 NPU 块数是否足够 seq_group 推理,若不够则返回 True ++ Args: ++ seq_group: ++ ++ Returns: ++ ++ """ ++ if self.swapping_out: ++ logger.debug( ++ f"not run seq_group: {seq_group.request_id}. There are {len(self.swapping_out)} seqs swapping out.") ++ return True ++ if self.swapped_out: ++ if len(self.swapped_out) > 1: ++ logger.debug(f"not run seq_group: {seq_group.request_id}. " ++ f"There are {len(self.swapped_out)} seqs swapped out.") ++ return True ++ seq_group_needed_gpu_block_num = self.get_needed_gpu_block_num(seq_group, ActionType.ALLOCATE) ++ if seq_group_needed_gpu_block_num > 0: ++ logger.debug(f"not run seq_group: {seq_group.request_id}. " ++ f"need more npu blocks: {seq_group_needed_gpu_block_num}") ++ return True ++ return False ++ ++ def try_running_prefilling(self, ++ seq_group: SequenceGroup, ++ blocks_to_copy: List[Tuple[int, int]], ++ blocks_to_swap_in: Dict[str, List[Tuple[int, int]]], ++ blocks_to_swap_out: Dict[str, List[Tuple[int, int]]], ++ ignored_seq_groups: List[SequenceGroup], ++ num_new_seqs: int, ++ budget_len: Optional[int] = None, ++ is_strict: bool = False) -> bool: ++ """计算prefill阶段能否正常执行block的分配""" ++ assert_or_raise( ++ seq_group.status == SequenceGroupStatus.WAITING, ++ "the seq status should be WAITTING", ++ ) ++ # should check: 若有 swapped_out 的 seq,那么当前NPU块必须很充足,足够这个 seq_group 推。否则又需要 swap out ++ self.check_cache_cap(blocks_to_swap_out, seq_group, ActionType.ALLOCATE) ++ if self.avoid_swapped_out_seqs(seq_group): ++ return False ++ if not self.check_swap_in_blocks( ++ seq_group=seq_group, blocks_to_swap_in=blocks_to_swap_in ++ ): ++ return False ++ return self.try_running(seq_group, blocks_to_copy, blocks_to_swap_out, ignored_seq_groups, ++ num_new_seqs, budget_len, is_strict) ++ ++ def check_swap_in_blocks( ++ self, ++ seq_group: SequenceGroup, ++ blocks_to_swap_in: Dict[str, List[Tuple[int, int]]], ++ ) -> bool: ++ """检查swap_in的blocks是否满足要求""" ++ num_blocks_required = 0 ++ prefix_cache_cpu_tmp = [] ++ seq = seq_group.seqs[0] ++ cached_len_cpu = 0 ++ depth = 0 ++ if not seq_group.seqs[0].prefix_matched: ++ cached_len_cpu, prefix_cache_npu_tmp, prefix_cache_cpu_tmp, _ = ( ++ self.block_manager.auto_prefix_sharing_multi_level(seq_group.seqs[0]) ++ ) ++ num_blocks_to_copy = 0 # not processing any COW blocks right now ++ num_blocks_required = len(prefix_cache_cpu_tmp) + num_blocks_to_copy ++ # bind npu and cpu block with seq ++ for block in prefix_cache_npu_tmp: ++ self.block_manager.bind_seq_with_block(seq, block, depth) ++ depth += 1 ++ num_blocks_free = self.block_manager.get_num_free_gpu_blocks() ++ logger.debug(f"num_blocks_required={num_blocks_required}") ++ if 0 < num_blocks_required <= num_blocks_free: ++ # 需要换入 CPU 块,且,有足够的空闲块供 CPU Block 换入 NPU ++ if prefix_cache_cpu_tmp: ++ for block in prefix_cache_cpu_tmp: ++ self.block_manager.bind_seq_with_block(seq, block, depth) ++ depth += 1 ++ # cpu块换入前会自动改 block_tables ++ blk_mapping_host2npu = self.block_manager.swap_in_cache( ++ seq_group, prefix_cache_cpu_tmp, cached_len_cpu ++ ) ++ self.swapping_in_prefix_cache.append( ++ SwappingSequenceGroup(seq_group, self.num_workers) ++ ) ++ self.swapping_in_prefix_cache_record.add(seq_group.request_id) ++ blocks_to_swap_in[seq_group.request_id] = blk_mapping_host2npu ++ return False # 需要等待cpu上的cache完全copy到npu之后才能运行 ++ elif 0 < num_blocks_required: ++ # 当前的空闲 NPU 块还不够给 匹配到的 CPU Block 换入 ++ logger.error( ++ f"err! num_blocks_required {num_blocks_required} > num_blocks_free {num_blocks_free}" ++ ) ++ if seq_group.request_id in self.swapping_in_prefix_cache_record: ++ return False ++ return True ++ ++ def check_cache_cap(self, ++ blocks_to_swap_out: Dict, ++ seq_group: Optional[SequenceGroup] = None, ++ action_type: Optional[ActionType] = None): ++ """检查gpu block 是否满足要求""" ++ threshold_gpu_blocks = int(self.block_manager.num_total_gpu_blocks * self.gpu_usage_threshold) ++ evict_gpu_num = (int(self.block_manager.num_total_gpu_blocks * self.evict_gpu_usage) + ++ random.randint(0, int(self.block_manager.num_total_gpu_blocks * 0.1))) ++ threshold_cpu_blocks_rtc = int(self.block_manager.num_total_cpu_blocks * self.memory_usage_threshold) ++ evict_cpu_num = (int(self.block_manager.num_total_cpu_blocks * self.evict_memory_usage) + ++ random.randint(0, int(self.block_manager.num_total_cpu_blocks * 0.1))) ++ num_gpu_blocks = self.block_manager.num_total_gpu_blocks - self.block_manager.get_num_free_gpu_blocks() ++ num_cpu_blocks = self.block_manager.num_total_cpu_blocks - self.block_manager.get_num_free_cpu_blocks() ++ ++ logger.info(f'Prefix sharing checking, ' ++ f'rtc gpu blocks: {num_gpu_blocks}, threshold: {threshold_gpu_blocks}, ' ++ f'total num: {self.block_manager.num_total_gpu_blocks}, ' ++ f'rtc cpu blocks: {num_cpu_blocks}, threshold: {threshold_cpu_blocks_rtc}, ' ++ f'total num: {self.block_manager.num_total_cpu_blocks}') ++ ++ if seq_group is not None: ++ needed_gpu_blocks = self.get_needed_gpu_block_num(seq_group, action_type) ++ evict_gpu_num = max(evict_gpu_num, needed_gpu_blocks) ++ ++ if num_cpu_blocks >= threshold_cpu_blocks_rtc: ++ with self.block_manager.index_tree_lock: ++ evict_cpu_num = max(evict_cpu_num, evict_gpu_num) ++ _, cpu_evict_block_list = self.index_tree_manager.select_evict( ++ evict_cpu_num, device=Device.CPU ++ ) ++ self.block_manager.free_cpu_prefix_cache(cpu_evict_block_list, blocks_to_swap_out) ++ ++ if num_gpu_blocks >= threshold_gpu_blocks: ++ with self.block_manager.index_tree_lock: ++ evict_gpu_num = min(evict_gpu_num, self.block_manager.get_num_free_cpu_blocks()) ++ cache_blocks_to_swap_out: BlockTable = self.index_tree_manager.select_swap_out(evict_gpu_num) ++ if cache_blocks_to_swap_out: ++ dummy_seq_group_swapping_out = SequenceGroup( ++ request_id='RTC_CACHE_SWAPPING_OUT_' + \ ++ ''.join(random.choices(string.ascii_letters + string.digits, k=5)), ++ seqs=[], ++ sampling_params=None ++ ) ++ self.swapping_out_prefix_cache.append( ++ SwappingSequenceGroup( ++ dummy_seq_group_swapping_out, self.num_workers ++ ) ++ ) ++ block_mapping_swap_out = self.block_manager.swap_out_cache( ++ dummy_seq_group_swapping_out, cache_blocks_to_swap_out) ++ else: ++ block_mapping_swap_out = [] ++ ++ if block_mapping_swap_out: ++ blocks_to_swap_out[dummy_seq_group_swapping_out.request_id] = block_mapping_swap_out ++ ++ def schedule_decodings(self, blocks_to_copy, blocks_to_swap_out, ignored_seq_groups, ++ scheduling, reqs_sync_swap_out, not_scheduled_seq_ids=None): ++ """对decode阶段进行调度""" ++ preempted: bool = False ++ idx = 0 ++ while self.decoding: ++ if len(self.decoding) <= idx: ++ break ++ ++ if not_scheduled_seq_ids is not None \ ++ and self.decoding[idx].seqs[0].seq_id in not_scheduled_seq_ids: ++ idx += 1 ++ continue ++ ++ num_seqs = self.decoding[idx].num_seqs(status=SequenceStatus.RUNNING) ++ if self.decoding_first_left_budget < (1 + self.max_speculate_len) * num_seqs: ++ break ++ ++ seq_group = self.decoding.pop(idx) ++ while not self.block_manager.can_append_slot(seq_group): ++ # 1. 优先淘汰offline Cache ++ if self.evict_offline_cache(seq_group, ActionType.APPEND_SLOT): ++ continue ++ ++ # 2. 将async swap(进行中)的改为sync swap,立即释放gpu block ++ if self.swapping_out: ++ waiting_seq_group = self.swapping_out.pop(0).seq_group ++ reqs_sync_swap_out.append(waiting_seq_group.request_id) ++ logger.debug(f"async swap to sync swap: {waiting_seq_group.request_id}") ++ if not waiting_seq_group.request_id.startswith( ++ "RTC_CACHE_SWAPPING_OUT" ++ ): ++ self.swapped_out.append(waiting_seq_group) ++ self.block_manager.free_swapping_blocks(waiting_seq_group.request_id, sync_swap_out=True) ++ # 下面这段可以提取为函数 ++ if (PrefixSharingType.using_cpu_cache(self.prefix_sharing_type) ++ and waiting_seq_group.request_id ++ in self.block_manager.swapping_out_map_cpu2npu): ++ with self.block_manager.index_tree_lock: ++ self.block_manager.index_tree_manager.swapped_out( ++ self.block_manager.swapping_out_map_cpu2npu[ ++ waiting_seq_group.request_id] ++ ) ++ del self.block_manager.swapping_out_map_cpu2npu[waiting_seq_group.request_id] ++ preempted = True ++ continue ++ ++ # 3. SYNC_SWAP the lowest-priority,立即释放gpu block后再尝试调度 ++ if self.decoding: ++ # Preempt the lowest-priority sequence groups. ++ victim_seq_group = self.decoding.pop(-1) ++ self._preempt(victim_seq_group, blocks_to_swap_out, ignored_seq_groups, PreemptionMode.SYNC_SWAP, ++ reqs_sync_swap_out) ++ preempted = True ++ continue ++ ++ # 4. 无法调度seq_group,则ASYNC_SWAP seq_group ++ self._preempt(seq_group, blocks_to_swap_out, ignored_seq_groups, PreemptionMode.ASYNC_SWAP) ++ preempted = True ++ break ++ else: ++ # Append new slots to the sequence group. ++ if seq_group.pd_info is not None: ++ blocks_to_copy.extend(seq_group.pd_info.blocks_to_copy) ++ self._append_slot(seq_group, blocks_to_copy) ++ seq_group.is_prompt = False ++ ++ scheduling.append(seq_group) ++ ++ self.decoding_first_left_budget -= (1 + self.max_speculate_len) * num_seqs ++ self.num_batched_decoding_tokens += num_seqs * (1 + self.max_speculate_len) ++ ++ return preempted ++ ++ def has_unfinished_reqs_ml_prefix_cache(self) -> bool: ++ """是否有未完成的请求""" ++ if self.swapping_in_prefix_cache_record or self.swapping_out_prefix_cache: ++ return True ++ ++ return False ++ ++ def _swap_out( ++ self, ++ seq_group: SequenceGroup, ++ blocks_to_swap_out: Dict[str, List[Tuple[int, int]]], ++ ignored_seq_groups: List[SequenceGroup], ++ sync_op: bool = False, ++ reqs_sync_swap_out: List[str] = None, ++ ) -> None: ++ if (len(self.swapped_out) >= self.scheduler_config.max_swapped_req_num ++ or not self.block_manager.can_swap_out(seq_group)): ++ self._abort_single_seq_group(seq_group, ignored_seq_groups) ++ return ++ ++ logger.debug(f'%s out seq_group %s', 'Sync swap' if sync_op else 'Async swapping', ++ seq_group.request_id) ++ ++ blocks_to_swap_out[seq_group.request_id] = self.block_manager.swap_out(seq_group) ++ for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): ++ seq.status = SequenceStatus.SWAPPED ++ ++ if sync_op: ++ reqs_sync_swap_out.append(seq_group.request_id) ++ self.swapped_out.append(seq_group) ++ # 需要立即做后处理 ++ self.block_manager.free_swapping_blocks(seq_group.request_id, sync_swap_out=True) ++ # 下面这段可以提取为函数 ++ if (PrefixSharingType.using_cpu_cache(self.prefix_sharing_type) ++ and seq_group.request_id ++ in self.block_manager.swapping_out_map_cpu2npu): ++ with self.block_manager.index_tree_lock: ++ self.block_manager.index_tree_manager.swapped_out( ++ self.block_manager.swapping_out_map_cpu2npu[ ++ seq_group.request_id] ++ ) ++ del self.block_manager.swapping_out_map_cpu2npu[seq_group.request_id] ++ else: ++ self.swapping_out.append(SwappingSequenceGroup(seq_group, self.num_workers)) ++ ++ def _handle_swap_out_queue(self, finished_swap_out_req_ids, swapping_queue): ++ for swapping_seq_group in reversed(swapping_queue): ++ if swapping_seq_group.seq_group.request_id not in finished_swap_out_req_ids: ++ continue ++ swapping_seq_group.num_swapping_workers -= 1 ++ if swapping_seq_group.num_swapping_workers != 0: ++ continue ++ swapping_queue.remove(swapping_seq_group) ++ self.block_manager.free_swapping_blocks( ++ swapping_seq_group.seq_group.request_id ++ ) ++ ++ if not swapping_seq_group.seq_group.request_id.startswith( ++ "RTC_CACHE_SWAPPING_OUT" ++ ): ++ self.swapped_out.append(swapping_seq_group.seq_group) ++ if ( ++ PrefixSharingType.using_cpu_cache(self.prefix_sharing_type) ++ and swapping_seq_group.seq_group.request_id ++ in self.block_manager.swapping_out_map_cpu2npu ++ ): ++ with self.block_manager.index_tree_lock: ++ self.block_manager.index_tree_manager.swapped_out( ++ self.block_manager.swapping_out_map_cpu2npu[ ++ swapping_seq_group.seq_group.request_id ++ ] ++ ) ++ del self.block_manager.swapping_out_map_cpu2npu[ ++ swapping_seq_group.seq_group.request_id ++ ] ++ ++ def _check_finished_swappings(self, swap_finished_req_ids) -> None: ++ # Move local_thread.sync_swap_out_blocks to global free_blocks ++ self.block_manager.merge_sync_swap_out_blocks() ++ logger.debug("begin _check_finished_swappings") ++ ++ for single_worker_finished_req_ids in swap_finished_req_ids: ++ swap_in_req_ids, swap_out_req_ids = single_worker_finished_req_ids ++ logger.debug(f"swap_finished_req_ids swap_in:{swap_in_req_ids}, swap_out:{swap_out_req_ids}") ++ logger.debug(f"swapping_in list: {[_g.seq_group.request_id for _g in self.swapping_in]}") ++ # check swap in ++ for swapping_seq_group in reversed(self.swapping_in): ++ if swapping_seq_group.seq_group.request_id in swap_in_req_ids: ++ swapping_seq_group.num_swapping_workers -= 1 ++ if swapping_seq_group.num_swapping_workers == 0: ++ # 这里通知 watcher 哪些块换入成功了,以得到完全成功换入的 seqs,再放到 decoding 队列里 ++ fin_req_seq_group_list = self.block_manager.fetch_finish_swap_in_seq_group( ++ swapping_seq_group.seq_group.request_id) ++ for req_seq_group in fin_req_seq_group_list: ++ logger.debug(f"seq {req_seq_group.request_id} finished swap in, become decoding seq") ++ self.decoding.insert(0, req_seq_group) ++ ++ self.swapping_in.remove(swapping_seq_group) ++ # 这里面能通知 Radix Cache 块已换入 swapped_in ++ self.block_manager.free_swapping_in_blocks(swapping_seq_group.seq_group.request_id) ++ # swapping_seq_group 指示换入的块已换入,但它可能还依赖于其他正在换入中的块,因此不能直接插入 decoding 队列 ++ self._handle_swap_out_queue(swap_out_req_ids, self.swapping_out) ++ self._handle_swap_out_queue( ++ swap_out_req_ids, self.swapping_out_prefix_cache ++ ) ++ if PrefixSharingType.using_cpu_cache(self.prefix_sharing_type): ++ for swapping_seq_group in self.swapping_in_prefix_cache: ++ if swapping_seq_group.seq_group.request_id in swap_in_req_ids: ++ swapping_seq_group.num_swapping_workers -= 1 ++ if swapping_seq_group.num_swapping_workers == 0: ++ if swapping_seq_group not in self.swapping_in_prefix_cache: ++ continue ++ # 这里通知 watcher 哪些块换入成功了,以得到完全成功换入的 seqs,再放到 decoding 队列里 ++ fin_req_seq_group_list = self.block_manager.fetch_finish_swap_in_seq_group( ++ swapping_seq_group.seq_group.request_id) ++ for req_seq_group in fin_req_seq_group_list: ++ self.decoding.insert(0, req_seq_group) ++ ++ self.swapping_in_prefix_cache.remove(swapping_seq_group) ++ self.swapping_in_prefix_cache_record.remove(swapping_seq_group.seq_group.request_id) ++ self.block_manager.free_swapping_in_blocks(swapping_seq_group.seq_group.request_id) +diff --git a/dllm_tools/dllm/dkvc/prefix_sharing_multi_level/core/swap_in_watcher.py b/dllm_tools/dllm/dkvc/prefix_sharing_multi_level/core/swap_in_watcher.py +new file mode 100644 +index 000000000..e0fa611bf +--- /dev/null ++++ b/dllm_tools/dllm/dkvc/prefix_sharing_multi_level/core/swap_in_watcher.py +@@ -0,0 +1,121 @@ ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++ ++from vllm.logger import init_logger ++ ++from dllm.dkvc.util import assert_or_raise ++ ++logger = init_logger(f"vllm.{__name__}") ++ ++ ++class ReqSwapInWatcher: ++ """ ++ 用于跟踪当前正在换入的 reqs 依赖于哪些还在换入过程中的 cpu blocks。 ++ 仅当一个 req 依赖的所有 cpu blocks 都换入了,这个 req 所表示的 seq_group 才能执行。 ++ """ ++ def __init__(self): ++ # req_blocks 存储 正在换入的 req 依赖哪些正在换入的 blocks ++ # 格式: {seq_group: [block, block, ...]} ++ self.req_blocks = {} ++ # swapped_in_blocks 已换入的 block ++ self.swapped_in_blocks = set() ++ # prefill 阶段的 req,不要让它作为 update_swapped_in_blocks 的返回值,避免被放入 decoding 队列。 ++ self.prefill_reqs = set() ++ self.cpu_npu_block_mapping = {} # cpu block --> npu block 的映射 ++ ++ def add_dependency(self, req_seq_group, block_list, block_map, is_prefill=False): ++ """ ++ 添加 req 与换入中的 block 的依赖关系 ++ ++ Args: ++ req_seq_group: req 的 seq_group ++ block_list: req 期待换入的 cpu block list. ++ block_map: req 指示换入的 block dict, key 是 CPU Block, value 是对应的 NPU Block. block_map.keys 是 block_list 的子集 ++ is_prefill: 如果是 prefill 阶段的 Seq,那说明是 prefix match 命中的,不要在 update_swapped_in_blocks 里把它作为返回值 TODO ++ """ ++ self.req_blocks[req_seq_group] = block_list ++ self.cpu_npu_block_mapping.update(block_map) ++ if is_prefill: ++ self.prefill_reqs.add(req_seq_group) ++ logger.debug(f"req_id: {req_seq_group.request_id} is swapping in, " ++ f"dependency cpu blocks: {block_list}, is_prefill: {is_prefill}") ++ ++ def remove_req(self, req_id) -> None: ++ """ ++ Args: ++ req_id: 要移除的 req 的 ID,给 abort 用 ++ ++ """ ++ if req_id in self.req_blocks: ++ del self.req_blocks[req_id] ++ self._remove_unused_blocks() ++ ++ def update_swapped_in_blocks(self, blocks, block_tables): ++ """ ++ 检查已换入的 block,将 cpu block 的信息迁移到 npu block,确定哪些 req 换入成功, 并移除已换入的 req 和 不再使用的 block ++ Args: ++ blocks: 完成换入的 CPU block 列表 ++ block_tables: block_manager.block_tables ++ Returns: ++ newly_swapped_in_reqs ++ """ ++ if not blocks: ++ return set() ++ self.swapped_in_blocks.update(blocks) # 这些 block 已换入,下面实施 cpu --> npu block 的 seq 信息双向迁移 ++ ++ for cpu_blk in blocks: ++ npu_blk = self.cpu_npu_block_mapping.get(cpu_blk) ++ npu_blk.depth = cpu_blk.depth ++ npu_blk.ref_count = 1 + cpu_blk.ref_count ++ npu_blk.ref_set_of_seq = cpu_blk.ref_set_of_seq ++ for seq2 in npu_blk.ref_set_of_seq: ++ block_table_inner = block_tables[seq2.seq_id] ++ assert_or_raise(npu_blk.depth >= 0) ++ block_table_inner[npu_blk.depth] = npu_blk ++ cpu_blk.ref_count = 0 ++ cpu_blk.ref_set_of_seq = set() ++ ++ newly_swapped_in_reqs = set() ++ newly_swapped_in_reqs_of_prefill = [] ++ for req_seq_group, dependency_blocks in self.req_blocks.items(): ++ if all(block in self.swapped_in_blocks for block in dependency_blocks): ++ logger.debug(f"seq {req_seq_group.request_id} swapped in all cpu blocks.") ++ if req_seq_group not in self.prefill_reqs: ++ newly_swapped_in_reqs.add(req_seq_group) ++ else: ++ logger.debug(f"watcher: seq {req_seq_group.request_id} is prefill_reqs. not return.") ++ self.prefill_reqs.remove(req_seq_group) ++ newly_swapped_in_reqs_of_prefill.append(req_seq_group) ++ ++ for req_seq_group in newly_swapped_in_reqs_of_prefill: ++ del self.req_blocks[req_seq_group] ++ # 移除已换入的 req ++ for req_seq_group in newly_swapped_in_reqs: ++ del self.req_blocks[req_seq_group] ++ ++ # 移除不再使用的 block ++ self._remove_unused_blocks() ++ ++ return newly_swapped_in_reqs ++ ++ def _remove_unused_blocks(self): ++ """ ++ 移除不再使用的 block ++ """ ++ using_blocks = set() ++ for blocks in self.req_blocks.values(): ++ using_blocks.update(blocks) ++ # 移除掉已经不被换入中req使用的 blocks ++ self.swapped_in_blocks = using_blocks & self.swapped_in_blocks ++ # 移除掉已经用不上的 cpu->npu 映射 ++ self.cpu_npu_block_mapping = {k: v for k, v in self.cpu_npu_block_mapping.items() if k in using_blocks} +diff --git a/dllm_tools/dllm/dkvc/prefix_sharing_multi_level/index/LRU.py b/dllm_tools/dllm/dkvc/prefix_sharing_multi_level/index/LRU.py +new file mode 100644 +index 000000000..0245ef4b9 +--- /dev/null ++++ b/dllm_tools/dllm/dkvc/prefix_sharing_multi_level/index/LRU.py +@@ -0,0 +1,81 @@ ++#!/usr/bin/env python3 ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++ ++class DoublyLinkedListNode: ++ def __init__(self, value=None, prev_node=None, next_node=None): ++ self.value = value ++ self.prev = prev_node # 指向前一个节点 ++ self.next = next_node # 指向后一个节点 ++ ++ ++class DoublyLinkedList: ++ def __init__(self): ++ self.head = None ++ self.tail = None ++ ++ def push_front(self, value): ++ """Push new node to the front of a linked list""" ++ new_node = DoublyLinkedListNode(value) ++ if self.head is None: ++ # 如果链表为空,初始化头和尾 ++ self.head = new_node ++ self.tail = new_node ++ else: ++ # 将新节点添加到链表的头 ++ new_node.next = self.head ++ self.head.prev = new_node ++ self.head = new_node ++ ++ def delete(self, node: DoublyLinkedListNode): ++ """delete node""" ++ if node.prev is not None: ++ node.prev.next = node.next ++ else: ++ self.head = node.next ++ if node.next is not None: ++ node.next.prev = node.prev ++ else: ++ self.tail = node.prev ++ ++ ++class LRU: ++ def __init__(self): ++ # dll, map ++ self.doubly_linked_list = DoublyLinkedList() ++ self.tree_node_to_DLL_node_map = dict() # key: TreeNode -> DoublyLinkedListNode ++ ++ # The last_access_time of element must be the newest ++ # The element is TreeNode ++ def push_front(self, element): ++ """Push new node to the front of a linked list""" ++ if element in self.tree_node_to_DLL_node_map: ++ # refresh: have duplicate element, delete it, then push to front. ++ self.delete(element) ++ self.doubly_linked_list.push_front(element) ++ self.tree_node_to_DLL_node_map[element] = self.doubly_linked_list.head ++ ++ # The element is TreeNode ++ def delete(self, element): ++ """delete node""" ++ dll_node = self.tree_node_to_DLL_node_map[element] ++ self.doubly_linked_list.delete(dll_node) ++ del self.tree_node_to_DLL_node_map[element] ++ ++ def get_tail(self) -> DoublyLinkedListNode: ++ """get tail node""" ++ return self.doubly_linked_list.tail ++ ++ def get_head(self) -> DoublyLinkedListNode: ++ """get head node""" ++ return self.doubly_linked_list.head +diff --git a/dllm_tools/dllm/dkvc/prefix_sharing_multi_level/index/__init__.py b/dllm_tools/dllm/dkvc/prefix_sharing_multi_level/index/__init__.py +new file mode 100644 +index 000000000..e69de29bb +diff --git a/dllm_tools/dllm/dkvc/prefix_sharing_multi_level/index/index_tree_manager.py b/dllm_tools/dllm/dkvc/prefix_sharing_multi_level/index/index_tree_manager.py +new file mode 100644 +index 000000000..b20e02df8 +--- /dev/null ++++ b/dllm_tools/dllm/dkvc/prefix_sharing_multi_level/index/index_tree_manager.py +@@ -0,0 +1,149 @@ ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++ ++from hashlib import sha1 ++from typing import List, Tuple ++ ++from vllm.block import PhysicalTokenBlock ++from vllm.core.block_manager import BlockTable ++from vllm.logger import init_logger ++from vllm.sequence import Sequence ++ ++from dllm.dkvc.prefix_sharing_multi_level.index.radix_cache import ( ++ RadixCache, TreeNode) ++from dllm.dkvc.util import assert_or_raise ++ ++logger = init_logger(__name__) ++ ++ ++class IndexTreeManager: ++ def __init__( ++ self, ++ block_size, ++ cpu_block_num, ++ npu_block_num, ++ ) -> None: ++ """IndexManager. ++ ++ Args: ++ block_size ++ allocate_block: the function to allocate a block ++ free_block: the function to free a block ++ """ ++ self.tree_cache = RadixCache(block_size, cpu_block_num, npu_block_num) ++ self.block_size = block_size ++ self.cpu_block_num = cpu_block_num ++ self.npu_block_num = npu_block_num ++ ++ @staticmethod ++ def generate_hash_key(is_first_block: bool, prev_block_hash: str, cur_block_token_ids: Tuple[int]) -> int: ++ """生成当前block的hash值""" ++ assert_or_raise((prev_block_hash == "") == is_first_block) ++ prev_block_hash = int(prev_block_hash) if prev_block_hash else -1 ++ hash_object = sha1() ++ data = (is_first_block, int(prev_block_hash), *cur_block_token_ids) ++ for item in data: ++ hash_object.update(str(item).encode('utf-8')) ++ return str(int(hash_object.hexdigest(), 16)) ++ ++ def match_prefix(self, seq: Sequence): ++ """ ++ 查找匹配点位置,直接更新seq的cache_nodes(匹配的Radix Tree Nodes) ++ """ ++ cached_len, matched_nodes_in_npu, matched_nodes_in_cpu, need_copy_block_list = self._match_prefix( ++ seq.data.prompt_token_ids, seq) ++ logger.debug(f'Seq %s Auto prefix sharing len %d', seq.seq_id, cached_len) ++ ++ return cached_len, matched_nodes_in_npu, matched_nodes_in_cpu, need_copy_block_list ++ ++ def insert_nodes(self, keys: List[int], values: BlockTable) -> List[PhysicalTokenBlock]: ++ """ ++ 根据原始key插入block到radix tree, prompt_token_ids按block切分,与value中的block一一映射 ++ """ ++ # 推理完成后,保存token和block到radix tree ++ need_free_blocks = [] ++ # the last output token has no kv cache, so dont handle the last token ++ split_keys = self.split_keys(keys) ++ if len(split_keys) == 0: ++ return [] ++ ++ if len(values) > len(split_keys): ++ for value in values[len(split_keys):]: ++ need_free_blocks.append(value) ++ values = values[:len(split_keys)] ++ before_hash = "" ++ for i, block in enumerate(values): ++ if not block.block_hash: ++ block.set_hash(IndexTreeManager.generate_hash_key(before_hash == "", before_hash, split_keys[i])) ++ before_hash = block.block_hash ++ need_free_blocks.extend(self.tree_cache.insert_nodes(split_keys, values)) ++ return need_free_blocks ++ ++ def split_keys(self, keys: List[int]) -> List[Tuple[int]]: ++ """将keys 分割成不超过block_size的元组""" ++ split_keys: List[Tuple[int]] = [] ++ while len(keys) >= self.block_size: ++ split_keys.append(tuple(keys[:self.block_size])) ++ keys = keys[self.block_size:] ++ if len(keys) != 0: ++ split_keys.append(tuple(keys)) ++ return split_keys ++ ++ def select_evict(self, num_blocks, device) -> Tuple[List[PhysicalTokenBlock], List[PhysicalTokenBlock]]: ++ """ ++ 淘汰目标对象:叶子节点按访问时间LRU排序,把最长时间不访问且引用计数为1的block,从 cache淘汰 ++ 返回:npu 上应当释放的块 List, cpu 上应当释放的块 List ++ """ ++ return self.tree_cache.select_evict(num_blocks, device) ++ ++ def select_swap_out(self, num_blocks: int): ++ """选出num_block个swapout的block""" ++ return self.tree_cache.select_swap_out(num_blocks) ++ ++ def swapped_in(self, swapped_pair_list: List[Tuple[int, PhysicalTokenBlock]]): ++ """根据swapped_pair_list映射关系swap in""" ++ self.tree_cache.swapped_in(swapped_pair_list) ++ ++ def swapped_out( ++ self, ++ swapped_pair_list: List[Tuple[PhysicalTokenBlock, int]], ++ ): ++ """根据swapped_pair_list映射关系swap out""" ++ self.tree_cache.swapped_out(swapped_pair_list) ++ ++ def refresh_nodes_status(self, nodes: List[TreeNode]): ++ """初始化节点状态为INITIALIZE状态""" ++ if nodes: ++ return self.tree_cache.refresh_node_status(nodes) ++ return None ++ ++ def delete_leaf(self, node: TreeNode): ++ """删除node节点""" ++ return self.tree_cache.delete_leaf(node) ++ ++ def get_num_nodes(self, device=None) -> int: ++ """获取指定device类型的节点数量""" ++ return self.tree_cache.get_num_nodes(device) ++ ++ def get_num_nodes_can_swap_out(self) -> int: ++ """获取可以换出的节点数量""" ++ return self.tree_cache.get_num_nodes_can_swap_out() ++ ++ # 把 seq 传进去,就是为了通过双向索引进行检查,防止 ref_count 重复加 ++ def _match_prefix(self, token_ids: List[int], seq) -> Tuple[int, List[PhysicalTokenBlock], ++ List[PhysicalTokenBlock], ++ List[PhysicalTokenBlock]]: ++ split_keys = self.split_keys(token_ids) ++ if len(split_keys) == 0: ++ return 0, [], [], [] ++ return self.tree_cache.trie_match(len(token_ids), split_keys, seq) +diff --git a/dllm_tools/dllm/dkvc/prefix_sharing_multi_level/index/radix_cache.py b/dllm_tools/dllm/dkvc/prefix_sharing_multi_level/index/radix_cache.py +new file mode 100644 +index 000000000..7bad2bef9 +--- /dev/null ++++ b/dllm_tools/dllm/dkvc/prefix_sharing_multi_level/index/radix_cache.py +@@ -0,0 +1,619 @@ ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++ ++ ++from collections import defaultdict ++from typing import List, Optional, Tuple ++ ++from vllm.logger import init_logger ++from vllm.prefix_sharing.index.LRU import LRU, DoublyLinkedListNode ++from vllm.utils import Device ++ ++# PhysicalTokenBlock replaced with PhysicalTokenBlockExt ++from dllm.dkvc.prefix_sharing_multi_level.block import \ ++ KVCacheProgressStatus ++from dllm.dkvc.prefix_sharing_multi_level.block import \ ++ PhysicalTokenBlockExt as PhysicalTokenBlock ++from dllm.dkvc.util import assert_or_raise ++ ++logger = init_logger(__name__) ++ ++ ++class TreeNode: ++ def __init__(self, num_valid_token: int = 0, block: PhysicalTokenBlock = None): ++ self.children = defaultdict(TreeNode) # key: id, value: TreeNode ++ self.parent: Optional[TreeNode] = None ++ self.block: Optional[PhysicalTokenBlock] = block if block else None ++ self.num_valid_token = num_valid_token ++ self.key: Tuple[int] = None ++ self.id: int = -1 # id 应由 TreeNodePool 生成 ++ ++ def clear(self): ++ """重置节点""" ++ self.children = defaultdict(TreeNode) ++ self.parent = None ++ self.block = None ++ self.num_valid_token = 0 ++ self.key = None ++ ++ def insert_child(self, key: Tuple[int], block: PhysicalTokenBlock, new_node): ++ """插入一个child""" ++ block.in_radix_cache = True ++ new_node.parent = self ++ new_node.block = block ++ new_node.num_valid_token = len(key) ++ new_node.key = key ++ self.children[new_node.id] = new_node ++ ++ return new_node ++ ++ def ref_count_up(self): ++ """增加节点的引用计数""" ++ self.block.ref_count += 1 ++ ++ def ref_count_down(self): ++ """减少节点的引用计数""" ++ self.block.ref_count -= 1 ++ ++ ++class TreeNodePool: ++ def __init__(self): ++ self._pool = [] ++ self._create_count = 0 # 总共创建过的对象数量 ++ ++ def acquire(self): ++ """ ++ Acquire an object from the pool, creating a new one if necessary. ++ :return: An TreeNode object from the pool ++ """ ++ if self._pool: ++ obj = self._pool.pop() ++ obj.clear() ++ return obj ++ obj = TreeNode() ++ obj.id = self._create_count ++ self._create_count += 1 ++ return obj ++ ++ def release(self, obj): ++ """ ++ Release an object back into the pool. ++ :param obj: The object to be released back into the pool ++ """ ++ self._pool.append(obj) ++ ++ def size(self): ++ """ ++ Get the current size of the pool. ++ :return: The number of objects currently in the pool ++ """ ++ return len(self._pool) ++ ++ ++# 按位比较key和seq的单个元素 ++# 如果seq比key长,则zip以key为准 ++def match(key: Tuple, seq: Tuple) -> int: ++ """获取有多少个相同位置一样长的key,seq元素""" ++ match_len = 0 ++ for left, right in zip(key, seq): ++ if left != right: ++ break ++ match_len += 1 ++ return match_len ++ ++ ++class RadixCache: ++ def __init__(self, block_size, cpu_block_num, npu_block_num): ++ logger.debug(f"init RadixCache. block_size is {block_size}") ++ self.root_node = TreeNode() ++ self.root_node.block = PhysicalTokenBlock(None, 0, 0) ++ self.block_size = block_size ++ self.cpu_block_num = cpu_block_num ++ self.npu_block_num = npu_block_num ++ self.block_num_to_cpu_tn = [None] * cpu_block_num # cpu block_num --> cpu TreeNode ++ self.block_num_to_npu_tn = [None] * npu_block_num # npu block_num --> npu TreeNode ++ self.node_in_cpu_count = 0 ++ self.node_in_npu_count = 0 ++ self.lru = LRU() ++ self.tree_node_pool = TreeNodePool() ++ # 删掉的TreeNode 先追加在这里,延迟到空闲时回收 ++ ++ @staticmethod ++ def refresh_node_status(nodes: List[TreeNode]): ++ """将对应节点的状态设置为INITIALIZE""" ++ for node in nodes: ++ node.block.progress_status = KVCacheProgressStatus.INITIALIZE ++ ++ @staticmethod ++ def _build_copy_block_list(match_node, already_matched): ++ if not already_matched: ++ match_node.ref_count_down() ++ # 因为做了增计数,拷贝到一半不会被 evict 或者换出。但是拷贝完成后需要对原 Block 做减计数。 ++ # 判断节点是在 CPU 还是 NPU ++ if match_node.block.device == Device.GPU: ++ need_copy_block_list = [match_node.block, None] ++ else: ++ need_copy_block_list = [None, match_node.block] ++ return need_copy_block_list ++ ++ @staticmethod ++ def _check_node_in_cpu( ++ matched_nodes_in_npu, matched_nodes_in_cpu, match_node, already_matched ++ ): ++ if match_node.block.device == Device.GPU: ++ matched_nodes_in_npu.append(match_node.block) ++ in_cpu = False ++ else: ++ if not already_matched: ++ match_node.ref_count_down() # 不用增计数,swap in 完成后,会在 update_swapped_in_blocks 对 npu block 增计数 ++ in_cpu = True ++ matched_nodes_in_cpu.append(match_node.block) ++ return in_cpu ++ ++ def all_check(self): ++ """检查 TN, LRU, TREE 三者完全一致""" ++ # 递归获取 TreeNode 到一个 set 且确保没有重复 ++ dfs_visited = set() ++ ++ def dfs(node: TreeNode): ++ if node in dfs_visited: ++ raise ValueError(f"Found duplicate TreeNode {node}") ++ dfs_visited.add(node) ++ for _, child in node.children.items(): ++ dfs(child) ++ ++ dfs(self.root_node) ++ dfs_visited.discard(self.root_node) ++ ++ # 获取 LRU 所有节点到一个 set 并检查,需和 TreeNode 完全一致 ++ lru_treenode = self.lru.get_all_values() ++ lru_set = set(lru_treenode) ++ assert_or_raise(len(lru_set) == len(lru_treenode)) # 确保 LRU 所有节点不重复 ++ assert_or_raise(lru_set == dfs_visited) ++ ++ # 检查两个 TN 的正确性: 块号/块种类到 TreeNode的 block。 TreeNode 和 set1 完全一致 ++ _tn_visited = set() ++ for block_num, _ in enumerate(self.block_num_to_npu_tn): ++ if self.block_num_to_npu_tn[block_num] is not None: ++ cur_treenode: TreeNode = self.block_num_to_npu_tn[block_num] ++ assert_or_raise(block_num == cur_treenode.block.block_number) ++ assert_or_raise(cur_treenode.block.device == Device.GPU) ++ assert_or_raise(cur_treenode not in _tn_visited) ++ _tn_visited.add(cur_treenode) ++ for block_num, _ in enumerate(self.block_num_to_cpu_tn): ++ if self.block_num_to_cpu_tn[block_num] is not None: ++ cur_treenode: TreeNode = self.block_num_to_cpu_tn[block_num] ++ assert_or_raise(block_num == cur_treenode.block.block_number) ++ assert_or_raise(cur_treenode.block.device == Device.CPU) ++ assert_or_raise(cur_treenode not in _tn_visited) ++ _tn_visited.add(cur_treenode) ++ assert_or_raise(_tn_visited == dfs_visited) ++ ++ # 检查 每个 TreeNode 的 ref_count 都 >= 1 且小于等于其父节点 ++ for node in dfs_visited: ++ assert_or_raise(node.block.ref_count >= 1) ++ if node.parent != self.root_node: ++ assert_or_raise(node.block.ref_count <= node.parent.block.ref_count) ++ if node.block.ref_count > node.parent.block.ref_count: ++ self.pretty_print() ++ assert_or_raise(node.block.ref_count <= node.parent.block.ref_count) ++ ++ return True ++ ++ def insert_nodes( ++ self, ++ split_keys: List[Tuple[int]], ++ block_table: List[PhysicalTokenBlock], ++ ) -> List[PhysicalTokenBlock]: ++ """根据根据split_keys 与block_table往缓存中从插入节点""" ++ return self._insert_helper(self.root_node, split_keys, block_table) ++ ++ def pretty_print(self): ++ """打印当前树的统计信息""" ++ self._print_helper(self.root_node, 2) ++ current_node = self.lru.get_tail() ++ ++ while current_node is not None: ++ tree_node: TreeNode = current_node.value ++ device = "NPU" ++ if tree_node.block.device == Device.CPU: ++ device = "CPU" ++ logger.debug(f"{device}_{tree_node.block.block_number}") ++ current_node = current_node.prev ++ ++ def select_evict(self, num_blocks, device: Device) -> Tuple[List[PhysicalTokenBlock], List[PhysicalTokenBlock]]: ++ """ 按num_tokens个数淘汰,node的value存的是内存的索引,不是真的内存。kvcache修改函数:self._free_value_block""" ++ npu_evict_block_list = [] ++ cpu_evict_block_list = [] ++ num_evicted = 0 ++ cur_dll_node = self.lru.get_tail() ++ while num_evicted < num_blocks: ++ if cur_dll_node is None: ++ break ++ reverse_next_dll_node: DoublyLinkedListNode = cur_dll_node.prev ++ cur_node: TreeNode = cur_dll_node.value ++ ++ can_evict = ( ++ not cur_node.block.ref_set_of_seq ++ and len(cur_node.children) == 0 ++ and cur_node.block.device == device ++ and cur_node.block.progress_status != KVCacheProgressStatus.SWAPPING_IN ++ ) ++ if can_evict: ++ # evict ++ cur_node.block.in_radix_cache = False ++ if cur_node.block.device == Device.GPU: ++ npu_evict_block_list.append(cur_node.block) ++ else: ++ cpu_evict_block_list.append(cur_node.block) ++ num_evicted += 1 ++ self.delete_leaf(cur_node) ++ cur_dll_node = reverse_next_dll_node ++ # 释放得不够多,且要求释放GPU块,那么从后往前再释放一遍。这一次,CPU和GPU块都可选中 ++ if num_evicted < num_blocks and device == Device.GPU: ++ cur_dll_node = self.lru.get_tail() ++ while num_evicted < num_blocks: ++ if cur_dll_node is None: ++ break ++ reverse_next_dll_node: DoublyLinkedListNode = cur_dll_node.prev ++ cur_node: TreeNode = cur_dll_node.value ++ ++ can_evict = ( ++ not cur_node.block.ref_set_of_seq ++ and len(cur_node.children) == 0 ++ and cur_node.block.progress_status ++ != KVCacheProgressStatus.SWAPPING_IN ++ ) ++ if can_evict: ++ # evict ++ cur_node.block.in_radix_cache = False ++ if cur_node.block.device == Device.GPU: ++ npu_evict_block_list.append(cur_node.block) ++ num_evicted += 1 ++ else: ++ cpu_evict_block_list.append(cur_node.block) ++ # CPU块,它不算释放块数 ++ self.delete_leaf(cur_node) ++ cur_dll_node = reverse_next_dll_node ++ ++ return npu_evict_block_list, cpu_evict_block_list ++ ++ def select_swap_out(self, num_blocks) -> List[PhysicalTokenBlock]: ++ """按num_tokens个数swap,node的value存的是内存的索引,不是真的内存。kvcache释放函数:self._modify_tree_value_block""" ++ npu_swap_block_list = [] ++ num_selected = 0 ++ cur_dll_node = self.lru.get_tail() ++ while num_selected < num_blocks: ++ if cur_dll_node is None: ++ break ++ reverse_next_dll_node: DoublyLinkedListNode = cur_dll_node.prev ++ cur_node: TreeNode = cur_dll_node.value ++ if ( ++ not cur_node.block.ref_set_of_seq ++ and cur_node.block.device == Device.GPU ++ and cur_node.block.progress_status == KVCacheProgressStatus.INITIALIZE ++ ): ++ # 它的子节点不可能在 SWAPPING_IN, SWAPPING_IN 的块,其所有祖先,引用计数必然 > 1 ++ cur_node.block.progress_status = KVCacheProgressStatus.SWAPPING_OUT ++ npu_swap_block_list.append(cur_node.block) ++ num_selected += 1 ++ cur_dll_node = reverse_next_dll_node ++ ++ return npu_swap_block_list ++ ++ def trie_match(self, key_len: int, split_keys: List[Tuple[int]], seq) -> (int, List[PhysicalTokenBlock], ++ List[PhysicalTokenBlock], ++ List[PhysicalTokenBlock]): ++ """对split_keys进行前缀匹配""" ++ tree_node_list = [] # for update lru ++ matched_len = 0 ++ matched_nodes_in_npu: List[PhysicalTokenBlock] = [] ++ matched_nodes_in_cpu: List[PhysicalTokenBlock] = [] # 匹配命中的节点 ++ need_copy_block_list = [] # npu, cpu. example: [(), (PhysicalTokenBlock(6))] ++ if len(split_keys) == 0: ++ return matched_len, matched_nodes_in_npu, matched_nodes_in_cpu, need_copy_block_list ++ cur_node = self.root_node ++ # 本次每匹配到一个节点,就向 nodes 列表追加一个 TreeNode,需要按相反顺序更新 LRU ++ prev_is_cpu_block = False ++ for split_key in split_keys: ++ # 找到子节点中最长匹配的节点 ++ match_node, matched_len, max_match_len = self._match_node_info( ++ matched_len=matched_len, cur_node=cur_node, split_key=split_key ++ ) ++ # 没有匹配,结束 ++ if max_match_len == 0: ++ break ++ # 这种情况处理起来比较复杂,所以这里直接停止匹配 ++ if prev_is_cpu_block and match_node.block.device == Device.GPU: ++ break ++ already_matched = seq in match_node.block.ref_set_of_seq ++ if not already_matched: ++ match_node.ref_count_up() ++ if max_match_len == match_node.num_valid_token and match_node.num_valid_token == self.block_size: ++ # 完全匹配到一个 block 的全部 token ++ tree_node_list.append(match_node) ++ # 若prefill的token完全匹配,后续调度逻辑会存在问题,则保留最后一个token来做prefill ++ if matched_len == key_len: ++ matched_len -= 1 ++ if match_node.num_valid_token == 1: ++ if not already_matched: ++ match_node.ref_count_down() ++ tree_node_list.pop() ++ break # 不用加入 matched_nodes_in_npu or matched_nodes_in_cpu ++ # 判断节点是在 CPU 还是 NPU ++ prev_is_cpu_block = self._check_node_in_cpu( ++ matched_nodes_in_npu=matched_nodes_in_npu, ++ matched_nodes_in_cpu=matched_nodes_in_cpu, ++ match_node=match_node, ++ already_matched=already_matched, ++ ) ++ else: ++ # 半匹配,需拷贝 ++ need_copy_block_list = self._build_copy_block_list( ++ match_node=match_node, already_matched=already_matched ++ ) ++ # 若使用半匹配,此处应当判断节点是在 CPU 还是 NPU,对CPU块改 SWAPPING_IN 状态 ++ # 节点满且完全匹配,递归匹配下一个节点 ++ if max_match_len == self.block_size: ++ cur_node = match_node ++ else: ++ break ++ # LRU 从后往前地更新 ++ self._update_lru_list(tree_node_list) ++ return matched_len, matched_nodes_in_npu, matched_nodes_in_cpu, need_copy_block_list ++ ++ def swapped_in(self, swapped_pair_list: List[Tuple[int, PhysicalTokenBlock]]): ++ """执行swap in""" ++ for cpu_block_num, npu_block in swapped_pair_list: ++ assert_or_raise( ++ self.block_num_to_npu_tn[npu_block.block_number] is None, ++ f"swapped in, duplicated nodes: {npu_block.block_number}", ++ ) ++ assert_or_raise( ++ self.block_num_to_cpu_tn[cpu_block_num] is not None, ++ f"swapped in, not exist cpu block: {cpu_block_num}", ++ ) ++ assert_or_raise(npu_block.in_radix_cache is True) ++ tree_node: TreeNode = self.block_num_to_cpu_tn[cpu_block_num] ++ logger.debug(f"radix cache swapped_in cpu: {tree_node.block} -> npu : {npu_block}") ++ tree_node.block.progress_status = KVCacheProgressStatus.INITIALIZE ++ tree_node.block = npu_block ++ # 引用计数在 generate_swap_mapping_for_cache 里已经赋值 ++ self.block_num_to_npu_tn[npu_block.block_number] = tree_node ++ self.block_num_to_cpu_tn[cpu_block_num] = None # Radix Tree 不再管这个 CPU 块 ++ ++ def swapped_out(self, swapped_pair_list: List[Tuple[PhysicalTokenBlock, int]]): ++ """执行swap out""" ++ for cpu_block, npu_block_num in swapped_pair_list: ++ if cpu_block.in_radix_cache is False: ++ logger.debug("debug pretty print") ++ self.pretty_print() ++ assert_or_raise( ++ cpu_block.in_radix_cache is True, ++ f"err cpu_block: {cpu_block}, npu_blk_num: {npu_block_num}", ++ ) ++ tree_node: TreeNode = self.block_num_to_npu_tn[npu_block_num] ++ logger.debug(f"radix cache swapped_out cpu: {cpu_block} <- npu: {tree_node.block}") ++ try: ++ tree_node.block.progress_status = KVCacheProgressStatus.INITIALIZE ++ except Exception as e: ++ raise e ++ ++ tree_node.block = cpu_block ++ self.block_num_to_cpu_tn[cpu_block.block_number] = tree_node ++ self.block_num_to_npu_tn[npu_block_num] = None ++ ++ def get_num_nodes(self, device: Device) -> int: ++ """获取指定device的节点数量""" ++ count = 0 ++ for tree_node in self.lru.tree_node_to_DLL_node_map: ++ if tree_node.block.device == device: ++ count += 1 ++ return count ++ ++ def get_max_matched_child(self, key: Tuple[int], node: TreeNode) -> Tuple[Tuple[int], TreeNode, int]: ++ """遍历获取node所有子节点,选出子节点key列表元素与参数key元素匹配最多的节点""" ++ max_match_len: int = 0 ++ match_node: Optional[TreeNode] = None ++ match_key: Optional[Tuple[int]] = None ++ for child in node.children.values(): ++ ++ c_key = child.key ++ match_len = match(c_key, key) ++ ++ if match_len > max_match_len: ++ match_key = c_key ++ max_match_len = match_len ++ match_node = child ++ ++ # 匹配到self.block_size,不可能更长了,不用再找 ++ if max_match_len == self.block_size: ++ break ++ ++ return match_key, match_node, max_match_len ++ ++ def get_num_nodes_can_swap_out(self) -> int: ++ """获取能够swapout的节点数量""" ++ def dfs_(cur_node: TreeNode): ++ num_nodes = 0 ++ if cur_node.block.device == Device.CPU: ++ return 0 ++ if not cur_node.block.ref_set_of_seq: ++ num_nodes += 1 ++ ++ for child_node in cur_node.children.values(): ++ num_nodes += dfs_(child_node) ++ ++ return num_nodes ++ ++ return dfs_(self.root_node) ++ ++ def delete_leaf(self, node: TreeNode): ++ """删除叶子节点""" ++ if node.block.device == Device.CPU: ++ self.block_num_to_cpu_tn[node.block.block_number] = None ++ else: ++ self.block_num_to_npu_tn[node.block.block_number] = None ++ ++ del node.parent.children[node.id] ++ self.lru.delete(node) ++ self.tree_node_pool.release(node) ++ ++ def _update_lru_list(self, tree_node_list): ++ for i in range(len(tree_node_list) - 1, -1, -1): ++ self.lru.push_front(tree_node_list[i]) ++ ++ def _match_node_info(self, matched_len, cur_node, split_key): ++ _, match_node, max_match_len = self.get_max_matched_child(split_key, cur_node) ++ matched_len += max_match_len ++ return match_node, matched_len, max_match_len ++ ++ def _insert_helper( ++ self, ++ root: TreeNode, ++ split_keys: List[Tuple[int]], ++ block_table: List[PhysicalTokenBlock], ++ ) -> List[PhysicalTokenBlock]: ++ cur_node = root ++ if len(split_keys) == 0: ++ return [] ++ tree_node_list = [] # 本次每插入或匹配到一个节点,就向这个列表追加一个 TreeNode,供按相反顺序创建 LRU ++ blocks_need_free = [] ++ ++ block_len = min(len(split_keys), len(block_table)) ++ for i in range(block_len): ++ block = block_table[i] ++ split_key = split_keys[i] ++ assert_or_raise( ++ block.block_number not in self.block_num_to_npu_tn, ++ f"duplicated node: {block.block_number}", ++ ) ++ ++ # 获取最长匹配的child ++ _, match_node, max_match_len = self.get_max_matched_child( ++ split_key, cur_node ++ ) ++ # 没有匹配的节点 或 两个节点都是部分匹配,插入新节点 ++ if (max_match_len == 0 ++ or (max_match_len != match_node.num_valid_token and max_match_len != len(split_key))): ++ child = cur_node.insert_child(split_key, block, self.tree_node_pool.acquire()) ++ assert_or_raise( ++ block.block_number not in self.block_num_to_npu_tn, ++ (f"insert1, duplicated nodes: " f"{block.block_number}"), ++ ) ++ self.block_num_to_npu_tn[block.block_number] = child ++ tree_node_list.append(child) ++ cur_node = child ++ ++ # 原Tree节点完全匹配,新节点key更长,释放并删除原Tree节点,插入新节点 ++ elif max_match_len == match_node.num_valid_token and max_match_len < len(split_key): ++ # 未满节点必须是叶子节点 ++ assert_or_raise(len(match_node.children) == 0) ++ ++ # 是复用的 block,那么就更新 key ++ if block == match_node.block: ++ match_node.key = split_key ++ match_node.num_valid_token = len(split_key) ++ tree_node_list.append(match_node) ++ else: ++ child = cur_node.insert_child(split_key, block, self.tree_node_pool.acquire()) ++ assert_or_raise( ++ block.block_number not in self.block_num_to_npu_tn, ++ (f"insert2, duplicated nodes: " f"{block.block_number}"), ++ ) ++ self.block_num_to_npu_tn[block.block_number] = child ++ tree_node_list.append(child) ++ cur_node = child ++ # 原节点若未被任何请求引用,则可移除 ++ if not match_node.block.ref_set_of_seq: ++ # 稍后释放块 ++ match_node.block.in_radix_cache = False ++ blocks_need_free.append(match_node.block) ++ self.delete_leaf(match_node) ++ ++ # 原Tree节点key更长新节点完全匹配 或 两个节点都是完全匹配,保留原Tree节点,释放新节点 ++ else: ++ if match_node.block == block: ++ # 是完全复用的节点,无需释放 ++ pass ++ else: ++ blocks_need_free.append(block) ++ cur_node = match_node ++ tree_node_list.append(cur_node) ++ ++ # LRU 从后往前地更新 ++ for i in range(len(tree_node_list) - 1, -1, -1): ++ self.lru.push_front(tree_node_list[i]) ++ return blocks_need_free ++ ++ def _print_helper(self, node: TreeNode, indent): ++ for _, child in node.children.items(): ++ device = "NPU" ++ if child.block.device == Device.CPU: ++ device = "CPU" ++ logger.debug("%s %s_block id: %s, ref count: %s, stat: %s, key_len: %s", ++ " " * indent, ++ device, ++ child.block.block_number, ++ child.block.ref_count, ++ child.block.progress_status, ++ len(child.key), ++ ) ++ self._print_helper(child, indent=indent + 2) ++ ++ # dfs遍历叶子节点 ++ def _collect_leaves_for_evict(self, num_blocks: int, device: Device): ++ ret_list: List[TreeNode] = [] ++ ++ def dfs_(cur_node: TreeNode) -> None: ++ if ( ++ len(cur_node.children) == 0 ++ and cur_node.block.device == device ++ and cur_node.block.progress_status == KVCacheProgressStatus.INITIALIZE ++ ): ++ ret_list.append(cur_node) ++ return ++ ++ if len(ret_list) >= num_blocks: ++ return ++ ++ for x in cur_node.children.values(): ++ if len(ret_list) >= num_blocks: ++ return ++ dfs_(x) ++ ++ dfs_(self.root_node) ++ return ret_list ++ ++ def _collect_nodes_for_swap(self): ++ ret_list = [] ++ ++ def dfs_(cur_node: TreeNode) -> None: ++ gpu_child = [] ++ for child_node in cur_node.children.values(): ++ if ( ++ child_node.block.device == Device.GPU ++ and child_node.block.progress_status ++ == KVCacheProgressStatus.INITIALIZE ++ ): ++ gpu_child.append(child_node) ++ # 子节点正在换入,父节点ref_count>0,不考虑 ++ ++ if len(gpu_child) == 0: ++ ret_list.append(cur_node) ++ ++ for x in gpu_child: ++ dfs_(x) ++ ++ dfs_(self.root_node) ++ return ret_list +diff --git a/dllm_tools/dllm/dkvc/util.py b/dllm_tools/dllm/dkvc/util.py +new file mode 100644 +index 000000000..14d2c4ffd +--- /dev/null ++++ b/dllm_tools/dllm/dkvc/util.py +@@ -0,0 +1,17 @@ ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++ ++def assert_or_raise(condition: bool, msg: str = "") -> None: ++ '''断言condition 为真''' ++ if not condition: ++ raise RuntimeError(msg) +diff --git a/dllm_tools/dllm/dkvc/v1/__init__.py b/dllm_tools/dllm/dkvc/v1/__init__.py +new file mode 100644 +index 000000000..e69de29bb +diff --git a/dllm_tools/dllm/dkvc/v1/dllm_ds_connector.py b/dllm_tools/dllm/dkvc/v1/dllm_ds_connector.py +new file mode 100644 +index 000000000..8580c5ccd +--- /dev/null ++++ b/dllm_tools/dllm/dkvc/v1/dllm_ds_connector.py +@@ -0,0 +1,849 @@ ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++ ++import os ++import enum ++import hashlib ++from dataclasses import dataclass ++from typing import TYPE_CHECKING, List, Optional, Any ++import threading ++from collections import defaultdict ++import asyncio ++ ++import numpy ++import torch ++from vllm.config import VllmConfig ++from vllm.distributed.kv_transfer.kv_connector.v1.base import ( ++ KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) ++from vllm.logger import init_logger ++from vllm.v1.attention.backends.mla.common import MLACommonMetadata ++from vllm.v1.core.sched.output import SchedulerOutput ++from vllm.distributed.parallel_state import (get_world_group, get_tp_group) ++ ++from dllm.cpp_ext.kvc import KvcStore, KvcFuture ++from dllm.kvc import TorchAdaptor ++ ++ENABLE_PREFIX_CACHING = int(os.environ.get("USING_PREFIX_CONNECTOR", 0)) ++FUTURE_TIMEOUT = int(os.getenv("FUTURE_TIMEOUT", 0)) ++SLEEP_TIMEOUT = 0.005 ++ ++if TYPE_CHECKING: ++ from vllm.attention.backends.abstract import AttentionMetadata ++ from vllm.forward_context import ForwardContext ++ from vllm.v1.request import Request ++ ++logger = init_logger(f"vllm.{__name__}") ++ ++ ++class RequestStatus(enum.IntEnum): ++ WAITING = enum.auto() ++ TIMEOUT = enum.auto() ++ FINISHED = enum.auto() ++ ++ ++@dataclass ++class RequestTracker: ++ # Request Id ++ request_id: str ++ # Request tokens ++ token_ids: torch.Tensor ++ # block_ids ++ block_ids: list[int] ++ num_scheduled_tokens: int ++ ++ @staticmethod ++ def from_new_request(request_id, token_ids, block_ids, num_scheduled_tokens) -> "RequestTracker": ++ """ ++ Create the request tracker from a new request. ++ """ ++ return RequestTracker( ++ request_id=request_id, ++ token_ids=token_ids, ++ block_ids=block_ids, ++ num_scheduled_tokens=num_scheduled_tokens ++ ) ++ ++ def update( ++ self, ++ block_ids, ++ num_external_scheduled_tokens ++ ) -> None: ++ """ ++ Update the request tracker when a running request is ++ scheduled again ++ """ ++ self.block_ids[0].extend(block_ids[0]) ++ self.num_scheduled_tokens += num_external_scheduled_tokens ++ ++ ++@dataclass ++class ReqMeta: ++ # Request Id ++ request_id: str ++ # Request tokens ++ token_ids: torch.Tensor ++ # block_ids ++ block_ids: list[int] ++ ds_cached_block_num: int ++ ++ @staticmethod ++ def make_meta(request_id: str, token_ids: list[int], block_ids: list[int], ++ block_size: int, ds_cached_block_num: int) -> "ReqMeta": ++ """make request meta""" ++ valid_num_tokens = align_to_block_size(len(token_ids), block_size) ++ valid_block_ids = valid_num_tokens // block_size ++ return ReqMeta( ++ request_id=request_id, ++ token_ids=numpy.array(token_ids), ++ block_ids=block_ids[0][:valid_block_ids], ++ ds_cached_block_num=ds_cached_block_num ++ ) ++ ++ ++@dataclass ++class DLLMDsConnectorMetadata(KVConnectorMetadata): ++ requests: list[ReqMeta] ++ ++ def __init__(self): ++ self.requests = [] ++ ++ def add_request( ++ self, ++ request_id: str, ++ token_ids: list[int], ++ block_ids: list[int], ++ block_size: int, ++ ds_cached_block_num: int ++ ) -> None: ++ """add request meta""" ++ self.requests.append( ++ ReqMeta.make_meta(request_id, token_ids, block_ids, block_size, ds_cached_block_num)) ++ ++ ++@dataclass ++class ReqState: ++ """Per-request state for tracking async transfers.""" ++ num_pending: int = -1 ++ finished: bool = False ++ ++ ++class AsyncHandler: ++ """Manage async saving/loading in separate thread.""" ++ ++ def __init__(self, role, task_list): ++ self._async_save_reqs = defaultdict[str, ReqState](ReqState) ++ self._async_load_reqs = defaultdict[str, ReqState](ReqState) ++ self._is_producer = role ++ self._finished_save_reqs = asyncio.Queue() ++ self._finished_load_reqs = asyncio.Queue() ++ self._future_save_list = asyncio.Queue() ++ self._future_load_list = asyncio.Queue() ++ self.task = asyncio.get_event_loop().create_task(self.get_futures_async()) ++ task_list.append(self.task) ++ ++ async def start_event_loop(self): ++ """start event loop""" ++ self.task = asyncio.create_task(self.get_futures_async()) ++ ++ async def get_futures_async(self): ++ """async get futures""" ++ while True: ++ try: ++ while not self._future_save_list.empty(): ++ request_id, future = self._future_save_list.get_nowait() ++ res = get_future(future) ++ req_state = self._async_save_reqs[request_id] ++ if res == RequestStatus.FINISHED: ++ req_state.num_pending -= 1 ++ if req_state.finished and not req_state.num_pending: ++ logger.debug(f"self._finished_save_reqs:{self._finished_save_reqs.qsize()}, " ++ f"request_id:{request_id}") ++ self._finished_save_reqs.put_nowait(request_id) ++ del self._async_save_reqs[request_id] ++ elif res == RequestStatus.WAITING or not req_state.finished: ++ self._future_save_list.put_nowait((request_id, future)) ++ else: ++ logger.error(f"request:{request_id} get save future timeout, res:{res}") ++ self._finished_save_reqs.put_nowait(request_id) ++ del self._async_save_reqs[request_id] ++ ++ while not self._future_load_list.empty(): ++ request_id, future = self._future_load_list.get_nowait() ++ res = get_future(future) ++ req_state = self._async_load_reqs[request_id] ++ if res == RequestStatus.FINISHED: ++ req_state.num_pending -= 1 ++ if not req_state.num_pending: ++ logger.debug(f"self._finished_load_reqs:{self._finished_load_reqs.qsize()}, " ++ f"request_id:{request_id}") ++ self._finished_load_reqs.put_nowait(request_id) ++ del self._async_load_reqs[request_id] ++ elif res == RequestStatus.WAITING: ++ self._future_load_list.put_nowait((request_id, future)) ++ else: ++ logger.error(f"request:{request_id} get load future timeout, res:{res}") ++ self._finished_load_reqs.put_nowait(request_id) ++ del self._async_load_reqs[request_id] ++ await asyncio.sleep(SLEEP_TIMEOUT) ++ except Exception as e: ++ logger.error(f"get_futures_async fail, error:{e}") ++ ++ def add_save_request(self, request, future_num): ++ """add save request future""" ++ self._async_save_reqs[request.request_id].num_pending = future_num ++ ++ def add_load_request(self, request, future_num): ++ """add load reqeust future""" ++ self._async_load_reqs[request.request_id].num_pending = future_num ++ ++ def add_save_future(self, request, future): ++ """add save reqeust future""" ++ self._future_save_list.put_nowait((request.request_id, future)) ++ ++ def add_load_future(self, request, future): ++ """add load request future""" ++ self._future_load_list.put_nowait((request.request_id, future)) ++ ++ def get_save_finished(self, finished_request_ids: set[str]) -> set[str]: ++ """Finished saving request ids.""" ++ finished_reqs = set() ++ for req_id in finished_request_ids: ++ req_state = self._async_save_reqs[req_id] ++ if req_state: ++ req_state.finished = True ++ if not req_state.num_pending: ++ finished_reqs.add(req_id) ++ logger.debug(f"_finished_save_reqs, finished_reqs = {req_id}") ++ del self._async_save_reqs[req_id] ++ ++ while not self._finished_save_reqs.empty(): ++ logger.debug(f"_finished_save_reqs.qsize:{self._finished_save_reqs.qsize()}") ++ finished_reqs.add(self._finished_save_reqs.get_nowait()) ++ if len(finished_reqs) != 0: ++ logger.debug(f"get_finished, finished_reqs:{finished_reqs}, length:{len(finished_reqs)}") ++ else: ++ finished_reqs = None ++ return finished_reqs ++ ++ def get_load_finished(self) -> set[str]: ++ """Finished saving request ids.""" ++ finished_reqs = set() ++ while not self._finished_load_reqs.empty(): ++ logger.debug(f"finished_queue.qsize:{self._finished_load_reqs.qsize()}") ++ finished_reqs.add(self._finished_load_reqs.get_nowait()) ++ if len(finished_reqs) != 0: ++ logger.debug(f"get_finished, finished_reqs:{finished_reqs}, length:{len(finished_reqs)}") ++ else: ++ finished_reqs = None ++ return finished_reqs ++ ++ ++class DLLMDsConnector(KVConnectorBase_V1): ++ ++ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): ++ super().__init__(vllm_config=vllm_config, role=role) ++ self._block_size = vllm_config.cache_config.block_size ++ self._requests_need_load: dict[str, Request] = {} ++ self.config = vllm_config.kv_transfer_config ++ self.is_producer = self.config.is_kv_producer ++ self.do_async_save = os.getenv("ASYNC_SAVE", True) ++ self.layer_name_list = [] ++ self.kv_caches = [] ++ self.key_caches = [] ++ self.value_caches = [] ++ self._skip_blocks: dict[str, int] = {} ++ self._ds_cached_blocks: dict[str, int] = {} ++ self._delay_save = {} ++ self._load_request_queue = asyncio.Queue() ++ self._save_request_queue = asyncio.Queue() ++ self.task_list = [] ++ self.is_ms_non_mla_type = False ++ self.is_ms_mla = False ++ self.is_mla = False ++ # # Complete transfer tracker. Used by the rank 0 to track finished ++ # # transactions on ranks 1 to N-1. ++ # # [req_id -> count] ++ self._done_recving_count: defaultdict[str, int] = defaultdict(lambda: 0) ++ self._done_sending_count: defaultdict[str, int] = defaultdict(lambda: 0) ++ ++ thread_num = int(os.getenv("THREAD_NUM", 64)) ++ conn_timeout_ms = int(os.getenv("CONN_TIMEOUT_MS", 6000)) ++ self.tp_size = vllm_config.parallel_config.tensor_parallel_size ++ if role == KVConnectorRole.WORKER: ++ self.tp_rank = get_tp_group().rank_in_group ++ self.tp_group = get_tp_group() ++ self.device = get_world_group().local_rank ++ else: ++ self.device = vllm_config.kv_transfer_config.kv_connector_extra_config["device_ids"][0] ++ self.tp_group = None ++ ds_worker_addr = os.getenv("DS_WORKER_ADDR", "172.17.0.4:9000") ++ ip_port = ds_worker_addr.split(":") ++ ip = ip_port[0] ++ port = int(ip_port[1]) ++ logger.warning(f"init datasystem ip = {ip}, port = {port}, device_id = {self.device}") ++ self._request_token_key = {} ++ store = KvcStore() ++ store.init(ip, port, conn_timeout_ms, thread_num) ++ self.kvc_store = TorchAdaptor(store, self.device) ++ if self.do_async_save and role == KVConnectorRole.WORKER: ++ self.loop = asyncio.get_event_loop() ++ self._async_handler = AsyncHandler(self.is_producer, self.task_list) ++ if ENABLE_PREFIX_CACHING or not self.is_producer: ++ self.load_task = self.loop.create_task(self.consumer_request_task()) ++ self.task_list.append(self.load_task) ++ ++ if ENABLE_PREFIX_CACHING or self.is_producer: ++ self.save_task = self.loop.create_task(self.producer_request_task()) ++ self.task_list.append(self.save_task) ++ ++ thread = threading.Thread(target=self.start_event_loop, daemon=True) ++ thread.start() ++ else: ++ self._async_handler = None ++ ++ def start_event_loop(self): ++ """start event loop""" ++ self.loop.run_until_complete(asyncio.gather(*self.task_list)) ++ self.loop.close() ++ ++ async def producer_request_task(self): ++ """consumer request task""" ++ while True: ++ try: ++ while not self._save_request_queue.empty(): ++ request = self._save_request_queue.get_nowait() ++ self.do_save_request(request) ++ await asyncio.sleep(SLEEP_TIMEOUT) ++ except Exception as e: ++ logger.error(f"producer_request_task fail, error:{e}") ++ self._save_request_queue.put_nowait(request) ++ await asyncio.sleep(SLEEP_TIMEOUT) ++ ++ async def consumer_request_task(self): ++ """consumer request task""" ++ while True: ++ try: ++ while not self._load_request_queue.empty(): ++ request = self._load_request_queue.get_nowait() ++ self.do_load_kv(request) ++ await asyncio.sleep(SLEEP_TIMEOUT) ++ except Exception as e: ++ logger.error(f"consumer_request_task fail, error:{e}") ++ self._load_request_queue.put_nowait(request) ++ await asyncio.sleep(SLEEP_TIMEOUT) ++ ++ def generate_kv_cache_token_key(self, request): ++ """ ++ generate kv_cache token key. ++ ++ Args: ++ request: request. ++ block_index: block_index ++ """ ++ if not self.is_mla: ++ external_key = "-" + str(self.tp_rank) ++ else: ++ external_key = "-0" ++ return generate_hash_md5(len(request.block_ids), request.token_ids, self._block_size, external_key) ++ ++ def start_load_kv(self, forward_context: "ForwardContext", ++ **kwargs) -> None: ++ """Start loading the KV cache from the connector buffer to vLLM's ++ paged KV buffer. ++ ++ Args: ++ forward_context (ForwardContext): the forward context. ++ **kwargs: additional arguments for the load operation ++ ++ Note: ++ The number of elements in kv_caches and layer_names should be ++ the same. ++ """ ++ # effective only when prefix cache is disabled and the role is producer. ++ if self.is_producer and not ENABLE_PREFIX_CACHING: ++ return ++ ++ metadata: KVConnectorMetadata = self._get_connector_metadata() ++ if len(metadata.requests) == 0: ++ return ++ ++ if len(self.kv_caches) == 0: ++ self._init_kv_caches_from_forward_context(forward_context) ++ ++ for request in metadata.requests: ++ if self._async_handler is not None: ++ self._load_request_queue.put_nowait(request) ++ else: ++ self.do_load_kv(request) ++ ++ def get_finished( ++ self, finished_req_ids: set[str] ++ ) -> tuple[Optional[set[str]], Optional[set[str]]]: ++ """Finished (saving, loading) request ids.""" ++ logger.debug(f"get_finished, finished_req_ids:{finished_req_ids}") ++ sending_count = 1 if self.is_mla else self.tp_size ++ finished_saved_req, finished_loaded_req = None, None ++ if self._async_handler is not None: ++ if self.is_producer or ENABLE_PREFIX_CACHING: ++ finished_saved_req = self._async_handler.get_save_finished(finished_req_ids) ++ ++ if not self.is_producer or ENABLE_PREFIX_CACHING: ++ finished_loaded_req = self._async_handler.get_load_finished() ++ ++ if self.tp_size == 1: ++ return finished_saved_req, finished_loaded_req ++ ++ if self.tp_rank == 0: ++ for req_id in finished_saved_req or []: ++ self._done_sending_count[req_id] += 1 ++ for req_id in finished_loaded_req or []: ++ self._done_recving_count[req_id] += 1 ++ other_ranks_finished_save_ids: list[str] = [] ++ other_ranks_finished_rec_ids: list[str] = [] ++ for i in range(1, self.tp_size): ++ receive_object = self.tp_group.recv_object(src=i) ++ other_ranks_finished_save_ids.extend(receive_object[0]) ++ other_ranks_finished_rec_ids.extend(receive_object[1]) ++ ++ for req_id in other_ranks_finished_save_ids: ++ self._done_sending_count[req_id] += 1 ++ for req_id in other_ranks_finished_rec_ids: ++ self._done_recving_count[req_id] += 1 ++ all_done_recving: set[str] = set() ++ for req_id in list(self._done_recving_count.keys()): ++ if self._done_recving_count[req_id] == self.tp_size: ++ del self._done_recving_count[req_id] ++ all_done_recving.add(req_id) ++ all_done_sending: set[str] = set() ++ for req_id in list(self._done_sending_count.keys()): ++ if self._done_sending_count[req_id] == sending_count: ++ del self._done_sending_count[req_id] ++ all_done_sending.add(req_id) ++ logger.debug(f"all_done_sending = {all_done_sending}, all_done_recving = {all_done_recving}") ++ return all_done_sending, all_done_recving ++ ++ self.tp_group.send_object((list(finished_saved_req or []), list(finished_loaded_req or [])), dst=0) ++ # Unused as only Rank 0 results are sent to scheduler. ++ return finished_saved_req, finished_loaded_req ++ return None, None ++ ++ def do_load_kv(self, request) -> None: ++ """Start loading the KV cache from the connector buffer to vLLM's ++ paged KV buffer. ++ ++ Args: ++ forward_context (ForwardContext): the forward context. ++ **kwargs: additional arguments for the load operation ++ ++ Note: ++ The number of elements in kv_caches and layer_names should be ++ the same. ++ """ ++ ds_cached_block_num = request.ds_cached_block_num ++ logger.debug(f" ds_cached_block_num = {ds_cached_block_num}") ++ if ds_cached_block_num == 0: ++ return ++ token_key_list = self.generate_kv_cache_token_key(request) ++ key_list = token_key_list[:ds_cached_block_num] ++ block_id_list = request.block_ids[:ds_cached_block_num] ++ if not block_id_list: ++ return ++ if not self.is_mla: ++ key_cache_key_list = key_list ++ value_cache_key_list = [key + "-value" for key in key_list] ++ future = self.kvc_store.mget_page_attn_blockwise_h2d(key_cache_key_list, self.key_caches, block_id_list) ++ future_1 = self.kvc_store.mget_page_attn_blockwise_h2d(value_cache_key_list, self.value_caches, ++ block_id_list) ++ if not self.do_async_save: ++ get_future(future) ++ get_future(future_1) ++ else: ++ self._async_handler.add_load_request(request, 2) ++ self._async_handler.add_load_future(request, future) ++ self._async_handler.add_load_future(request, future_1) ++ logger.debug(f"mget_tensors_h2d success, request.request_id:{request.request_id}," ++ f"key_list length:{len(key_cache_key_list)}, block_id_list:{block_id_list}") ++ return ++ ++ future = self.kvc_store.mget_page_attn_blockwise_h2d(key_list, self.kv_caches, block_id_list) ++ if not self.do_async_save: ++ get_future(future) ++ else: ++ self._async_handler.add_load_request(request, 1) ++ self._async_handler.add_load_future(request, future) ++ logger.debug(f"mget_tensors_h2d success, request.request_id:{request.request_id}, " ++ f"key_list length:{len(key_list)}, block_id_list:{block_id_list}") ++ ++ def wait_for_layer_load(self, layer_name: str) -> None: ++ """ ++ wait_for_layer_load ++ """ ++ return ++ ++ def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, ++ attn_metadata: "AttentionMetadata", **kwargs) -> None: ++ """ ++ save_kv_layer ++ """ ++ if not ENABLE_PREFIX_CACHING and not self.is_producer: ++ return ++ if layer_name not in self.layer_name_list: ++ self.layer_name_list.append(layer_name) ++ self.is_ms_non_mla_type = isinstance(kv_layer, tuple) and len(kv_layer) == 2 ++ self.is_ms_mla = os.getenv("vLLM_MODEL_BACKEND", None) == "MindFormers" and not self.is_ms_non_mla_type ++ self.is_mla = isinstance(attn_metadata, MLACommonMetadata) or self.is_ms_mla ++ if self.is_mla: ++ self.kv_caches.append(kv_layer) ++ else: ++ self.key_caches.append(kv_layer[0]) ++ self.value_caches.append(kv_layer[1]) ++ ++ def do_save_request(self, request) -> None: ++ """Start saving the KV cache of the layer from vLLM's paged buffer ++ to the connector. ++ ++ Args: ++ layer_name (str): the name of the layer. ++ kv_layer (torch.Tensor): the paged KV buffer of the current ++ layer in vLLM. ++ attn_metadata (AttentionMetadata): the attention metadata. ++ **kwargs: additional arguments for the save operation. ++ """ ++ logger.debug(f"do_save_request, request:{request}") ++ if not self.is_producer: ++ return ++ ++ if self.is_mla and self.tp_rank > 0: ++ return ++ ++ if not request.block_ids: ++ return ++ ++ token_key_list = self.generate_kv_cache_token_key(request) ++ if not self.is_mla: ++ key_cache_key_list = token_key_list ++ value_cache_key_list = [key + "-value" for key in token_key_list] ++ future = self.kvc_store.mset_page_attn_blockwise_d2h(key_cache_key_list, self.key_caches, request.block_ids) ++ future_1 = self.kvc_store.mset_page_attn_blockwise_d2h(value_cache_key_list, self.value_caches, ++ request.block_ids) ++ if not self.do_async_save: ++ get_future(future) ++ get_future(future_1) ++ else: ++ self._async_handler.add_save_request(request, 2) ++ self._async_handler.add_save_future(request, future) ++ self._async_handler.add_save_future(request, future_1) ++ logger.debug(f"mset_tensors_d2h success, request.request_id:{request.request_id}, " ++ f"key_list length:{len(key_cache_key_list)}") ++ return ++ ++ future = self.kvc_store.mset_page_attn_blockwise_d2h(token_key_list, self.kv_caches, request.block_ids) ++ if not self.do_async_save: ++ get_future(future) ++ else: ++ self._async_handler.add_save_request(request, 1) ++ self._async_handler.add_save_future(request, future) ++ logger.debug(f"mset_tensors_d2h success, request.request_id:{request.request_id}, " ++ f"key_list length:{len(token_key_list)}, request.block_ids:{request.block_ids}") ++ ++ def wait_for_save(self): ++ """ ++ wait_for_save ++ """ ++ if not self.is_producer: ++ return ++ ++ connector_metadata = self._get_connector_metadata() ++ if not isinstance(connector_metadata, DLLMDsConnectorMetadata): ++ raise ValueError("connector_metadata is not an instance of DLLMDsConnectorMetadata") ++ ++ if not connector_metadata.requests: ++ return ++ ++ for request in connector_metadata.requests: ++ if self._async_handler is not None: ++ self._save_request_queue.put_nowait(request) ++ else: ++ self.do_save_request(request) ++ ++ def get_num_new_matched_tokens( ++ self, ++ request: "Request", ++ num_computed_tokens: int, ++ ) -> tuple[int, bool]: ++ """ ++ Get number of new tokens that can be loaded from the ++ external KV cache beyond the num_computed_tokens. ++ ++ Args: ++ request (Request): the request object. ++ num_computed_tokens (int): the number of locally ++ computed tokens for this request ++ ++ Returns: ++ the number of tokens that can be loaded from the ++ external KV cache beyond what is already computed. ++ """ ++ if not ENABLE_PREFIX_CACHING and not self.is_producer: ++ num_tokens_to_check = align_to_block_size(len(request.prompt_token_ids), self._block_size) ++ self._skip_blocks[request.request_id] = num_computed_tokens // self._block_size ++ num_external_computed_tokens = num_tokens_to_check - num_computed_tokens ++ self._ds_cached_blocks[request.request_id] = num_tokens_to_check // self._block_size ++ if self.do_async_save and num_external_computed_tokens > 0: ++ return num_external_computed_tokens, True ++ ++ return num_external_computed_tokens, False ++ if ENABLE_PREFIX_CACHING: ++ nuw_external_hit_blocks = 0 ++ tokens = request.prompt_token_ids ++ hashes = generate_hash_md5(len(tokens) // self._block_size, torch.tensor(tokens).numpy(), ++ self._block_size, "") ++ keys = [] ++ for hash_value in hashes: ++ keys.append(f"{hash_value}-0") ++ ++ try: ++ logger.debug(f"exists, keys = {keys}") ++ exists = self.kvc_store.exist(keys) ++ except RuntimeError as e: ++ logger.warning(f"KVCacheStore exist() failed, all tensors are not saved: {e}") ++ exists = [False] * len(keys) ++ ++ for item in exists: ++ if item: ++ nuw_external_hit_blocks += 1 ++ else: ++ break ++ num_external_hit_tokens = nuw_external_hit_blocks * self._block_size ++ logger.info( ++ f"req id = {request.request_id}, get_num_new_matched_tokens, exists keys = {exists}, " ++ f"hit blocks = {nuw_external_hit_blocks}, computed tokens = {num_computed_tokens}") ++ ++ self._skip_blocks[request.request_id] = num_computed_tokens // self._block_size ++ self._ds_cached_blocks[request.request_id] = nuw_external_hit_blocks ++ ++ need_to_allocate = num_external_hit_tokens - num_computed_tokens ++ logger.info( ++ "Reqid: %s, Total tokens %d, DSCache hit tokens: %d, " ++ "need to load: %d", request.request_id, request.num_tokens, num_external_hit_tokens, need_to_allocate) ++ ++ if self.do_async_save and need_to_allocate > 0: ++ logger.info(f"need to allocate > 0, waiting for remote kv. req id = {request.request_id}") ++ return need_to_allocate, True ++ ++ return 0, False ++ ++ def update_state_after_alloc(self, request: "Request", ++ blocks: "KVCacheBlocks", ++ num_external_tokens: int): ++ """ ++ Update KVConnector state after block allocation. ++ ++ If blocks were allocated, add to _requests_need_load, ++ such that we load the KVs in the next forward pass. ++ """ ++ if num_external_tokens > 0: ++ block = blocks.get_unhashed_block_ids() ++ self._requests_need_load[request.request_id] = (request, [block]) ++ logger.debug(f"_requests_need_load add request_id: {request.request_id}, block:{block}") ++ ++ def build_connector_meta( ++ self, ++ scheduler_output: SchedulerOutput, ++ ) -> KVConnectorMetadata: ++ """Build the connector metadata for this step. ++ ++ This function should NOT modify any fields in the scheduler_output. ++ Also, calling this function will reset the state of the connector. ++ ++ Args: ++ scheduler_output (SchedulerOutput): the scheduler output object. ++ """ ++ meta = DLLMDsConnectorMetadata() ++ total_need_load = 0 ++ for new_req in scheduler_output.scheduled_new_reqs: ++ if new_req.req_id in self._requests_need_load: ++ meta.add_request(request_id=new_req.req_id, ++ token_ids=new_req.prompt_token_ids, ++ block_ids=new_req.block_ids, ++ block_size=self._block_size, ++ ds_cached_block_num=self._ds_cached_blocks.get(new_req.req_id, 0)) ++ total_need_load += 1 ++ else: ++ if self.is_producer: ++ num_scheduled_tokens = scheduler_output.num_scheduled_tokens.get(new_req.req_id) ++ num_scheduled_tokens += new_req.num_computed_tokens ++ if len(new_req.prompt_token_ids) > num_scheduled_tokens: ++ self._delay_save[new_req.req_id] = RequestTracker.from_new_request(new_req.req_id, ++ new_req.prompt_token_ids, ++ new_req.block_ids, ++ num_scheduled_tokens) ++ else: ++ meta.add_request(request_id=new_req.req_id, ++ token_ids=new_req.prompt_token_ids, ++ block_ids=new_req.block_ids, ++ block_size=self._block_size, ++ ds_cached_block_num=self._ds_cached_blocks.get(new_req.req_id, 0)) ++ ++ for cached_req in scheduler_output.scheduled_cached_reqs: ++ # NOTE(rob): here we rely on the resumed requests being ++ # the first N requests in the list scheduled_cache_reqs. ++ if not cached_req.resumed_from_preemption: ++ if cached_req.req_id in self._delay_save: ++ request_tracker = self._delay_save.get(cached_req.req_id) ++ num_external_scheduled_tokens = scheduler_output.num_scheduled_tokens.get(cached_req.req_id) ++ request_tracker.update(cached_req.new_block_ids, num_external_scheduled_tokens) ++ if len(request_tracker.token_ids) <= request_tracker.num_scheduled_tokens: ++ del self._delay_save[cached_req.req_id] ++ logger.debug(f"add delay save request, request id:{request_tracker.request_id}") ++ meta.add_request(request_id=request_tracker.request_id, ++ token_ids=request_tracker.token_ids, ++ block_ids=request_tracker.block_ids, ++ block_size=self._block_size, ++ ds_cached_block_num=self._ds_cached_blocks.get(request_tracker.request_id, 0)) ++ break ++ ++ if cached_req.req_id in self._requests_need_load: ++ # NOTE(rob): cached_req_data does not have the full ++ # list of token ids (only new tokens). So we look it ++ # up in the actual request object. ++ request = self._requests_need_load[cached_req.req_id] ++ token_ids = request.all_token_ids[:len(request.prompt_token_ids)] ++ logger.debug(f"request_id:{request.request_id} resumed from preemption") ++ # NOTE(rob): For resumed req, new_block_ids is all ++ # of the block_ids for the request. ++ block_ids = cached_req.new_block_ids ++ meta.add_request(request_id=cached_req.req_id, ++ token_ids=token_ids, ++ block_ids=block_ids, ++ block_size=self._block_size, ++ ds_cached_block_num=self._ds_cached_blocks.get(cached_req.req_id, 0)) ++ total_need_load += 1 ++ if self.do_async_save: ++ for req_id, (req, block_ids) in self._requests_need_load.items(): ++ if not block_ids: ++ logger.debug( ++ "Skipping adding request %s to ConnectorMetadata, " ++ "as there are no remote blocks to pull", req_id) ++ continue ++ ++ meta.add_request( ++ request_id=req_id, ++ token_ids=req.prompt_token_ids, ++ block_ids=block_ids, ++ block_size=self._block_size, ++ ds_cached_block_num=self._ds_cached_blocks.get(req_id, 0)) ++ total_need_load += 1 ++ ++ # Clear the list once workers start the transfers ++ logger.debug(f"total_need_load:{total_need_load}, self._requests_need_load:{len(self._requests_need_load)}") ++ # Clear the list once workers start the transfers ++ if total_need_load != len(self._requests_need_load): ++ logger.error(f"total_need_load={total_need_load} " ++ f"is not equal to requests_need_load={len(self._requests_need_load)}") ++ raise ValueError("total_need_load is not equal to requests_need_load") ++ self._requests_need_load.clear() ++ self._ds_cached_blocks.clear() ++ return meta ++ ++ def request_finished( ++ self, ++ request: "Request", ++ block_ids: list[int], ++ ) -> tuple[bool, Optional[dict[str, Any]]]: ++ # Return True to indicate that saving may be happening ++ # asynchronously. ++ """ ++ request_finished ++ """ ++ if self.is_producer: ++ return self.do_async_save, None ++ ++ return False, None ++ ++ def _init_kv_caches_from_forward_context( ++ self, forward_context: "ForwardContext"): ++ """ ++ _init_kv_caches_from_forward_context ++ ++ Args: ++ forward_context: forward_context. ++ """ ++ attn_metadata = forward_context.attn_metadata ++ for layer_name in forward_context.no_compile_layers: ++ attn_layer = forward_context.no_compile_layers[layer_name] ++ kv_layer = attn_layer.kv_cache[forward_context.virtual_engine] ++ self.is_ms_non_mla_type: bool = isinstance(kv_layer, tuple) and len(kv_layer) == 2 ++ self.is_ms_mla = os.getenv("vLLM_MODEL_BACKEND", None) == "MindFormers" and not self.is_ms_non_mla_type ++ self.is_mla = isinstance(attn_metadata, MLACommonMetadata) or self.is_ms_mla ++ if layer_name not in self.layer_name_list: ++ self.layer_name_list.append(layer_name) ++ logger.debug(f"_init_kv_caches_from_forward_context, layer_name:{layer_name}") ++ if not self.is_mla: ++ self.key_caches.append(kv_layer[0]) ++ self.value_caches.append(kv_layer[1]) ++ elif self.is_ms_mla: ++ self.kv_caches.append(kv_layer[0]) ++ else: ++ self.kv_caches.append(kv_layer) ++ ++ ++def extract_number(s): ++ """extract number""" ++ parts = s.split('.') ++ for part in parts: ++ if part.isdigit(): ++ return int(part) ++ return None # 如果没有找到数字则返回None ++ ++ ++def align_to_block_size(num_tokens: int, block_size) -> int: ++ """Align the number of tokens to the block size. ++ """ ++ return (num_tokens - 1) // block_size * block_size ++ ++ ++def generate_hash_md5(block_id_num, token_ids, block_size, external_key): ++ """ ++ generate kv_cache token key. ++ ++ Args: ++ block_id_num: number of block ids. ++ token_ids: token ids ++ block_size: block size of vllm ++ external_key: additional key ++ """ ++ hash_list = [] ++ for block_index in range(block_id_num): ++ end_index = (block_index + 1) * block_size ++ input_ids = token_ids[:end_index] ++ input_ids_bytes = input_ids.tobytes() ++ token_hash = hashlib.md5(input_ids_bytes).hexdigest() + external_key ++ hash_list.append(token_hash) ++ return hash_list ++ ++ ++def get_future(fut: KvcFuture) -> List[str]: ++ """get future""" ++ try: ++ rc = fut.result(FUTURE_TIMEOUT) ++ except TimeoutError: ++ return RequestStatus.WAITING ++ ++ if rc.status_code != 0 or len(rc.failed_list) != 0: ++ # failed, should wait again ++ logger.warning(f"rc.status_code:{rc.status_code}, rc.error_message:{rc.error_message}" ++ f"rc.failed_list:{rc.failed_list}") ++ return RequestStatus.TIMEOUT ++ ++ return RequestStatus.FINISHED +diff --git a/dllm_tools/dllm/dkvc/v1/dllm_ds_d2d_connector.py b/dllm_tools/dllm/dkvc/v1/dllm_ds_d2d_connector.py +new file mode 100644 +index 000000000..8f7701e2a +--- /dev/null ++++ b/dllm_tools/dllm/dkvc/v1/dllm_ds_d2d_connector.py +@@ -0,0 +1,798 @@ ++#!/usr/bin/env python3 ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++ ++import contextlib ++import math ++import threading ++import time ++import socket ++from collections import defaultdict ++from collections.abc import Iterator ++from dataclasses import dataclass ++from typing import TYPE_CHECKING, Any, Optional, List, Tuple ++import pickle ++from concurrent.futures import Future ++from concurrent.futures import ThreadPoolExecutor ++import os ++ ++import msgspec ++import torch ++import zmq ++ ++from vllm.config import VllmConfig ++from vllm.distributed.kv_transfer.kv_connector.v1.base import ( ++ KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) ++from vllm.distributed.parallel_state import ( ++ get_tensor_model_parallel_rank, ++ get_tp_group, get_world_group) ++from vllm.logger import logger ++from vllm.utils import make_zmq_path, make_zmq_socket ++from vllm.v1.core.sched.output import SchedulerOutput ++from vllm.v1.request import RequestStatus ++ ++from datasystem.hetero_client import HeteroClient, Blob, DeviceBlobList ++ ++if TYPE_CHECKING: ++ from vllm.attention.backends.abstract import AttentionMetadata ++ from vllm.forward_context import ForwardContext ++ from vllm.v1.core.kv_cache_manager import KVCacheBlocks ++ from vllm.v1.request import Request ++ ++ ++BASE_PORT = int(os.getenv("VLLM_BASE_PORT", "8790")) ++GET_META_MSG = b"get_meta_msg" ++DONE_RECVING_MSG = b"done_recving_msg" ++ ++ ++class DLLMDsD2DAgentMetadata( ++ msgspec.Struct, ++ omit_defaults=True, # type: ignore[call-arg] ++ # required for @cached_property. ++ dict=True): ++ engine_id: str ++ kv_caches_base_addr: list[int] ++ num_blocks: int ++ ++ ++@dataclass ++class ReqMeta: ++ local_block_ids: list[int] ++ remote_block_ids: list[int] ++ remote_host: str ++ remote_port: int ++ remote_engine_id: str ++ ++ ++class DLLMDsD2DConnectorMetadata(KVConnectorMetadata): ++ ++ def __init__(self): ++ self.requests: dict[str, ReqMeta] = {} ++ ++ def add_new_req( ++ self, ++ request_id: str, ++ local_block_ids: list[int], ++ kv_transfer_params: dict[str, Any], ++ ): ++ """ ++ add new request metadata ++ """ ++ self.requests[request_id] = ReqMeta( ++ local_block_ids=local_block_ids, ++ remote_block_ids=kv_transfer_params["remote_block_ids"], ++ remote_engine_id=kv_transfer_params["remote_engine_id"], ++ remote_host=kv_transfer_params["remote_host"], ++ remote_port=kv_transfer_params["remote_port"], ++ ) ++ ++ ++class DLLMDsD2DConnector(KVConnectorBase_V1): ++ ++ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): ++ if vllm_config.kv_transfer_config is None: ++ raise ValueError("vllm_config.kv_transfer_config is none") ++ self.engine_id = vllm_config.kv_transfer_config.engine_id ++ ++ if role == KVConnectorRole.SCHEDULER: ++ self.connector_scheduler: Optional[DLLMDsD2DConnectorScheduler] = \ ++ DLLMDsD2DConnectorScheduler(vllm_config, str(self.engine_id)) ++ self.connector_worker: Optional[DLLMDsD2DConnectorWorker] = None ++ elif role == KVConnectorRole.WORKER: ++ self.connector_scheduler = None ++ self.connector_worker = DLLMDsD2DConnectorWorker( ++ vllm_config, str(self.engine_id)) ++ ++ ############################################################ ++ # Scheduler Side Methods ++ ############################################################ ++ ++ def get_num_new_matched_tokens( ++ self, request: "Request", ++ num_computed_tokens: int) -> tuple[int, bool]: ++ """ ++ get the number of new matched tokens ++ """ ++ return self.connector_scheduler.get_num_new_matched_tokens( ++ request, num_computed_tokens) ++ ++ def update_state_after_alloc(self, request: "Request", ++ blocks: "KVCacheBlocks", ++ num_external_tokens: int): ++ """ ++ update state after alloc ++ """ ++ return self.connector_scheduler.update_state_after_alloc( ++ request, blocks, num_external_tokens) ++ ++ def build_connector_meta( ++ self, ++ scheduler_output: SchedulerOutput, ++ ) -> KVConnectorMetadata: ++ """ ++ build connector meta ++ """ ++ return self.connector_scheduler.build_connector_meta(scheduler_output) ++ ++ def request_finished( ++ self, ++ request: "Request", ++ block_ids: list[int], ++ ) -> tuple[bool, Optional[dict[str, Any]]]: ++ """ ++ request finished ++ """ ++ return self.connector_scheduler.request_finished(request, block_ids) ++ ++ ############################################################ ++ # Worker Side Methods ++ ############################################################ ++ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): ++ """ ++ register kv caches ++ """ ++ self.connector_worker.register_kv_caches(kv_caches) ++ ++ def get_finished(self, ++ finished_req_ids: set[str]) -> tuple[set[str], set[str]]: ++ """Get the finished recving and sending requests.""" ++ return self.connector_worker.get_finished() ++ ++ def start_load_kv(self, forward_context: "ForwardContext", ++ **kwargs) -> None: ++ """ ++ get the number of new matched tokens ++ """ ++ if not isinstance(self._connector_metadata, DLLMDsD2DConnectorMetadata): ++ raise ValueError("connector_metadata is not an instance of DLLMDsD2DConnectorMetadata") ++ self.connector_worker.start_load_kv(self._connector_metadata) ++ ++ def wait_for_layer_load(self, layer_name: str) -> None: ++ """DLLMDsD2DConnector does not do layerwise saving.""" ++ pass ++ ++ def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, ++ attn_metadata: "AttentionMetadata", **kwargs) -> None: ++ """DLLMDsD2DConnector does not save explicitly.""" ++ pass ++ ++ def wait_for_save(self): ++ """DLLMDsD2DConnector does not save explicitly.""" ++ pass ++ ++ ++class DLLMDsD2DConnectorScheduler: ++ """Implementation of Scheduler side methods""" ++ ++ def __init__(self, vllm_config: VllmConfig, engine_id: str): ++ self.vllm_config = vllm_config ++ self.block_size = vllm_config.cache_config.block_size ++ self.engine_id = engine_id ++ ++ self.side_channel_host = get_local_ip_by_remote() ++ self.side_channel_port = ( ++ BASE_PORT + ++ vllm_config.parallel_config.data_parallel_rank_local * ++ vllm_config.parallel_config.tensor_parallel_size) ++ ++ logger.info("Initializing DLLMDsD2D Scheduler %s", engine_id) ++ ++ # Requests that need to start recv. ++ # New requests are added by update_state_after_alloc in ++ # the scheduler. Used to make metadata passed to Worker. ++ self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {} ++ ++ @staticmethod ++ def get_num_new_matched_tokens( ++ request: "Request", ++ num_computed_tokens: int) -> tuple[int, bool]: ++ """ ++ For remote prefill, pull all prompt blocks from remote ++ asynchronously relative to engine execution. ++ Args: ++ request (Request): the request object. ++ num_computed_tokens (int): the number of locally ++ computed tokens for this request ++ Returns: ++ * the number of tokens that can be loaded from the ++ external KV cache beyond what is already computed. ++ * true if the external KV cache tokens will be loaded ++ asynchronously (between scheduler steps). ++ """ ++ ++ params = request.kv_transfer_params ++ logger.info( ++ "DLLMDsD2DConnector get_num_new_matched_tokens: request_id=%s " ++ "num_computed_tokens=%s, kv_transfer_params=%s", request.request_id, ++ num_computed_tokens, params) ++ ++ if params is not None and params.get("do_remote_prefill"): ++ # Remote prefill: get all prompt blocks from remote. ++ ++ # 因为prefill阶段结束会把生成first_token拼到prompts上,所以prompt_token_ids会多一个,这里要减掉 ++ count = max(len(request.prompt_token_ids) - 1 - num_computed_tokens, 0) ++ if count > 0: ++ return count, True ++ ++ # No remote prefill for this request. ++ return 0, False ++ ++ def update_state_after_alloc(self, request: "Request", ++ blocks: "KVCacheBlocks", ++ num_external_tokens: int): ++ """ ++ update state after alloc ++ """ ++ params = request.kv_transfer_params ++ logger.info( ++ "DLLMDsD2DConnector update_state_after_alloc: " ++ "num_external_tokens=%s, kv_transfer_params=%s blocks=%s request_id=%s", ++ num_external_tokens, params, blocks, request.request_id) ++ ++ if params is not None and params.get("do_remote_prefill"): ++ if params.get("remote_block_ids"): ++ if all(p in params for p in ("remote_engine_id", "remote_host", "remote_port")): ++ # Get unhashed blocks to pull from remote. ++ self._reqs_need_recv[request.request_id] = ( ++ request, blocks.get_unhashed_block_ids()) ++ else: ++ logger.warning( ++ "Got invalid KVTransferParams: %s. This " ++ "request will not utilize KVTransfer", params) ++ else: ++ if num_external_tokens != 0: ++ raise ValueError("num_external_tokens is not 0") ++ # Only trigger 1 KV transfer per request. ++ params["do_remote_prefill"] = False ++ ++ def build_connector_meta( ++ self, ++ scheduler_output: SchedulerOutput, ++ ) -> KVConnectorMetadata: ++ """ ++ build connector meta ++ """ ++ meta = DLLMDsD2DConnectorMetadata() ++ ++ # Loop through scheduled reqs and convert to ReqMeta. ++ for req_id, (req, block_ids) in self._reqs_need_recv.items(): ++ if req.kv_transfer_params is None: ++ raise ValueError(f"request_id={req_id} kv_transfer_params is none.") ++ meta.add_new_req( ++ request_id=req_id, ++ local_block_ids=block_ids, ++ kv_transfer_params=req.kv_transfer_params, ++ ) ++ ++ # Clear the list once workers start the transfers ++ self._reqs_need_recv.clear() ++ ++ return meta ++ ++ def request_finished( ++ self, ++ request: "Request", ++ block_ids: list[int], ++ ) -> tuple[bool, Optional[dict[str, Any]]]: ++ """ ++ Once a request is finished, determine whether request blocks ++ should be freed now or will be sent asynchronously and freed later. ++ """ ++ params = request.kv_transfer_params ++ logger.info(f"DLLMDsD2DConnector request_finished, request_id={request.request_id},block_ids={block_ids}," ++ f"params={params}, request.status={request.status}") ++ ++ if (params is None or not params.get("do_remote_decode") ++ or request.status != RequestStatus.FINISHED_LENGTH_CAPPED): ++ return False, None ++ ++ # Get computed blocks. ++ # 因为我们的decode不能提供prefill的功能,所以我们不能舍弃block_size粒度以外的kv cache ++ computed_block_ids = block_ids ++ ++ # If prompt < block_size, no xfer so free blocks immediately. ++ delay_free_blocks = len(computed_block_ids) > 0 ++ ++ return delay_free_blocks, dict( ++ do_remote_prefill=True, ++ do_remote_decode=False, ++ remote_block_ids=computed_block_ids, ++ remote_engine_id=self.engine_id, ++ remote_host=self.side_channel_host, ++ remote_port=self.side_channel_port, ++ ) ++ ++ ++class DLLMDsD2DConnectorWorker: ++ """Implementation of Worker side methods""" ++ ++ def __init__(self, vllm_config: VllmConfig, engine_id: str): ++ # Metadata. ++ self.engine_id = engine_id ++ self.tp_rank = get_tensor_model_parallel_rank() ++ self.tp_size = vllm_config.parallel_config.tensor_parallel_size ++ self.tp_group = get_tp_group() ++ self.dp_rank = vllm_config.parallel_config.data_parallel_rank_local ++ self.dp_size = vllm_config.parallel_config.data_parallel_size ++ self.kv_caches: dict[str, torch.Tensor] = {} ++ # map of kv_caches_base_addr -> tensor ++ self.kv_caches_addr_dict: dict[str, torch.Tensor] = {} ++ self.side_channel_host = get_local_ip_by_remote() ++ self.side_channel_port = ( ++ BASE_PORT + ++ vllm_config.parallel_config.data_parallel_rank_local * ++ vllm_config.parallel_config.tensor_parallel_size) ++ logger.info(f"tp_rank={self.tp_rank}, self.side_channel_port={self.side_channel_port}") ++ ++ # Map of engine_id -> agent_name. ++ self._remote_engine: List = [] ++ ++ # Map of engine_id -> kv_caches_base_addr ++ self.kv_caches_base_addr: dict[str, list[int]] = {} ++ ++ self.num_layers = 0 ++ ++ # # Complete transfer tracker. Used by the rank 0 to track finished ++ # # transactions on ranks 1 to N-1. ++ # # [req_id -> count] ++ self._done_recving_count: defaultdict[str, int] = defaultdict(lambda: 0) ++ self._done_sending_count: defaultdict[str, int] = defaultdict(lambda: 0) ++ self.is_producer = vllm_config.kv_transfer_config.is_kv_producer ++ ++ # Background thread for establishing new connections. ++ self._message_listener_loop: Optional[threading.Thread] = None ++ ++ self.vllm_config = vllm_config ++ self.block_size = vllm_config.cache_config.block_size ++ self.block_len = 0 ++ ++ # create async thread pool ++ self.futures: dict[str, list[Future]] = {} ++ self.req_record: dict[str, tuple[str, int]] = {} ++ ++ # gey device id ++ self._is_ms = os.getenv("vLLM_MODEL_BACKEND", None) == "MindFormers" ++ if self._is_ms: ++ self.device = get_world_group().local_rank ++ else: ++ self.device = torch.npu.current_device() ++ # init kvcStore ++ ds_worker_addr = os.getenv("DS_WORKER_ADDR", "172.17.0.4:9000") ++ ip_port = ds_worker_addr.split(":") ++ ip = ip_port[0] ++ port = int(ip_port[1]) ++ logger.info(f"init datasystem ip = {ip}, port = {port}, device_id = {self.device},_is_ms = {self._is_ms}") ++ ++ self._ds_client = HeteroClient(ip, port, 6000) ++ self._ds_client.init() ++ ++ def register_kv_caches(self, kv_caches: dict[str, Tuple[torch.Tensor]]): ++ """Register the KV Cache data.""" ++ _, first_kv_cache_tuple = next(iter(kv_caches.items())) ++ first_kv_cache = first_kv_cache_tuple[0] ++ kv_elem_size = first_kv_cache.element_size() ++ use_mla = len(first_kv_cache.shape) == 3 ++ ++ num_blocks = first_kv_cache.shape[0] ++ block_shape = first_kv_cache.shape[-3:] ++ # hybrid attn, etc ++ self.block_len = kv_elem_size * math.prod(block_shape) ++ logger.info("Registering KV_Caches. use_mla: %s, shape %s, num_blocks: %s, block_shape: %s, kv_elem_size:%s", ++ use_mla, first_kv_cache.shape, num_blocks, block_shape, kv_elem_size) ++ ++ self.kv_caches = kv_caches ++ kv_caches_base_addr = [] ++ for cache_or_caches in kv_caches.values(): ++ # Normalize to always be a list of caches ++ cache_list = [cache_or_caches] if use_mla else cache_or_caches ++ for cache in cache_list: ++ base_addr = cache.data_ptr() ++ kv_caches_base_addr.append(base_addr) ++ logger.info(f"base_addr={base_addr}") ++ for block_index in range(num_blocks): ++ block_base_addr = base_addr + block_index * self.block_len ++ self.kv_caches_addr_dict[self.engine_id + "_" + str(self.tp_rank) + "_" + str(block_base_addr)] \ ++ = DeviceBlobList(self.device, [Blob(block_base_addr, self.block_len)]) ++ self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr ++ if not self.is_producer: ++ return ++ # Only instance P needs to register KV cache address ++ self._register_kv_addr(num_blocks) ++ ++ def add_remote_agent(self, agent_meta: DLLMDsD2DAgentMetadata): ++ """add remote agent""" ++ engine_id = agent_meta.engine_id ++ if engine_id == self.engine_id: ++ logger.error(f"Conflict engine id found! engine_id={engine_id}") ++ raise ValueError("Conflict engine id found!") ++ if engine_id in self._remote_engine: ++ return ++ self._remote_engine.append(engine_id) ++ self.kv_caches_base_addr[ ++ engine_id] = agent_meta.kv_caches_base_addr ++ ++ def get_finished(self) -> tuple[set[str], set[str]]: ++ """get finished sending and receiving requests""" ++ # get the async transfer completed tasks (done recving) ++ req_id_list = list(self.futures.keys()) ++ time1 = time.monotonic() ++ self._handle_done_transfers(req_id_list) ++ done_sending = set(self._done_sending_count.keys()) ++ done_recving = set(self._done_recving_count.keys()) ++ if len(done_sending) == 0 and len(done_recving) == 0: ++ return done_sending, done_recving ++ logger.info( ++ "Rank %s, get_finished: %s requests done sending " ++ "and %s requests done recving", self.tp_rank, done_sending, ++ done_recving) ++ logger.info(f"future result time {time.monotonic() - time1}") ++ if self.tp_size == 1: ++ self._done_sending_count.clear() ++ self._done_recving_count.clear() ++ return done_sending, done_recving ++ # Rank 0: get finished from all other ranks. ++ if self.tp_rank == 0: ++ # Keep track of how many other ranks have finished. ++ other_ranks_finished_ids: list[str] = [] ++ for i in range(1, self.tp_size): ++ other_ranks_finished_ids.extend(self.tp_group.recv_object(src=i)) ++ for req_id in other_ranks_finished_ids: ++ if req_id in self._done_recving_count: ++ self._done_recving_count[req_id] += 1 ++ else: ++ self._done_sending_count[req_id] += 1 ++ # Return ids that finished on all ranks to the scheduler. ++ all_done_recving: set[str] = set() ++ for req_id in list(self._done_recving_count.keys()): ++ if self._done_recving_count[req_id] == self.tp_size: ++ del self._done_recving_count[req_id] ++ all_done_recving.add(req_id) ++ all_done_sending: set[str] = set() ++ for req_id in list(self._done_sending_count.keys()): ++ if self._done_sending_count[req_id] == self.tp_size: ++ del self._done_sending_count[req_id] ++ all_done_sending.add(req_id) ++ logger.info(f"tp_rank 0,return result time {time.monotonic() - time1}") ++ return all_done_sending, all_done_recving ++ # Ranks 1 to N-1: send finished ids to Rank 0. ++ finished_req_ids = list(done_recving.union(done_sending)) ++ self.tp_group.send_object(finished_req_ids, dst=0) ++ # Unused as only Rank 0 results are sent to scheduler. ++ self._done_sending_count.clear() ++ self._done_recving_count.clear() ++ logger.info(f"tp_rank={self.tp_rank},return result time {time.monotonic() - time1}") ++ return done_sending, done_recving ++ ++ def start_load_kv(self, metadata: DLLMDsD2DConnectorMetadata): ++ """ ++ Start loading KV blocks from remote engine. ++ Args: ++ metadata: dict of request_id -> DLLMDsD2DConnectorMetadata ++ """ ++ if self.is_producer: ++ return ++ for req_id, meta in metadata.requests.items(): ++ logger.debug( ++ "start_load_kv for request %s from remote engine %s. " ++ "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id, ++ meta.remote_engine_id, len(meta.local_block_ids), ++ len(meta.remote_block_ids)) ++ self._read_blocks( ++ request_id=req_id, ++ remote_engine_id=meta.remote_engine_id, ++ local_block_ids=meta.local_block_ids, ++ remote_block_ids=meta.remote_block_ids, ++ remote_host=meta.remote_host, ++ remote_port=meta.remote_port, ++ ) ++ ++ def _register_kv_addr(self, num_blocks): ++ items = list(self.kv_caches_addr_dict.items()) ++ kv_caches_addr_dict_list = [dict(items[i:i + 9999]) for i in range(0, len(items), 9999)] ++ for kv_caches_addr_dict in kv_caches_addr_dict_list: ++ try: ++ failed_keys = self._ds_client.dev_mset(list(kv_caches_addr_dict.keys()), ++ list(kv_caches_addr_dict.values())) ++ if failed_keys: ++ logger.warning(f"dev_mset failed,failed_keys:{failed_keys}") ++ except Exception as e: ++ self._ds_client.dev_delete(list(kv_caches_addr_dict.keys())) ++ logger.error(f"dev_mset error.e:{e}") ++ self.num_layers = len(self.kv_caches.keys()) ++ # After KV Caches registered, listen for new connections. ++ metadata = DLLMDsD2DAgentMetadata( ++ engine_id=self.engine_id, ++ kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], ++ num_blocks=num_blocks, ++ ) ++ ready_event = threading.Event() ++ self._message_listener_loop = threading.Thread( ++ target=self._message_listener, ++ args=(metadata, ready_event, self.tp_rank), ++ daemon=True, ++ name="message_listener") ++ self._message_listener_loop.start() ++ ready_event.wait() ++ ++ def _message_listener(self, metadata: DLLMDsD2DAgentMetadata, ++ ready_event: threading.Event, tp_rank: int): ++ """Background thread for getting new DLLMDsD2D handshakes.""" ++ ++ encoder = msgspec.msgpack.Encoder() ++ encoded_data = encoder.encode(metadata) ++ size_in_bytes = len(encoded_data) ++ logger.debug("Size of encoded DLLMDsD2DAgentMetadata: %s bytes", ++ str(size_in_bytes)) ++ ++ # Listen for new requests for metadata. ++ # NOTE(rob): we need each rank to have a unique port. This ++ # hack to keeps us moving. We will switch when moving to etcd ++ # or where we have a single ZMQ socket in the scheduler. ++ handshake_port = self.side_channel_port + tp_rank ++ path = make_zmq_path("tcp", self.side_channel_host, handshake_port) ++ logger.info("Starting listening on path: %s, side_channel_port:%s, tp_rank:%s", path, self.side_channel_port, ++ tp_rank) ++ with zmq_ctx(zmq.ROUTER, path) as sock: ++ ready_event.set() ++ while True: ++ try: ++ identity, _, msg = sock.recv_multipart() ++ msg = pickle.loads(msg) ++ logger.info(f"msg={msg}") ++ if msg[0] == GET_META_MSG: ++ logger.info( ++ "Got notify from remote engine that get meta msg") ++ sock.send_multipart((identity, b"", encoded_data)) ++ elif msg[0] == DONE_RECVING_MSG: ++ logger.info( ++ "Got notify from remote engine that load kv is done") ++ self._done_sending_count[msg[1]] += 1 ++ sock.send_multipart((identity, b"", encoded_data)) ++ ++ else: ++ logger.warning( ++ "Connection listener got unexpected message %s", msg) ++ except Exception as e: ++ logger.error(f"error:{e}") ++ ++ def _message_req(self, host: str, port: int, msg: tuple[bytes, str]): ++ """send a normal message with a remote instance.""" ++ ++ start_time = time.perf_counter() ++ # NOTE(rob): we need each rank to have a unique port. This is ++ # a hack to keep us moving. We will switch when moving to etcd ++ # or where we have a single ZMQ socket in the scheduler. ++ path = make_zmq_path("tcp", host, port) ++ logger.info("Querying metadata on path: %s, msg:%s", path, msg) ++ with zmq_ctx(zmq.REQ, path) as sock: ++ # Send msg to remote. It will recv a msg in shakehand case and other would not ++ if msg[0] == GET_META_MSG: ++ logger.info("Sending query for metadata") ++ data_bytes = pickle.dumps(msg) ++ sock.send(data_bytes) ++ metadata_bytes = sock.recv() ++ decoder = msgspec.msgpack.Decoder(DLLMDsD2DAgentMetadata) ++ metadata = decoder.decode(metadata_bytes) ++ # Register Remote agent. ++ self.add_remote_agent(metadata) ++ elif msg[0] == DONE_RECVING_MSG: ++ logger.info("Sending notify to prefill that load is done") ++ # 将元组序列化为字节 ++ data_bytes = pickle.dumps(msg) ++ sock.send(data_bytes) ++ sock.recv() ++ else: ++ logger.warning( ++ "Connection listener got unexpected message %s", msg) ++ raise RuntimeError(f"Unexpected message: {msg}") ++ ++ end_time = time.perf_counter() ++ logger.info("send %s message took: %s", ++ msg[0], end_time - start_time) ++ ++ def _read_blocks( ++ self, ++ local_block_ids: list[int], ++ remote_block_ids: list[int], ++ remote_host: str, ++ remote_port: int, ++ remote_engine_id: str, ++ request_id: str, ++ ): ++ # get target tp rank ++ logger.info(f"local_block_ids={local_block_ids}, remote_block_ids={remote_block_ids}," ++ f"remote_host={remote_host}, remote_port={remote_port}, tp_rank={self.tp_rank}," ++ f" remote_engine_id={remote_engine_id}," ++ f"request_id={request_id}") ++ remote_handshake_port = remote_port + self.tp_rank ++ self.req_record[request_id] = (remote_host, remote_handshake_port) ++ # NOTE(rob): this takes ~2s. We need to get this off the hotpath. ++ if remote_engine_id not in self._remote_engine: ++ msg = (GET_META_MSG, "") ++ self._message_req(remote_host, remote_handshake_port, msg) ++ # NOTE(rob): having the staging blocks be on the READER side is ++ # not going to work well (since we will have to call rearrange tensors). ++ # after we detect the txn is complete (which means we cannot make the ++ # read trxn async easily). If we want to make "READ" happen cleanly, ++ # then we will need to have the staging blocks on the remote side. ++ # NOTE(rob): according to nvidia the staging blocks are used to ++ # saturate IB with heterogeneous TP sizes. We should remove the staging ++ # blocks until we are ready. ++ # Full prefix cache hit: do not need to read remote blocks, ++ # just notify P worker that we have the blocks we need. ++ num_local_blocks = len(local_block_ids) ++ if num_local_blocks == 0: ++ self._done_recving_count[request_id] += 1 ++ msg = (DONE_RECVING_MSG, request_id) ++ self._message_req(remote_host, remote_handshake_port, msg) ++ return ++ # Partial prefix cache hit: just read uncomputed blocks. ++ num_remote_blocks = len(remote_block_ids) ++ if num_local_blocks > num_remote_blocks: ++ logger.error(f"num_local_blocks is bigger than num_remote_blocks request_id={request_id}," ++ f"local_block_ids={local_block_ids}, remote_block_ids={remote_block_ids}.") ++ raise ValueError(f"request_id={request_id},num_local_blocks={num_local_blocks} " ++ f"is bigger than num_remote_blocks={num_remote_blocks}.") ++ if num_local_blocks < num_remote_blocks: ++ remote_block_ids = remote_block_ids[-num_local_blocks:] ++ logger.info(f"request_id={request_id},engine_ids={self.kv_caches_base_addr.keys()}, " ++ f"device={torch.npu.current_device()}, tp_rank={self.tp_rank}") ++ self._async_load_kv(local_block_ids, remote_block_ids, remote_engine_id, request_id) ++ ++ def _async_load_kv(self, local_block_ids, remote_block_ids, remote_engine_id, request_id): ++ with ThreadPoolExecutor() as executor: ++ recv_dev_blobs_dict = {} ++ if self.engine_id not in self.kv_caches_base_addr or remote_engine_id not in self.kv_caches_base_addr: ++ logger.error(f"{self.engine_id} or {remote_engine_id} is not in " ++ f"kv_caches_base_addr={self.kv_caches_base_addr}") ++ return ++ for dst_base_addr, src_base_addr in zip(self.kv_caches_base_addr[self.engine_id], ++ self.kv_caches_base_addr[remote_engine_id]): ++ for local_block_id, remote_block_id in zip(local_block_ids, remote_block_ids): ++ dst = dst_base_addr + local_block_id * self.block_len ++ src = src_base_addr + remote_block_id * self.block_len ++ key = remote_engine_id + "_" + str(self.tp_rank) + "_" + str(src) ++ recv_dev_blobs_dict[key] = DeviceBlobList(self.device, [Blob(dst, self.block_len)]) ++ logger.debug(f"request_id={request_id}, key={key}, dst_base_addr={dst_base_addr}, " ++ f"local_block_id={local_block_id},remote_block_id={remote_block_id}," ++ f"self.block_len={self.block_len}") ++ ++ future = executor.submit( ++ self._sync_transfer, recv_dev_blobs_dict) ++ if request_id in self.futures: ++ self.futures[request_id].append(future) ++ else: ++ self.futures[request_id] = [future] ++ ++ def _sync_transfer( ++ self, recv_dev_blobs_dict: dict): ++ # """Synchronously transfer data to the specified address.""" ++ if self._is_ms: ++ import acl ++ import mindspore as ms ++ acl.init() ++ acl.rt.set_device(self.device) ++ ms.set_device(device_target="Ascend", device_id=self.device) ++ else: ++ torch.npu.set_device(self.device) ++ items = list(recv_dev_blobs_dict.items()) ++ # The length of the keys should be less than 10000 ++ recv_dev_blobs_dict_list = [dict(items[i:i + 9999]) for i in range(0, len(items), 9999)] ++ for dev_blobs_dict in recv_dev_blobs_dict_list: ++ failed_keys = self._ds_client.dev_mget(list(dev_blobs_dict.keys()), ++ list(dev_blobs_dict.values()), 0) ++ del_failed_keys = self._ds_client.dev_local_delete(list(dev_blobs_dict.keys())) ++ logger.info(f"local delete failed keys={del_failed_keys}") ++ if len(failed_keys) != 0: ++ logger.error(f"dev_mget failed.failed_keys={failed_keys}") ++ return -1 ++ return 0 ++ ++ def _handle_done_transfers(self, req_id_list): ++ for req_id in req_id_list: ++ if req_id not in self.futures: ++ logger.warning(f"req_id={req_id} is not transferring") ++ continue ++ success_count = 0 ++ try: ++ # 获取每个任务的返回值 ++ future = self.futures[req_id] ++ for trans_task in future: ++ ret = trans_task.result() ++ # 假设status为True时表示成功 ++ if ret >= 0: ++ success_count += 1 ++ except Exception as e: ++ # 处理任务中的异常情况 ++ logger.error(f"DLLMDsD2D Transfer Engine Return Error req_id:{req_id}.e:{e}") ++ if success_count == len(self.futures[req_id]): ++ self._done_recving_count[req_id] += 1 ++ msg = (DONE_RECVING_MSG, req_id) ++ if req_id not in self.req_record: ++ logger.warning(f"req_id={req_id} is not in req_record") ++ continue ++ remote_host = self.req_record[req_id][0] ++ remote_handshake_port = self.req_record[req_id][1] ++ self._message_req(remote_host, remote_handshake_port, msg) ++ del self.futures[req_id] ++ del self.req_record[req_id] ++ ++ ++@contextlib.contextmanager ++def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: ++ """Context manager for a ZMQ socket""" ++ if socket_type not in (zmq.ROUTER, zmq.REQ): ++ raise ValueError(f"Unexpected socket type: {socket_type}") ++ ctx: Optional[zmq.Context] = None ++ try: ++ ctx = zmq.Context() # type: ignore[attr-defined] ++ yield make_zmq_socket(ctx=ctx, ++ path=addr, ++ socket_type=socket_type, ++ bind=socket_type == zmq.ROUTER) ++ finally: ++ if ctx is not None: ++ ctx.destroy(linger=0) ++ ++ ++def get_local_ip_by_remote() -> str: ++ """get local ip by remote,try ipv4/ipv6""" ++ # try ipv4 ++ s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) ++ try: ++ s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable ++ return s.getsockname()[0] ++ except Exception as e: ++ logger.error(f"socket connect error,e:{e}") ++ finally: ++ if s is not None: ++ s.close() ++ try: ++ hostname = socket.gethostname() ++ ip = socket.gethostbyname(hostname) ++ if ip and ip != "127.0.0.1" and ip != "0.0.0.0": ++ return ip ++ except Exception as e: ++ logger.error(f"socket get host by name error,e:{e}") ++ # try ipv6 ++ try: ++ s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) ++ # Google's public DNS server, see ++ # https://developers.google.com/speed/public-dns/docs/using#addresses ++ s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable ++ return s.getsockname()[0] ++ except Exception as e: ++ raise ValueError("Can not get local ip") from e ++ finally: ++ if s is not None: ++ s.close() +diff --git a/dllm_tools/dllm/entities.py b/dllm_tools/dllm/entities.py +new file mode 100644 +index 000000000..78d326059 +--- /dev/null ++++ b/dllm_tools/dllm/entities.py +@@ -0,0 +1,114 @@ ++#!/usr/bin/env python3 ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++ ++from enum import Enum, auto ++from dataclasses import dataclass ++from typing import ClassVar, List, Optional ++import ray ++from ray.util.placement_group import PlacementGroup ++ ++from dllm import constants ++ ++ ++class VllmInstanceStatus(Enum): ++ UNREADY = auto() ++ RUNNING = auto() ++ SUBPROCESS_EXITED = auto() ++ HEALTHCHECK_FAILED = auto() ++ ++ ++class Role(Enum): ++ """定义实例角色类型""" ++ ++ PREFILL = 0 ++ DECODE = 1 ++ MIXED = 2 ++ ++ ++class SchedulerPolicy(Enum): ++ """负载均衡策略类型""" ++ ++ ROUND_ROBIN = 0 # 轮询 ++ ++ ++@dataclass ++class VllmInstanceInfo: ++ """VLLM实例的健康状态信息""" ++ ++ id: str ++ #: the api server's http address, e.g. "http://10.2.3.4:8000" ++ uri: str ++ #: the instance's role ++ role: Role ++ #: the instance's status ++ status: VllmInstanceStatus = VllmInstanceStatus.UNREADY ++ #: dp master ip ++ dp_master_ip: str = "" ++ #: dp master port ++ dp_master_port: int = 0 ++ ++ ++@dataclass ++class DispatchResult: ++ prefill_vllm_instance_uri: str ++ decode_vllm_instance_uri: str ++ ++ ++@dataclass ++class MetricsInfo: ++ running_num: int = 0 ++ waiting_num: int = 0 ++ npu_usage_perc: float = 0.0 ++ ++ # key:param_name, value:metric_name ++ METRIC_NAME_MAPPING: ClassVar[dict] = { ++ "running_num": constants.NUM_REQUESTS_RUNNING, ++ "waiting_num": constants.NUM_REQUESTS_WAITING, ++ "npu_usage_perc": constants.GPU_CACHE_USAGE_PERC, ++ } ++ ++ ++@dataclass ++class InstanceInfo: ++ actor: ray.actor.ActorHandle ++ in_p: bool ++ in_d: bool ++ ++ ++@dataclass ++class ActorInstanceInfo: ++ actor: ray.actor.ActorHandle ++ config: 'VllmInstanceConfig' ++ pg: PlacementGroup ++ instance_id: str ++ ++ ++@dataclass ++class DeployConfig: ++ head_ip: str ++ prefill_instances_num: int ++ prefill_startup_params: List[str] ++ prefill_startup_env: Optional[str] ++ prefill_data_parallel_size: int ++ prefill_tensor_parallel_size: int ++ prefill_expert_parallel_size: int ++ decode_instances_num: int ++ decode_startup_params: List[str] ++ decode_startup_env: Optional[str] ++ decode_data_parallel_size: int ++ decode_tensor_parallel_size: int ++ decode_expert_parallel_size: int ++ scheduler_policy: SchedulerPolicy ++ proxy_host: str ++ proxy_port: int +diff --git a/dllm_tools/dllm/kvc/__init__.py b/dllm_tools/dllm/kvc/__init__.py +new file mode 100644 +index 000000000..a51e21d78 +--- /dev/null ++++ b/dllm_tools/dllm/kvc/__init__.py +@@ -0,0 +1,15 @@ ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++ ++__all__ = ['TorchAdaptor'] ++ ++from .torch_adaptor import TorchAdaptor +diff --git a/dllm_tools/dllm/kvc/torch_adaptor.py b/dllm_tools/dllm/kvc/torch_adaptor.py +new file mode 100644 +index 000000000..661c529bc +--- /dev/null ++++ b/dllm_tools/dllm/kvc/torch_adaptor.py +@@ -0,0 +1,139 @@ ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++ ++from typing import Optional ++from torch import Tensor ++from dllm.cpp_ext.kvc import Blob, DeviceBlobList, KvcFuture, KvcTensor, PageAttnUtils, KvcStore ++ ++ ++class TorchAdaptor: ++ def __init__(self, store: KvcStore, device_id: Optional[int] = None): ++ self._store = store ++ self.device_id = device_id ++ ++ @staticmethod ++ def get_start_data_ptr(tensor: Tensor) -> int: ++ """get start data pointer""" ++ if TorchAdaptor._is_ms_tensor(tensor): ++ element_size = tensor.element_size() ++ return tensor.data_ptr() + (tensor.storage_offset() * element_size) ++ return tensor.data_ptr() ++ ++ @staticmethod ++ def _is_ms_tensor(tensor: Tensor) -> str: ++ """check if the tensor is mindspore type""" ++ ++ is_ms = tensor.device.type == "Ascend" ++ return is_ms ++ ++ @staticmethod ++ def _check_tensor_device_type(tensor: Tensor) -> None: ++ """check tensor device type""" ++ if tensor.device.type not in ["Ascend", "npu"]: ++ raise ValueError("Not a npu/Ascend tensor") ++ ++ @classmethod ++ def tensor_2_dbl(cls, tensor_list: list[Tensor], device_id: Optional[int] = None) -> DeviceBlobList: ++ """Convert the tensor to DeviceBlobList object""" ++ if len(tensor_list) == 0: ++ return None ++ ++ cls._check_tensor_device_type(tensor_list[0]) ++ blob_list = [Blob(cls.get_start_data_ptr(tensor), tensor.nbytes) for tensor in tensor_list] ++ ++ if TorchAdaptor._is_ms_tensor(tensor_list[0]): ++ return DeviceBlobList(blob_list, device_id) ++ ++ return DeviceBlobList(blob_list, tensor_list[0].device.index) ++ ++ @classmethod ++ def page_attn_layerwise_dbls(cls, layer_tensors: list[Tensor], block_ids: list[int]) -> list[DeviceBlobList]: ++ """Convert the page attention layer wise to DeviceBlobList object""" ++ kvc_tensors = cls.construct_layerwise_tensors(block_ids, layer_tensors) ++ return PageAttnUtils.layerwise_dev_blob_lists(layer_tensors[0].device.index, kvc_tensors, block_ids) ++ ++ @classmethod ++ def construct_layerwise_tensors(cls, block_ids, layer_tensors): ++ """construct layerwise tensors""" ++ if not layer_tensors: ++ raise ValueError("No layer tensor") ++ if not block_ids: ++ raise ValueError("No block id") ++ for tensor in layer_tensors: ++ cls._check_tensor_device_type(tensor) ++ if tensor.device.index != layer_tensors[0].device.index: ++ raise ValueError("Tensors not from a same device") ++ kvc_tensors = [KvcTensor(cls.get_start_data_ptr(t), t.element_size(), list(t.shape)) ++ for t in layer_tensors] ++ return kvc_tensors ++ ++ @classmethod ++ def page_attn_blockwise_dbls(cls, layer_tensors: list[Tensor], block_ids: list[int], device_id) \ ++ -> list[DeviceBlobList]: ++ """Convert the page attention block wise to DeviceBlobList object""" ++ kvc_tensors = cls.construct_layerwise_tensors(block_ids, layer_tensors) ++ if TorchAdaptor._is_ms_tensor(layer_tensors[0]): ++ return PageAttnUtils.blockwise_dev_blob_lists(device_id, kvc_tensors, block_ids) ++ return PageAttnUtils.blockwise_dev_blob_lists(layer_tensors[0].device.index, kvc_tensors, block_ids) ++ ++ def put_tensors_d2d(self, keys: list[str], tensors: list[Tensor]) -> list[KvcFuture]: ++ """put tensors via D2D communication""" ++ dbls = [self.tensor_2_dbl(t) for t in tensors] ++ return self._store.put_d2d(keys, dbls) ++ ++ def get_tensors_d2d(self, keys: list[str], tensors: list[Tensor]) -> list[KvcFuture]: ++ """get tensors via D2D communication""" ++ dbls = [self.tensor_2_dbl(t) for t in tensors] ++ return self._store.get_d2d(keys, dbls) ++ ++ def put_page_attn_layerwise_d2d(self, keys: list[str], layer_tensors: list[Tensor], block_ids: list[int])\ ++ -> list[KvcFuture]: ++ """put page attention layer wise via D2D communication""" ++ dbls = self.page_attn_layerwise_dbls(layer_tensors, block_ids) ++ return self._store.put_d2d(keys, dbls) ++ ++ def get_page_attn_layerwise_d2d(self, keys: list[str], layer_tensors: list[Tensor], block_ids: list[int])\ ++ -> list[KvcFuture]: ++ """get page attention layer wise via D2D communication""" ++ dbls = self.page_attn_layerwise_dbls(layer_tensors, block_ids) ++ return self._store.get_d2d(keys, dbls) ++ ++ def mset_tensors_d2h(self, keys: list[str], tensors: list[Tensor]) -> KvcFuture: ++ """set tensors via D2H communication""" ++ dbls = [self.tensor_2_dbl(t, self.device_id) for t in tensors] ++ return self._store.mset_d2h(keys, dbls) ++ ++ def mget_tensors_h2d(self, keys: list[str], tensors: list[list[Tensor]]) -> KvcFuture: ++ """get tensors via D2H communication""" ++ dbls = [self.tensor_2_dbl(t, self.device_id) for t in tensors] ++ return self._store.mget_h2d(keys, dbls) ++ ++ def mset_page_attn_blockwise_d2h(self, keys: list[str], layer_tensors: list[Tensor], block_ids: list[int])\ ++ -> KvcFuture: ++ """set page attention block wise via D2H communication""" ++ dbls = self.page_attn_blockwise_dbls(layer_tensors, block_ids, self.device_id) ++ return self._store.mset_d2h(keys, dbls) ++ ++ def mget_page_attn_blockwise_h2d(self, keys: list[str], layer_tensors: list[Tensor], block_ids: list[int])\ ++ -> KvcFuture: ++ """get page attention block wise via D2H communication""" ++ dbls = self.page_attn_blockwise_dbls(layer_tensors, block_ids, self.device_id) ++ return self._store.mget_h2d(keys, dbls) ++ ++ def delete(self, keys: list[str]) -> KvcFuture: ++ """delete store keys""" ++ return self._store.delete(keys) ++ ++ def exist(self, keys: list[str]) -> list[bool]: ++ """check existence of kv tensors """ ++ return self._store.exist(keys) +diff --git a/dllm_tools/dllm/logging.py b/dllm_tools/dllm/logging.py +new file mode 100644 +index 000000000..a09568099 +--- /dev/null ++++ b/dllm_tools/dllm/logging.py +@@ -0,0 +1,27 @@ ++#!/usr/bin/env python3 ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++ ++import logging ++ ++ ++def setup_logging(level=logging.INFO): ++ """setup logging""" ++ logger = logging.getLogger("dllm") ++ logger.propagate = False ++ if not logger.handlers: ++ handler = logging.StreamHandler() ++ formatter = logging.Formatter("[%(asctime)s][%(levelname)s][%(filename)s:%(lineno)s] %(message)s") ++ handler.setFormatter(formatter) ++ logger.addHandler(handler) ++ logger.setLevel(level) +diff --git a/dllm_tools/dllm/monkey_patch/README.md b/dllm_tools/dllm/monkey_patch/README.md +new file mode 100644 +index 000000000..b7d3597ab +--- /dev/null ++++ b/dllm_tools/dllm/monkey_patch/README.md +@@ -0,0 +1,22 @@ ++# About ++We implement VizTracer profiling for DLLM using a monkey-patching approach and vllm-plugin, which is automatically loaded when a vllm-process starts. ++ ++# Usage +++ enable porfile: `export ENABLE_VIZTRACER_PROFILE=1` +++ install viztracer: `pip install viztracer` +++ set env to enable sitecustomize: `export PYTHONPATH=PYTHONPATH:/xxx/dllm/monkey_patch/` +++ set the trace output dir, default is `/tmp/viztracer_profile/`: `export VIZTRACER_OUTPUT_PATH=xxx` +++ set the profiling param, default is `{"tracer_entries"=10000000, "min_duration"=20}`: `VIZTRACER_PARAM = "json manner"`, the param same as https://viztracer.readthedocs.io/en/latest/viztracer.html ++ +++ deploy dllm ++ +++ start profiling: curl -X POST --max-time 600 "http://xxxx:xxx/start_viz_profile" -H "Authorization: Bearer YOUR_API_KEY" ++ +++ do your benchmark ++ +++ stop profiling: curl -X POST --max-time 600 "http://xxxx:xxx/stop_viz_profile" -H "Authorization: Bearer YOUR_API_KEY" ++ +++ The profiling traces can be found in $VIZTRACER_OUTPUT_PATH, with each process generating a separate trace file. To analyze all traces on a unified timeline, we can merge them using : `viztracer --combine xxx xxx --output_file xxxx` ++ ++# Limit ++Since sys.setprofile, which VizTracer relies on, is a thread-level mechanism, profiling only takes effect on the main thread when started, and does not apply to other threads that were already running. Fortunately, most of the workload occurs in the main thread, so this limitation does not significantly impact our analysis. +\ No newline at end of file +diff --git a/dllm_tools/dllm/monkey_patch/__init__.py b/dllm_tools/dllm/monkey_patch/__init__.py +new file mode 100644 +index 000000000..e69de29bb +diff --git a/dllm_tools/dllm/monkey_patch/viz_profile/__init__.py b/dllm_tools/dllm/monkey_patch/viz_profile/__init__.py +new file mode 100644 +index 000000000..e69de29bb +diff --git a/dllm_tools/dllm/monkey_patch/viz_profile/common.py b/dllm_tools/dllm/monkey_patch/viz_profile/common.py +new file mode 100644 +index 000000000..faaa2a19e +--- /dev/null ++++ b/dllm_tools/dllm/monkey_patch/viz_profile/common.py +@@ -0,0 +1,82 @@ ++#!/usr/bin/env python3 ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++ ++ ++# common ++import logging ++viz_logger = logging.getLogger("VIZ_PROFILE_LOG") ++viz_logger.setLevel(logging.INFO) ++ ++## default ++VIZ_DEFAULT_PARAM = '{"tracer_entries": 10000000, "min_duration": 20}' ++VIZ_DEFAULT_OUTPUT_DIR = "/tmp/viztracer_profile/" ++ ++ ++## basic func ++def check_viztracer_modules() -> bool: ++ '''Check viztracer avaliable.''' ++ import importlib ++ return importlib.util.find_spec("viztracer") is not None ++ ++ ++def get_global_viztracer() -> "VizTracer": ++ '''Ensure unique instance of viztracer.''' ++ from viztracer import VizTracer, get_tracer ++ import os ++ _viztracer = get_tracer() ++ if _viztracer is None: ++ import json ++ json_arg = os.environ.get("VIZTRACER_PARAM", VIZ_DEFAULT_PARAM) ++ dict_arg = json.loads(json_arg) ++ _viztracer = VizTracer(**dict_arg) ++ _viztracer.register_global() ++ return _viztracer ++ ++ ++def generate_output_path(process_name : str) -> str: ++ '''Generate output path.''' ++ import time ++ import os ++ output_path = os.path.join(os.environ.get("VIZTRACER_OUTPUT_PATH", VIZ_DEFAULT_OUTPUT_DIR), ++ f"trace_{process_name}_{os.getpid()}_{time.time()}.json") ++ viz_logger.info(f"[viztracer monkey patch] output path: {output_path}") ++ return output_path ++ ++ ++def viz_profile_basic(process_name: str, is_start: bool = True) -> tuple[bool, str]: ++ '''Basic function for start/stop profile''' ++ viz_logger.info(f"[viztracer monkey patch][{process_name}]: \ ++ {'Starting' if is_start else 'Stopping'} viztracer profiler...") ++ if not check_viztracer_modules(): ++ viz_logger.error(f"[viztracer monkey patch][{process_name}]: Failed to import VizTracer.") ++ return False, "Need to install Viztracer first." ++ _viztracer = get_global_viztracer() ++ ++ if is_start: ++ if _viztracer.enable: ++ viz_logger.warning(f"[viztracer monkey patch][{process_name}]: Profile is already started") ++ return False, "Profile is already started." ++ _viztracer.start() ++ viz_logger.info(f"[viztracer monkey patch][{process_name}]: \ ++ Profiler viztracer Started") ++ return True, "success" ++ ++ if not _viztracer.enable: ++ viz_logger.warning(f"[viztracer monkey patch][{process_name}]: Profile not started") ++ return False, "Need to start profile first." ++ _viztracer.stop() ++ _viztracer.save(output_file=generate_output_path(process_name)) ++ viz_logger.info(f"[viztracer monkey patch][{process_name}]: \ ++ Profiler viztracer stopped") ++ return True, "success" +diff --git a/dllm_tools/dllm/monkey_patch/viz_profile/viz_profile_plugin.py b/dllm_tools/dllm/monkey_patch/viz_profile/viz_profile_plugin.py +new file mode 100644 +index 000000000..9c4150f64 +--- /dev/null ++++ b/dllm_tools/dllm/monkey_patch/viz_profile/viz_profile_plugin.py +@@ -0,0 +1,20 @@ ++#!/usr/bin/env python3 ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++def viz_profile_plugin(): ++ '''Plugin loaded by vllm''' ++ import os ++ if os.environ.get("ENABLE_VIZTRACER_PROFILE", '0') == '1': ++ # Importing this module for viztracer profiling. ++ import dllm.monkey_patch.viz_profile.vllm_engine_core_patch ++ import dllm.monkey_patch.viz_profile.vllm_api_server_patch +diff --git a/dllm_tools/dllm/monkey_patch/viz_profile/vllm_api_server_patch.py b/dllm_tools/dllm/monkey_patch/viz_profile/vllm_api_server_patch.py +new file mode 100644 +index 000000000..653b4e5be +--- /dev/null ++++ b/dllm_tools/dllm/monkey_patch/viz_profile/vllm_api_server_patch.py +@@ -0,0 +1,66 @@ ++#!/usr/bin/env python3 ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++# monkey patch for vllm api server ++import sys ++api_server = sys.modules.get("vllm.entrypoints.openai.api_server") ++core_client = sys.modules.get("vllm.v1.engine.core_client") ++async_llm = sys.modules.get("vllm.v1.engine.async_llm") ++router = engine_client = EngineCoreClient = AsyncLLM = None ++ ++if all(v is not None for v in [api_server, core_client, async_llm]): ++ router = getattr(api_server, "router", None) ++ engine_client = getattr(api_server, "engine_client", None) ++ EngineCoreClient = getattr(core_client, "EngineCoreClient", None) ++ AsyncLLM = getattr(async_llm, "AsyncLLM", None) ++ ++if all(v is not None for v in [router, engine_client, EngineCoreClient, AsyncLLM]): ++ from dllm.monkey_patch.viz_profile.common import viz_profile_basic ++ from fastapi import Request ++ from fastapi.responses import Response ++ ++ @router.post("/start_viz_profile") ++ async def start_viz_profile(raw_request: Request): ++ '''Start profile interface.''' ++ _success, _message = viz_profile_basic("api_server", True) ++ if not _success: ++ return Response(status_code=500, content=_message) ++ ++ await engine_client(raw_request).async_llm_start_viz_profile() ++ return Response(status_code=200, content="success") ++ ++ @router.post("/stop_viz_profile") ++ async def stop_viz_profile(raw_request: Request): ++ '''Stop profile interface.''' ++ _success, _message = viz_profile_basic("api_server", False) ++ if not _success: ++ return Response(status_code=500, content=_message) ++ ++ await engine_client(raw_request).async_llm_stop_viz_profile() ++ return Response(status_code=200, content="success") ++ ++ async def async_llm_start_viz_profile(self) -> None: ++ '''New interface for [asyncLLM] -> [engine core client]''' ++ await self.engine_core.ec_client_viz_profile_async(True) ++ ++ async def async_llm_stop_viz_profile(self) -> None: ++ '''New interface for [asyncLLM] -> [engine core client]''' ++ await self.engine_core.ec_client_viz_profile_async(False) ++ ++ AsyncLLM.async_llm_start_viz_profile = async_llm_start_viz_profile ++ AsyncLLM.async_llm_stop_viz_profile = async_llm_stop_viz_profile ++ ++ async def ec_client_viz_profile_async(self, is_start: bool = True) -> None: ++ '''New interface for [engine core client] -> [engine core]''' ++ await self.call_utility_async("engine_core_viz_profile", is_start) ++ EngineCoreClient.ec_client_viz_profile_async = ec_client_viz_profile_async +diff --git a/dllm_tools/dllm/monkey_patch/viz_profile/vllm_engine_core_patch.py b/dllm_tools/dllm/monkey_patch/viz_profile/vllm_engine_core_patch.py +new file mode 100644 +index 000000000..aef500d4d +--- /dev/null ++++ b/dllm_tools/dllm/monkey_patch/viz_profile/vllm_engine_core_patch.py +@@ -0,0 +1,49 @@ ++#!/usr/bin/env python3 ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++ ++import sys ++core = sys.modules.get("vllm.v1.engine.core") ++abstract = sys.modules.get("vllm.v1.executor.abstract") ++EngineCore = Executor = None ++ ++if all(v is not None for v in [core, abstract]): ++ EngineCore = getattr(core, "EngineCore", None) ++ Executor = getattr(abstract, "Executor", None) ++ ++if all(v is not None for v in [EngineCore, Executor]): ++ from dllm.monkey_patch.viz_profile.common import viz_profile_basic ++ ++ def engine_core_viz_profile(self, is_start: bool = True): ++ '''New interface for [engine core] -> [executor]''' ++ viz_profile_basic("engine_core", is_start) ++ self.model_executor.executor_viz_profile(is_start) ++ EngineCore.engine_core_viz_profile = engine_core_viz_profile ++ ++ ++ def viz_profile(any_value: object, is_start: bool = True): ++ ''' ++ New interface for [executor] -> [worker] ++ since it will be send to remote process, so we need to add dependent internal ++ ''' ++ from dllm.monkey_patch.viz_profile.common import viz_profile_basic ++ viz_profile_basic("worker", is_start) ++ ++ def executor_viz_profile(self, is_start: bool = True): ++ ''' ++ We send the function bytecode directly instead of the function name, ++ in order to avoid importing low-level modules like vllm_ascend, ++ which could lead to circular imports. ++ ''' ++ self.collective_rpc(viz_profile, args=(is_start, )) ++ Executor.executor_viz_profile = executor_viz_profile +diff --git a/dllm_tools/dllm/scripts.py b/dllm_tools/dllm/scripts.py +new file mode 100644 +index 000000000..ccf40a8f7 +--- /dev/null ++++ b/dllm_tools/dllm/scripts.py +@@ -0,0 +1,238 @@ ++#!/usr/bin/env python3 ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++ ++from typing import List, Optional ++import logging ++import shlex ++import ray ++import click ++ ++from dllm.controller.controller import Controller ++from dllm.controller.endpoint import deploy_endpoint_to_cluster ++from dllm.logging import setup_logging ++from dllm.constants import ENDPOINT_APPLICATION_NAME, DLLM_NAMESPACE, CONTROLLER_ACTOR_NAME ++from dllm.entities import SchedulerPolicy, DeployConfig ++from dllm.config import ControllerConfig, InferenceInstanceConfig ++ ++setup_logging() ++logger = logging.getLogger(__name__) ++ ++ ++@click.group() ++def cli(): ++ """DLLM Cluster Management""" ++ pass ++ ++ ++@cli.command(name="deploy", context_settings={"show_default": True}) ++@click.option("--head-ip", type=str, help='IP of Ray head node (e.g. "10.2.3.4")', default="auto") ++@click.option("--prefill-instances-num", type=int, help="the num of Prefill instances", default=1) ++@click.option( ++ "--prefill-startup-params", ++ type=str, ++ help="the Prefill instance start up command", ++ default="vllm serve /workspace/models/qwen2.5_7B", ++ callback=lambda ctx, param, value: shlex.split(value), ++) ++@click.option( ++ "--prefill-startup-env", ++ type=str, ++ help="the Prefill instance start up env", ++ default=None, ++) ++@click.option("--prefill-data-parallel-size", "-pdp", type=int, help="the dp of Prefill instances", default=1) ++@click.option("--prefill-tensor-parallel-size", "-ptp", type=int, help="the tp of Prefill instances", default=1) ++@click.option( ++ "--prefill-expert-parallel-size", ++ "-pep", ++ type=int, ++ help="the ep of Prefill instances, should be equal to dp*tp, 0 means disable expert parallelism", ++ default=0, ++) ++@click.option("--decode-instances-num", type=int, help="the num of Decode instances", default=1) ++@click.option( ++ "--decode-startup-params", ++ type=str, ++ help="the Decode instance start up command", ++ default="vllm serve /workspace/models/qwen2.5_7B", ++ callback=lambda ctx, param, value: shlex.split(value), ++) ++@click.option( ++ "--decode-startup-env", ++ type=str, ++ help="the decode instance start up env", ++ default=None, ++) ++@click.option("--decode-data-parallel-size", "-ddp", type=int, help="the dp of Decode instances", default=1) ++@click.option("--decode-tensor-parallel-size", "-dtp", type=int, help="the tp of Decode instances", default=1) ++@click.option( ++ "--decode-expert-parallel-size", ++ "-dep", ++ type=int, ++ help="the ep of Decode instances, should be equal to dp*tp, 0 means disable expert parallelism", ++ default=0, ++) ++@click.option( ++ "--scheduler-policy", ++ type=click.Choice([e.name for e in SchedulerPolicy], case_sensitive=False), ++ help="the scheduling policy, default to RoundRobin", ++ default=SchedulerPolicy.ROUND_ROBIN.name, ++ callback=lambda ctx, param, value: SchedulerPolicy[value.upper()], ++) ++@click.option("--proxy-host", type=str, help="the dllm service listening host", default="0.0.0.0") ++@click.option("--proxy-port", type=int, help="the dllm service listening port", default=8000) ++def deploy( ++ head_ip: str, ++ prefill_instances_num: int, ++ prefill_startup_params: List[str], ++ prefill_startup_env: Optional[str], ++ prefill_data_parallel_size: int, ++ prefill_tensor_parallel_size: int, ++ prefill_expert_parallel_size: int, ++ decode_instances_num: int, ++ decode_startup_params: List[str], ++ decode_startup_env: Optional[str], ++ decode_data_parallel_size: int, ++ decode_tensor_parallel_size: int, ++ decode_expert_parallel_size: int, ++ scheduler_policy: SchedulerPolicy, ++ proxy_host: str, ++ proxy_port: int, ++): ++ """deploy dllm""" ++ _inner_deploy(DeployConfig( ++ head_ip, ++ prefill_instances_num, ++ prefill_startup_params, ++ prefill_startup_env, ++ prefill_data_parallel_size, ++ prefill_tensor_parallel_size, ++ prefill_expert_parallel_size, ++ decode_instances_num, ++ decode_startup_params, ++ decode_startup_env, ++ decode_data_parallel_size, ++ decode_tensor_parallel_size, ++ decode_expert_parallel_size, ++ scheduler_policy, ++ proxy_host, ++ proxy_port) ++ ) ++ ++ ++def _inner_deploy( ++ deploy_config: DeployConfig ++): ++ config = ControllerConfig( ++ scheduler_policy=deploy_config.scheduler_policy, ++ prefill_instances_num=deploy_config.prefill_instances_num, ++ p_inference_instance_config=InferenceInstanceConfig( ++ startup_params=deploy_config.prefill_startup_params, ++ startup_env=deploy_config.prefill_startup_env, ++ dp=deploy_config.prefill_data_parallel_size, ++ tp=deploy_config.prefill_tensor_parallel_size, ++ ep=deploy_config.prefill_expert_parallel_size, ++ ), ++ decode_instances_num=deploy_config.decode_instances_num, ++ d_inference_instance_config=InferenceInstanceConfig( ++ startup_params=deploy_config.decode_startup_params, ++ startup_env=deploy_config.decode_startup_env, ++ dp=deploy_config.decode_data_parallel_size, ++ tp=deploy_config.decode_tensor_parallel_size, ++ ep=deploy_config.decode_expert_parallel_size, ++ ), ++ ) ++ logger.info(f"deploy_config:{deploy_config}") ++ ++ """Deploy to Ray cluster""" ++ try: ++ logger.info("Connecting to existing Ray cluster at: %s", deploy_config.head_ip) ++ ray.init(address=deploy_config.head_ip, namespace=DLLM_NAMESPACE, ++ runtime_env={"worker_process_setup_hook": setup_logging}) ++ except Exception as e: ++ logger.exception("Failed to connect ray cluster: %s", str(e)) ++ return ++ ++ logger.info("Ray cluster resources: %s", ray.cluster_resources()) ++ ++ should_start_controller = False ++ try: ++ controller = ray.get_actor(CONTROLLER_ACTOR_NAME) ++ logger.exception( ++ "There is already an dllm controller running in the cluster, please clean dllm before " "deploy again" ++ ) ++ except ValueError: ++ should_start_controller = True ++ ++ if not should_start_controller: ++ return ++ ++ logger.info("No existing Controller found, creating new instance") ++ controller = ray.remote(Controller).options( ++ name=CONTROLLER_ACTOR_NAME, ++ lifetime="detached", ++ ).remote(config) ++ ray.get(controller.initialize.remote()) ++ logger.info("Controller actor created.") ++ ++ try: ++ ray.serve.shutdown() ++ deploy_endpoint_to_cluster(deploy_config.proxy_host, deploy_config.proxy_port) ++ logger.info("Deployment completed successfully") ++ except Exception as e: ++ logger.exception("Deployment failed: %s", str(e)) ++ ++ ++@cli.command("clean", context_settings={"show_default": True}) ++@click.option("--head-ip", type=str, help='IP of Ray head node (e.g. "10.2.3.4")', default="auto") ++@click.option("--shutdown-ray-serve/--no-shutdown-ray-serve", type=bool, is_flag=True, ++ help="whether or not to shutdown Ray serve proxy", default=True) ++def clean(head_ip, shutdown_ray_serve): ++ """Clean up deployment from Ray cluster""" ++ _inner_clean(head_ip, shutdown_ray_serve) ++ ++ ++def _inner_clean(head_ip, shutdown_ray_serve): ++ try: ++ logger.info("Connecting to existing Ray cluster at: %s", head_ip) ++ ray.init(address=head_ip, namespace=DLLM_NAMESPACE, log_to_driver=False, ++ runtime_env={"worker_process_setup_hook": setup_logging}) ++ except Exception as e: ++ logger.exception("Failed to connect ray cluster: %s", str(e)) ++ return ++ ++ if shutdown_ray_serve: ++ ray.serve.shutdown() ++ else: ++ try: ++ ray.serve.delete(ENDPOINT_APPLICATION_NAME) ++ except Exception as e: ++ logger.warning("Cleanup endpoint failed: %s", str(e)) ++ ++ controller = None ++ try: ++ controller = ray.get_actor(CONTROLLER_ACTOR_NAME) ++ logger.info("Found existing Controller actor, attempting to kill it") ++ ray.get(controller.terminate.remote()) ++ except ValueError: ++ logger.info("No existing Controller actor found, nothing to clean") ++ except Exception as e: ++ logger.info(f"Failed to clean up controller {e}") ++ finally: ++ if controller: ++ ray.kill(controller) ++ ++ ++if __name__ == "__main__": ++ cli() +diff --git a/dllm_tools/dllm/utils.py b/dllm_tools/dllm/utils.py +new file mode 100644 +index 000000000..6ac91b1f0 +--- /dev/null ++++ b/dllm_tools/dllm/utils.py +@@ -0,0 +1,140 @@ ++#!/usr/bin/env python3 ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++ ++import errno ++import glob ++import logging ++import socket ++import psutil ++import ray ++ ++logger = logging.getLogger(__name__) ++ ++ ++def ray_run_on_every_nodes(func, *args, **kwargs): ++ """run a func on every node of ray """ ++ unique_ips = set([node["NodeManagerAddress"] for node in ray.nodes() if node["Alive"]]) ++ futures = [ray.remote(func).options(resources={f"node:{ip}": 0.01}).remote(*args, **kwargs) for ip in unique_ips] ++ return ray.get(futures) ++ ++ ++def get_num_npus() -> int: ++ """get npu number from `/dev/davinci?`""" ++ try: ++ return len(glob.glob("/dev/davinci[0-9]*")) ++ except Exception as e: ++ logger.error("Failed to get npu number! exception: %s.", e) ++ pass ++ return 0 ++ ++ ++def find_node_ip(address: str = "8.8.8.8:53") -> str: ++ """ ++ NOTE: this implementation is adapted from ray-project/ray, see: ++ https://github.com/ray-project/ray/blob/aa2dede7f795d21407deebf4cefc61fd00e68e84/python/ray/_private/services.py#L637 ++ ++ IP address by which the local node can be reached *from* the `address`. ++ ++ Args: ++ address: The IP address and port of any known live service on the ++ network you care about. ++ ++ Returns: ++ The IP address by which the local node can be reached from the address. ++ """ ++ ip_address, port = address.split(":") ++ s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) ++ try: ++ # This command will raise an exception if there is no internet ++ # connection. ++ s.connect((ip_address, int(port))) ++ node_ip_address = s.getsockname()[0] ++ except OSError as e: ++ node_ip_address = "127.0.0.1" ++ # [Errno 101] Network is unreachable ++ if e.errno == errno.ENETUNREACH: ++ try: ++ # try get node ip address from host name ++ host_name = socket.getfqdn(socket.gethostname()) ++ node_ip_address = socket.gethostbyname(host_name) ++ except Exception: ++ logger.error("find node ip error, host_name: %s.", host_name) ++ pass ++ finally: ++ s.close() ++ ++ return node_ip_address ++ ++ ++def find_free_port(address: str = "") -> str: ++ """ ++ find one free port ++ ++ Returns: ++ port ++ """ ++ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: ++ s.bind((address, 0)) ++ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) ++ return str(s.getsockname()[1]) ++ ++ ++def find_interface_by_ip(ip_address): ++ """ ++ Find the network interface name associated with the given IP address. ++ ++ Args: ++ ip_address (str): The IP address to look up (e.g., "192.168.1.100"). ++ ++ Returns: ++ str: The name of the matching network interface (e.g., "eth0" or "wlan0"), or None if not found. ++ """ ++ interfaces = psutil.net_if_addrs() ++ ++ for interface_name, addresses in interfaces.items(): ++ for address in addresses: ++ if address.family == socket.AF_INET and address.address == ip_address: ++ return interface_name ++ ++ # Return None if no match is found ++ return None ++ ++ ++def find_ip_by_interface(interface_name: str): ++ """ ++ Find the IP address associated with the given network interface name. ++ ++ Args: ++ interface_name (str): The name of the network interface (e.g., "eth0", "wlan0"). ++ ++ Returns: ++ str: The IP address associated with the interface, or None if not found. ++ """ ++ # Get all network interfaces and their addresses ++ interfaces = psutil.net_if_addrs() ++ ++ # Check if the interface exists ++ if interface_name not in interfaces: ++ return None ++ ++ # Determine the address family (IPv4 or IPv6) ++ family = socket.AF_INET # IPv6: 10 (AF_INET6), IPv4: 2 (AF_INET) ++ ++ # Iterate through the addresses of the specified interface ++ for address in interfaces[interface_name]: ++ if address.family == family: ++ return address.address ++ ++ # Return None if no matching IP address is found ++ return None +diff --git a/dllm_tools/launch_test.py b/dllm_tools/launch_test.py +new file mode 100644 +index 000000000..5be3183b5 +--- /dev/null ++++ b/dllm_tools/launch_test.py +@@ -0,0 +1,103 @@ ++import argparse ++import fnmatch ++import glob ++import os ++import random ++import socket ++import subprocess ++import sys ++import threading ++import time ++from os import path ++ ++BASE_DIR = path.dirname(path.realpath(__file__)) ++PIP_VERSION = "pip3.11" ++ ++ ++def remove_file(file): ++ if path.exists(file): ++ os.remove(file) ++ ++def timeout_run_case_wrapper(*args): ++ run_cmd = args[0] ++ start_time = time.time() ++ try: ++ subprocess.run(run_cmd, timeout=300, stdout=sys.stdout, stderr=subprocess.STDOUT, check=True) ++ end_time = time.time() ++ execute_time = end_time - start_time ++ print(f"The cmd {' '.join(run_cmd)} executed cost {int(execute_time)} seconds") ++ return 0 ++ except subprocess.TimeoutExpired: ++ print(f"The cmd {' '.join(run_cmd)} timed out after 300 seconds") ++ return 1 ++ except subprocess.CalledProcessError as e: ++ print(f"The cmd {' '.join(run_cmd)} exited with non-zero code {e.returncode}") ++ return e.returncode ++ ++ ++def install_whl(whl_path, whl_pattern): ++ whl_files = glob.glob(os.path.join(whl_path, whl_pattern)) ++ ++ if whl_files: ++ for whl_file in whl_files: ++ whl_file_name = os.path.basename(whl_file) ++ print(f"[INFO][Main] Installing {whl_file_name}") ++ try: ++ subprocess.run([PIP_VERSION, "install", whl_file]) ++ print(f"[INFO][Main] Successfully installed {whl_file_name}") ++ break ++ except subprocess.CalledProcessError as e: ++ print(f"[ERROR][Main] Failed to install {whl_file_name}: {e}") ++ else: ++ print(f"[WARNING][Main] No {whl_pattern} found in", whl_path) ++ ++ ++def launch(timeout): ++ log_file = path.join(BASE_DIR, "log.txt") ++ remove_file(log_file) ++ report_file = path.join(BASE_DIR, "log.xml") ++ remove_file(report_file) ++ install_whl("/workspace/Ascend/ascend-toolkit", "vllm*.whl") ++ subprocess.run(["mkdir", "-p", "/tmp/yr_sessions"], check=True) ++ subprocess.run( ++ [ ++ PIP_VERSION, ++ "install", ++ "pytest", ++ "pytest-cov", ++ "pytest-mock", ++ "pytest-xdist", ++ "pytest-asyncio", ++ "torch-npu==2.5.1rc1", ++ ], ++ check=True, ++ ) ++ subprocess.run([PIP_VERSION, "install", "--force-reinstall", "numpy==1.26.4"]) ++ ++ pytest_cmd = [ ++ "python3.11", ++ "-m", ++ "pytest", ++ "-vs", ++ "--junitxml=" + report_file, ++ ] ++ ++ try: ++ result = timeout_run_case_wrapper(pytest_cmd, timeout) ++ if not result: ++ print("All tests passed successfully!") ++ return 0 ++ else: ++ print(f"Tests failed with return code: {result}") ++ return 1 ++ except Exception as e: ++ print(f"An error occurred during test execution: {e}") ++ return 1 ++ ++ ++if __name__ == "__main__": ++ parser = argparse.ArgumentParser() ++ parser.add_argument("--timeout", type=int, default=300, help="Timeout for test execution in seconds") ++ args = parser.parse_args() ++ exit_code = launch(args.timeout) ++ sys.exit(exit_code) +diff --git a/dllm_tools/pyproject.toml b/dllm_tools/pyproject.toml +new file mode 100644 +index 000000000..9b3c1d462 +--- /dev/null ++++ b/dllm_tools/pyproject.toml +@@ -0,0 +1,51 @@ ++[build-system] ++requires = ["setuptools>=61.0", "wheel==0.43.0", "pybind11==2.13.6", "cmake"] ++build-backend = "setuptools.build_meta" ++ ++[project] ++name = "dllm" ++version = "0.0.1" ++description = "distributed vllm" ++ ++requires-python = ">=3.9" ++dependencies = [ ++ "requests>=2.25.1", ++ "numpy==1.26.4", ++ "fastapi==0.115.11", ++ "aiohttp==3.11.10", ++ "uvicorn", ++ "ray[serve]==2.46.0", ++ "click", ++ "pybind11==2.13.6", ++ "psutil", ++ "prometheus_client", ++] ++classifiers = [ ++ "Programming Language :: Python :: 3.9", ++ "Operating System :: OS Independent", ++] ++ ++[project.scripts] ++dllm = "dllm.scripts:cli" ++ ++[project.entry-points."vllm.general_plugins"] ++viz_profile = "dllm.monkey_patch.viz_profile.viz_profile_plugin:viz_profile_plugin" ++ ++[project.optional-dependencies] ++dev = [ ++ "pytest>=7.0.0", ++ "black>=23.0.0", ++ "sphinx", ++ "sphinx-design", ++ "myst-parser", ++ "sphinx-click", ++] ++build = [ ++ "wheel==0.43.0", ++ "twine>=4.0.0", ++ "build>=0.10.0", ++] ++ ++[tool.black] ++line-length = 120 ++target-version = ['py39', 'py310', 'py311'] +diff --git a/dllm_tools/pytest.ini b/dllm_tools/pytest.ini +new file mode 100644 +index 000000000..f1169d20a +--- /dev/null ++++ b/dllm_tools/pytest.ini +@@ -0,0 +1,9 @@ ++[pytest] ++pythonpath = ./ ++testpaths = ++ ./tests/ut ++markers = ++ smoke: quick and must-have ++ asyncio: for async def ++ st: system test ++addopts = --cov=dllm --cov-report html --cov-report term --cov-config=.coveragerc +\ No newline at end of file +diff --git a/dllm_tools/requirements.txt b/dllm_tools/requirements.txt +new file mode 100644 +index 000000000..f740a1f15 +--- /dev/null ++++ b/dllm_tools/requirements.txt +@@ -0,0 +1,12 @@ ++pybind11==2.13.6 ++wheel==0.43.0 ++requests>=2.25.1 ++numpy==1.26.4 ++fastapi==0.115.11 ++aiohttp==3.11.10 ++uvicorn==0.35.0 ++ray[serve]==2.46.0 ++click==8.1.8 ++prometheus_client==0.21.1 ++pandas==2.2.3 ++# simpy +\ No newline at end of file +diff --git a/dllm_tools/setup.py b/dllm_tools/setup.py +new file mode 100644 +index 000000000..1e64b78c6 +--- /dev/null ++++ b/dllm_tools/setup.py +@@ -0,0 +1,163 @@ ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ++# This software is licensed under Mulan PSL v2. ++# You can use this software according to the terms and conditions of the Mulan PSL v2. ++# You may obtain a copy of Mulan PSL v2 at: ++# ++# http://license.coscl.org.cn/MulanPSL2 ++# ++# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, ++# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, ++# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. ++# See the Mulan PSL v2 for more details. ++ ++import logging ++import os ++import shutil ++import subprocess ++import sys ++ ++import setuptools ++from setuptools import Extension ++from setuptools.command.build_ext import build_ext ++ ++logger = logging.getLogger(__name__) ++ ++# can not specify [tool.setuptools.packages.find] in pyproject.toml ++dllm_build_for = os.environ.get("DLLM_BUILD_FOR", "default") ++ ++ ++class CMakeExtension(Extension): ++ def __init__(self, name, sourcedir=""): ++ super().__init__(name, sources=[]) ++ self.sourcedir = os.path.abspath(sourcedir) ++ ++ ++class CMakeBuild(build_ext): ++ ++ def __init__(self, *args, **kwargs): ++ super().__init__(*args, **kwargs) ++ self.pybind_path = None ++ self.enable_perf = None ++ self.ds_include_dir = None ++ self.ds_lib_dir = None ++ ++ def run(self): ++ self._find_pybind11() ++ self._find_ds() ++ self.enable_perf = os.environ.get("ENABLE_PERF", "False") ++ # Find the absolute path of the cmake executable ++ cmake_executable = shutil.which("cmake") ++ ++ if cmake_executable is None: ++ raise RuntimeError( ++ "CMake must be installed to build the following extensions: " ++ + ", ".join(e.name for e in self.extensions) ++ ) ++ ++ if cmake_executable is None: ++ raise RuntimeError( ++ "CMake must be installed to build the following extensions: " ++ + ", ".join(e.name for e in self.extensions) ++ ) ++ for ext in self.extensions: ++ self.build_extension(ext) ++ ++ def build_extension(self, ext): ++ extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) ++ cmake_args = [ ++ "-G Ninja", ++ f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}", ++ f"-DPYTHON_EXECUTABLE={sys.executable}", ++ f"-Dpybind11_DIR={self.pybind_path}", ++ f"-DENABLE_PERF={self.enable_perf}", ++ ] ++ if self.ds_lib_dir and self.ds_include_dir: ++ cmake_args.append(f"-DDatasystem_INCLUDE_DIR={self.ds_include_dir}") ++ cmake_args.append(f"-DDatasystem_LIBRARY_DIR={self.ds_lib_dir}") ++ else: ++ raise RuntimeError("No datasystem dependency when building.") ++ ++ build_args = ["--", "-j8"] ++ ++ build_tmp = os.path.join(os.getcwd(), "build") ++ if self.build_temp is not None: ++ build_tmp = self.build_temp ++ ++ if not os.path.exists(build_tmp): ++ os.makedirs(build_tmp) ++ ++ # Run CMake configuration ++ subprocess.check_call([shutil.which("cmake"), ext.sourcedir] + cmake_args, cwd=build_tmp) ++ # Build the extension ++ subprocess.check_call([shutil.which("cmake"), "--build", "."] + build_args, cwd=build_tmp) ++ ++ def _find_ds(self): ++ global dllm_build_for ++ if dllm_build_for == "xiaoyi": ++ try: ++ import yr ++ yr_home = os.path.dirname(yr.__file__) ++ except ModuleNotFoundError as e: ++ logger.error("Missing required dependency: yr") ++ raise e ++ self.ds_include_dir = os.path.join(yr_home, "inner", "data_system", "sdk", "cpp", "include") ++ self.ds_lib_dir = yr_home ++ else: ++ ds_home = os.environ.get("DS_DIR", None) ++ if ds_home is None: ++ try: ++ import datasystem ++ ds_home = os.path.dirname(datasystem.__file__) ++ except ModuleNotFoundError as e: ++ logger.error("Missing required dependency: datasystem.") ++ raise e ++ self.ds_include_dir = os.path.join(ds_home, "include") ++ self.ds_lib_dir = os.path.join(ds_home, "lib") ++ logger.info(f"DS_INCLUDE_DIR: {self.ds_include_dir}" ++ f"DS_LIB_DIR: {self.ds_lib_dir}") ++ ++ def _find_pybind11(self): ++ import pybind11 ++ ++ # pybind11 include dir ++ include_path = pybind11.get_include() ++ # See comments from Caleb and MrCrHaM (works also for conda) ++ # https://stackoverflow.com/questions/63254584/how-to-make-cmake-find-pybind11 ++ base_path, _ = include_path.rsplit("/include", 1) ++ self.pybind_path = os.path.join(base_path, "share", "cmake", "pybind11") ++ ++ ++# For including C/C++ headers in the .whl ++# Then the client C++ code would include like: #include ++tmp_include_dir = os.path.join("dllm", "include", "dllm") ++shutil.copytree(os.path.join("csrc", "include"), tmp_include_dir, dirs_exist_ok=True) ++ ++ ++if dllm_build_for == "xiaoyi": ++ packages = ["dllm.kvc", "dllm.include"] ++else: ++ packages = setuptools.find_packages(include=["dllm", "dllm.*"], ++ exclude=["csrc", "cmake", "secbrella", "vllm_patchs"]) ++ ++setuptools.setup( ++ name="dllm", ++ version="0.1", ++ description="Distributed LLM", ++ packages=packages, ++ ext_modules=[CMakeExtension(name="dllm.cpp_ext")], ++ python_requires=">=3.9", ++ cmdclass=dict(build_ext=CMakeBuild), ++ zip_safe=False, ++ entry_points={ ++ "console_scripts": [ ++ "dllm=dllm.scripts:cli", ++ ], ++ "vllm.general_plugins": [ ++ "viz_profile = dllm.monkey_patch.viz_profile.viz_profile_plugin:viz_profile_plugin", ++ ] ++ }, ++ include_package_data=True, ++) ++ ++shutil.rmtree(tmp_include_dir) +-- +2.35.1.windows.2 + diff --git a/vllm.spec b/vllm.spec index 3ebe760..fbf1f3f 100644 --- a/vllm.spec +++ b/vllm.spec @@ -3,12 +3,17 @@ Name: vllm Version: 0.9.1 -Release: 1 +Release: 2 Summary: Powerful engine for LLMs License: (Apache-2.0 AND BSD-3-Clause) OR BSD-3-CLause URL: https://github.com/vllm-project/vllm Source0: https://gitee.com/src-openeuler/vllm/raw/master/vllm-%{version}.tar.gz +Patch1: 0001-bugfix-support-lower-version-setuptools-on-openeuler.patch +Patch2: 0002-bugfix-prefix-cache.patch +Patch3: 0003-bugfix-for-dllm-register.patch +Patch4: 0004-feature-dllm-tools.patch + BuildArch: noarch %description @@ -27,7 +32,7 @@ Buildrequires: python3-pytorch %package_help %prep -%autosetup -n %{name}-%{version} -N +%autosetup -n %{name}-%{version} -p1 %build export SETUPTOOLS_SCM_PRETEND_VERSION=%{version} @@ -69,6 +74,9 @@ mv %{buildroot}/filelist.lst . %files -n python3-%{_name} -f filelist.lst %changelog +* Thu Jul 24 2025 gongzequn - 0.9.1-2 +- Add dllm deploy and clean command support + * Fri Jul 4 2025 gongzequn - 0.9.1-1 - Change the baseline version to 0.9.1 -- Gitee