Skip to content

Commit 03f71ec

Browse files
authored
[QDP] DLPack shape/strides: Support batch 2D tensor (#747)
1 parent aa30679 commit 03f71ec

File tree

11 files changed

+221
-41
lines changed

11 files changed

+221
-41
lines changed

qdp/benchmark/benchmark_e2e.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -277,15 +277,17 @@ def run_mahout_parquet(engine, n_qubits, n_samples):
277277
dlpack_time = time.perf_counter() - dlpack_start
278278
print(f" DLPack conversion: {dlpack_time:.4f} s")
279279

280-
# Reshape to [n_samples, state_len] (still complex)
280+
# Tensor is already 2D [n_samples, state_len] from to_dlpack()
281281
state_len = 1 << n_qubits
282+
assert gpu_batched.shape == (n_samples, state_len), (
283+
f"Expected shape ({n_samples}, {state_len}), got {gpu_batched.shape}"
284+
)
282285

283286
# Convert to float for model (batch already on GPU)
284287
reshape_start = time.perf_counter()
285-
gpu_reshaped = gpu_batched.view(n_samples, state_len)
286-
gpu_all_data = gpu_reshaped.abs().to(torch.float32)
288+
gpu_all_data = gpu_batched.abs().to(torch.float32)
287289
reshape_time = time.perf_counter() - reshape_start
288-
print(f" Reshape & convert: {reshape_time:.4f} s")
290+
print(f" Convert to float32: {reshape_time:.4f} s")
289291

290292
# Forward pass (data already on GPU)
291293
for i in range(0, n_samples, BATCH_SIZE):
@@ -299,7 +301,7 @@ def run_mahout_parquet(engine, n_qubits, n_samples):
299301
# Clean cache after benchmark completion
300302
clean_cache()
301303

302-
return total_time, gpu_reshaped
304+
return total_time, gpu_batched
303305

304306

305307
# -----------------------------------------------------------
@@ -325,13 +327,16 @@ def run_mahout_arrow(engine, n_qubits, n_samples):
325327
dlpack_time = time.perf_counter() - dlpack_start
326328
print(f" DLPack conversion: {dlpack_time:.4f} s")
327329

330+
# Tensor is already 2D [n_samples, state_len] from to_dlpack()
328331
state_len = 1 << n_qubits
332+
assert gpu_batched.shape == (n_samples, state_len), (
333+
f"Expected shape ({n_samples}, {state_len}), got {gpu_batched.shape}"
334+
)
329335

330336
reshape_start = time.perf_counter()
331-
gpu_reshaped = gpu_batched.view(n_samples, state_len)
332-
gpu_all_data = gpu_reshaped.abs().to(torch.float32)
337+
gpu_all_data = gpu_batched.abs().to(torch.float32)
333338
reshape_time = time.perf_counter() - reshape_start
334-
print(f" Reshape & convert: {reshape_time:.4f} s")
339+
print(f" Convert to float32: {reshape_time:.4f} s")
335340

336341
for i in range(0, n_samples, BATCH_SIZE):
337342
batch = gpu_all_data[i : i + BATCH_SIZE]
@@ -344,7 +349,7 @@ def run_mahout_arrow(engine, n_qubits, n_samples):
344349
# Clean cache after benchmark completion
345350
clean_cache()
346351

347-
return total_time, gpu_reshaped
352+
return total_time, gpu_batched
348353

349354

350355
def compare_states(name_a, states_a, name_b, states_b):

qdp/qdp-core/src/dlpack.rs

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,25 @@ impl GpuStateVector {
120120
/// Freed by DLPack deleter when PyTorch releases tensor.
121121
/// Do not free manually.
122122
pub fn to_dlpack(&self) -> *mut DLManagedTensor {
123-
// Allocate shape/strides on heap (freed by deleter)
124-
let shape = vec![self.size_elements as i64];
125-
let strides = vec![1i64];
123+
// Always return 2D tensor: Batch [num_samples, state_len], Single [1, state_len]
124+
let (shape, strides) = if let Some(num_samples) = self.num_samples {
125+
// Batch: [num_samples, state_len_per_sample]
126+
debug_assert!(
127+
num_samples > 0 && self.size_elements % num_samples == 0,
128+
"Batch state vector size must be divisible by num_samples"
129+
);
130+
let state_len_per_sample = self.size_elements / num_samples;
131+
let shape = vec![num_samples as i64, state_len_per_sample as i64];
132+
let strides = vec![state_len_per_sample as i64, 1i64];
133+
(shape, strides)
134+
} else {
135+
// Single: [1, size_elements]
136+
let state_len = self.size_elements;
137+
let shape = vec![1i64, state_len as i64];
138+
let strides = vec![state_len as i64, 1i64];
139+
(shape, strides)
140+
};
141+
let ndim: c_int = 2;
126142

127143
// Transfer ownership to DLPack deleter
128144
let shape_ptr = Box::into_raw(shape.into_boxed_slice()) as *mut i64;
@@ -142,7 +158,7 @@ impl GpuStateVector {
142158
device_type: DLDeviceType::kDLCUDA,
143159
device_id: self.device_id as c_int,
144160
},
145-
ndim: 1,
161+
ndim,
146162
dtype: DLDataType {
147163
code: DL_COMPLEX,
148164
bits: dtype_bits,

qdp/qdp-core/src/gpu/memory.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,8 @@ pub struct GpuStateVector {
190190
pub(crate) buffer: Arc<BufferStorage>,
191191
pub num_qubits: usize,
192192
pub size_elements: usize,
193+
/// Batch size (None for single state)
194+
pub(crate) num_samples: Option<usize>,
193195
pub device_id: usize,
194196
}
195197

@@ -230,6 +232,7 @@ impl GpuStateVector {
230232
buffer: Arc::new(BufferStorage::F64(GpuBufferRaw { slice })),
231233
num_qubits: qubits,
232234
size_elements: _size_elements,
235+
num_samples: None,
233236
device_id: _device.ordinal(),
234237
})
235238
}
@@ -302,6 +305,7 @@ impl GpuStateVector {
302305
buffer: Arc::new(BufferStorage::F64(GpuBufferRaw { slice })),
303306
num_qubits: qubits,
304307
size_elements: total_elements,
308+
num_samples: Some(num_samples),
305309
device_id: _device.ordinal(),
306310
})
307311
}
@@ -367,6 +371,7 @@ impl GpuStateVector {
367371
buffer: Arc::new(BufferStorage::F32(GpuBufferRaw { slice })),
368372
num_qubits: self.num_qubits,
369373
size_elements: self.size_elements,
374+
num_samples: self.num_samples, // Preserve batch information
370375
device_id: device.ordinal(),
371376
})
372377
}

qdp/qdp-core/src/lib.rs

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ impl QdpEngine {
8787
/// * `encoding_method` - Strategy: "amplitude", "angle", or "basis"
8888
///
8989
/// # Returns
90-
/// DLPack pointer for zero-copy PyTorch integration
90+
/// DLPack pointer for zero-copy PyTorch integration (shape: [1, 2^num_qubits])
9191
///
9292
/// # Safety
9393
/// Pointer freed by DLPack deleter, do not free manually.
@@ -201,6 +201,27 @@ impl QdpEngine {
201201
if sample_size == 0 {
202202
return Err(MahoutError::InvalidInput("Sample size cannot be zero".into()));
203203
}
204+
if sample_size > STAGE_SIZE_ELEMENTS {
205+
return Err(MahoutError::InvalidInput(format!(
206+
"Sample size {} exceeds staging buffer capacity {} (elements)",
207+
sample_size, STAGE_SIZE_ELEMENTS
208+
)));
209+
}
210+
211+
// Reuse a single norm buffer across chunks to avoid per-chunk allocations.
212+
//
213+
// Important: the norm buffer must outlive the async kernels that consume it.
214+
// Per-chunk allocation + drop can lead to use-after-free when the next chunk
215+
// reuses the same device memory while the previous chunk is still running.
216+
let max_samples_per_chunk = std::cmp::max(
217+
1,
218+
std::cmp::min(num_samples, STAGE_SIZE_ELEMENTS / sample_size),
219+
);
220+
let mut norm_buffer = self.device.alloc_zeros::<f64>(max_samples_per_chunk)
221+
.map_err(|e| MahoutError::MemoryAllocation(format!(
222+
"Failed to allocate norm buffer: {:?}",
223+
e
224+
)))?;
204225

205226
full_buf_tx.send(Ok((host_buf_first, first_len)))
206227
.map_err(|_| MahoutError::Io("Failed to send first buffer".into()))?;
@@ -277,9 +298,10 @@ impl QdpEngine {
277298
let state_ptr_offset = total_state_vector.ptr_void().cast::<u8>()
278299
.add(offset_bytes)
279300
.cast::<std::ffi::c_void>();
280-
281-
let mut norm_buffer = self.device.alloc_zeros::<f64>(samples_in_chunk)
282-
.map_err(|e| MahoutError::MemoryAllocation(format!("Failed to allocate norm buffer: {:?}", e)))?;
301+
debug_assert!(
302+
samples_in_chunk <= max_samples_per_chunk,
303+
"samples_in_chunk must be <= max_samples_per_chunk"
304+
);
283305

284306
{
285307
crate::profile_scope!("GPU::NormBatch");

qdp/qdp-core/src/preprocessing.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,12 @@ impl Preprocessor {
8484
sample_size: usize,
8585
num_qubits: usize,
8686
) -> Result<()> {
87+
if num_samples == 0 {
88+
return Err(MahoutError::InvalidInput(
89+
"num_samples must be greater than 0".to_string()
90+
));
91+
}
92+
8793
if batch_data.len() != num_samples * sample_size {
8894
return Err(MahoutError::InvalidInput(
8995
format!("Batch data length {} doesn't match num_samples {} * sample_size {}",

qdp/qdp-core/tests/api_workflow.rs

Lines changed: 100 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,7 @@ fn test_amplitude_encoding_workflow() {
5555
println!("Created test data: {} elements", data.len());
5656

5757
let result = engine.encode(&data, 10, "amplitude");
58-
assert!(result.is_ok(), "Encoding should succeed");
59-
60-
let dlpack_ptr = result.unwrap();
58+
let dlpack_ptr = result.expect("Encoding should succeed");
6159
assert!(!dlpack_ptr.is_null(), "DLPack pointer should not be null");
6260
println!("PASS: Encoding succeeded, DLPack pointer valid");
6361

@@ -91,9 +89,7 @@ fn test_amplitude_encoding_async_pipeline() {
9189
println!("Created test data: {} elements", data.len());
9290

9391
let result = engine.encode(&data, 18, "amplitude");
94-
assert!(result.is_ok(), "Encoding should succeed");
95-
96-
let dlpack_ptr = result.unwrap();
92+
let dlpack_ptr = result.expect("Encoding should succeed");
9793
assert!(!dlpack_ptr.is_null(), "DLPack pointer should not be null");
9894
println!("PASS: Encoding succeeded, DLPack pointer valid");
9995

@@ -108,6 +104,104 @@ fn test_amplitude_encoding_async_pipeline() {
108104
}
109105
}
110106

107+
#[test]
108+
#[cfg(target_os = "linux")]
109+
fn test_batch_dlpack_2d_shape() {
110+
println!("Testing batch DLPack 2D shape...");
111+
112+
let engine = match QdpEngine::new(0) {
113+
Ok(e) => e,
114+
Err(_) => {
115+
println!("SKIP: No GPU available");
116+
return;
117+
}
118+
};
119+
120+
// Create batch data: 3 samples, each with 4 elements (2 qubits)
121+
let num_samples = 3;
122+
let num_qubits = 2;
123+
let sample_size = 4;
124+
let batch_data: Vec<f64> = (0..num_samples * sample_size)
125+
.map(|i| (i as f64) / 10.0)
126+
.collect();
127+
128+
let result = engine.encode_batch(&batch_data, num_samples, sample_size, num_qubits, "amplitude");
129+
let dlpack_ptr = result.expect("Batch encoding should succeed");
130+
assert!(!dlpack_ptr.is_null(), "DLPack pointer should not be null");
131+
132+
unsafe {
133+
let managed = &*dlpack_ptr;
134+
let tensor = &managed.dl_tensor;
135+
136+
// Verify 2D shape for batch tensor
137+
assert_eq!(tensor.ndim, 2, "Batch tensor should be 2D");
138+
139+
let shape_slice = std::slice::from_raw_parts(tensor.shape, tensor.ndim as usize);
140+
assert_eq!(shape_slice[0], num_samples as i64, "First dimension should be num_samples");
141+
assert_eq!(shape_slice[1], (1 << num_qubits) as i64, "Second dimension should be 2^num_qubits");
142+
143+
let strides_slice = std::slice::from_raw_parts(tensor.strides, tensor.ndim as usize);
144+
let state_len = 1 << num_qubits;
145+
assert_eq!(strides_slice[0], state_len as i64, "Stride for first dimension should be state_len");
146+
assert_eq!(strides_slice[1], 1, "Stride for second dimension should be 1");
147+
148+
println!("PASS: Batch DLPack tensor has correct 2D shape: [{}, {}]", shape_slice[0], shape_slice[1]);
149+
println!("PASS: Strides are correct: [{}, {}]", strides_slice[0], strides_slice[1]);
150+
151+
// Free memory
152+
if let Some(deleter) = managed.deleter {
153+
deleter(dlpack_ptr);
154+
}
155+
}
156+
}
157+
158+
#[test]
159+
#[cfg(target_os = "linux")]
160+
fn test_single_encode_dlpack_2d_shape() {
161+
println!("Testing single encode returns 2D shape...");
162+
163+
let engine = match QdpEngine::new(0) {
164+
Ok(e) => e,
165+
Err(_) => {
166+
println!("SKIP: No GPU available");
167+
return;
168+
}
169+
};
170+
171+
let data = common::create_test_data(16);
172+
let result = engine.encode(&data, 4, "amplitude");
173+
assert!(result.is_ok(), "Encoding should succeed");
174+
175+
let dlpack_ptr = result.unwrap();
176+
assert!(!dlpack_ptr.is_null(), "DLPack pointer should not be null");
177+
178+
unsafe {
179+
let managed = &*dlpack_ptr;
180+
let tensor = &managed.dl_tensor;
181+
182+
// Verify 2D shape for single encode: [1, 2^num_qubits]
183+
assert_eq!(tensor.ndim, 2, "Single encode should be 2D");
184+
185+
let shape_slice = std::slice::from_raw_parts(tensor.shape, tensor.ndim as usize);
186+
assert_eq!(shape_slice[0], 1, "First dimension should be 1 for single encode");
187+
assert_eq!(shape_slice[1], 16, "Second dimension should be [2^4]");
188+
189+
let strides_slice = std::slice::from_raw_parts(tensor.strides, tensor.ndim as usize);
190+
assert_eq!(strides_slice[0], 16, "Stride for first dimension should be state_len");
191+
assert_eq!(strides_slice[1], 1, "Stride for second dimension should be 1");
192+
193+
println!(
194+
"PASS: Single encode returns 2D shape: [{}, {}]",
195+
shape_slice[0], shape_slice[1]
196+
);
197+
198+
// Free memory
199+
if let Some(deleter) = managed.deleter {
200+
deleter(dlpack_ptr);
201+
}
202+
}
203+
}
204+
111205
#[test]
112206
#[cfg(target_os = "linux")]
113207
fn test_dlpack_device_id() {

qdp/qdp-core/tests/memory_safety.rs

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -106,24 +106,26 @@ fn test_dlpack_tensor_metadata_default() {
106106
let managed = &mut *ptr;
107107
let tensor = &managed.dl_tensor;
108108

109-
assert_eq!(tensor.ndim, 1, "Should be 1D tensor");
109+
assert_eq!(tensor.ndim, 2, "Should be 2D tensor");
110110
assert!(!tensor.data.is_null(), "Data pointer should be valid");
111111
assert!(!tensor.shape.is_null(), "Shape pointer should be valid");
112112
assert!(!tensor.strides.is_null(), "Strides pointer should be valid");
113113

114-
let shape = *tensor.shape;
115-
assert_eq!(shape, 1024, "Shape should be 1024 (2^10)");
114+
let shape = std::slice::from_raw_parts(tensor.shape, tensor.ndim as usize);
115+
assert_eq!(shape[0], 1, "First dimension should be 1 for single encode");
116+
assert_eq!(shape[1], 1024, "Second dimension should be 1024 (2^10)");
116117

117-
let stride = *tensor.strides;
118-
assert_eq!(stride, 1, "Stride for 1D contiguous array should be 1");
118+
let strides = std::slice::from_raw_parts(tensor.strides, tensor.ndim as usize);
119+
assert_eq!(strides[0], 1024, "Stride for first dimension should be state_len");
120+
assert_eq!(strides[1], 1, "Stride for second dimension should be 1");
119121

120122
assert_eq!(tensor.dtype.code, 5, "Should be complex type (code=5)");
121-
assert_eq!(tensor.dtype.bits, 64, "Should be 64 bits (2x32-bit floats)");
123+
assert_eq!(tensor.dtype.bits, 128, "Should be 128 bits (2x64-bit floats, Float64)");
122124

123125
println!("PASS: DLPack metadata verified");
124126
println!(" ndim: {}", tensor.ndim);
125-
println!(" shape: {}", shape);
126-
println!(" stride: {}", stride);
127+
println!(" shape: [{}, {}]", shape[0], shape[1]);
128+
println!(" strides: [{}, {}]", strides[0], strides[1]);
127129
println!(
128130
" dtype: code={}, bits={}",
129131
tensor.dtype.code, tensor.dtype.bits
@@ -154,16 +156,18 @@ fn test_dlpack_tensor_metadata_f64() {
154156
let managed = &mut *ptr;
155157
let tensor = &managed.dl_tensor;
156158

157-
assert_eq!(tensor.ndim, 1, "Should be 1D tensor");
159+
assert_eq!(tensor.ndim, 2, "Should be 2D tensor");
158160
assert!(!tensor.data.is_null(), "Data pointer should be valid");
159161
assert!(!tensor.shape.is_null(), "Shape pointer should be valid");
160162
assert!(!tensor.strides.is_null(), "Strides pointer should be valid");
161163

162-
let shape = *tensor.shape;
163-
assert_eq!(shape, 1024, "Shape should be 1024 (2^10)");
164+
let shape = std::slice::from_raw_parts(tensor.shape, tensor.ndim as usize);
165+
assert_eq!(shape[0], 1, "First dimension should be 1 for single encode");
166+
assert_eq!(shape[1], 1024, "Second dimension should be 1024 (2^10)");
164167

165-
let stride = *tensor.strides;
166-
assert_eq!(stride, 1, "Stride for 1D contiguous array should be 1");
168+
let strides = std::slice::from_raw_parts(tensor.strides, tensor.ndim as usize);
169+
assert_eq!(strides[0], 1024, "Stride for first dimension should be state_len");
170+
assert_eq!(strides[1], 1, "Stride for second dimension should be 1");
167171

168172
assert_eq!(tensor.dtype.code, 5, "Should be complex type (code=5)");
169173
assert_eq!(
@@ -173,8 +177,8 @@ fn test_dlpack_tensor_metadata_f64() {
173177

174178
println!("PASS: DLPack metadata verified");
175179
println!(" ndim: {}", tensor.ndim);
176-
println!(" shape: {}", shape);
177-
println!(" stride: {}", stride);
180+
println!(" shape: [{}, {}]", shape[0], shape[1]);
181+
println!(" strides: [{}, {}]", strides[0], strides[1]);
178182
println!(
179183
" dtype: code={}, bits={}",
180184
tensor.dtype.code, tensor.dtype.bits

0 commit comments

Comments
 (0)