core: Safer tensor handles + checked vectorization#877
core: Safer tensor handles + checked vectorization#877ariawisp wants to merge 1 commit intotracel-ai:mainfrom
Conversation
27f4c2e to
8e4b170
Compare
nathanielsimard
left a comment
There was a problem hiding this comment.
I'm not convinced this is a good idea; those calculations are 100% wasteful with Burn, since validations are done at the tensor API level (user-facing) and not internally. We assume data consistency internally, which reduces error handling to a minimum. Otherwise, we would have the same error handling at multiple levels of abstractions, which is very wasteful and adds complexity.
Do you have opinions on this @wingertge
| /// Convert the handle into a [tensor argument](TensorArg) with basic safety checks | ||
| /// for vectorization compatibility. | ||
| /// Try to convert the handle into a tensor argument, validating that the | ||
| /// requested vectorization factor is supported by the runtime. This does not | ||
| /// enforce inner-most contiguity or alignment requirements as kernels may | ||
| /// legally vectorize along axes other than the innermost. | ||
| pub fn try_as_tensor_arg( | ||
| &'a self, | ||
| vectorization: u8, | ||
| ) -> Result<TensorArg<'a, R>, TensorArgError> { | ||
| if !R::supported_line_sizes().contains(&vectorization) { | ||
| return Err(TensorArgError::UnsupportedVectorization { requested: vectorization, supported: R::supported_line_sizes() }); | ||
| } | ||
| Ok(self.as_tensor_arg(vectorization)) | ||
| } |
There was a problem hiding this comment.
The problem I see here is that the validation is done in 100% of cases in Burn and CubeCL when we choose the line size, not when applying it! It's kind of wasteful to do it multiple times, especially since we iterate over a list each time. I don't mind too much having a try function, but it would not be good practice to use them; validation should be done before creating the tensor argument.
| if shape.len() != strides.len() { | ||
| return Err(TensorHandleError::RankMismatch { | ||
| shape_rank: shape.len(), | ||
| stride_rank: strides.len(), | ||
| }); | ||
| } | ||
| if elem_size == 0 { | ||
| return Err(TensorHandleError::ElemSizeZero); | ||
| } | ||
| // Disallow zero strides when corresponding dimension extent > 1 (broadcasted dims with extent 1 are allowed). | ||
| for (i, (&s, &d)) in strides.iter().zip(shape.iter()).enumerate() { | ||
| if s == 0 && d > 1 { | ||
| return Err(TensorHandleError::ZeroStride { axis: i }); | ||
| } | ||
| } |
There was a problem hiding this comment.
The same thing here: most of those things are validated in other places, and it's kind of wasteful to do those validations multiple times.
|
I actually think it could be a good idea, but specifically gated behind |
58e6a34 to
714498a
Compare
|
Per feedback, I reverted internal hot paths to the fast path ( |
nathanielsimard
left a comment
There was a problem hiding this comment.
That seems like the right balance. The CI should be fixed before merging.
b52c9ee to
77556c2
Compare
…t across crates
- TensorHandleRef::{try_from_parts, try_from_typed}
- TensorHandleRef::try_as_tensor_arg (validates runtime-supported vectorization only)
- Errors: #[non_exhaustive], Display impls; UnsupportedVectorization { requested, supported }
- Adopt try_as_tensor_arg in attention/matmul/convolution/reduce/std
- Runtime tests for handle validation and unsupported vectorization factors
core(tensor): avoid redundant checks in hot paths; use debug_asserts and clarify try_* docs
internal: use direct as_tensor_arg in internal launch paths; reserve try_* for FFI/tests
77556c2 to
3a98f13
Compare
|
Branch has been rebased on main (cf40b4e) and confirmed CI workflow passes on my fork. |
This PR makes constructing tensor handles and choosing a vectorization factor safer and more ergonomic—especially for host wrappers and FFI—by surfacing misuse as clear, early errors instead of relying on scattered
unsafeand assumptions.Behavior
unsafeentry points remain available; no changes to kernel or runtime ABIs.Impact
Validation
Notes
-D warnings.PR has been validated with Burn — no compilation or test errors.