24 changes: 20 additions & 4 deletions corelib/dynamicemb/dynamicemb/key_value_table.py
@@ -420,8 +420,8 @@ def find_impl(
         load_from_pointers(pointers, unique_embs)
 
         missing = torch.logical_not(founds)
-        num_missing_0: torch.Tensor = torch.empty(1, dtype=torch.long, device=device)
-        num_missing_1: torch.Tensor = torch.empty(1, dtype=torch.long, device=device)
+        num_missing_0: torch.Tensor = torch.zeros(1, dtype=torch.long, device=device)

> Collaborator: why do we need torch.zeros here?

+        num_missing_1: torch.Tensor = torch.zeros(1, dtype=torch.long, device=device)
         missing_keys: torch.Tensor = torch.empty_like(unique_keys)
         missing_indices: torch.Tensor = torch.empty(
             batch, dtype=torch.long, device=device
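A plausible answer to the reviewer's question, consistent with the early returns added to `select()`/`select_index()` in `index_calculation.cu` below: when `batch == 0` those kernels never run, so nothing writes the missing counters. With `torch.empty` they would hold stale allocator bytes; `torch.zeros` guarantees the missing count reads 0. A minimal sketch of the hazard (illustrative, not from the PR; assumes a CUDA device):

```python
import torch

# A counter that is only written by a possibly-skipped kernel must be
# zero-initialized: torch.empty leaves whatever bytes the caching allocator
# last held, so a skipped write would report a garbage count.
device = torch.device("cuda")
num_missing = torch.zeros(1, dtype=torch.long, device=device)

batch = 0
if batch > 0:
    pass  # the select kernel would write the real count here

assert num_missing.item() == 0  # holds even though nothing wrote the counter
```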
@@ -546,6 +546,9 @@ def update(

         batch = keys.size(0)
 
+        if batch == 0:
+            return 0, None, None
+
         device = keys.device
         founds = torch.empty(batch, dtype=torch.bool, device=device)
         pointers = torch.empty(batch, dtype=torch.long, device=device)
@@ -1184,6 +1187,16 @@ def find_impl(

         scores = self.create_scores(batch, device, input_scores)
 
+        if batch == 0:
+            return (
+                0,
+                torch.empty_like(unique_keys),
+                torch.empty(batch, dtype=torch.long, device=device),
+                torch.empty(batch, dtype=torch.uint64, device=device)
+                if scores is not None
+                else None,
+            )
+
         score_args_lookup = [
             ScoreArg(
                 name=self.score_policy.name,
@@ -1203,8 +1216,8 @@
         )
 
         missing = torch.logical_not(founds)
-        num_missing_0: torch.Tensor = torch.empty(1, dtype=torch.long, device=device)
-        num_missing_1: torch.Tensor = torch.empty(1, dtype=torch.long, device=device)
+        num_missing_0: torch.Tensor = torch.zeros(1, dtype=torch.long, device=device)
+        num_missing_1: torch.Tensor = torch.zeros(1, dtype=torch.long, device=device)
         missing_keys: torch.Tensor = torch.empty_like(unique_keys)
         missing_indices: torch.Tensor = torch.empty(
             batch, dtype=torch.long, device=device
@@ -1354,6 +1367,9 @@ def update(

         batch = keys.size(0)
 
+        if batch == 0:
+            return 0, None, None
+
         device = keys.device
         founds = torch.empty(batch, dtype=torch.bool, device=device)
         indices = torch.empty(batch, dtype=self.key_index_map.index_type, device=device)
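Taken together, these guards give `update()` and `find_impl()` a well-defined empty-batch contract: a count of 0 plus either `None`s or zero-length tensors. A quick illustrative check that zero-length allocations preserve the dtype/device contract (assumes a CUDA device):

```python
import torch

device = torch.device("cuda")
unique_keys = torch.empty(0, dtype=torch.int64, device=device)

# Zero-length tensors are valid allocations that keep dtype and device, so
# downstream concatenation and indexing need no special-casing.
missing_keys = torch.empty_like(unique_keys)
missing_indices = torch.empty(0, dtype=torch.long, device=device)

assert missing_keys.shape == (0,)
assert missing_indices.dtype == torch.long and missing_indices.numel() == 0
```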
6 changes: 4 additions & 2 deletions corelib/dynamicemb/src/dynamic_emb_op.cu
@@ -898,6 +898,10 @@ void load_from_combined_table(std::optional<at::Tensor> dev_table,
                               std::optional<at::Tensor> uvm_table,
                               at::Tensor indices, at::Tensor output) {
 
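+  // Bail out before any launch math when there is nothing to load.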
+  int64_t num_total = indices.size(0);
+  if (num_total == 0) {
+    return;
+  }
   int64_t stride = -1;
   int64_t dim = output.size(1);
   if ((not dev_table.has_value()) and (not uvm_table.has_value())) {
@@ -934,8 +938,6 @@
   auto val_type = get_data_type(output);
   auto index_type = get_data_type(indices);
 
-  int64_t num_total = indices.size(0);
-
   constexpr int kWarpSize = 32;
   constexpr int MULTIPLIER = 4;
   constexpr int BLOCK_SIZE_VEC = 64;
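Why the guard must come before the launch: the grid size is derived from `num_total`, and a zero-block grid is an invalid CUDA launch configuration. Illustrative arithmetic (the real computation lives in the .cu file; `BLOCK_SIZE` here mirrors the `BLOCK_SIZE_VEC = 64` constant above):

```python
# With num_total == 0 the usual ceil-division produces a zero-block grid,
# which CUDA rejects at launch time, hence the early return.
BLOCK_SIZE = 64
num_total = 0
grid_size = (num_total + BLOCK_SIZE - 1) // BLOCK_SIZE
print(grid_size)  # 0 -> an invalid kernel launch configuration
```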
6 changes: 6 additions & 0 deletions corelib/dynamicemb/src/index_calculation.cu
@@ -523,6 +523,9 @@ void select(at::Tensor flags, at::Tensor inputs, at::Tensor outputs,
   auto num_select_iter_type =
       scalartype_to_datatype(num_selected.dtype().toScalarType());
 
+  if (num_total == 0) {
+    return;
+  }
   DISPATCH_INTEGER_DATATYPE_FUNCTION(key_type, KeyType, [&] {
     DISPATCH_INTEGER_DATATYPE_FUNCTION(
         num_select_iter_type, NumSelectedIteratorT, [&] {
@@ -545,6 +548,9 @@ void select_index(at::Tensor flags, at::Tensor output_indices,
   auto num_select_iter_type =
       scalartype_to_datatype(num_selected.dtype().toScalarType());
 
+  if (num_total == 0) {
+    return;
+  }
   DISPATCH_INTEGER_DATATYPE_FUNCTION(key_type, KeyType, [&] {
     DISPATCH_INTEGER_DATATYPE_FUNCTION(
         num_select_iter_type, NumSelectedIteratorT, [&] {
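`select()` is presumably a stream-compaction primitive: gather the flagged inputs and record how many were kept. Its Python-level analogue already tolerates zero-length inputs, which is the behavior the early return preserves on the kernel path. Illustrative sketch (assumes a CUDA device):

```python
import torch

device = torch.device("cuda")
flags = torch.empty(0, dtype=torch.bool, device=device)
inputs = torch.empty(0, dtype=torch.int64, device=device)

# High-level analogue of select(flags, inputs, outputs, num_selected):
# keep flagged items and count them; empty in, empty out.
outputs = inputs[flags]
num_selected = flags.sum()
assert outputs.numel() == 0 and num_selected.item() == 0
```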
9 changes: 6 additions & 3 deletions corelib/dynamicemb/src/optimizer.cu
@@ -20,6 +20,7 @@ All rights reserved. # SPDX-License-Identifier: Apache-2.0
#include "optimizer_kernel.cuh"
#include "torch_utils.h"
#include "utils.h"
#include <functional>

void find_pointers(std::shared_ptr<dyn_emb::DynamicVariableBase> table,
const size_t n, const at::Tensor keys, at::Tensor values,
@@ -545,7 +546,8 @@ void launch_update_kernel_for_combined_table(
     GradType *grads, WeightType *dev_table, WeightType *uvm_table,
     IndexType *indices, OptimizerType opt, int64_t const ev_nums,
     uint32_t const dim, int64_t const stride, int64_t const split_index,
-    int device_id) {
+    int device_id,
+    std::function<float(int)> smem_size_f = [](int block_size) { return 0; }) {
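+  // smem_size_f maps the launch's block size to the dynamic shared-memory
+  // byte count (the third <<<...>>> parameter); the default requests none.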
   auto stream = at::cuda::getCurrentCUDAStream().stream();
   auto &device_prop = DeviceProp::getDeviceProp(device_id);
   if (dim % 4 == 0) {
@@ -574,7 +576,7 @@

     auto kernel = update_with_index_kernel<GradType, WeightType, IndexType,
                                            OptimizerType>;
-    kernel<<<grid_size, block_size, 0, stream>>>(
+    kernel<<<grid_size, block_size, smem_size_f(block_size), stream>>>(
         ev_nums, dim, stride, split_index, grads, dev_table, uvm_table, indices,
         nullptr, opt);
   }
@@ -797,7 +799,8 @@ void rowwise_adagrad_for_combined_table(at::Tensor grads, at::Tensor indices,
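+    // The lambda below requests one float of dynamic shared memory per
+    // thread, presumably for per-thread partials in the row-wise reduction.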

     launch_update_kernel_for_combined_table<g_t, w_t, i_t, decltype(opt)>(
         grad_ptr, dev_ptr, uvm_ptr, index_ptr, opt, ev_nums, dim, stride,
-        split_index, device_id);
+        split_index, device_id,
+        [](int block_size) { return block_size * sizeof(float); });
       });
     });
   });
198 changes: 160 additions & 38 deletions corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py
@@ -404,7 +404,7 @@ def test_forward_train_eval(opt_type, opt_params, caching, deterministic, PS):

"""
For torchrec's adam optimizer, it will increment the optimizer_step in every forward,
which will affect the weights update, pay attention to it or try to use `set_optimizer_step()`
which will affect the weights update, pay attention to it or try to use `set_optimizer_step()`
to control(not verified) it.
"""

@@ -444,6 +444,7 @@ def test_forward_train_eval(opt_type, opt_params, caching, deterministic, PS):
     [
         (True, DynamicEmbPoolingMode.NONE, [8, 8, 8]),
         (False, DynamicEmbPoolingMode.NONE, [16, 16, 16]),
+        (False, DynamicEmbPoolingMode.NONE, [17, 17, 17]),
         (False, DynamicEmbPoolingMode.SUM, [128, 32, 16]),
         (False, DynamicEmbPoolingMode.MEAN, [4, 8, 16]),
     ],
@@ -467,7 +468,10 @@ def test_backward(opt_type, opt_params, caching, pooling_mode, dims, determinist
     max_capacity = 2048
 
     dyn_emb_table_options_list = []
+    cmp_with_torchrec = True
     for dim in dims:
+        if dim % 4 != 0:
+            cmp_with_torchrec = False
         dyn_emb_table_options = DynamicEmbTableOptions(
             dim=dim,
             init_capacity=max_capacity,
@@ -492,49 +496,68 @@
     **opt_params,
     )
     num_embs = [max_capacity // 2 for d in dims]
-    stbe = create_split_table_batched_embedding(
-        table_names,
-        feature_table_map,
-        OPTIM_TYPE[opt_type],
-        opt_params,
-        dims,
-        num_embs,
-        POOLING_MODE[pooling_mode],
-        device,
-    )
-    init_embedding_tables(stbe, bdeb)
-    """
-    feature number = 4, batch size = 2
-
-    f0 [0,1], [12],
-    f1 [64,8], [12],
-    f2 [15, 2, 7], [105],
-    f3 [], [0]
-    """
-    for i in range(10):
-        indices = torch.tensor(
-            [0, 1, 12, 64, 8, 12, 15, 2, 7, 105, 0], device=device
-        ).to(key_type)
-        offsets = torch.tensor([0, 2, 3, 5, 6, 9, 10, 10, 11], device=device).to(
-            key_type
-        )
-
-        embs_bdeb = bdeb(indices, offsets)
-        embs_stbe = stbe(indices, offsets)
-
-        torch.cuda.synchronize()
-        with torch.no_grad():
-            torch.testing.assert_close(embs_bdeb, embs_stbe, rtol=1e-06, atol=1e-06)
-
-        loss = embs_bdeb.mean()
-        loss.backward()
-        loss_stbe = embs_stbe.mean()
-        loss_stbe.backward()
-
-        torch.cuda.synchronize()
-        torch.testing.assert_close(loss, loss_stbe)
-
-        print(f"Passed iteration {i}")
+    if cmp_with_torchrec:
+        stbe = create_split_table_batched_embedding(
+            table_names,
+            feature_table_map,
+            OPTIM_TYPE[opt_type],
+            opt_params,
+            dims,
+            num_embs,
+            POOLING_MODE[pooling_mode],
+            device,
+        )
+        init_embedding_tables(stbe, bdeb)
+        """
+        feature number = 4, batch size = 2
+        f0 [0,1], [12],
+        f1 [64,8], [12],
+        f2 [15, 2, 7], [105],
+        f3 [], [0]
+        """
+        for i in range(10):
+            indices = torch.tensor(
+                [0, 1, 12, 64, 8, 12, 15, 2, 7, 105, 0], device=device
+            ).to(key_type)
+            offsets = torch.tensor([0, 2, 3, 5, 6, 9, 10, 10, 11], device=device).to(
+                key_type
+            )
+
+            embs_bdeb = bdeb(indices, offsets)
+            embs_stbe = stbe(indices, offsets)
+
+            torch.cuda.synchronize()
+            with torch.no_grad():
+                torch.testing.assert_close(embs_bdeb, embs_stbe, rtol=1e-06, atol=1e-06)
+
+            loss = embs_bdeb.mean()
+            loss.backward()
+            loss_stbe = embs_stbe.mean()
+            loss_stbe.backward()
+
+            torch.cuda.synchronize()
+            torch.testing.assert_close(loss, loss_stbe)
+
+            print(f"Passed iteration {i}")
+    else:
+        # This branch does not compare against torchrec for numerical
+        # correctness (dims not divisible by 4); it only exercises the path
+        # to make sure forward and backward run.
+        for i in range(10):
+            indices = torch.tensor(
+                [0, 1, 12, 64, 8, 12, 15, 2, 7, 105, 0], device=device
+            ).to(key_type)
+            offsets = torch.tensor([0, 2, 3, 5, 6, 9, 10, 10, 11], device=device).to(
+                key_type
+            )
+
+            embs_bdeb = bdeb(indices, offsets)
+            loss = embs_bdeb.mean()
+            loss.backward()
+
+            torch.cuda.synchronize()
+
+            print(f"Passed iteration {i}")

     if deterministic:
         del os.environ["DEMB_DETERMINISM_MODE"]
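For reference, the `indices`/`offsets` pair used throughout these tests is a feature-major CSR layout: 4 features with batch size 2 gives 8 bags, so `offsets` has 9 entries and bag `k` spans `indices[offsets[k]:offsets[k+1]]`. A quick illustrative decode of the fixture:

```python
import torch

indices = torch.tensor([0, 1, 12, 64, 8, 12, 15, 2, 7, 105, 0])
offsets = torch.tensor([0, 2, 3, 5, 6, 9, 10, 10, 11])

# Bag k is indices[offsets[k]:offsets[k + 1]]; f3's first bag is empty.
bags = [indices[offsets[k]:offsets[k + 1]].tolist() for k in range(len(offsets) - 1)]
print(bags)  # [[0, 1], [12], [64, 8], [12], [15, 2, 7], [105], [], [0]]
```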
@@ -853,3 +876,102 @@ def test_deterministic_insert(opt_type, opt_params, caching, PS, iteration, batc

     del os.environ["DEMB_DETERMINISM_MODE"]
     print("all check passed")


+@pytest.mark.parametrize(
+    "opt_type,opt_params",
+    [
+        (EmbOptimType.SGD, {"learning_rate": 0.3}),
+        (
+            EmbOptimType.EXACT_ROWWISE_ADAGRAD,
+            {
+                "learning_rate": 0.3,
+                "eps": 3e-5,
+            },
+        ),
+    ],
+)
+@pytest.mark.parametrize("dim", [7, 8])
+@pytest.mark.parametrize("caching", [True, False])
+@pytest.mark.parametrize("deterministic", [True, False])
+@pytest.mark.parametrize("PS", [None])
+def test_empty_batch(opt_type, opt_params, dim, caching, deterministic, PS):
+    print(
+        f"step in test_empty_batch, opt_type = {opt_type} opt_params = {opt_params}"
+    )
+
+    if deterministic:
+        os.environ["DEMB_DETERMINISM_MODE"] = "ON"
+
+    assert torch.cuda.is_available()
+    device_id = 0
+    device = torch.device(f"cuda:{device_id}")
+
+    dims = [dim, dim, dim]
+    table_names = ["table0", "table1", "table2"]
+    key_type = torch.int64
+    value_type = torch.float32
+
+    init_capacity = 1024
+    max_capacity = 2048
+
+    dyn_emb_table_options_list = []
+    for dim in dims:
+        dyn_emb_table_options = DynamicEmbTableOptions(
+            dim=dim,
+            init_capacity=init_capacity,
+            max_capacity=max_capacity,
+            index_type=key_type,
+            embedding_dtype=value_type,
+            device_id=device_id,
+            score_strategy=DynamicEmbScoreStrategy.TIMESTAMP,
+            caching=caching,
+            local_hbm_for_values=1024**3,
+            external_storage=PS,
+        )
+        dyn_emb_table_options_list.append(dyn_emb_table_options)
+
+    bdebt = BatchedDynamicEmbeddingTablesV2(
+        table_names=table_names,
+        table_options=dyn_emb_table_options_list,
+        feature_table_map=[0, 0, 1, 2],
+        pooling_mode=DynamicEmbPoolingMode.NONE,
+        optimizer=opt_type,
+        use_index_dedup=True,
+        **opt_params,
+    )
+    bdebt.enable_prefetch = True
+    """
+    feature number = 4, batch size = 1
+    f0 [],
+    f1 [],
+    f2 [],
+    f3 [],
+    """
+    indices = torch.tensor([], dtype=key_type, device=device)
+    offsets = torch.tensor([0, 0, 0, 0, 0], dtype=key_type, device=device)
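+    # All-zero offsets: 4 features with batch size 1 gives 4 bags, every one
+    # empty, so indices is a zero-length tensor.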

+    prefetch_stream = torch.cuda.Stream()
+    forward_stream = torch.cuda.Stream()
+
+    if caching:
+        with torch.cuda.stream(prefetch_stream):
+            bdebt.prefetch(indices, offsets, forward_stream)
+        torch.cuda.synchronize()
+
+    with torch.cuda.stream(forward_stream):
+        res = bdebt(indices, offsets)
+    torch.cuda.synchronize()
+
+    res.mean().backward()
+
+    with torch.no_grad():
+        bdebt.eval()
+        bdebt(indices, offsets)
+    torch.cuda.synchronize()
+
+    if deterministic:
+        del os.environ["DEMB_DETERMINISM_MODE"]
+
+    print("all check passed")