From 237fca9ba3b94a1748c7b27a85e329bf3d976c2f Mon Sep 17 00:00:00 2001
From: Jiashu Yao
Date: Mon, 19 Jan 2026 02:54:35 +0000
Subject: [PATCH 1/8] Fix issue related to empty batch

---
 corelib/dynamicemb/src/dynamic_emb_op.cu      |  6 +-
 corelib/dynamicemb/src/index_calculation.cu   |  6 ++
 ...est_batched_dynamic_embedding_tables_v2.py | 89 +++++++++++++++++++
 3 files changed, 99 insertions(+), 2 deletions(-)

diff --git a/corelib/dynamicemb/src/dynamic_emb_op.cu b/corelib/dynamicemb/src/dynamic_emb_op.cu
index 2d58bf7ce..c484e3c0f 100644
--- a/corelib/dynamicemb/src/dynamic_emb_op.cu
+++ b/corelib/dynamicemb/src/dynamic_emb_op.cu
@@ -898,6 +898,10 @@ void load_from_combined_table(std::optional dev_table,
                               std::optional uvm_table,
                               at::Tensor indices, at::Tensor output) {
+  int64_t num_total = indices.size(0);
+  if (num_total == 0) {
+    return;
+  }
   int64_t stride = -1;
   int64_t dim = output.size(1);
   if ((not dev_table.has_value()) and (not uvm_table.has_value())) {
@@ -934,8 +938,6 @@ void load_from_combined_table(std::optional dev_table,
   auto val_type = get_data_type(output);
   auto index_type = get_data_type(indices);
 
-  int64_t num_total = indices.size(0);
-
   constexpr int kWarpSize = 32;
   constexpr int MULTIPLIER = 4;
   constexpr int BLOCK_SIZE_VEC = 64;
diff --git a/corelib/dynamicemb/src/index_calculation.cu b/corelib/dynamicemb/src/index_calculation.cu
index 101410685..5ae796eba 100644
--- a/corelib/dynamicemb/src/index_calculation.cu
+++ b/corelib/dynamicemb/src/index_calculation.cu
@@ -523,6 +523,9 @@ void select(at::Tensor flags, at::Tensor inputs, at::Tensor outputs,
   auto num_select_iter_type =
       scalartype_to_datatype(num_selected.dtype().toScalarType());
 
+  if (num_total == 0) {
+    return;
+  }
   DISPATCH_INTEGER_DATATYPE_FUNCTION(key_type, KeyType, [&] {
     DISPATCH_INTEGER_DATATYPE_FUNCTION(
         num_select_iter_type, NumSelectedIteratorT, [&] {
@@ -545,6 +548,9 @@ void select_index(at::Tensor flags, at::Tensor output_indices,
   auto num_select_iter_type =
       scalartype_to_datatype(num_selected.dtype().toScalarType());
 
+  if (num_total == 0) {
+    return;
+  }
   DISPATCH_INTEGER_DATATYPE_FUNCTION(key_type, KeyType, [&] {
     DISPATCH_INTEGER_DATATYPE_FUNCTION(
         num_select_iter_type, NumSelectedIteratorT, [&] {
diff --git a/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py b/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py
index 7d02c2f40..224a87a9e 100644
--- a/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py
+++ b/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py
@@ -853,3 +853,92 @@ def test_deterministic_insert(opt_type, opt_params, caching, PS, iteration, batc
         del os.environ["DEMB_DETERMINISM_MODE"]
 
     print("all check passed")
+
+
+@pytest.mark.parametrize(
+    "opt_type,opt_params",
+    [
+        (EmbOptimType.SGD, {"learning_rate": 0.3}),
+        (
+            EmbOptimType.EXACT_ROWWISE_ADAGRAD,
+            {
+                "learning_rate": 0.3,
+                "eps": 3e-5,
+            },
+        ),
+    ],
+)
+@pytest.mark.parametrize("dim", [7, 8])
+@pytest.mark.parametrize("caching", [True, False])
+@pytest.mark.parametrize("deterministic", [True, False])
+@pytest.mark.parametrize("PS", [None])
+def test_forward_train_eval_empty_batch(
+    opt_type, opt_params, dim, caching, deterministic, PS
+):
+    print(
+        f"step in test_forward_train_eval_empty_batch , opt_type = {opt_type} opt_params = {opt_params}"
+    )
+
+    if deterministic:
+        os.environ["DEMB_DETERMINISM_MODE"] = "ON"
+
+    assert torch.cuda.is_available()
+    device_id = 0
+    device = torch.device(f"cuda:{device_id}")
+
+    dims = [dim, dim, dim]
+    table_names = ["table0", "table1", "table2"]
+    key_type = torch.int64
+    value_type = torch.float32
+
+    init_capacity = 1024
+    max_capacity = 2048
+
+    dyn_emb_table_options_list = []
+    for dim in dims:
+        dyn_emb_table_options = DynamicEmbTableOptions(
+            dim=dim,
+            init_capacity=init_capacity,
+            max_capacity=max_capacity,
+            index_type=key_type,
+            embedding_dtype=value_type,
+            device_id=device_id,
+            score_strategy=DynamicEmbScoreStrategy.TIMESTAMP,
+            caching=caching,
+            local_hbm_for_values=1024**3,
+            external_storage=PS,
+        )
+        dyn_emb_table_options_list.append(dyn_emb_table_options)
+
+    bdebt = BatchedDynamicEmbeddingTablesV2(
+        table_names=table_names,
+        table_options=dyn_emb_table_options_list,
+        feature_table_map=[0, 0, 1, 2],
+        pooling_mode=DynamicEmbPoolingMode.NONE,
+        optimizer=opt_type,
+        use_index_dedup=True,
+        **opt_params,
+    )
+    """
+    feature number = 4, batch size = 1
+
+    f0 [],
+    f1 [],
+    f2 [],
+    f3 [],
+    """
+    indices = torch.tensor([], dtype=key_type, device=device)
+    offsets = torch.tensor([0, 0, 0, 0, 0], dtype=key_type, device=device)
+
+    bdebt(indices, offsets)
+    torch.cuda.synchronize()
+
+    with torch.no_grad():
+        bdebt.eval()
+        bdebt(indices, offsets)
+        torch.cuda.synchronize()
+
+    if deterministic:
+        del os.environ["DEMB_DETERMINISM_MODE"]
+
+    print("all check passed")

From ff913964f15175b660a1b1e8b1b3eb59e0a86dab Mon Sep 17 00:00:00 2001
From: Jiashu Yao
Date: Mon, 19 Jan 2026 09:09:20 +0000
Subject: [PATCH 2/8] Empty batch in prefetch

---
 .../dynamicemb/dynamicemb/key_value_table.py  | 18 +++++++++++----
 ...est_batched_dynamic_embedding_tables_v2.py | 23 +++++++++++++------
 2 files changed, 30 insertions(+), 11 deletions(-)

diff --git a/corelib/dynamicemb/dynamicemb/key_value_table.py b/corelib/dynamicemb/dynamicemb/key_value_table.py
index cd0db1716..ac1cf761f 100644
--- a/corelib/dynamicemb/dynamicemb/key_value_table.py
+++ b/corelib/dynamicemb/dynamicemb/key_value_table.py
@@ -420,8 +420,8 @@ def find_impl(
             load_from_pointers(pointers, unique_embs)
 
         missing = torch.logical_not(founds)
-        num_missing_0: torch.Tensor = torch.empty(1, dtype=torch.long, device=device)
-        num_missing_1: torch.Tensor = torch.empty(1, dtype=torch.long, device=device)
+        num_missing_0: torch.Tensor = torch.zeros(1, dtype=torch.long, device=device)
+        num_missing_1: torch.Tensor = torch.zeros(1, dtype=torch.long, device=device)
         missing_keys: torch.Tensor = torch.empty_like(unique_keys)
         missing_indices: torch.Tensor = torch.empty(
             batch, dtype=torch.long, device=device
         )
@@ -1184,6 +1184,16 @@ def find_impl(
 
         scores = self.create_scores(batch, device, input_scores)
 
+        if batch == 0:
+            return (
+                0,
+                torch.empty_like(unique_keys),
+                torch.empty(batch, dtype=torch.long, device=device),
+                torch.empty(batch, dtype=torch.uint64, device=device)
+                if scores is not None
+                else None,
+            )
+
         score_args_lookup = [
             ScoreArg(
                 name=self.score_policy.name,
@@ -1203,8 +1213,8 @@ def find_impl(
         )
 
         missing = torch.logical_not(founds)
-        num_missing_0: torch.Tensor = torch.empty(1, dtype=torch.long, device=device)
-        num_missing_1: torch.Tensor = torch.empty(1, dtype=torch.long, device=device)
+        num_missing_0: torch.Tensor = torch.zeros(1, dtype=torch.long, device=device)
+        num_missing_1: torch.Tensor = torch.zeros(1, dtype=torch.long, device=device)
         missing_keys: torch.Tensor = torch.empty_like(unique_keys)
         missing_indices: torch.Tensor = torch.empty(
             batch, dtype=torch.long, device=device
diff --git a/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py b/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py
index 224a87a9e..1f54f9d6b 100644
--- a/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py
+++ b/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 
 import os
-import random
 from typing import Dict, Optional, Tuple, cast
 
 import pytest
@@ -404,7 +403,7 @@ def test_forward_train_eval(opt_type, opt_params, caching, deterministic, PS):
 
     """
    For torchrec's adam optimizer, it will increment the optimizer_step in every forward,
-    which will affect the weights update, pay attention to it or try to use `set_optimizer_step()`
+    which will affect the weights update, pay attention to it or try to use `set_optimizer_step()`
     to control(not verified) it.
     """
 
@@ -919,6 +918,7 @@ def test_forward_train_eval_empty_batch(
         use_index_dedup=True,
         **opt_params,
     )
+    bdebt.enable_prefetch = True
     """
     feature number = 4, batch size = 1
 
     f0 [],
     f1 [],
     f2 [],
     f3 [],
     """
     indices = torch.tensor([], dtype=key_type, device=device)
     offsets = torch.tensor([0, 0, 0, 0, 0], dtype=key_type, device=device)
 
-    bdebt(indices, offsets)
-    torch.cuda.synchronize()
+    prefetch_stream = torch.cuda.Stream()
+    forward_stream = torch.cuda.Stream()
 
-    with torch.no_grad():
-        bdebt.eval()
+    if caching:
+        with torch.cuda.stream(prefetch_stream):
+            bdebt.prefetch(indices, offsets, forward_stream)
+        torch.cuda.synchronize()
+
+    with torch.cuda.stream(forward_stream):
         bdebt(indices, offsets)
-        torch.cuda.synchronize()
+        torch.cuda.synchronize()
+
+    with torch.no_grad():
+        bdebt.eval()
+        bdebt(indices, offsets)
+        torch.cuda.synchronize()
 
     if deterministic:
         del os.environ["DEMB_DETERMINISM_MODE"]

From 2be63fd8ac3cb331c3a4bb6bfc357157d724d4a2 Mon Sep 17 00:00:00 2001
From: Jiashu Yao
Date: Mon, 19 Jan 2026 09:16:56 +0000
Subject: [PATCH 3/8] Empty batch in backward

---
 .../test/test_batched_dynamic_embedding_tables_v2.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py b/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py
index 1f54f9d6b..0f9bff56f 100644
--- a/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py
+++ b/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import os
+import random
 from typing import Dict, Optional, Tuple, cast
 
 import pytest
@@ -871,9 +872,7 @@ def test_deterministic_insert(opt_type, opt_params, caching, PS, iteration, batc
 @pytest.mark.parametrize("caching", [True, False])
 @pytest.mark.parametrize("deterministic", [True, False])
 @pytest.mark.parametrize("PS", [None])
-def test_forward_train_eval_empty_batch(
-    opt_type, opt_params, dim, caching, deterministic, PS
-):
+def test_empty_batch(opt_type, opt_params, dim, caching, deterministic, PS):
     print(
         f"step in test_forward_train_eval_empty_batch , opt_type = {opt_type} opt_params = {opt_params}"
     )
@@ -939,9 +938,11 @@
         torch.cuda.synchronize()
 
     with torch.cuda.stream(forward_stream):
-        bdebt(indices, offsets)
+        res = bdebt(indices, offsets)
         torch.cuda.synchronize()
 
+    res.mean().backward()
+
     with torch.no_grad():
         bdebt.eval()
         bdebt(indices, offsets)

From d430904d87bc760d74574b3a650b1790ef280a85 Mon Sep 17 00:00:00 2001
From: Jiashu Yao
Date: Mon, 19 Jan 2026 09:51:25 +0000
Subject: [PATCH 4/8] debug

---
 corelib/dynamicemb/dynamicemb/key_value_table.py | 6 ++++++
 corelib/dynamicemb/dynamicemb/optimizer.py       | 1 +
 2 files changed, 7 insertions(+)

diff --git a/corelib/dynamicemb/dynamicemb/key_value_table.py b/corelib/dynamicemb/dynamicemb/key_value_table.py
index ac1cf761f..110f05911 100644
--- a/corelib/dynamicemb/dynamicemb/key_value_table.py
+++ b/corelib/dynamicemb/dynamicemb/key_value_table.py
@@ -546,6 +546,9 @@ def update(
 
         batch = keys.size(0)
 
+        # if batch == 0:
+        #     return None, None, None
+
         device = keys.device
         founds = torch.empty(batch, dtype=torch.bool, device=device)
         pointers = torch.empty(batch, dtype=torch.long, device=device)
@@ -1363,6 +1366,9 @@ def update(
         assert self._score_update == False, "update is called only in backward."
 
         batch = keys.size(0)
+
+        # if batch == 0:
+        #     return None, None, None
 
         device = keys.device
         founds = torch.empty(batch, dtype=torch.bool, device=device)
diff --git a/corelib/dynamicemb/dynamicemb/optimizer.py b/corelib/dynamicemb/dynamicemb/optimizer.py
index f51c0d3a8..284f46cd4 100644
--- a/corelib/dynamicemb/dynamicemb/optimizer.py
+++ b/corelib/dynamicemb/dynamicemb/optimizer.py
@@ -945,6 +945,7 @@ def fused_update_with_index(
 
         emb_dim = grads.size(1)
         state_dim = self.get_state_dim(emb_dim)
+        print(f"Optimizer: {grads.shape}, {indices.shape}")
         rowwise_adagrad_for_combined_table(
             grads,
             indices,

From f0f387dd4a31a434f0bfe0915e2c83a458583cec Mon Sep 17 00:00:00 2001
From: Jiashu Yao
Date: Tue, 20 Jan 2026 01:59:18 +0000
Subject: [PATCH 5/8] Fix IMA caused by not configuring shared memory size

---
 corelib/dynamicemb/dynamicemb/key_value_table.py |  2 +-
 corelib/dynamicemb/dynamicemb/optimizer.py       |  1 -
 corelib/dynamicemb/src/optimizer.cu              |  9 +-
 ...est_batched_dynamic_embedding_tables_v2.py    | 95 ++++++++++++-------
 4 files changed, 66 insertions(+), 41 deletions(-)

diff --git a/corelib/dynamicemb/dynamicemb/key_value_table.py b/corelib/dynamicemb/dynamicemb/key_value_table.py
index 110f05911..008e27738 100644
--- a/corelib/dynamicemb/dynamicemb/key_value_table.py
+++ b/corelib/dynamicemb/dynamicemb/key_value_table.py
@@ -1366,7 +1366,7 @@ def update(
         assert self._score_update == False, "update is called only in backward."
 
         batch = keys.size(0)
-        
+
         # if batch == 0:
         #     return None, None, None
 
diff --git a/corelib/dynamicemb/dynamicemb/optimizer.py b/corelib/dynamicemb/dynamicemb/optimizer.py
index 284f46cd4..f51c0d3a8 100644
--- a/corelib/dynamicemb/dynamicemb/optimizer.py
+++ b/corelib/dynamicemb/dynamicemb/optimizer.py
@@ -945,7 +945,6 @@ def fused_update_with_index(
 
         emb_dim = grads.size(1)
         state_dim = self.get_state_dim(emb_dim)
-        print(f"Optimizer: {grads.shape}, {indices.shape}")
         rowwise_adagrad_for_combined_table(
             grads,
             indices,
diff --git a/corelib/dynamicemb/src/optimizer.cu b/corelib/dynamicemb/src/optimizer.cu
index 859847305..84f0ab5db 100644
--- a/corelib/dynamicemb/src/optimizer.cu
+++ b/corelib/dynamicemb/src/optimizer.cu
@@ -20,6 +20,7 @@ All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #include "optimizer_kernel.cuh"
 #include "torch_utils.h"
 #include "utils.h"
+#include 
 
 void find_pointers(std::shared_ptr table, const size_t n,
                    const at::Tensor keys, at::Tensor values,
@@ -545,7 +546,8 @@ void launch_update_kernel_for_combined_table(
     GradType *grads, WeightType *dev_table, WeightType *uvm_table,
     IndexType *indices, OptimizerType opt, int64_t const ev_nums,
     uint32_t const dim, int64_t const stride, int64_t const split_index,
-    int device_id) {
+    int device_id,
+    std::function smem_size_f = [](int block_size) { return 0; }) {
   auto stream = at::cuda::getCurrentCUDAStream().stream();
   auto &device_prop = DeviceProp::getDeviceProp(device_id);
   if (dim % 4 == 0) {
@@ -574,7 +576,7 @@ void launch_update_kernel_for_combined_table(
 
     auto kernel = update_with_index_kernel;
 
-    kernel<<>>(
+    kernel<<>>(
         ev_nums, dim, stride, split_index, grads, dev_table, uvm_table,
         indices, nullptr, opt);
   }
@@ -797,7 +799,8 @@ void rowwise_adagrad_for_combined_table(at::Tensor grads, at::Tensor indices,
         launch_update_kernel_for_combined_table(
             grad_ptr, dev_ptr, uvm_ptr, index_ptr, opt, ev_nums, dim, stride,
-            split_index, device_id);
+            split_index, device_id,
+            [](int block_size) { return block_size * sizeof(float); });
       });
     });
   });
diff --git a/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py b/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py
index 0f9bff56f..d28306cb6 100644
--- a/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py
+++ b/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py
@@ -444,6 +444,7 @@ def test_forward_train_eval(opt_type, opt_params, caching, deterministic, PS):
     [
         (True, DynamicEmbPoolingMode.NONE, [8, 8, 8]),
         (False, DynamicEmbPoolingMode.NONE, [16, 16, 16]),
+        (False, DynamicEmbPoolingMode.NONE, [17, 17, 17]),
         (False, DynamicEmbPoolingMode.SUM, [128, 32, 16]),
         (False, DynamicEmbPoolingMode.MEAN, [4, 8, 16]),
     ],
@@ -467,7 +468,10 @@ def test_backward(opt_type, opt_params, caching, pooling_mode, dims, determinist
     max_capacity = 2048
 
     dyn_emb_table_options_list = []
+    cmp_with_torchrec = True
     for dim in dims:
+        if dim % 4 != 0:
+            cmp_with_torchrec = False
         dyn_emb_table_options = DynamicEmbTableOptions(
             dim=dim,
             init_capacity=max_capacity,
@@ -492,49 +496,68 @@ def test_backward(opt_type, opt_params, caching, pooling_mode, dims, determinist
         **opt_params,
     )
     num_embs = [max_capacity // 2 for d in dims]
-    stbe = create_split_table_batched_embedding(
-        table_names,
-        feature_table_map,
-        OPTIM_TYPE[opt_type],
-        opt_params,
-        dims,
-        num_embs,
-        POOLING_MODE[pooling_mode],
-        device,
-    )
-    init_embedding_tables(stbe, bdeb)
-    """
-    feature number = 4, batch size = 2
-
-    f0 [0,1], [12],
-    f1 [64,8], [12],
-    f2 [15, 2, 7], [105],
-    f3 [], [0]
-    """
-    for i in range(10):
-        indices = torch.tensor(
-            [0, 1, 12, 64, 8, 12, 15, 2, 7, 105, 0], device=device
-        ).to(key_type)
-        offsets = torch.tensor([0, 2, 3, 5, 6, 9, 10, 10, 11], device=device).to(
-            key_type
+    if cmp_with_torchrec:
+        stbe = create_split_table_batched_embedding(
+            table_names,
+            feature_table_map,
+            OPTIM_TYPE[opt_type],
+            opt_params,
+            dims,
+            num_embs,
+            POOLING_MODE[pooling_mode],
+            device,
+        )
+        init_embedding_tables(stbe, bdeb)
+        """
+        feature number = 4, batch size = 2
+
+        f0 [0,1], [12],
+        f1 [64,8], [12],
+        f2 [15, 2, 7], [105],
+        f3 [], [0]
+        """
+        for i in range(10):
+            indices = torch.tensor(
+                [0, 1, 12, 64, 8, 12, 15, 2, 7, 105, 0], device=device
+            ).to(key_type)
+            offsets = torch.tensor([0, 2, 3, 5, 6, 9, 10, 10, 11], device=device).to(
+                key_type
+            )
 
-        embs_bdeb = bdeb(indices, offsets)
-        embs_stbe = stbe(indices, offsets)
+            embs_bdeb = bdeb(indices, offsets)
+            embs_stbe = stbe(indices, offsets)
 
-        torch.cuda.synchronize()
-        with torch.no_grad():
-            torch.testing.assert_close(embs_bdeb, embs_stbe, rtol=1e-06, atol=1e-06)
+            torch.cuda.synchronize()
+            with torch.no_grad():
+                torch.testing.assert_close(embs_bdeb, embs_stbe, rtol=1e-06, atol=1e-06)
 
-        loss = embs_bdeb.mean()
-        loss.backward()
-        loss_stbe = embs_stbe.mean()
-        loss_stbe.backward()
+            loss = embs_bdeb.mean()
+            loss.backward()
+            loss_stbe = embs_stbe.mean()
+            loss_stbe.backward()
 
-        torch.cuda.synchronize()
-        torch.testing.assert_close(loss, loss_stbe)
+            torch.cuda.synchronize()
+            torch.testing.assert_close(loss, loss_stbe)
+
+            print(f"Passed iteration {i}")
+    else:
+        # This branch does not check numerical correctness against torchrec;
+        # it only verifies that forward and backward run without error.
+        for i in range(10):
+            indices = torch.tensor(
+                [0, 1, 12, 64, 8, 12, 15, 2, 7, 105, 0], device=device
+            ).to(key_type)
+            offsets = torch.tensor([0, 2, 3, 5, 6, 9, 10, 10, 11], device=device).to(
+                key_type
+            )
+
+            embs_bdeb = bdeb(indices, offsets)
+            loss = embs_bdeb.mean()
+            loss.backward()
+
+            torch.cuda.synchronize()
 
-        print(f"Passed iteration {i}")
+            print(f"Passed iteration {i}")
 
     if deterministic:
         del os.environ["DEMB_DETERMINISM_MODE"]

From a239176a17331d18bef0fba6108a1b64e1af2f75 Mon Sep 17 00:00:00 2001
From: Jiashu Yao
Date: Tue, 20 Jan 2026 02:42:58 +0000
Subject: [PATCH 6/8] Remove comments in update

---
 corelib/dynamicemb/dynamicemb/key_value_table.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/corelib/dynamicemb/dynamicemb/key_value_table.py b/corelib/dynamicemb/dynamicemb/key_value_table.py
index 008e27738..96af6327d 100644
--- a/corelib/dynamicemb/dynamicemb/key_value_table.py
+++ b/corelib/dynamicemb/dynamicemb/key_value_table.py
@@ -546,8 +546,8 @@ def update(
 
         batch = keys.size(0)
 
-        # if batch == 0:
-        #     return None, None, None
+        if batch == 0:
+            return None, None, None
 
         device = keys.device
         founds = torch.empty(batch, dtype=torch.bool, device=device)
         pointers = torch.empty(batch, dtype=torch.long, device=device)
@@ -1367,8 +1367,8 @@ def update(
 
         batch = keys.size(0)
 
-        # if batch == 0:
-        #     return None, None, None
+        if batch == 0:
+            return None, None, None
 
         device = keys.device
         founds = torch.empty(batch, dtype=torch.bool, device=device)

From 67b39edf8d16610189bd2bfca76690110f5c3b60 Mon Sep 17 00:00:00 2001
From: Jiashu Yao
Date: Tue, 27 Jan 2026 01:57:12 +0000
Subject: [PATCH 7/8] fix return value of DynamicEmbeddingTable.update when empty batch

---
 corelib/dynamicemb/dynamicemb/key_value_table.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/corelib/dynamicemb/dynamicemb/key_value_table.py b/corelib/dynamicemb/dynamicemb/key_value_table.py
index 96af6327d..f5f6e574f 100644
--- a/corelib/dynamicemb/dynamicemb/key_value_table.py
+++ b/corelib/dynamicemb/dynamicemb/key_value_table.py
@@ -1368,7 +1368,7 @@ def update(
         batch = keys.size(0)
 
         if batch == 0:
-            return None, None, None
+            return 0, None, None
 
         device = keys.device
         founds = torch.empty(batch, dtype=torch.bool, device=device)

From 4dcb338cf82d655b5c76ab91ff4a423c0cffabfd Mon Sep 17 00:00:00 2001
From: Jiashu Yao
Date: Tue, 27 Jan 2026 02:03:54 +0000
Subject: [PATCH 8/8] fix as AI suggests, AI is useful and helpful

---
 corelib/dynamicemb/dynamicemb/key_value_table.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/corelib/dynamicemb/dynamicemb/key_value_table.py b/corelib/dynamicemb/dynamicemb/key_value_table.py
index f5f6e574f..2b3927da4 100644
--- a/corelib/dynamicemb/dynamicemb/key_value_table.py
+++ b/corelib/dynamicemb/dynamicemb/key_value_table.py
@@ -547,7 +547,7 @@ def update(
         batch = keys.size(0)
 
         if batch == 0:
-            return None, None, None
+            return 0, None, None
 
         device = keys.device
         founds = torch.empty(batch, dtype=torch.bool, device=device)