diff --git a/0005-feature-kv-cache-protect-for-dllm-tools.patch b/0005-feature-kv-cache-protect-for-dllm-tools.patch
new file mode 100644
index 0000000000000000000000000000000000000000..ed9e5c54317118ceabd7a5d50e0f6f9c703758c6
--- /dev/null
+++ b/0005-feature-kv-cache-protect-for-dllm-tools.patch
@@ -0,0 +1,583 @@
+From 2d905ddb84c1d07cc9d69dfd7c6924eeedd0eeed Mon Sep 17 00:00:00 2001
+From: hanzhibin
+Date: Wed, 30 Jul 2025 15:11:21 +0800
+Subject: [PATCH] feature: kv cache protect for dllm tools
+
+---
+ dllm_tools/README.md | 8 +
+ dllm_tools/dllm/controller/vllm_instance.py | 15 +-
+ dllm_tools/dllm/dkvc/v1/dllm_ds_connector.py | 194 ++++++++++++++++++
+ dllm_tools/dllm/dkvc/v1/sec_mask/__init__.py | 0
+ dllm_tools/dllm/dkvc/v1/sec_mask/crypto.py | 19 ++
+ .../dllm/dkvc/v1/sec_mask/sec_mask_manager.py | 176 ++++++++++++++++
+ 6 files changed, 411 insertions(+), 1 deletion(-)
+ create mode 100644 vllm-0.9.1/dllm_tools/dllm/dkvc/v1/sec_mask/__init__.py
+ create mode 100644 vllm-0.9.1/dllm_tools/dllm/dkvc/v1/sec_mask/crypto.py
+ create mode 100644 vllm-0.9.1/dllm_tools/dllm/dkvc/v1/sec_mask/sec_mask_manager.py
+
+diff --git a/dllm_tools/README.md b/dllm_tools/README.md
+index b3f7eb0..4770681 100644
+--- a/dllm_tools/README.md
++++ b/dllm_tools/README.md
+@@ -82,3 +82,11 @@ curl -X POST "http://127.0.0.1:8000/v1/completions" -H "Content-Type: applicatio
+ "temperature": 0
+ }'
+ ```
++
++### Enable KV cache protection
++
++To prevent private data leakage, dllm supports KV cache protection by encrypting KV cache data transmitted between the prefill and decode instances in a PD disaggregated deployment.
++
++KV cache data is encrypted by sec-mask in parallel with inference to improve encryption performance.
++
++To enable KV cache protection, set the environment variable **before starting Ray**: `ENABLE_KVC_PROTECT=True`
+\ No newline at end of file
+diff --git a/dllm_tools/dllm/controller/vllm_instance.py b/dllm_tools/dllm/controller/vllm_instance.py
+index ee48b84..f39d712 100644
+--- 
a/dllm_tools/dllm/controller/vllm_instance.py ++++ b/dllm_tools/dllm/controller/vllm_instance.py +@@ -15,6 +15,8 @@ + import asyncio + import json + import uuid ++import base64 ++import secrets + + import subprocess + import sys +@@ -36,6 +38,8 @@ from dllm.utils import find_node_ip, find_free_port, find_interface_by_ip, find_ + logger = logging.getLogger(__name__) + + ++AES_256_KEY_SIZE = 32 ++ + def select_distributed_torch_interface(): + """ + Determines the preferred network interface for distributed PyTorch communication. +@@ -91,6 +95,8 @@ class VllmInstance: + + self.__has_process_started = False + ++ self._secret_key = "" ++ + async def init_dp_master_ip_port(self): + """ + if dp config is None, init dp master +@@ -229,6 +235,11 @@ class VllmInstance: + ] + ) + elif self._vllm_instance_config.pd_config.is_disaggregated_p_d(): ++ enable_kvc_protect = True if os.environ.get("ENABLE_KVC_PROTECT", "FALSE").lower() == "true" else False ++ if enable_kvc_protect: ++ secret_key = secrets.token_bytes(AES_256_KEY_SIZE) ++ self._secret_key = base64.b64encode(secret_key).decode('ascii') ++ + self._vllm_instance_config.exec_cmd.extend( + [ + "--kv-transfer-config", +@@ -243,7 +254,9 @@ class VllmInstance: + "kv_parallel_size": 2, + "kv_rank": 0 if self._vllm_instance_config.pd_config.role is Role.PREFILL else 1, + "kv_connector_extra_config": { +- "device_ids": [int(i) for i in os.environ["ASCEND_RT_VISIBLE_DEVICES"].split(",")] ++ "device_ids": [int(i) for i in os.environ["ASCEND_RT_VISIBLE_DEVICES"].split(",")], ++ "enable_kvc_protect": enable_kvc_protect, ++ "secret_key": self._secret_key + }, + } + ), +diff --git a/dllm_tools/dllm/dkvc/v1/dllm_ds_connector.py b/dllm_tools/dllm/dkvc/v1/dllm_ds_connector.py +index 8580c5c..3be5dfd 100644 +--- a/dllm_tools/dllm/dkvc/v1/dllm_ds_connector.py ++++ b/dllm_tools/dllm/dkvc/v1/dllm_ds_connector.py +@@ -19,6 +19,9 @@ from typing import TYPE_CHECKING, List, Optional, Any + import threading + from collections import defaultdict 
+ import asyncio ++import base64 ++import atexit ++import ml_dtypes + + import numpy + import torch +@@ -32,11 +35,20 @@ from vllm.distributed.parallel_state import (get_world_group, get_tp_group) + + from dllm.cpp_ext.kvc import KvcStore, KvcFuture + from dllm.kvc import TorchAdaptor ++from dllm.dkvc.v1.sec_mask.sec_mask_manager import (SecMaskManager, SecMaskMode, SecMaskConfig, DEFAULE_IV_SIZE) ++from dllm.dkvc.dllm_cache_engine import DLLMCacheEngine + + ENABLE_PREFIX_CACHING = int(os.environ.get("USING_PREFIX_CONNECTOR", 0)) + FUTURE_TIMEOUT = int(os.getenv("FUTURE_TIMEOUT", 0)) + SLEEP_TIMEOUT = 0.005 + ++DTYPE_SIZE_TO_NUMPY_INT_DTYPE = { ++ 1: numpy.uint8, ++ 2: numpy.uint16, ++ 4: numpy.uint32, ++ 8: numpy.uint64 ++} ++ + if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext +@@ -311,6 +323,25 @@ class DLLMDsConnector(KVConnectorBase_V1): + else: + self._async_handler = None + ++ # create sec mask manager ++ self._enable_kvc_protect = vllm_config.kv_transfer_config.kv_connector_extra_config["enable_kvc_protect"] ++ logger.debug(f"enable kv cache protect: {self._enable_kvc_protect}") ++ if self._enable_kvc_protect: ++ # _block_data_size: bytes size of a block ++ self._block_data_size = get_block_data_size(vllm_config) ++ self._layer_num = vllm_config.model_config.get_num_layers(vllm_config.parallel_config) ++ # _mask_block_size: bytes size of one layer of a block ++ self._mask_block_size = self._block_data_size // self._layer_num if self._layer_num >= 0 else 0 ++ if role == KVConnectorRole.WORKER: ++ sec_mask_mode = SecMaskMode.MASTER if self.is_producer else SecMaskMode.WORKER ++ secret_key_str = vllm_config.kv_transfer_config.kv_connector_extra_config["secret_key"] ++ secret_key = base64.b64decode(secret_key_str.encode('ascii')) ++ self._sec_mask_manager = SecMaskManager(config=SecMaskConfig(mode=sec_mask_mode, ++ process_parallel_num=4, ++ symmetric_key=secret_key)) ++ 
self._load_req_id_to_block_id : dict[str, List[int]] = {} ++ atexit.register(self._sec_mask_manager.close) ++ + def start_event_loop(self): + """start event loop""" + self.loop.run_until_complete(asyncio.gather(*self.task_list)) +@@ -371,6 +402,11 @@ class DLLMDsConnector(KVConnectorBase_V1): + """ + # effective only when prefix cache is disabled and the role is producer. + if self.is_producer and not ENABLE_PREFIX_CACHING: ++ if self._enable_kvc_protect: ++ # start to generate mask ++ metadata: KVConnectorMetadata = self._get_connector_metadata() ++ for request in metadata.requests: ++ self._gen_mask_for_request(request.request_id, len(request.block_ids) * self._block_data_size) + return + + metadata: KVConnectorMetadata = self._get_connector_metadata() +@@ -381,6 +417,9 @@ class DLLMDsConnector(KVConnectorBase_V1): + self._init_kv_caches_from_forward_context(forward_context) + + for request in metadata.requests: ++ if self._enable_kvc_protect: ++ # start to generate mask ++ self._gen_mask_for_request(request.request_id, len(request.block_ids) * self._block_data_size) + if self._async_handler is not None: + self._load_request_queue.put_nowait(request) + else: +@@ -398,6 +437,10 @@ class DLLMDsConnector(KVConnectorBase_V1): + finished_saved_req = self._async_handler.get_save_finished(finished_req_ids) + + if not self.is_producer or ENABLE_PREFIX_CACHING: ++ if self._enable_kvc_protect and finished_loaded_req is not None: ++ for req_id in finished_loaded_req: ++ block_ids = self._load_req_id_to_block_id.pop(req_id, None) ++ self._get_and_mask_kvc_layerwise(req_id, block_ids) + finished_loaded_req = self._async_handler.get_load_finished() + + if self.tp_size == 1: +@@ -464,9 +507,14 @@ class DLLMDsConnector(KVConnectorBase_V1): + future = self.kvc_store.mget_page_attn_blockwise_h2d(key_cache_key_list, self.key_caches, block_id_list) + future_1 = self.kvc_store.mget_page_attn_blockwise_h2d(value_cache_key_list, self.value_caches, + block_id_list) ++ if 
self._enable_kvc_protect: ++ self._load_req_id_to_block_id[request.request_id] = block_id_list ++ + if not self.do_async_save: + get_future(future) + get_future(future_1) ++ if self._enable_kvc_protect: ++ self._get_and_mask_kvc_layerwise(request.request_id, block_id_list) + else: + self._async_handler.add_load_request(request, 2) + self._async_handler.add_load_future(request, future) +@@ -476,6 +524,10 @@ class DLLMDsConnector(KVConnectorBase_V1): + return + + future = self.kvc_store.mget_page_attn_blockwise_h2d(key_list, self.kv_caches, block_id_list) ++ ++ if self._enable_kvc_protect: ++ self._load_req_id_to_block_id[request.request_id] = block_id_list ++ + if not self.do_async_save: + get_future(future) + else: +@@ -529,6 +581,10 @@ class DLLMDsConnector(KVConnectorBase_V1): + if not request.block_ids: + return + ++ if self._enable_kvc_protect: ++ # encrypt kv cache with mask ++ self._get_and_mask_kvc_layerwise(request.request_id, request.block_ids) ++ + token_key_list = self.generate_kv_cache_token_key(request) + if not self.is_mla: + key_cache_key_list = token_key_list +@@ -797,6 +853,129 @@ class DLLMDsConnector(KVConnectorBase_V1): + else: + self.kv_caches.append(kv_layer) + ++ def _gen_mask_for_request(self, request_id, mask_size): ++ if mask_size == 0: ++ return ++ tp_rank_str = str(self.tp_rank) + "-" ++ if self.is_mla: ++ kv_mask_key = "mask-kv-" + tp_rank_str + request_id ++ self._gen_mask_for_blocks(kv_mask_key, mask_size) ++ else: ++ k_mask_key = "mask-k-" + tp_rank_str + request_id ++ v_mask_key = "mask-v-" + tp_rank_str + request_id ++ self._gen_mask_for_blocks(k_mask_key, mask_size) ++ self._gen_mask_for_blocks(v_mask_key, mask_size) ++ ++ def _gen_mask_for_blocks(self, mask_key: str, mask_size: int): ++ """ ++ In prefill, generate iv and mask, put iv into datasystem; ++ In decode, get iv from datasystem, generate mask ++ ++ Args: ++ mask_key: mask key ++ mask_size: mask size ++ """ ++ if self.is_producer: ++ extra_key = "producer-" ++ else: ++ 
extra_key = "consumer-" ++ iv_tensor = torch.zeros((DEFAULE_IV_SIZE // 4,), dtype=torch.float32).npu() ++ iv_future = self.kvc_store.mget_tensors_h2d([mask_key], [[iv_tensor]]) ++ get_future(iv_future) ++ iv_bytes = iv_tensor.cpu().numpy().tobytes() ++ self._sec_mask_manager.set_iv(mask_key, iv_bytes) ++ ++ self._sec_mask_manager.generate_mask(mask_key, extra_key, mask_size) ++ ++ if self.is_producer: ++ iv_bytes = self._sec_mask_manager.get_iv(mask_key) ++ if iv_bytes is not None: ++ iv_tensor = torch.frombuffer(buffer=iv_bytes, dtype=torch.float32).npu() ++ iv_future = self.kvc_store.mset_tensors_d2h([mask_key], [[iv_tensor]]) ++ get_future(iv_future) ++ else: ++ logger.error("get iv failed when set iv d2h after generate mask") ++ raise ValueError("get iv failed when set iv d2h after generate mask") ++ ++ def _get_mask_layerwise(self, mask_key: str): ++ """ ++ get mask for each layer tensor ++ """ ++ mask_bytes = self._sec_mask_manager.get_mask(mask_key) ++ mask_size = len(mask_bytes) ++ ++ assert self._layer_num > 0 ++ assert mask_size % self._layer_num == 0 ++ mask_size_per_layer = mask_size // self._layer_num ++ mask_ndarray_list = [numpy.ndarray(shape=(mask_size_per_layer,), dtype=numpy.uint8, ++ buffer=mask_bytes[i:i + mask_size_per_layer]) for i in range(0, mask_size, mask_size_per_layer)] ++ return mask_ndarray_list ++ ++ def _mask_kvc_layerwise(self, kvc_list: List[torch.Tensor], mask_list: List[numpy.ndarray]): ++ """ ++ mask/unmask kvc tensor with mask ++ """ ++ assert len(kvc_list) == len(mask_list) ++ if len(kvc_list) == 0: ++ return ++ # need to skip the masks of hit blocks if prefix caching hit ++ kvc_block_num = kvc_list[0].shape[0] ++ assert self._mask_block_size != 0 ++ mask_block_num = mask_list[0].nbytes // self._mask_block_size ++ logger.debug(f"kvc block num: {kvc_block_num}, mask_block_num: {mask_block_num}") ++ if mask_block_num > kvc_block_num: ++ skip_size = (mask_block_num - kvc_block_num) * self._mask_block_size ++ mask_list = 
[mask[skip_size:] for mask in mask_list] ++ ++ # reinterpret to the same shape and dtype to do bitwise mask/unmask ++ uint_dtype = DTYPE_SIZE_TO_NUMPY_INT_DTYPE[kvc_list[0].element_size()] ++ kvc_array_list = [kvc_tensor.numpy().view(uint_dtype) for kvc_tensor in kvc_list] ++ mask_list = [mask.view(kvc_array_list[0].dtype).reshape(kvc_array_list[0].shape) for mask in mask_list] ++ ++ # xor kvc with mask ++ masked_kvc_list = [numpy.bitwise_xor(kvc_array_list[i], mask_list[i]) for i in range(len(kvc_array_list))] ++ masked_kvc_tensor = [torch.from_numpy(masked_kvc.view(ml_dtypes.bfloat16)).npu() for masked_kvc in masked_kvc_list] ++ return masked_kvc_tensor ++ ++ def _get_and_mask_kvc_layerwise(self, request_id, block_ids) -> List[torch.Tensor]: ++ """ ++ mask/unmask kv cache with sec mask ++ """ ++ if block_ids is None: ++ return ++ tp_rank_str = str(self.tp_rank) + "-" ++ if self.is_mla: ++ target_kv_caches = [self.kv_caches[i][block_ids] for i in range(len(self.kv_caches))] ++ kv_mask_key = "mask-kv-" + tp_rank_str + request_id ++ kv_masks = self._get_mask_layerwise(kv_mask_key) ++ masked_kv = self._mask_kvc_layerwise(target_kv_caches, kv_masks) ++ ++ index = torch.tensor(block_ids).npu() ++ index_shape = [len(block_ids)] + [1] * (len(masked_kv[0].shape) - 1) ++ index = index.view(*index_shape).expand(*masked_kv[0].shape) ++ for i in range(len(self.kv_caches)): ++ torch.scatter(self.kv_caches[i], 0, index, masked_kv[i]) ++ ++ else: ++ target_key_caches = [self.key_caches[i][block_ids] for i in range(len(self.key_caches))] ++ target_value_caches = [self.value_caches[i][block_ids] for i in range(len(self.value_caches))] ++ ++ k_mask_key = "mask-k-" + tp_rank_str + request_id ++ v_mask_key = "mask-v-" + tp_rank_str + request_id ++ ++ k_masks = self._get_mask_layerwise(k_mask_key) ++ v_masks = self._get_mask_layerwise(v_mask_key) ++ ++ masked_k = self._mask_kvc_layerwise(target_key_caches, k_masks) ++ masked_v = self._mask_kvc_layerwise(target_value_caches, v_masks) 
++ ++ index = torch.tensor(block_ids).npu() ++ index_shape = [len(block_ids)] + [1] * (len(masked_k[0].shape) - 1) ++ index = index.view(*index_shape).expand(*masked_k[0].shape) ++ for i in range(len(self.key_caches)): ++ torch.scatter(self.key_caches[i], 0, index, masked_k[i]) ++ torch.scatter(self.value_caches[i], 0, index, masked_v[i]) ++ + + def extract_number(s): + """extract number""" +@@ -847,3 +1026,18 @@ def get_future(fut: KvcFuture) -> List[str]: + return RequestStatus.TIMEOUT + + return RequestStatus.FINISHED ++ ++ ++def get_block_data_size(vllm_config) -> int: ++ """ ++ calculate total block data size (bytes), used for generating sec mask to encrypt/decrypt kvc for saving/loading ++ per_block_size = block_size * layer_num * kv_head_num * head_size * dtype_size ++ """ ++ block_size = vllm_config.cache_config.block_size ++ layer_num = vllm_config.model_config.get_num_layers(vllm_config.parallel_config) ++ dtype = vllm_config.model_config.dtype ++ kv_head_num = vllm_config.model_config.get_num_kv_heads(vllm_config.parallel_config) ++ head_size = vllm_config.model_config.get_head_size() ++ dtype_size = DLLMCacheEngine._get_dtype_size(dtype) ++ ++ return block_size * layer_num * kv_head_num * head_size * dtype_size +\ No newline at end of file +diff --git a/dllm_tools/dllm/dkvc/v1/sec_mask/__init__.py b/dllm_tools/dllm/dkvc/v1/sec_mask/__init__.py +new file mode 100644 +index 0000000..e69de29 +diff --git a/dllm_tools/dllm/dkvc/v1/sec_mask/crypto.py b/dllm_tools/dllm/dkvc/v1/sec_mask/crypto.py +new file mode 100644 +index 0000000..7200ee4 +--- /dev/null ++++ b/dllm_tools/dllm/dkvc/v1/sec_mask/crypto.py +@@ -0,0 +1,19 @@ ++from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes ++ ++ ++class CryptoBase(): ++ def encrypt(self, iv: bytes, data: bytes) -> bytes: ++ raise NotImplementedError ++ ++ def decrypt(self, iv: bytes, data: bytes) -> bytes: ++ raise NotImplementedError ++ ++ ++class CryptoCryptography(CryptoBase): ++ def 
__init__(self, key: bytes) -> None: ++ self._key = key ++ ++ def encrypt(self, iv: bytes, data: bytes) -> bytes: ++ cipher = Cipher(algorithms.AES(self._key), modes.CTR(iv)) ++ encryptor = cipher.encryptor() ++ return encryptor.update(data) + encryptor.finalize() +\ No newline at end of file +diff --git a/dllm_tools/dllm/dkvc/v1/sec_mask/sec_mask_manager.py b/dllm_tools/dllm/dkvc/v1/sec_mask/sec_mask_manager.py +new file mode 100644 +index 0000000..fe67684 +--- /dev/null ++++ b/dllm_tools/dllm/dkvc/v1/sec_mask/sec_mask_manager.py +@@ -0,0 +1,176 @@ ++from __future__ import annotations ++from multiprocessing import shared_memory, Pool, Manager ++from dataclasses import dataclass ++from typing import Dict, Optional ++import secrets ++import numpy as np ++from enum import Enum ++from threading import Lock ++import torch ++ ++from vllm.logger import init_logger ++ ++logger = init_logger(f"vllm.sec_mask.{__name__}") ++ ++DEFAULE_GEN_MASK_PARALLEL_NUM = 4 ++DEFAULE_IV_SIZE = 16 ++ ++class SharedMemoryBlock(): ++ def __init__(self, name: str): ++ self.name = name ++ self.shm: shared_memory.SharedMemory = None ++ ++ def create(self, size: int): ++ try: ++ self.shm = shared_memory.SharedMemory(name=self.name, create=True, size=size) ++ logger.debug(f"[MASK] create shm: {self.name}") ++ except FileExistsError: ++ logger.debug(f"[MASK] Shared memory ({self.name}) already exists, try to release") ++ existing_shm = shared_memory.SharedMemory(name=self.name, create=False) ++ existing_shm.close() ++ existing_shm.unlink() ++ return self.create(size) ++ except PermissionError: ++ logger.error(f"Permission denied for creating Shared memory ({self.name}).") ++ return False ++ except Exception as e: ++ logger.error(f"Err: {str(e)}") ++ return False ++ return True ++ ++ def destroy(self): ++ if self.shm is not None: ++ self.shm.close() ++ self.shm.unlink() ++ logger.debug(f"[MASK] destroy shm: {self.name}") ++ ++ ++class SecMaskMode(Enum): ++ MASTER = 0 # for prefill, generate iv and 
mask ++ WORKER = 1 # for decode, get iv that already exists and generate mask ++ ++ ++@dataclass ++class SecMaskConfig(): ++ '''SecMaskManager config''' ++ mode: SecMaskMode = SecMaskMode.MASTER ++ process_parallel_num: int = DEFAULE_GEN_MASK_PARALLEL_NUM ++ symmetric_key: bytes = None ++ mask_dtype: torch.dtype = None ++ ++ ++def gen_mask(symmetric_key: bytes, iv: bytes, size: int, shm_name: str, event): ++ ''' ++ run in subprocess, generate sec mask and write into shared memory ++ ++ Args: ++ symmetric_key: symmetric key for generating mask ++ iv: initial vector for generating mask ++ size: mask size ++ shm_name: share memory for mask storage ++ event: notification whether mask generation is done ++ ''' ++ shm = shared_memory.SharedMemory(name=shm_name, create=False) ++ # mask is generated from zeros ++ zeros = b'\x00' * size ++ from dllm.dkvc.v1.sec_mask.crypto import CryptoCryptography ++ encryptor = CryptoCryptography(key=symmetric_key) ++ mask_bytes = encryptor.encrypt(iv, zeros) ++ shm.buf[:size] = mask_bytes ++ shm.close() ++ event.set() ++ logger.debug(f"[MASK] generate mask, iv: {iv}, size: {size}, shm_name: {shm_name}") ++ ++ ++@dataclass ++class MaskDataMeta(): ++ shm_block: SharedMemoryBlock ++ size: int ++ event: Event ++ used: bool = False ++ ++ ++class SecMaskManager(): ++ def __init__(self, config: SecMaskConfig = None): ++ self._config = config ++ ++ self._iv_dict: Dict[str, bytes] = {} # mask_key: iv_bytes ++ self._mask_tasks: Dict[str, MaskDataMeta] = {} # mask_key: mask_meta ++ ++ self._shm_block = None ++ self._process_pool = None ++ self._manager = Manager() ++ self._process_pool = Pool(processes=self._config.process_parallel_num) ++ ++ def close(self): ++ if len(self._mask_tasks) > 0: ++ for mask_task in self._mask_tasks.values(): ++ mask_task.shm_block.destroy() ++ self._process_pool.close() ++ self._process_pool.join() ++ self._manager.shutdown() ++ ++ def _gen_iv(self): ++ return secrets.token_bytes(DEFAULE_IV_SIZE) ++ ++ def 
_gen_mask_err_handler(self, err: Exception, mask_key: str): ++ logger.error(f"mask_key: {mask_key}, type: {type(err)}, err: {err}") ++ self._iv_dict.pop(mask_key) ++ mask_task = self._mask_tasks.pop(mask_key, None) ++ if mask_task: ++ mask_task.shm_block.destroy() ++ ++ def generate_mask(self, mask_key: str, extra_key: str = "", size: int = 0): ++ ''' ++ 1. when in prefill or for encrypt, generate iv and mask ++ 2. when in decode or for decrypt, generate mask with iv that already exists ++ ++ Args: ++ mask_key: key of mask ++ size: size of mask ++ ''' ++ shm_name = "shm-" + extra_key + mask_key ++ shm_block = SharedMemoryBlock(shm_name) ++ if not shm_block.create(size): ++ raise ValueError(f"create shm ({shm_name}) failed") ++ event = self._manager.Event() ++ if self._config.mode is SecMaskMode.MASTER: ++ iv = self._gen_iv() ++ self._iv_dict[mask_key] = iv ++ elif self._config.mode is SecMaskMode.WORKER: ++ iv = self._iv_dict.get(mask_key) ++ if iv is None: ++ logger.error(f"iv of mask_key: {mask_key} is not exist") ++ raise ValueError(f"iv of mask_key: {mask_key} is not exist") ++ ++ ++ self._mask_tasks[mask_key] = MaskDataMeta(shm_block=shm_block, size=size, event=event) ++ self._process_pool.apply_async(gen_mask, args=(self._config.symmetric_key, iv, size, shm_name, event), ++ error_callback=lambda e, mask_key=mask_key: self._gen_mask_err_handler(e, mask_key)) ++ logger.debug(f"start to gen iv and mask, mask_key: {mask_key}, size: {size}, shm_name: {shm_name}, iv: {iv}") ++ ++ def get_mask(self, mask_key: str) -> Optional[bytes]: ++ mask_task = self._mask_tasks.get(mask_key, None) ++ if not mask_task: ++ logger.error(f"[MASK]mask of key ({mask_key}) not exist") ++ raise ValueError(f"mask of key ({mask_key}) not exist") ++ if mask_task.used: ++ logger.error(f"[MASK]mask of key ({mask_key}) has beed used") ++ raise ValueError(f"mask of key ({mask_key}) has beed used") ++ ++ mask_task.event.wait() ++ logger.debug(f"[MASK]get mask task, mask_key: {mask_key}") ++ ++ 
mask_bytes = bytes(mask_task.shm_block.shm.buf[:mask_task.size]) ++ mask_task.used = True ++ mask_task.shm_block.destroy() ++ ++ del self._mask_tasks[mask_key] ++ logger.debug(f"[MASK] delete mask task when get, mask_key: {mask_key}") ++ return mask_bytes ++ ++ def get_iv(self, mask_key: str) -> Optional[bytes]: ++ return self._iv_dict.get(mask_key, None) ++ ++ def set_iv(self, mask_key: str, iv_bytes: bytes): ++ self._iv_dict[mask_key] = iv_bytes +-- +2.25.1 + diff --git a/vllm.spec b/vllm.spec index fbf1f3ffc00128614ef906f25737fe644dde13bd..6de972a83fc0127dc4e5b7e2f9487007e1857268 100644 --- a/vllm.spec +++ b/vllm.spec @@ -3,7 +3,7 @@ Name: vllm Version: 0.9.1 -Release: 2 +Release: 3 Summary: Powerful engine for LLMs License: (Apache-2.0 AND BSD-3-Clause) OR BSD-3-CLause URL: https://github.com/vllm-project/vllm @@ -13,6 +13,7 @@ Patch1: 0001-bugfix-support-lower-version-setuptools-on-openeuler.patch Patch2: 0002-bugfix-prefix-cache.patch Patch3: 0003-bugfix-for-dllm-register.patch Patch4: 0004-feature-dllm-tools.patch +Patch5: 0005-feature-kv-cache-protect-for-dllm-tools.patch BuildArch: noarch @@ -74,6 +75,9 @@ mv %{buildroot}/filelist.lst . %files -n python3-%{_name} -f filelist.lst %changelog +* Wed Jul 30 2025 hanzhibin - 0.9.1-3 +- Add dllm kv cache protect support + * Thu Jul 24 2025 gongzequn - 0.9.1-2 - Add dllm deploy and clean command support