From f2cf00e5302ce4038cf4841bd72d5c6edff7a305 Mon Sep 17 00:00:00 2001
From: Runchu Zhao
Date: Tue, 28 Oct 2025 06:59:50 +0000
Subject: [PATCH 1/3] Support LFU frequency masking on V2

---
 .../dynamicemb/batched_dynamicemb_function.py | 69 ++++++++++++++++++-
 .../dynamicemb/batched_dynamicemb_tables.py   |  8 ++-
 .../dynamicemb/dynamicemb/key_value_table.py  | 24 +++++++
 corelib/dynamicemb/example/example.py         |  4 +-
 4 files changed, 102 insertions(+), 3 deletions(-)

diff --git a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py
index 24b32a92b..1f9ba21c1 100644
--- a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py
+++ b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py
@@ -45,6 +45,60 @@
 )


+def _mask_embeddings_by_frequency(
+    cache: Optional[Cache],
+    storage: Storage,
+    unique_keys: torch.Tensor,
+    unique_embs: torch.Tensor,
+    frequency_threshold: int,
+    mask_dims: int,
+) -> None:
+    """
+    Mask low-frequency embeddings by setting specific dimensions to zero.
+
+    This function queries scores from cache and storage, then masks embeddings
+    whose scores are below the frequency threshold.
+
+    Args:
+        cache: Optional cache table (can be None if caching is disabled)
+        storage: Storage table (always present)
+        unique_keys: Keys to query scores for
+        unique_embs: Embeddings to mask (modified in-place)
+        frequency_threshold: Minimum score threshold
+        mask_dims: Number of dimensions to mask (from the end)
+    """
+    batch = unique_keys.size(0)
+    if batch == 0:
+        return
+
+    # Query scores from cache and storage
+    if cache is not None:
+        # 1. Query cache first
+        cache_scores = cache.query_scores(unique_keys)
+        cache_founds = cache_scores > 0
+
+        # 2. Query storage for cache misses
+        if (~cache_founds).any():
+            missing_keys = unique_keys[~cache_founds]
+            storage_scores = storage.query_scores(missing_keys)
+            cache_scores[~cache_founds] = storage_scores
+
+        scores = cache_scores
+    else:
+        # Without cache: query from storage only
+        scores = storage.query_scores(unique_keys)
+
+    # Apply masking
+    low_freq_mask = scores < frequency_threshold
+    if low_freq_mask.any():
+        unique_embs[low_freq_mask, -mask_dims:] = 0.0
+
+    for i in range(unique_embs.size(0)):
+        print(
+            f"Row {i}: score = {scores[i].item()}, last {mask_dims} dims = {unique_embs[i, -mask_dims:].tolist()}"
+        )
+
+
 # TODO: BatchedDynamicEmbeddingFunction is more concrete.
 class DynamicEmbeddingBagFunction(torch.autograd.Function):
     @staticmethod
@@ -348,6 +402,8 @@ def forward(
         input_dist_dedup: bool = False,
         training: bool = True,
         frequency_counters: Optional[torch.Tensor] = None,
+        frequency_threshold: int = 0,
+        mask_dims: int = 0,
         *args,
     ):
         table_num = len(storages)
@@ -426,6 +482,17 @@ def forward(
                 lfu_accumulated_frequency_per_table,
             )
+            # Apply frequency-based masking if enabled
+            if is_lfu_enabled and mask_dims > 0 and frequency_threshold > 0:
+                _mask_embeddings_by_frequency(
+                    caches[i] if caching else None,
+                    storages[i],
+                    unique_indices_per_table,
+                    unique_embs_per_table,
+                    frequency_threshold,
+                    mask_dims,
+                )
+
         if training or caching:
             output_embs = torch.empty(
                 indices.shape[0], emb_dim, dtype=output_dtype, device=indices.device
             )
@@ -501,4 +568,4 @@ def backward(ctx, grads):
             optimizer,
         )

-        return (None,) * 14
+        return (None,) * 16
diff --git a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py
index f4814d318..81c1aa83d 100644
--- a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py
+++ b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py
@@ -459,6 +459,9 @@ def __init__(
         cowclip_regularization: Optional[
             CowClipDefinition
         ] = None,  # used by Rowwise Adagrad
+        # Frequency masking parameters
+        frequency_threshold: int = 0,  # Frequency threshold for masking
+        mask_dims: int = 0,  # Number of dimensions to mask
         # TO align with FBGEMM TBE
         *args,
         **kwargs,
     ):
@@ -483,7 +486,8 @@ def __init__(
         self._table_names = table_names
         self.bounds_check_mode_int: int = bounds_check_mode.value
         self._create_score()
-
+        self.frequency_threshold = frequency_threshold
+        self.mask_dims = mask_dims
         if device is not None:
             self.device_id = int(str(device)[-1])
         else:
@@ -984,6 +988,8 @@ def forward(
             self.use_index_dedup,
             self.training,
             per_sample_weights,  # Pass frequency counters as weights
+            self.frequency_threshold,
+            self.mask_dims,
             self._empty_tensor,
         )
         for cache in self._caches:
diff --git a/corelib/dynamicemb/dynamicemb/key_value_table.py b/corelib/dynamicemb/dynamicemb/key_value_table.py
index 6e4950438..5f34a1bea 100644
--- a/corelib/dynamicemb/dynamicemb/key_value_table.py
+++ b/corelib/dynamicemb/dynamicemb/key_value_table.py
@@ -40,6 +40,7 @@
     erase,
     export_batch,
     export_batch_matched,
+    find,
     find_pointers,
     find_pointers_with_scores,
     insert_and_evict,
@@ -347,6 +348,29 @@ def create_scores(
         else:
             return None

+    def query_scores(self, unique_keys: torch.Tensor) -> torch.Tensor:
+        """Query scores for given keys from the table.
+
+        Returns:
+            scores: Tensor of scores, with 0 for keys not found in table
+        """
+
+        batch = unique_keys.size(0)
+        device = unique_keys.device
+
+        scores = torch.empty(batch, device=device, dtype=torch.long)
+        values = torch.empty(
+            batch, self.value_dim(), device=device, dtype=self.embedding_dtype()
+        )
+        founds = torch.empty(batch, device=device, dtype=torch.bool)
+
+        find(self.table, batch, unique_keys, values, founds, score=scores)
+
+        # for not found keys, set score to 0
+        scores[~founds] = 0
+
+        return scores
+
     def insert(
         self,
         unique_keys: torch.Tensor,
diff --git a/corelib/dynamicemb/example/example.py b/corelib/dynamicemb/example/example.py
index 2725fb48f..cc12e77c5 100644
--- a/corelib/dynamicemb/example/example.py
+++ b/corelib/dynamicemb/example/example.py
@@ -391,6 +391,8 @@ def get_sharder(args, optimizer_type):
     ] = (
         SparseType.FP32
     )  # data type of the output after lookup, and can differ from the stored.
+    fused_params["frequency_threshold"] = 10
+    fused_params["mask_dims"] = 10
     fused_params.update(optimizer_kwargs)
     fused_params[
         "prefetch_pipeline"
     ] = True
@@ -486,7 +488,7 @@ def get_planner(device, eb_configs, batch_size, optimizer_type, training, cachin
                 initializer_args=DynamicEmbInitializerArgs(
                     mode=DynamicEmbInitializerMode.NORMAL
                 ),
-                score_strategy=DynamicEmbScoreStrategy.STEP,
+                score_strategy=DynamicEmbScoreStrategy.LFU,
                 caching=caching,
                 training=training,
             ),

From c31bb9c6d983d70486f68fb460fe3a5d3e8f1701 Mon Sep 17 00:00:00 2001
From: Runchu Zhao
Date: Tue, 28 Oct 2025 07:06:51 +0000
Subject: [PATCH 2/3] Remove debug print

---
 corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py
index 1f9ba21c1..17450de14 100644
--- a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py
+++ b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py
@@ -93,11 +93,6 @@ def _mask_embeddings_by_frequency(
     low_freq_mask = scores < frequency_threshold
     if low_freq_mask.any():
         unique_embs[low_freq_mask, -mask_dims:] = 0.0
-
-    for i in range(unique_embs.size(0)):
-        print(
-            f"Row {i}: score = {scores[i].item()}, last {mask_dims} dims = {unique_embs[i, -mask_dims:].tolist()}"
-        )


 # TODO: BatchedDynamicEmbeddingFunction is more concrete.

From d941891493c4982440701948d6c672b95d224b93 Mon Sep 17 00:00:00 2001
From: Runchu Zhao
Date: Tue, 28 Oct 2025 08:43:28 +0000
Subject: [PATCH 3/3] Added enter conditions for LFU masking

---
 .../dynamicemb/dynamicemb/batched_dynamicemb_function.py | 4 +++-
 .../dynamicemb/dynamicemb/batched_dynamicemb_tables.py   | 7 ++-----
 corelib/dynamicemb/dynamicemb/dynamicemb_config.py        | 3 +++
 corelib/dynamicemb/example/example.py                     | 4 ++--
 4 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py
index 17450de14..947b7c366 100644
--- a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py
+++ b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py
@@ -70,7 +70,9 @@ def _mask_embeddings_by_frequency(
     batch = unique_keys.size(0)
     if batch == 0:
         return
-
+    assert hasattr(
+        storage, "query_scores"
+    ), "If you want to use frequency masking, storage must implement the query_scores method"
     # Query scores from cache and storage
     if cache is not None:
         # 1. Query cache first
diff --git a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py
index 81c1aa83d..c56e5f140 100644
--- a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py
+++ b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py
@@ -459,9 +459,6 @@ def __init__(
         cowclip_regularization: Optional[
             CowClipDefinition
         ] = None,  # used by Rowwise Adagrad
-        # Frequency masking parameters
-        frequency_threshold: int = 0,  # Frequency threshold for masking
-        mask_dims: int = 0,  # Number of dimensions to mask
         # TO align with FBGEMM TBE
         *args,
         **kwargs,
     ):
@@ -486,8 +483,8 @@ def __init__(
         self._table_names = table_names
         self.bounds_check_mode_int: int = bounds_check_mode.value
         self._create_score()
-        self.frequency_threshold = frequency_threshold
-        self.mask_dims = mask_dims
+        self.frequency_threshold = table_option.frequency_threshold
+        self.mask_dims = table_option.mask_dims
         if device is not None:
             self.device_id = int(str(device)[-1])
         else:
diff --git a/corelib/dynamicemb/dynamicemb/dynamicemb_config.py b/corelib/dynamicemb/dynamicemb/dynamicemb_config.py
index eb849db7c..4be9203a4 100644
--- a/corelib/dynamicemb/dynamicemb/dynamicemb_config.py
+++ b/corelib/dynamicemb/dynamicemb/dynamicemb_config.py
@@ -386,6 +386,9 @@ class DynamicEmbTableOptions(_ContextOptions):
     global_hbm_for_values: int = 0  # in bytes
     external_storage: Storage = None
     index_type: Optional[torch.dtype] = None
+    # Frequency-based masking parameters
+    frequency_threshold: int = 0  # frequency threshold for masking (0 = disabled)
+    mask_dims: int = 0  # number of dimensions to mask (0 = disabled)

     def __post_init__(self):
         assert (
diff --git a/corelib/dynamicemb/example/example.py b/corelib/dynamicemb/example/example.py
index cc12e77c5..283dcd16c 100644
--- a/corelib/dynamicemb/example/example.py
+++ b/corelib/dynamicemb/example/example.py
@@ -391,8 +391,6 @@ def get_sharder(args, optimizer_type):
     ] = (
         SparseType.FP32
     )  # data type of the output after lookup, and can differ from the stored.
-    fused_params["frequency_threshold"] = 10
-    fused_params["mask_dims"] = 10
     fused_params.update(optimizer_kwargs)
     fused_params[
         "prefetch_pipeline"
     ] = True
@@ -491,6 +489,8 @@ def get_planner(device, eb_configs, batch_size, optimizer_type, training, cachin
                 score_strategy=DynamicEmbScoreStrategy.LFU,
                 caching=caching,
                 training=training,
+                frequency_threshold=10,
+                mask_dims=10,
             ),
         )
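
Usage sketch (illustrative, not part of the patches above): after this series, frequency masking is configured through DynamicEmbTableOptions rather than fused_params. The score_strategy, frequency_threshold, and mask_dims fields are taken from PATCH 3/3; the dynamicemb import path, the caching/training values, and the thresholds of 10 are assumptions for the example.

# Illustrative only: the `dynamicemb` import path and the literal values are
# assumptions; score_strategy, frequency_threshold, and mask_dims are the
# fields this patch series actually reads.
from dynamicemb import (
    DynamicEmbInitializerArgs,
    DynamicEmbInitializerMode,
    DynamicEmbScoreStrategy,
    DynamicEmbTableOptions,
)

table_options = DynamicEmbTableOptions(
    initializer_args=DynamicEmbInitializerArgs(
        mode=DynamicEmbInitializerMode.NORMAL
    ),
    # Masking only runs when LFU scores are enabled: the forward pass guards
    # it with `is_lfu_enabled and mask_dims > 0 and frequency_threshold > 0`.
    score_strategy=DynamicEmbScoreStrategy.LFU,
    caching=True,
    training=True,
    # Rows whose LFU score is below frequency_threshold have their last
    # mask_dims embedding dimensions zeroed at lookup; 0 disables masking.
    frequency_threshold=10,
    mask_dims=10,
)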