diff --git a/torchrec/hybrid_torchrec/hybrid_torchrec/sparse/jagged_tensor_with_count.py b/torchrec/hybrid_torchrec/hybrid_torchrec/sparse/jagged_tensor_with_count.py index 5e5158ff1f237997d9b6029572bff4c4e991f42a..c613cc8b65f01bfa6762d973cf95c29e0a56dfba 100644 --- a/torchrec/hybrid_torchrec/hybrid_torchrec/sparse/jagged_tensor_with_count.py +++ b/torchrec/hybrid_torchrec/hybrid_torchrec/sparse/jagged_tensor_with_count.py @@ -8,10 +8,16 @@ from typing import Optional, Dict, List, Tuple import torch + +from torch.autograd.profiler import record_function from torchrec.sparse.jagged_tensor import ( + _pin_and_move, + _permute_tensor_by_segments, + _sum_by_splits, JaggedTensor, KeyedJaggedTensor, ) +from torchrec.pt2.checks import is_torchdynamo_compiling, is_non_strict_exporting class JaggedTensorWithCount(JaggedTensor): @@ -83,3 +89,604 @@ class KeyedJaggedTensorWithCount(KeyedJaggedTensor): ) self._counts: torch.Tensor = counts + + @property + def counts(self) -> torch.Tensor: + return self._counts + + @staticmethod + def from_jt_dict(jt_dict: Dict[str, JaggedTensorWithCount]) -> "KeyedJaggedTensorWithCount": + """ + Constructs a KeyedJaggedTensorWithCount from a dictionary of JaggedTensorWithCounts. + Automatically calls `kjt.sync()` on newly created KJT. + + Args: + jt_dict (Dict[str, JaggedTensorWithCount]): dictionary of JaggedTensorWithCounts. + + Returns: + KeyedJaggedTensorWithCount: constructed KeyedJaggedTensorWithCount. + """ + kjt_keys = list(jt_dict.keys()) + kjt_vals_list: List[torch.Tensor] = [] + kjt_counts_list: List[torch.Tensor] = [] + kjt_lens_list: List[torch.Tensor] = [] + kjt_weights_list: List[torch.Tensor] = [] + stride_per_key: List[int] = [] + for jt in jt_dict.values(): + stride_per_key.append(len(jt.lengths())) + kjt_vals_list.append(jt.values()) + kjt_counts_list.append(jt.counts) + kjt_lens_list.append(jt.lengths()) + weight = jt.weights_or_none() + if weight is not None: + kjt_weights_list.append(weight) + kjt_vals = torch.concat(kjt_vals_list) + kjt_lens = torch.concat(kjt_lens_list) + + # handle custom attribute: counts + kjt_counts = ( + torch.concat(kjt_counts_list) if len(kjt_counts_list) > 0 else None + ) + + kjt_weights = ( + torch.concat(kjt_weights_list) if len(kjt_weights_list) > 0 else None + ) + kjt_stride, kjt_stride_per_key_per_rank = ( + (stride_per_key[0], None) + if all(s == stride_per_key[0] for s in stride_per_key) + else (None, [[stride] for stride in stride_per_key]) + ) + kjt = KeyedJaggedTensorWithCount( + keys=kjt_keys, + values=kjt_vals, + counts=kjt_counts, + weights=kjt_weights, + lengths=kjt_lens, + stride=kjt_stride, + stride_per_key_per_rank=kjt_stride_per_key_per_rank, + ).sync() + return kjt + + def split(self, segments: List[int]) -> List["KeyedJaggedTensorWithCount"]: + split_list: List[KeyedJaggedTensorWithCount] = [] + start = 0 + start_offset = 0 + _length_per_key = self.length_per_key() + _offset_per_key = self.offset_per_key() + for segment in segments: + end = start + segment + end_offset = _offset_per_key[end] + keys: List[str] = self._keys[start:end] + + stride, stride_per_key_per_rank = ( + (None, self.stride_per_key_per_rank()[start:end]) + if self.variable_stride_per_key() + else (self._stride, None) + ) + if segment == len(self._keys): + # no torch slicing required + split_list.append( + KeyedJaggedTensorWithCount( + keys=self._keys, + values=self._values, + counts=self._counts, + weights=self.weights_or_none(), + lengths=self._lengths, + offsets=self._offsets, + stride=stride, + 
stride_per_key_per_rank=stride_per_key_per_rank, + length_per_key=self._length_per_key, + offset_per_key=self._offset_per_key, + index_per_key=self._index_per_key, + jt_dict=self._jt_dict, + ) + ) + elif segment == 0: + empty_int_list: List[int] = torch.jit.annotate(List[int], []) + split_list.append( + KeyedJaggedTensorWithCount( + keys=keys, + values=torch.tensor( + empty_int_list, + device=self.device(), + dtype=self._values.dtype, + ), + counts=torch.tensor( + empty_int_list, + device=self.device(), + dtype=self._counts.dtype, + ), + weights=( + None + if self.weights_or_none() is None + else torch.tensor( + empty_int_list, + device=self.device(), + dtype=self.weights().dtype, + ) + ), + lengths=torch.tensor( + empty_int_list, device=self.device(), dtype=torch.int + ), + offsets=torch.tensor( + empty_int_list, device=self.device(), dtype=torch.int + ), + stride=stride, + stride_per_key_per_rank=stride_per_key_per_rank, + length_per_key=None, + offset_per_key=None, + index_per_key=None, + jt_dict=None, + ) + ) + else: + split_length_per_key = _length_per_key[start:end] + split_list.append( + KeyedJaggedTensorWithCount( + keys=keys, + values=self._values[start_offset:end_offset], + counts=( + self._counts[start_offset:end_offset] + if self._counts is not None + else None + ), + weights=( + None + if self.weights_or_none() is None + else self.weights()[start_offset:end_offset] + ), + lengths=self.lengths()[ + self.lengths_offset_per_key()[ + start + ]: self.lengths_offset_per_key()[end] + ], + offsets=None, + stride=stride, + stride_per_key_per_rank=stride_per_key_per_rank, + length_per_key=split_length_per_key, + offset_per_key=None, + index_per_key=None, + jt_dict=None, + ) + ) + start = end + start_offset = end_offset + return split_list + + def permute( + self, indices: List[int], indices_tensor: Optional[torch.Tensor] = None + ) -> "KeyedJaggedTensorWithCount": + """ + Permutes the KeyedJaggedTensorWithCount. + + Args: + indices (List[int]): list of indices. + indices_tensor (Optional[torch.Tensor]): tensor of indices. + + Returns: + KeyedJaggedTensorWithCount: permuted KeyedJaggedTensorWithCount. 
+ """ + if indices_tensor is None: + indices_tensor = torch.tensor( + indices, dtype=torch.int, device=self.device() + ) + + length_per_key = self.length_per_key() + permuted_keys: List[str] = [] + permuted_stride_per_key_per_rank: List[List[int]] = [] + permuted_length_per_key: List[int] = [] + permuted_length_per_key_sum = 0 + for index in indices: + key = self.keys()[index] + permuted_keys.append(key) + permuted_length_per_key.append(length_per_key[index]) + if self.variable_stride_per_key(): + permuted_stride_per_key_per_rank.append( + self.stride_per_key_per_rank()[index] + ) + + permuted_length_per_key_sum = sum(permuted_length_per_key) + if not torch.jit.is_scripting() and is_non_strict_exporting(): + torch._check_is_size(permuted_length_per_key_sum) + torch._check(permuted_length_per_key_sum != -1) + torch._check(permuted_length_per_key_sum != 0) + + if self.variable_stride_per_key(): + length_per_key_tensor = _pin_and_move( + torch.tensor(self.length_per_key()), self.device() + ) + stride_per_key_tensor = _pin_and_move( + torch.tensor(self.stride_per_key()), self.device() + ) + permuted_lengths, _ = _permute_tensor_by_segments( + self.lengths(), + stride_per_key_tensor, + indices_tensor, + None, + ) + permuted_values, permuted_weights = _permute_tensor_by_segments( + self.values(), + length_per_key_tensor, + indices_tensor, + self.weights_or_none(), + ) + permuted_counts, _ = _permute_tensor_by_segments( + self.counts, + length_per_key_tensor, + indices_tensor, + self.weights_or_none(), + ) + elif is_torchdynamo_compiling() and not torch.jit.is_scripting(): + ( + permuted_lengths, + permuted_values, + permuted_weights, + ) = torch.ops.fbgemm.permute_2D_sparse_data_input1D( + indices_tensor, + self.lengths(), + self.values(), + self.stride(), + self.weights_or_none(), + permuted_length_per_key_sum, + ) + _, permuted_counts, _ = torch.ops.fbgemm.permute_2D_sparse_data_input1D( + indices_tensor, + self.lengths(), + self.counts, + self.stride(), + self.weights_or_none(), + permuted_length_per_key_sum, + ) + else: + ( + permuted_lengths, + permuted_values, + permuted_weights, + ) = torch.ops.fbgemm.permute_2D_sparse_data( + indices_tensor, + self.lengths().view(len(self._keys), -1), + self.values(), + self.weights_or_none(), + permuted_length_per_key_sum, + ) + _, permuted_counts, _ = torch.ops.fbgemm.permute_2D_sparse_data( + indices_tensor, + self.lengths().view(len(self._keys), -1), + self.counts, + self.weights_or_none(), + permuted_length_per_key_sum, + ) + stride_per_key_per_rank = ( + permuted_stride_per_key_per_rank if self.variable_stride_per_key() else None + ) + kjt = KeyedJaggedTensorWithCount( + keys=permuted_keys, + values=permuted_values, + counts=permuted_counts, + weights=permuted_weights, + lengths=permuted_lengths.view(-1), + offsets=None, + stride=self._stride, + stride_per_key_per_rank=stride_per_key_per_rank, + stride_per_key=None, + length_per_key=permuted_length_per_key if len(permuted_keys) > 0 else None, + lengths_offset_per_key=None, + offset_per_key=None, + index_per_key=None, + jt_dict=None, + inverse_indices=None, + ) + return kjt + + def pin_memory(self) -> "KeyedJaggedTensorWithCount": + weights = self._weights + lengths = self._lengths + offsets = self._offsets + stride, stride_per_key_per_rank = ( + (None, self._stride_per_key_per_rank) + if self.variable_stride_per_key() + else (self._stride, None) + ) + + return KeyedJaggedTensorWithCount( + keys=self._keys, + values=self._values.pin_memory(), + counts=( + self._counts.pin_memory() + if self._counts 
is not None + else None + ), + weights=weights.pin_memory() if weights is not None else None, + lengths=lengths.pin_memory() if lengths is not None else None, + offsets=offsets.pin_memory() if offsets is not None else None, + stride=stride, + stride_per_key_per_rank=stride_per_key_per_rank, + length_per_key=self._length_per_key, + offset_per_key=self._offset_per_key, + index_per_key=self._index_per_key, + jt_dict=None, + ) + + def to( + self, device: torch.device, non_blocking: bool = False + ) -> "KeyedJaggedTensorWithCount": + weights = self._weights + lengths = self._lengths + offsets = self._offsets + stride, stride_per_key_per_rank = ( + (None, self._stride_per_key_per_rank) + if self.variable_stride_per_key() + else (self._stride, None) + ) + length_per_key = self._length_per_key + offset_per_key = self._offset_per_key + index_per_key = self._index_per_key + jt_dict = self._jt_dict + + return KeyedJaggedTensorWithCount( + keys=self._keys, + values=self._values.to(device, non_blocking=non_blocking), + counts=( + self._counts.to(device, non_blocking=non_blocking) + if self._counts is not None + else None + ), + weights=( + weights.to(device, non_blocking=non_blocking) + if weights is not None + else None + ), + lengths=( + lengths.to(device, non_blocking=non_blocking) + if lengths is not None + else None + ), + offsets=( + offsets.to(device, non_blocking=non_blocking) + if offsets is not None + else None + ), + stride=stride, + stride_per_key_per_rank=stride_per_key_per_rank, + length_per_key=length_per_key, + offset_per_key=offset_per_key, + index_per_key=index_per_key, + jt_dict=jt_dict, + ) + + @torch.jit.unused + def record_stream(self, stream: torch.cuda.streams.Stream) -> None: + super().record_stream(stream) + if self._counts is not None: + self._counts.record_stream(stream) + + def to_dict(self) -> Dict[str, JaggedTensor]: + # Invoke the base class's method; the extra counts data is discarded in the result. + return super().to_dict() + + def dist_labels(self) -> List[str]: + labels = ["lengths", "values"] + if self.variable_stride_per_key(): + labels.append("strides") + if self.weights_or_none() is not None: + labels.append("weights") + if self._counts is not None: + labels.append("counts") + return labels + + def dist_splits(self, key_splits: List[int]) -> List[List[int]]: + batch_size_per_split = _sum_by_splits(self.stride_per_key(), key_splits) + length_per_split = _sum_by_splits(self.length_per_key(), key_splits) + splits = [batch_size_per_split, length_per_split] + if self.variable_stride_per_key(): + splits.append(key_splits) + if self.weights_or_none() is not None: + splits.append(length_per_split) + if self._counts is not None: + splits.append(length_per_split) + return splits + + def dist_tensors(self) -> List[torch.Tensor]: + tensors = [self.lengths(), self.values()] + if self.variable_stride_per_key(): + strides = _pin_and_move(torch.tensor(self.stride_per_key()), self.device()) + tensors.append(strides) + if self.weights_or_none() is not None: + tensors.append(self.weights()) + if self._counts is not None: + tensors.append(self._counts) + return tensors + + @staticmethod + def dist_init( + keys: List[str], + tensors: List[torch.Tensor], + variable_stride_per_key: bool, + num_workers: int, + recat: Optional[torch.Tensor], + stride_per_rank: Optional[List[int]], + stagger: int = 1, + ) -> "KeyedJaggedTensorWithCount": + # The base class passes at most 4 tensors; the extra counts tensor raises the maximum here to 5.
+ if len(tensors) not in [2, 3, 4, 5]: + raise RuntimeError(f"tensors length must be in [2, 3, 4, 5] but got: {len(tensors)}") + lengths = tensors[0] + values = tensors[1] + stride_per_rank_per_key = tensors[2] if variable_stride_per_key else None + + # KeyedJaggedTensorWithCount is only used for all2all when local unique is enabled and at least one table has admission enabled; + # in that case the counts tensor is always passed at the end of the tensors list. + weights = ( + tensors[-2] + if (variable_stride_per_key and len(tensors) == 5) + or (not variable_stride_per_key and len(tensors) == 4) + else None + ) + counts = tensors[-1] + + if variable_stride_per_key: + stride_per_key_per_rank_tensor: torch.Tensor = stride_per_rank_per_key.view( + num_workers, len(keys) + ).T.cpu() + + strides_cumsum: torch.Tensor = ( + torch.ops.fbgemm.asynchronous_complete_cumsum(stride_per_rank_per_key) + ).cpu() + + cumsum_lengths = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) + + n = strides_cumsum.size(0) + strides_cumsum_from_1 = torch.narrow( + strides_cumsum, dim=0, start=1, length=n - 1 + ) + strides_cumsum_to_minus_1 = torch.narrow( + strides_cumsum, dim=0, start=0, length=n - 1 + ) + length_per_key_tensor = ( + cumsum_lengths[strides_cumsum_from_1] + - cumsum_lengths[strides_cumsum_to_minus_1] + ) + + with record_function("## all2all_data:recat_values ##"): + if recat is not None: + new_lengths, _ = _permute_tensor_by_segments( + lengths, + stride_per_rank_per_key, + torch.jit._unwrap_optional(recat), + None, + ) + new_values, new_weights = _permute_tensor_by_segments( + values, + length_per_key_tensor, + torch.jit._unwrap_optional(recat), + weights, + ) + if counts is not None: + new_counts, _ = _permute_tensor_by_segments( + counts, + length_per_key_tensor, + torch.jit._unwrap_optional(recat), + None, + ) + + stride_per_key_per_rank = torch.jit.annotate( + List[List[int]], stride_per_key_per_rank_tensor.tolist() + ) + + if not stride_per_key_per_rank: + stride_per_key_per_rank = [[0]] * len(keys) + if stagger > 1: + stride_per_key_per_rank_stagger: List[List[int]] = [] + local_world_size = num_workers // stagger + for i in range(len(keys)): + stride_per_rank_stagger: List[int] = [] + for j in range(local_world_size): + stride_per_rank_stagger.extend( + stride_per_key_per_rank[i][j::local_world_size] + ) + stride_per_key_per_rank_stagger.append(stride_per_rank_stagger) + stride_per_key_per_rank = stride_per_key_per_rank_stagger + + kjt = KeyedJaggedTensorWithCount( + keys=keys, + values=new_values, + counts=new_counts, + weights=new_weights, + lengths=new_lengths, + stride_per_key_per_rank=stride_per_key_per_rank, + ) + return kjt.sync() + else: + with record_function("## all2all_data:recat_values ##"): + if recat is not None: + stride = stride_per_rank[0] + + single_batch_per_rank = True + new_counts = None + if not is_torchdynamo_compiling(): + single_batch_per_rank = all( + s == stride for s in stride_per_rank + ) + if ( + single_batch_per_rank + and is_torchdynamo_compiling() + and not torch.jit.is_scripting() + ): + ( + new_lengths, + new_values, + new_weights, + ) = torch.ops.fbgemm.permute_2D_sparse_data_input1D( + torch.jit._unwrap_optional(recat), + lengths, + values, + stride, + weights, + values.numel(), + ) + if counts is not None: + _, new_counts, _ = torch.ops.fbgemm.permute_2D_sparse_data_input1D( + torch.jit._unwrap_optional(recat), + lengths, + counts, + stride, + None, + counts.numel(), + ) + elif single_batch_per_rank: + ( + new_lengths, + new_values, + new_weights, + ) = torch.ops.fbgemm.permute_2D_sparse_data( + torch.jit._unwrap_optional(recat), + lengths.view(-1, stride), + values, + 
weights, + values.numel(), + ) + if counts is not None: + _, new_counts, _ = torch.ops.fbgemm.permute_2D_sparse_data( + torch.jit._unwrap_optional(recat), + lengths.view(-1, stride), + counts, + None, + counts.numel(), + ) + new_lengths = new_lengths.view(-1) + else: # variable batch size per rank + ( + new_lengths, + new_values, + new_weights, + ) = torch.ops.fbgemm.permute_1D_sparse_data( + torch.jit._unwrap_optional(recat), + lengths.view(-1), + values, + weights, + values.numel(), + ) + if counts is not None: + _, new_counts, _ = torch.ops.fbgemm.permute_1D_sparse_data( + torch.jit._unwrap_optional(recat), + lengths.view(-1), + counts, + None, + counts.numel(), + ) + else: + new_lengths = lengths + new_values = values + new_weights = weights + new_counts = counts + kjt = KeyedJaggedTensorWithCount( + keys=keys, + values=new_values, + counts=new_counts, + weights=new_weights, + lengths=new_lengths, + stride=sum(stride_per_rank), + ) + return kjt.sync() + + diff --git a/torchrec/torchrec_embcache/build.sh b/torchrec/torchrec_embcache/build.sh index 55581f470c35f06cd69d92296b362615be512b00..d6723f35ae48d157f1d663791faf6c2fe998d055 100644 --- a/torchrec/torchrec_embcache/build.sh +++ b/torchrec/torchrec_embcache/build.sh @@ -26,9 +26,16 @@ fi function prepare_deps() { python3 -m pip install pybind11 - cd "${SCRIPT_PATH}/src/3rdparty" - git clone -b master https://gitee.com/Janisa/huawei_secure_c.git securec - cd - + + local securec_dir="${SCRIPT_PATH}/src/3rdparty/securec" + if [ ! -d "$securec_dir" ]; then + echo "Cloning huawei_secure_c..." + cd "${SCRIPT_PATH}/src/3rdparty" + git clone -b master https://gitee.com/Janisa/huawei_secure_c.git securec + cd - + else + echo "securec directory already exists, skipping clone." + fi } function check_ret_fn() diff --git a/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/embedding_cache/embcache_manager.cpp b/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/embedding_cache/embcache_manager.cpp index 52c8d2f492dbd77b41406e183d4cbb4475e7bede..60d44769912489c058d590f8df4b6d68d1db3605 100644 --- a/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/embedding_cache/embcache_manager.cpp +++ b/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/embedding_cache/embcache_manager.cpp @@ -18,6 +18,7 @@ #include "utils/logger.h" #include "utils/time_cost.h" +#include "utils/string_tools.h" using namespace Embcache; @@ -49,7 +50,9 @@ EmbcacheManager::EmbcacheManager(const std::vector& embConfigs, bool } if (embConfigs[i].admitAndEvictConfig.IsFeatureFilterEnabled()) { - // 待补充 feature filter 初始化 + auto& aaeConfig = embConfigs[i].admitAndEvictConfig; + featureFilters_.emplace_back(FeatureFilter(embConfigs[i].tableName, aaeConfig.admitThreshold, + aaeConfig.evictThreshold, aaeConfig.evictStepInterval)); } } TORCH_CHECK(embConfigs.size() > 0, "ERROR, Size of embConfigs must > 0") @@ -99,7 +102,7 @@ SwapInfo EmbcacheManager::ComputeSwapInfo(const at::Tensor& batchKeys, const std for (int64_t i = 0; i < curTableIndices.size(); i++) { int64_t idx = curTableIndices[i]; if (embConfigs_[idx].admitAndEvictConfig.IsAdmitEnabled()) { - // 待补充 feature filter 统计 + featureFilters_[idx].CountFilter(keyPtr, offsetPerKey[i], offsetPerKey[i + 1]); } // 取出每个表的 key @@ -272,10 +275,44 @@ void EmbcacheManager::EmbeddingUpdate(const std::vector>& s void EmbcacheManager::RecordTimestamp(const at::Tensor& batchKeys, const std::vector& offsetPerKey, const at::Tensor& timestamps, const std::vector& tableIndices) { + LOG_INFO("Start invoke mgmt RecordTimestamp"); + TimeCost 
recordTimestampTC; + const auto* keyPtr = batchKeys.data_ptr(); + const auto* timestampsPtr = timestamps.data_ptr(); + const std::vector& curTableIndices = tableIndices.empty() ? embTableIndies_ : tableIndices; + TORCH_CHECK(curTableIndices.size() + 1 == offsetPerKey.size(), + "tableIndices size+1 must be equal to offsetPerKey size"); + + for (size_t i = 0; i < curTableIndices.size(); ++i) { + int32_t idx = curTableIndices[i]; + if (embConfigs_[idx].admitAndEvictConfig.IsEvictEnabled()) { + featureFilters_[idx].RecordTimestamp(keyPtr, offsetPerKey[i], offsetPerKey[i + 1], timestampsPtr); + } + } + LOG_INFO("RecordTimestamp execution time: {} ms", recordTimestampTC.ElapsedMS()); } void EmbcacheManager::EvictFeatures() { + LOG_INFO("Start EvictFeatures, ComputeSwapInfo execution count: {}", swapCount_); + TimeCost evictFeaturesTC; + size_t evictKeyCount = 0; + for (int32_t i = 0; i < embNum_; ++i) { + if (!embConfigs_[i].admitAndEvictConfig.IsEvictEnabled()) { + LOG_INFO("Table {} does not have eviction enabled, skipping feature eviction.", embConfigs_[i].tableName); + continue; + } + + // Get the keys to be evicted from the current table + const std::vector& evictFeatures = featureFilters_[i].evictFeatureRecord_.GetEvictKeys(); + // Call the swapManager to remove the mapping info; removal of the embeddings from embeddingTables + // is deferred until the swap-out emb update of the corresponding step has finished + swapManagers_[i].RemoveKeys(evictFeatures); + featureFilters_[i].evictFeatureRecord_.SetSwapCount(swapCount_); + evictKeyCount += evictFeatures.size(); + } + LOG_INFO("EvictFeatures execution time: {} ms, total evicted key count across tables: {}", evictFeaturesTC.ElapsedMS(), + evictKeyCount); } void EmbcacheManager::RecordEmbeddingUpdateTimes() @@ -297,14 +334,60 @@ AsyncTask EmbcacheManager::EmbeddingUpdateAsync(const SwapInfo& swapInfo, } bool EmbcacheManager::NeedEvictEmbeddingTable() { + for (int32_t i = 0; i < embNum_; ++i) { + // Skip tables that do not have eviction enabled + if (!embConfigs_[i].admitAndEvictConfig.IsEvictEnabled()) { + continue; + } + // The keys pending removal from the embTable are non-empty and the same step count as GetSwapInfo has been reached + if (!featureFilters_[i].evictFeatureRecord_.GetEvictKeys().empty() && + featureFilters_[i].evictFeatureRecord_.CanRemoveFromEmbTable(embUpdateCount_)) { + return true; + } + } return false; } void EmbcacheManager::RemoveEmbeddingTableInfo() { + LOG_INFO("Start RemoveEmbeddingTableInfo, embUpdateCount_: {}", embUpdateCount_); + TimeCost removeEmbeddingTableTC; + for (int32_t i = 0; i < embNum_; ++i) { + auto& keys = featureFilters_[i].evictFeatureRecord_.GetEvictKeys(); + if (keys.empty()) { + LOG_INFO("Evict key list is empty, skipping embedding removal for table: {}", embConfigs_[i].tableName); + continue; + } + + embeddingTables_[i]->RemoveEmbedding(keys); + LOG_INFO("Removed embedding info, table: {}, removed key count: {}, keys: {}", + embConfigs_[i].tableName, keys.size(), StringTools::ToString(keys)); + featureFilters_[i].evictFeatureRecord_.ClearEvictInfo(); + } + LOG_INFO("RemoveEmbeddingTableInfo execution time: {} ms", removeEmbeddingTableTC.ElapsedMS()); } void EmbcacheManager::StatisticsKeyCount(const at::Tensor& batchKeys, const torch::Tensor& offset, const at::Tensor& batchKeyCounts, int64_t tableIndex) { + LOG_INFO("StatisticsKeyCount, tableIndex: {}, isAdmit: {}", + tableIndex, embConfigs_[tableIndex].admitAndEvictConfig.IsAdmitEnabled()); + if (!embConfigs_[tableIndex].admitAndEvictConfig.IsAdmitEnabled()) { + return; + } + TORCH_CHECK(offset.numel() > tableIndex, "param error, tableIndex must be smaller than offset length," + " but got a value equal to or greater than offset length.") + // When local unique is not enabled, counts is an empty tensor and each key's count defaults to 1 + 
bool isCountDataEmpty = batchKeyCounts.numel() == 0; + if (!isCountDataEmpty) { + TORCH_CHECK(batchKeys.numel() == batchKeyCounts.numel(), + "batchKeys length must equal batchKeyCounts length when batchKeyCounts is not empty.") + } + auto* featureDataPtr = batchKeys.data_ptr(); + auto* countDataPtr = batchKeyCounts.data_ptr(); + auto* offsetDataPtr = offset.data_ptr(); + int64_t start = offsetDataPtr[tableIndex]; + int64_t end = offsetDataPtr[tableIndex + 1]; + TORCH_CHECK(end <= batchKeys.numel(), "offset end must not exceed batchKeys length") + featureFilters_[tableIndex].StatisticsKeyCount(featureDataPtr, countDataPtr, start, end, isCountDataEmpty); } diff --git a/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/embedding_cache/embcache_manager.h b/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/embedding_cache/embcache_manager.h index 1d74e60815fedeb0e0a1fac6a609ecba6c1ea3c9..b85a62ad450c5770c6c4025ddf430752a656fe22 100644 --- a/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/embedding_cache/embcache_manager.h +++ b/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/embedding_cache/embcache_manager.h @@ -16,6 +16,7 @@ #include "common/common.h" #include "emb_table/emb_table.h" +#include "feature_filter/feature_filter.h" #include "swap_manager.h" #include "utils/async_task.h" #include "utils/thread_pool.h" @@ -129,6 +130,7 @@ private: std::vector embConfigs_; std::vector swapManagers_; std::vector> embeddingTables_; + std::vector<FeatureFilter> featureFilters_; uint64_t swapCount_ = 0; // ComputeSwapInfo 执行次数 uint64_t embUpdateCount_ = 0; // EmbeddingUpdate 执行次数
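
For reviewers, a minimal usage sketch of the new class follows (illustrative only, not part of the patch). It assumes the module is importable as hybrid_torchrec.sparse.jagged_tensor_with_count, that fbgemm_gpu is installed for permute(), and that the tensor values are made up:

import torch
from hybrid_torchrec.sparse.jagged_tensor_with_count import KeyedJaggedTensorWithCount

# Two keys ("f1", "f2") with batch size 2; counts carries one entry per value,
# mirroring how dist_labels/dist_splits/dist_tensors ship it next to values.
kjt = KeyedJaggedTensorWithCount(
    keys=["f1", "f2"],
    values=torch.tensor([10, 11, 12, 20, 21]),
    counts=torch.tensor([3, 1, 5, 2, 7]),
    lengths=torch.tensor([2, 1, 1, 1]),
).sync()

f1_only, f2_only = kjt.split([1, 1])   # counts are sliced together with values
swapped = kjt.permute([1, 0])          # counts are permuted alongside values
print(f1_only.counts, swapped.counts)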