@@ -55,9 +55,7 @@ fn test_amplitude_encoding_workflow() {
5555 println ! ( "Created test data: {} elements" , data. len( ) ) ;
5656
5757 let result = engine. encode ( & data, 10 , "amplitude" ) ;
58- assert ! ( result. is_ok( ) , "Encoding should succeed" ) ;
59-
60- let dlpack_ptr = result. unwrap ( ) ;
58+ let dlpack_ptr = result. expect ( "Encoding should succeed" ) ;
6159 assert ! ( !dlpack_ptr. is_null( ) , "DLPack pointer should not be null" ) ;
6260 println ! ( "PASS: Encoding succeeded, DLPack pointer valid" ) ;
6361
@@ -91,9 +89,7 @@ fn test_amplitude_encoding_async_pipeline() {
9189 println ! ( "Created test data: {} elements" , data. len( ) ) ;
9290
9391 let result = engine. encode ( & data, 18 , "amplitude" ) ;
94- assert ! ( result. is_ok( ) , "Encoding should succeed" ) ;
95-
96- let dlpack_ptr = result. unwrap ( ) ;
92+ let dlpack_ptr = result. expect ( "Encoding should succeed" ) ;
9793 assert ! ( !dlpack_ptr. is_null( ) , "DLPack pointer should not be null" ) ;
9894 println ! ( "PASS: Encoding succeeded, DLPack pointer valid" ) ;
9995
@@ -108,6 +104,104 @@ fn test_amplitude_encoding_async_pipeline() {
108104 }
109105}
110106
107+ #[ test]
108+ #[ cfg( target_os = "linux" ) ]
109+ fn test_batch_dlpack_2d_shape ( ) {
110+ println ! ( "Testing batch DLPack 2D shape..." ) ;
111+
112+ let engine = match QdpEngine :: new ( 0 ) {
113+ Ok ( e) => e,
114+ Err ( _) => {
115+ println ! ( "SKIP: No GPU available" ) ;
116+ return ;
117+ }
118+ } ;
119+
120+ // Create batch data: 3 samples, each with 4 elements (2 qubits)
121+ let num_samples = 3 ;
122+ let num_qubits = 2 ;
123+ let sample_size = 4 ;
124+ let batch_data: Vec < f64 > = ( 0 ..num_samples * sample_size)
125+ . map ( |i| ( i as f64 ) / 10.0 )
126+ . collect ( ) ;
127+
128+ let result = engine. encode_batch ( & batch_data, num_samples, sample_size, num_qubits, "amplitude" ) ;
129+ let dlpack_ptr = result. expect ( "Batch encoding should succeed" ) ;
130+ assert ! ( !dlpack_ptr. is_null( ) , "DLPack pointer should not be null" ) ;
131+
132+ unsafe {
133+ let managed = & * dlpack_ptr;
134+ let tensor = & managed. dl_tensor ;
135+
136+ // Verify 2D shape for batch tensor
137+ assert_eq ! ( tensor. ndim, 2 , "Batch tensor should be 2D" ) ;
138+
139+ let shape_slice = std:: slice:: from_raw_parts ( tensor. shape , tensor. ndim as usize ) ;
140+ assert_eq ! ( shape_slice[ 0 ] , num_samples as i64 , "First dimension should be num_samples" ) ;
141+ assert_eq ! ( shape_slice[ 1 ] , ( 1 << num_qubits) as i64 , "Second dimension should be 2^num_qubits" ) ;
142+
143+ let strides_slice = std:: slice:: from_raw_parts ( tensor. strides , tensor. ndim as usize ) ;
144+ let state_len = 1 << num_qubits;
145+ assert_eq ! ( strides_slice[ 0 ] , state_len as i64 , "Stride for first dimension should be state_len" ) ;
146+ assert_eq ! ( strides_slice[ 1 ] , 1 , "Stride for second dimension should be 1" ) ;
147+
148+ println ! ( "PASS: Batch DLPack tensor has correct 2D shape: [{}, {}]" , shape_slice[ 0 ] , shape_slice[ 1 ] ) ;
149+ println ! ( "PASS: Strides are correct: [{}, {}]" , strides_slice[ 0 ] , strides_slice[ 1 ] ) ;
150+
151+ // Free memory
152+ if let Some ( deleter) = managed. deleter {
153+ deleter ( dlpack_ptr) ;
154+ }
155+ }
156+ }
157+
158+ #[ test]
159+ #[ cfg( target_os = "linux" ) ]
160+ fn test_single_encode_dlpack_2d_shape ( ) {
161+ println ! ( "Testing single encode returns 2D shape..." ) ;
162+
163+ let engine = match QdpEngine :: new ( 0 ) {
164+ Ok ( e) => e,
165+ Err ( _) => {
166+ println ! ( "SKIP: No GPU available" ) ;
167+ return ;
168+ }
169+ } ;
170+
171+ let data = common:: create_test_data ( 16 ) ;
172+ let result = engine. encode ( & data, 4 , "amplitude" ) ;
173+ assert ! ( result. is_ok( ) , "Encoding should succeed" ) ;
174+
175+ let dlpack_ptr = result. unwrap ( ) ;
176+ assert ! ( !dlpack_ptr. is_null( ) , "DLPack pointer should not be null" ) ;
177+
178+ unsafe {
179+ let managed = & * dlpack_ptr;
180+ let tensor = & managed. dl_tensor ;
181+
182+ // Verify 2D shape for single encode: [1, 2^num_qubits]
183+ assert_eq ! ( tensor. ndim, 2 , "Single encode should be 2D" ) ;
184+
185+ let shape_slice = std:: slice:: from_raw_parts ( tensor. shape , tensor. ndim as usize ) ;
186+ assert_eq ! ( shape_slice[ 0 ] , 1 , "First dimension should be 1 for single encode" ) ;
187+ assert_eq ! ( shape_slice[ 1 ] , 16 , "Second dimension should be [2^4]" ) ;
188+
189+ let strides_slice = std:: slice:: from_raw_parts ( tensor. strides , tensor. ndim as usize ) ;
190+ assert_eq ! ( strides_slice[ 0 ] , 16 , "Stride for first dimension should be state_len" ) ;
191+ assert_eq ! ( strides_slice[ 1 ] , 1 , "Stride for second dimension should be 1" ) ;
192+
193+ println ! (
194+ "PASS: Single encode returns 2D shape: [{}, {}]" ,
195+ shape_slice[ 0 ] , shape_slice[ 1 ]
196+ ) ;
197+
198+ // Free memory
199+ if let Some ( deleter) = managed. deleter {
200+ deleter ( dlpack_ptr) ;
201+ }
202+ }
203+ }
204+
111205#[ test]
112206#[ cfg( target_os = "linux" ) ]
113207fn test_dlpack_device_id ( ) {
0 commit comments