Fix wrong evicted values when insert failed/busy in insert_and_evict. #284
base: main
Conversation
```python
if insert_busy_mask.sum().item() != 0:
    out_indices = indices[insert_busy_mask]
    evicted_values[out_indices, :] = values.to(self.value_type())[
        insert_busy_mask
    ]
    indices[insert_busy_mask] = -1
```
We can remove the if statement (as well as the h2d), can't we?
If there are no busy indices, indices[insert_busy_mask] will return an empty tensor, and the following ops should be no-ops?
@jiashuy
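A quick standalone check of that claim (a sketch with illustrative tensor names, not the actual table code): with an all-False mask, the boolean-indexed assignments touch nothing, so functionally the if guard could be dropped.

```python
import torch

# Illustrative sketch: an all-False busy mask makes both assignments no-ops.
indices = torch.tensor([5, 6, 7])
values = torch.randn(3, 4)
evicted_values = torch.zeros(3, 4)
insert_busy_mask = torch.zeros(3, dtype=torch.bool)  # no busy insertions

evicted_values[indices[insert_busy_mask], :] = values[insert_busy_mask]  # assigns zero rows
indices[insert_busy_mask] = -1                                           # assigns zero elements
# indices and evicted_values are left unchanged
```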
We can remove the if statement, but indices[insert_busy_mask] will still bring a d2h, because the torch C++ backend uses cub to do the mask select and synchronizes to get the size of out_indices.
So to remove the d2h thoroughly, we need a customized CUDA kernel.
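For context, a minimal sketch (assuming a CUDA device and stock PyTorch) of why the mask select forces a host round trip while an in-place masked op does not:

```python
import torch

# Assumes a CUDA device is available.
x = torch.randn(1_000_000, device="cuda")
mask = x > 0

# Boolean-mask indexing has a data-dependent output size, so PyTorch must
# bring that size back to the host before allocating the result: implicit d2h sync.
selected = x[mask]

# An in-place masked op has a fixed-size output and launches asynchronously,
# with no size query on the host.
x.masked_fill_(mask, 0.0)
```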
Could you please try indices.masked_fill_(insert_busy_mask, -1)?
I'm not familiar with masked_fill_; can evicted_values also be filled using it?
If so, it would be simple.
If it's possible to avoid building out_indices, it would be easier.
Oh sorry, I meant line 1816, not 1812.
I believe the indices[insert_busy_mask] load implies an inevitable d2h. Yeah, I now agree with you. But I'm not sure whether it's really necessary to remove the d2h. (Is the perf loss significant?)
@shijieliu and I think that the more d2h there is, the harder it is to pipeline the embedding's forward.
You can see there is still some d2h in the forward, but we don't want to add more.
As for the performance, I haven't tested it. But if we don't use the pipeline, I think it makes little difference here?
Insertion failures hardly ever happen.
OK. I found a useful op:
src.masked_scatter_(mask, source)
https://docs.pytorch.org/docs/stable/generated/torch.Tensor.masked_scatter.html
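For reference, a small sketch of how the two suggested ops behave (illustrative shapes and names; whether masked_scatter_ can also target the out_indices positions needed for evicted_values is exactly the point to verify):

```python
import torch

indices = torch.tensor([7, 3, 9, 1])
busy = torch.tensor([False, True, False, True])

# masked_fill_: overwrite the busy slots with -1 in place, without building
# an intermediate out_indices tensor.
indices.masked_fill_(busy, -1)          # -> tensor([ 7, -1,  9, -1])

# masked_scatter_: copy elements of `src` (taken in order from the front)
# into the positions of `dst` where the mask is True.
dst = torch.zeros(4, 2)
src = torch.arange(8.0).reshape(4, 2)
dst.masked_scatter_(busy.unsqueeze(1).expand_as(dst), src)
# Rows 1 and 3 of dst now hold the first two rows of src; the destination
# positions are the masked positions themselves, not arbitrary indices.
```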
```python
)
evicted_scores = evicted_scores[0]

select_insert_failed_values(
```
Have you tried the masked_scatter_() operation? If it meets our requirements, I think we should adopt it for the sake of maintainability and robustness (unless the perf is really unsatisfying).
No, I haven't tried it.
I'm not sure whether it supports the CUDA device, bfloat16 dtype, and multi-dimensional tensors.
To achieve the goal quickly, I implemented this fused kernel yesterday.
I will try masked_scatter_ and masked_fill_ in the future, maybe in another PR. What do you think? @JacoCheung
(torch) It should support all those cases. My intent is to shift the responsibility to PyTorch, so we have less work to do (including compiling the unit).
I think it would be better to verify it now. 🚀
/review
Greptile Summary

This PR fixes a critical bug where evicted values were incorrectly populated when insertion operations failed or were busy in insert_and_evict. The fix consists of two parts:

- The insert-and-evict CUDA kernel now writes the eviction-buffer slot (out_id) into indices and records a Busy status when an insertion cannot complete.
- A new select_insert_failed_values kernel copies the input values of those busy rows into evicted_values at the recorded slots and resets their indices to -1.

The implementation includes both vectorized (Vec4) and non-vectorized variants for optimal performance across different embedding dimensions. Tests have been updated to properly handle the new behavior.

Confidence Score: 5/5
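As a rough reading of the semantics only, here is a Python reference of what the select_insert_failed_values step computes according to the summary and the diagram below; the PR itself implements this as a fused CUDA kernel with Vec4 and non-Vec4 variants, and the BUSY constant name here is hypothetical.

```python
import torch

BUSY = 4  # hypothetical status code for a failed/busy insertion

def select_insert_failed_values_ref(insert_results, indices, values, evicted_values):
    # For each busy row, hand the caller's value back through evicted_values
    # at the slot the insert kernel recorded, and mark the key as not stored.
    for emb_id in range(insert_results.numel()):
        if int(insert_results[emb_id]) == BUSY:
            out_idx = int(indices[emb_id])            # slot written by the insert kernel
            evicted_values[out_idx] = values[emb_id]  # caller's value is "evicted"
            indices[emb_id] = -1                      # key was not stored
    return indices, evicted_values
```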
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
participant User as User Code
participant KVT as DynamicEmbeddingTable
participant KIM as KeyIndexMap
participant Kernel as table_insert_and_evict_kernel
participant Select as select_insert_failed_values
participant Load as load_from_combined_table
User->>KVT: insert_and_evict(keys, values)
KVT->>KVT: Allocate insert_results tensor
KVT->>KIM: insert_and_evict(keys, indices, insert_results)
KIM->>Kernel: Launch CUDA kernel
Note over Kernel: For each key insertion attempt
alt Insert succeeds
Kernel->>Kernel: Set index = bucket_id * capacity + iter
Kernel->>Kernel: Set insert_results[i] = Insert/Reclaim/Assign/Evict
else Insert fails (Busy)
Kernel->>Kernel: Compute out_id from evicted_counter
Kernel->>Kernel: Set index = out_id (NEW FIX)
Kernel->>Kernel: Set insert_results[i] = Busy
Kernel->>Kernel: Store key in evicted_keys[out_id]
end
Kernel-->>KIM: Return evicted data
KIM-->>KVT: Return num_evicted, evicted_keys, evicted_indices
KVT->>Select: select_insert_failed_values(insert_results, indices, values, evicted_values)
Note over Select: CUDA kernel processes failed insertions
loop For each batch item with Busy status
Select->>Select: Read out_idx = indices[emb_id]
Select->>Select: Copy values[emb_id] to evicted_values[out_idx]
Select->>Select: Set indices[emb_id] = -1
end
Select-->>KVT: Return (indices updated)
KVT->>Load: load_from_combined_table(evicted_indices, evicted_values)
Load-->>KVT: Load complete
KVT-->>User: Return success
```

Description
Checklist