diff --git a/corelib/dynamicemb/dynamicemb/key_value_table.py b/corelib/dynamicemb/dynamicemb/key_value_table.py
index cd0db171..795cb234 100644
--- a/corelib/dynamicemb/dynamicemb/key_value_table.py
+++ b/corelib/dynamicemb/dynamicemb/key_value_table.py
@@ -546,6 +546,9 @@ def update(
         batch = keys.size(0)
 
+        if batch == 0:
+            return 0, None, None
+
         device = keys.device
         founds = torch.empty(batch, dtype=torch.bool, device=device)
         pointers = torch.empty(batch, dtype=torch.long, device=device)
@@ -1184,6 +1187,16 @@ def find_impl(
         scores = self.create_scores(batch, device, input_scores)
 
+        if batch == 0:
+            return (
+                0,
+                torch.empty_like(unique_keys),
+                torch.empty(batch, dtype=torch.long, device=device),
+                torch.empty(batch, dtype=torch.uint64, device=device)
+                if scores is not None
+                else None,
+            )
+
         score_args_lookup = [
             ScoreArg(
                 name=self.score_policy.name,
@@ -1354,6 +1367,9 @@ def update(
         batch = keys.size(0)
 
+        if batch == 0:
+            return 0, None, None
+
         device = keys.device
         founds = torch.empty(batch, dtype=torch.bool, device=device)
         indices = torch.empty(batch, dtype=self.key_index_map.index_type, device=device)
diff --git a/corelib/dynamicemb/src/dynamic_emb_op.cu b/corelib/dynamicemb/src/dynamic_emb_op.cu
index 2d58bf7c..c484e3c0 100644
--- a/corelib/dynamicemb/src/dynamic_emb_op.cu
+++ b/corelib/dynamicemb/src/dynamic_emb_op.cu
@@ -898,6 +898,10 @@ void load_from_combined_table(std::optional dev_table,
                               std::optional uvm_table, at::Tensor indices,
                               at::Tensor output) {
+  int64_t num_total = indices.size(0);
+  if (num_total == 0) {
+    return;
+  }
   int64_t stride = -1;
   int64_t dim = output.size(1);
   if ((not dev_table.has_value()) and (not uvm_table.has_value())) {
@@ -934,8 +938,6 @@ void load_from_combined_table(std::optional dev_table,
   auto val_type = get_data_type(output);
   auto index_type = get_data_type(indices);
 
-  int64_t num_total = indices.size(0);
-
   constexpr int kWarpSize = 32;
   constexpr int MULTIPLIER = 4;
   constexpr int BLOCK_SIZE_VEC = 64;
diff --git a/corelib/dynamicemb/src/index_calculation.cu b/corelib/dynamicemb/src/index_calculation.cu
index 10141068..be5053e9 100644
--- a/corelib/dynamicemb/src/index_calculation.cu
+++ b/corelib/dynamicemb/src/index_calculation.cu
@@ -523,6 +523,15 @@ void select(at::Tensor flags, at::Tensor inputs, at::Tensor outputs,
   auto num_select_iter_type =
       scalartype_to_datatype(num_selected.dtype().toScalarType());
 
+  if (num_total == 0) {
+    DISPATCH_INTEGER_DATATYPE_FUNCTION(
+        num_select_iter_type, NumSelectedIteratorT, [&] {
+          DEMB_CUDA_CHECK(cudaMemsetAsync(
+              reinterpret_cast(num_selected.data_ptr()), 0,
+              sizeof(NumSelectedIteratorT), stream));
+        });
+    return;
+  }
   DISPATCH_INTEGER_DATATYPE_FUNCTION(key_type, KeyType, [&] {
     DISPATCH_INTEGER_DATATYPE_FUNCTION(
         num_select_iter_type, NumSelectedIteratorT, [&] {
@@ -545,6 +554,15 @@ void select_index(at::Tensor flags, at::Tensor output_indices,
   auto num_select_iter_type =
       scalartype_to_datatype(num_selected.dtype().toScalarType());
 
+  if (num_total == 0) {
+    DISPATCH_INTEGER_DATATYPE_FUNCTION(
+        num_select_iter_type, NumSelectedIteratorT, [&] {
+          DEMB_CUDA_CHECK(cudaMemsetAsync(
+              reinterpret_cast(num_selected.data_ptr()), 0,
+              sizeof(NumSelectedIteratorT), stream));
+        });
+    return;
+  }
   DISPATCH_INTEGER_DATATYPE_FUNCTION(key_type, KeyType, [&] {
     DISPATCH_INTEGER_DATATYPE_FUNCTION(
         num_select_iter_type, NumSelectedIteratorT, [&] {
diff --git a/corelib/dynamicemb/src/optimizer.cu b/corelib/dynamicemb/src/optimizer.cu
index 85984730..84f0ab5d 100644
--- a/corelib/dynamicemb/src/optimizer.cu
+++ b/corelib/dynamicemb/src/optimizer.cu
@@ -20,6 +20,7 @@ All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #include "optimizer_kernel.cuh"
 #include "torch_utils.h"
 #include "utils.h"
+#include
 
 void find_pointers(std::shared_ptr table, const size_t n,
                    const at::Tensor keys, at::Tensor values,
@@ -545,7 +546,8 @@ void launch_update_kernel_for_combined_table(
     GradType *grads, WeightType *dev_table, WeightType *uvm_table,
     IndexType *indices, OptimizerType opt, int64_t const ev_nums,
     uint32_t const dim, int64_t const stride, int64_t const split_index,
-    int device_id) {
+    int device_id,
+    std::function smem_size_f = [](int block_size) { return 0; }) {
   auto stream = at::cuda::getCurrentCUDAStream().stream();
   auto &device_prop = DeviceProp::getDeviceProp(device_id);
   if (dim % 4 == 0) {
@@ -574,7 +576,7 @@ void launch_update_kernel_for_combined_table(
       auto kernel = update_with_index_kernel;
-      kernel<<>>(
+      kernel<<>>(
           ev_nums, dim, stride, split_index, grads, dev_table, uvm_table,
           indices, nullptr, opt);
     }
@@ -797,7 +799,8 @@ void rowwise_adagrad_for_combined_table(at::Tensor grads, at::Tensor indices,
         launch_update_kernel_for_combined_table(
             grad_ptr, dev_ptr, uvm_ptr, index_ptr, opt, ev_nums, dim, stride,
-            split_index, device_id);
+            split_index, device_id,
+            [](int block_size) { return block_size * sizeof(float); });
       });
     });
   });
diff --git a/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py b/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py
index 7d02c2f4..d28306cb 100644
--- a/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py
+++ b/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py
@@ -404,7 +404,7 @@ def test_forward_train_eval(opt_type, opt_params, caching, deterministic, PS):
     """
     For torchrec's adam optimizer, it will increment the optimizer_step in every forward,
-    which will affect the weights update, pay attention to it or try to use `set_optimizer_step()`
+    which will affect the weights update, pay attention to it or try to use `set_optimizer_step()`
     to control(not verified) it.
     """
@@ -444,6 +444,7 @@ def test_forward_train_eval(opt_type, opt_params, caching, deterministic, PS):
     [
         (True, DynamicEmbPoolingMode.NONE, [8, 8, 8]),
         (False, DynamicEmbPoolingMode.NONE, [16, 16, 16]),
+        (False, DynamicEmbPoolingMode.NONE, [17, 17, 17]),
         (False, DynamicEmbPoolingMode.SUM, [128, 32, 16]),
         (False, DynamicEmbPoolingMode.MEAN, [4, 8, 16]),
     ],
@@ -467,7 +468,10 @@ def test_backward(opt_type, opt_params, caching, pooling_mode, dims, determinist
     max_capacity = 2048
 
     dyn_emb_table_options_list = []
+    cmp_with_torchrec = True
     for dim in dims:
+        if dim % 4 != 0:
+            cmp_with_torchrec = False
         dyn_emb_table_options = DynamicEmbTableOptions(
             dim=dim,
             init_capacity=max_capacity,
@@ -492,49 +496,68 @@ def test_backward(opt_type, opt_params, caching, pooling_mode, dims, determinist
         **opt_params,
     )
     num_embs = [max_capacity // 2 for d in dims]
-    stbe = create_split_table_batched_embedding(
-        table_names,
-        feature_table_map,
-        OPTIM_TYPE[opt_type],
-        opt_params,
-        dims,
-        num_embs,
-        POOLING_MODE[pooling_mode],
-        device,
-    )
-    init_embedding_tables(stbe, bdeb)
-    """
-    feature number = 4, batch size = 2
-    f0 [0,1], [12],
-    f1 [64,8], [12],
-    f2 [15, 2, 7], [105],
-    f3 [], [0]
-    """
-    for i in range(10):
-        indices = torch.tensor(
-            [0, 1, 12, 64, 8, 12, 15, 2, 7, 105, 0], device=device
-        ).to(key_type)
-        offsets = torch.tensor([0, 2, 3, 5, 6, 9, 10, 10, 11], device=device).to(
-            key_type
+    if cmp_with_torchrec:
+        stbe = create_split_table_batched_embedding(
+            table_names,
+            feature_table_map,
+            OPTIM_TYPE[opt_type],
+            opt_params,
+            dims,
+            num_embs,
+            POOLING_MODE[pooling_mode],
+            device,
         )
+        init_embedding_tables(stbe, bdeb)
+        """
+        feature number = 4, batch size = 2
+
+        f0 [0,1], [12],
+        f1 [64,8], [12],
+        f2 [15, 2, 7], [105],
+        f3 [], [0]
+        """
+        for i in range(10):
+            indices = torch.tensor(
+                [0, 1, 12, 64, 8, 12, 15, 2, 7, 105, 0], device=device
+            ).to(key_type)
+            offsets = torch.tensor([0, 2, 3, 5, 6, 9, 10, 10, 11], device=device).to(
+                key_type
+            )
 
-        embs_bdeb = bdeb(indices, offsets)
-        embs_stbe = stbe(indices, offsets)
-
-        torch.cuda.synchronize()
-        with torch.no_grad():
-            torch.testing.assert_close(embs_bdeb, embs_stbe, rtol=1e-06, atol=1e-06)
+            embs_bdeb = bdeb(indices, offsets)
+            embs_stbe = stbe(indices, offsets)
+
+            torch.cuda.synchronize()
+            with torch.no_grad():
+                torch.testing.assert_close(embs_bdeb, embs_stbe, rtol=1e-06, atol=1e-06)
+
+            loss = embs_bdeb.mean()
+            loss.backward()
+            loss_stbe = embs_stbe.mean()
+            loss_stbe.backward()
+
+            torch.cuda.synchronize()
+            torch.testing.assert_close(loss, loss_stbe)
+
+            print(f"Passed iteration {i}")
+    else:
+        # This scenario will not test correctness, but rather test whether it functions correctly.
+        for i in range(10):
+            indices = torch.tensor(
+                [0, 1, 12, 64, 8, 12, 15, 2, 7, 105, 0], device=device
+            ).to(key_type)
+            offsets = torch.tensor([0, 2, 3, 5, 6, 9, 10, 10, 11], device=device).to(
+                key_type
+            )
 
-        loss = embs_bdeb.mean()
-        loss.backward()
-        loss_stbe = embs_stbe.mean()
-        loss_stbe.backward()
+            embs_bdeb = bdeb(indices, offsets)
+            loss = embs_bdeb.mean()
+            loss.backward()
 
-        torch.cuda.synchronize()
-        torch.testing.assert_close(loss, loss_stbe)
+            torch.cuda.synchronize()
 
-        print(f"Passed iteration {i}")
+            print(f"Passed iteration {i}")
 
     if deterministic:
         del os.environ["DEMB_DETERMINISM_MODE"]
@@ -853,3 +876,102 @@ def test_deterministic_insert(opt_type, opt_params, caching, PS, iteration, batc
         del os.environ["DEMB_DETERMINISM_MODE"]
 
     print("all check passed")
+
+
+@pytest.mark.parametrize(
+    "opt_type,opt_params",
+    [
+        (EmbOptimType.SGD, {"learning_rate": 0.3}),
+        (
+            EmbOptimType.EXACT_ROWWISE_ADAGRAD,
+            {
+                "learning_rate": 0.3,
+                "eps": 3e-5,
+            },
+        ),
+    ],
+)
+@pytest.mark.parametrize("dim", [7, 8])
+@pytest.mark.parametrize("caching", [True, False])
+@pytest.mark.parametrize("deterministic", [True, False])
+@pytest.mark.parametrize("PS", [None])
+def test_empty_batch(opt_type, opt_params, dim, caching, deterministic, PS):
+    print(
+        f"step in test_forward_train_eval_empty_batch , opt_type = {opt_type} opt_params = {opt_params}"
+    )
+
+    if deterministic:
+        os.environ["DEMB_DETERMINISM_MODE"] = "ON"
+
+    assert torch.cuda.is_available()
+    device_id = 0
+    device = torch.device(f"cuda:{device_id}")
+
+    dims = [dim, dim, dim]
+    table_names = ["table0", "table1", "table2"]
+    key_type = torch.int64
+    value_type = torch.float32
+
+    init_capacity = 1024
+    max_capacity = 2048
+
+    dyn_emb_table_options_list = []
+    for dim in dims:
+        dyn_emb_table_options = DynamicEmbTableOptions(
+            dim=dim,
+            init_capacity=init_capacity,
+            max_capacity=max_capacity,
+            index_type=key_type,
+            embedding_dtype=value_type,
+            device_id=device_id,
+            score_strategy=DynamicEmbScoreStrategy.TIMESTAMP,
+            caching=caching,
+            local_hbm_for_values=1024**3,
+            external_storage=PS,
+        )
+        dyn_emb_table_options_list.append(dyn_emb_table_options)
+
+    bdebt = BatchedDynamicEmbeddingTablesV2(
+        table_names=table_names,
+        table_options=dyn_emb_table_options_list,
+        feature_table_map=[0, 0, 1, 2],
+        pooling_mode=DynamicEmbPoolingMode.NONE,
+        optimizer=opt_type,
+        use_index_dedup=True,
+        **opt_params,
+    )
+    bdebt.enable_prefetch = True
+    """
+    feature number = 4, batch size = 1
+
+    f0 [],
+    f1 [],
+    f2 [],
+    f3 [],
+    """
+    indices = torch.tensor([], dtype=key_type, device=device)
+    offsets = torch.tensor([0, 0, 0, 0, 0], dtype=key_type, device=device)
+
+    pretch_stream = torch.cuda.Stream()
+    forward_stream = torch.cuda.Stream()
+
+    if caching:
+        with torch.cuda.stream(pretch_stream):
+            bdebt.prefetch(indices, offsets, forward_stream)
+        torch.cuda.synchronize()
+
+    with torch.cuda.stream(forward_stream):
+        res = bdebt(indices, offsets)
+    torch.cuda.synchronize()
+
+    res.mean().backward()
+
+    with torch.no_grad():
+        bdebt.eval()
+        bdebt(indices, offsets)
+    torch.cuda.synchronize()
+
+    if deterministic:
+        del os.environ["DEMB_DETERMINISM_MODE"]
+
+    print("all check passed")