diff --git a/torchrec/torchrec_embcache/src/torchrec_embcache/distributed/embedding_bag.py b/torchrec/torchrec_embcache/src/torchrec_embcache/distributed/embedding_bag.py
index 76d47d710797616e3544580f0a4d6a171fb8307c..f397fe72545ce0cdbd206535ebf3a35f72b58886 100644
--- a/torchrec/torchrec_embcache/src/torchrec_embcache/distributed/embedding_bag.py
+++ b/torchrec/torchrec_embcache/src/torchrec_embcache/distributed/embedding_bag.py
@@ -172,40 +172,6 @@ class EmbCacheEmbeddingBagCollection(EmbeddingBagCollection):
         tables (List[EmbeddingBagConfig]): list of embedding tables.
         is_weighted (bool): whether input `KeyedJaggedTensor` is weighted.
         device (Optional[torch.device]): default compute device.
-
-    Example::
-
-        table_0 = EmbeddingBagConfig(
-            name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
-        )
-        table_1 = EmbeddingBagConfig(
-            name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
-        )
-
-        ebc = EmbeddingBagCollection(tables=[table_0, table_1])
-
-        #        0       1        2  <-- batch
-        # "f1"   [0,1]  None    [2]
-        # "f2"   [3]    [4]     [5,6,7]
-        #  ^
-        # feature
-
-        features = KeyedJaggedTensor(
-            keys=["f1", "f2"],
-            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
-            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
-        )
-
-        pooled_embeddings = ebc(features)
-        print(pooled_embeddings.values())
-        tensor([[-0.8899, -0.1342, -1.9060, -0.0905, -0.2814, -0.9369, -0.7783],
-            [ 0.0000,  0.0000,  0.0000,  0.1598,  0.0695,  1.3265, -0.1011],
-            [-0.4256, -1.1846, -2.1648, -1.0893,  0.3590, -1.9784, -0.7681]],
-            grad_fn=<CatBackward0>)
-        print(pooled_embeddings.keys())
-        ['f1', 'f2']
-        print(pooled_embeddings.offset_per_key())
-        tensor([0, 3, 7])
     """

     def __init__(
diff --git a/torchrec/torchrec_embcache/src/torchrec_embcache/distributed/train_pipeline.py b/torchrec/torchrec_embcache/src/torchrec_embcache/distributed/train_pipeline.py
index 8651ab84c0495bc6dde45da692b07dbee1149895..2d670aecf2c34cecda8142245445abf7e450ba44 100644
--- a/torchrec/torchrec_embcache/src/torchrec_embcache/distributed/train_pipeline.py
+++ b/torchrec/torchrec_embcache/src/torchrec_embcache/distributed/train_pipeline.py
@@ -120,7 +120,6 @@ class EmbCacheTrainPipelineContext(TrainPipelineContext):


 class EmbCachePipelinedForward(PipelinedForward):
-    # pyre-ignore [2, 24]
     def __call__(self, *input_feature, **kwargs) -> Awaitable:
         self._context.sparse_features_after_restore_future.pop(self._name).get()
         data = self._context.sparse_features_after_post_dist.pop(self._name)
@@ -216,7 +215,6 @@ def _fuse_input_dist_splits(context: TrainPipelineContext) -> None:
             (
                 names,
                 FusedKJTListSplitsAwaitable(
-                    # pyre-ignore[6]
                     requests=[
                         context.input_dist_splits_requests[name] for name in names
                     ],
@@ -603,7 +601,6 @@ class EmbCacheTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
         return self._init_pipelined_modules(
-            # pyre-ignore [6]
             self.batches[0],
             self.contexts[0],
             EmbCachePipelinedForward,
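
Note: the `Example::` block removed above mirrors the upstream torchrec `EmbeddingBagCollection` docstring. For reference, a self-contained sketch of that example is given below; the imports are assumptions about the caller's environment, and the pooled values printed in the old docstring were illustrative only, since table weights are randomly initialized.

    import torch
    from torchrec.modules.embedding_configs import EmbeddingBagConfig
    from torchrec.modules.embedding_modules import EmbeddingBagCollection
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    # Two tables, one feature each: "f1" -> 3-dim, "f2" -> 4-dim embeddings.
    table_0 = EmbeddingBagConfig(
        name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
    )
    table_1 = EmbeddingBagConfig(
        name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
    )
    ebc = EmbeddingBagCollection(tables=[table_0, table_1])

    # Batch of 3: "f1" has lengths [2, 0, 1], "f2" has lengths [1, 1, 3];
    # the offsets delimit each bag within the flat values tensor.
    features = KeyedJaggedTensor(
        keys=["f1", "f2"],
        values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
        offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
    )

    pooled = ebc(features)          # KeyedTensor of pooled embeddings
    print(pooled.values().shape)    # torch.Size([3, 7]): 3- and 4-dim outputs concatenated
    print(pooled.keys())            # ['f1', 'f2']
    print(pooled.offset_per_key())  # per-key offsets into dim 1: 0, 3, 7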