3434import torch .nn as nn
3535import numpy as np
3636import os
37+ import itertools
3738import pyarrow as pa
3839import pyarrow .parquet as pq
3940from mahout_qdp import QdpEngine
@@ -73,10 +74,6 @@ def generate_data(n_qubits, n_samples):
7374 if os .path .exists (DATA_FILE ):
7475 os .remove (DATA_FILE )
7576
76- MAHOUT_DATA_FILE = DATA_FILE .replace (".parquet" , "_mahout.parquet" )
77- if os .path .exists (MAHOUT_DATA_FILE ):
78- os .remove (MAHOUT_DATA_FILE )
79-
8077 print (f"Generating { n_samples } samples of { n_qubits } qubits..." )
8178 dim = 1 << n_qubits
8279
@@ -93,20 +90,9 @@ def generate_data(n_qubits, n_samples):
9390 batch_table = pa .Table .from_arrays ([arrays ], names = ["feature_vector" ])
9491 writer .write_table (batch_table )
9592
96- # Generate for Mahout (flat Float64 format, one sample per batch)
97- schema_flat = pa .schema ([("data" , pa .float64 ())])
98- with pq .ParquetWriter (MAHOUT_DATA_FILE , schema_flat ) as writer :
99- for i in range (n_samples ):
100- sample_data = np .random .rand (dim ).astype (np .float64 )
101- array = pa .array (sample_data , type = pa .float64 ())
102- batch_table = pa .Table .from_arrays ([array ], names = ["data" ])
103- writer .write_table (batch_table )
104-
10593 file_size_mb = os .path .getsize (DATA_FILE ) / (1024 * 1024 )
106- mahout_size_mb = os .path .getsize (MAHOUT_DATA_FILE ) / (1024 * 1024 )
10794 print (f" Generated { n_samples } samples" )
108- print (f" PennyLane/Qiskit format: { file_size_mb :.2f} MB" )
109- print (f" Mahout format: { mahout_size_mb :.2f} MB" )
95+ print (f" Parquet file size: { file_size_mb :.2f} MB" )
11096
11197
11298# -----------------------------------------------------------
@@ -115,7 +101,7 @@ def generate_data(n_qubits, n_samples):
115101def run_qiskit (n_qubits , n_samples ):
116102 if not HAS_QISKIT :
117103 print ("\n [Qiskit] Not installed, skipping." )
118- return 0.0
104+ return 0.0 , None
119105
120106 print ("\n [Qiskit] Full Pipeline (Disk -> GPU)..." )
121107 model = DummyQNN (n_qubits ).cuda ()
@@ -132,6 +118,8 @@ def run_qiskit(n_qubits, n_samples):
132118 io_time = time .perf_counter () - start_time
133119 print (f" IO Time: { io_time :.4f} s" )
134120
121+ all_qiskit_states = []
122+
135123 # Process batches
136124 for i in range (0 , n_samples , BATCH_SIZE ):
137125 batch = raw_data [i : i + BATCH_SIZE ]
@@ -158,12 +146,15 @@ def run_qiskit(n_qubits, n_samples):
158146 gpu_tensor = torch .tensor (
159147 np .array (batch_states ), device = "cuda" , dtype = torch .complex64
160148 )
149+ all_qiskit_states .append (gpu_tensor )
161150 _ = model (gpu_tensor .abs ())
162151
163152 torch .cuda .synchronize ()
164153 total_time = time .perf_counter () - start_time
165154 print (f"\n Total Time: { total_time :.4f} s" )
166- return total_time
155+
156+ all_qiskit_tensor = torch .cat (all_qiskit_states , dim = 0 )
157+ return total_time , all_qiskit_tensor
167158
168159
169160# -----------------------------------------------------------
@@ -172,7 +163,7 @@ def run_qiskit(n_qubits, n_samples):
172163def run_pennylane (n_qubits , n_samples ):
173164 if not HAS_PENNYLANE :
174165 print ("\n [PennyLane] Not installed, skipping." )
175- return 0.0
166+ return 0.0 , None
176167
177168 print ("\n [PennyLane] Full Pipeline (Disk -> GPU)..." )
178169
@@ -198,6 +189,8 @@ def circuit(inputs):
198189 io_time = time .perf_counter () - start_time
199190 print (f" IO Time: { io_time :.4f} s" )
200191
192+ all_pl_states = []
193+
201194 # Process batches
202195 for i in range (0 , n_samples , BATCH_SIZE ):
203196 batch_cpu = torch .tensor (raw_data [i : i + BATCH_SIZE ])
@@ -208,14 +201,22 @@ def circuit(inputs):
208201 except Exception :
209202 state_cpu = torch .stack ([circuit (x ) for x in batch_cpu ])
210203
204+ all_pl_states .append (state_cpu )
205+
211206 # Transfer to GPU
212207 state_gpu = state_cpu .to ("cuda" , dtype = torch .float32 )
213208 _ = model (state_gpu .abs ())
214209
215210 torch .cuda .synchronize ()
216211 total_time = time .perf_counter () - start_time
217212 print (f" Total Time: { total_time :.4f} s" )
218- return total_time
213+
214+ # Stack all collected states
215+ all_pl_states_tensor = torch .cat (
216+ all_pl_states , dim = 0
217+ ) # Should handle cases where last batch is smaller
218+
219+ return total_time , all_pl_states_tensor
219220
220221
221222# -----------------------------------------------------------
@@ -224,28 +225,31 @@ def circuit(inputs):
224225def run_mahout (engine , n_qubits , n_samples ):
225226 print ("\n [Mahout] Full Pipeline (Disk -> GPU)..." )
226227 model = DummyQNN (n_qubits ).cuda ()
227- MAHOUT_DATA_FILE = DATA_FILE .replace (".parquet" , "_mahout.parquet" )
228228
229229 torch .cuda .synchronize ()
230230 start_time = time .perf_counter ()
231231
232- # Read Parquet and encode all samples
233- import pyarrow .parquet as pq
234-
235- parquet_file = pq .ParquetFile (MAHOUT_DATA_FILE )
232+ # Direct Parquet to GPU pipeline
233+ parquet_encode_start = time .perf_counter ()
234+ batched_tensor = engine .encode_from_parquet (DATA_FILE , n_qubits , "amplitude" )
235+ parquet_encode_time = time .perf_counter () - parquet_encode_start
236+ print (f" Parquet->GPU (IO+Encode): { parquet_encode_time :.4f} s" )
236237
237- all_states = []
238- for batch in parquet_file .iter_batches ():
239- sample_data = batch .column (0 ).to_numpy ()
240- qtensor = engine .encode (sample_data .tolist (), n_qubits , "amplitude" )
241- gpu_state = torch .from_dlpack (qtensor )
242- all_states .append (gpu_state )
238+ # Convert to torch tensor (single DLPack call)
239+ dlpack_start = time .perf_counter ()
240+ gpu_batched = torch .from_dlpack (batched_tensor )
241+ dlpack_time = time .perf_counter () - dlpack_start
242+ print (f" DLPack conversion: { dlpack_time :.4f} s" )
243243
244- # Stack and convert
245- gpu_all_data = torch .stack (all_states ).abs ().to (torch .float32 )
244+ # Reshape to [n_samples, state_len] (still complex)
245+ state_len = 1 << n_qubits
246+ gpu_reshaped = gpu_batched .view (n_samples , state_len )
246247
247- encode_time = time .perf_counter () - start_time
248- print (f" IO + Encode Time: { encode_time :.4f} s" )
248+ # Convert to float for model (batch already on GPU)
249+ reshape_start = time .perf_counter ()
250+ gpu_all_data = gpu_reshaped .abs ().to (torch .float32 )
251+ reshape_time = time .perf_counter () - reshape_start
252+ print (f" Reshape & convert: { reshape_time :.4f} s" )
249253
250254 # Forward pass (data already on GPU)
251255 for i in range (0 , n_samples , BATCH_SIZE ):
@@ -255,7 +259,46 @@ def run_mahout(engine, n_qubits, n_samples):
255259 torch .cuda .synchronize ()
256260 total_time = time .perf_counter () - start_time
257261 print (f" Total Time: { total_time :.4f} s" )
258- return total_time
262+ return total_time , gpu_reshaped
263+
264+
def compare_states(name_a, states_a, name_b, states_b):
    """Print a verification report comparing two batches of encoded states.

    Only the first ``min(len(states_a), len(states_b))`` rows are compared.
    Both slices are moved to a single device (CUDA when available, otherwise
    CPU) before they are differenced.

    Args:
        name_a: Label printed for the first set of states.
        states_a: Tensor of states, indexed along dim 0.
        name_b: Label printed for the second set of states.
        states_b: Tensor of states to compare against ``states_a``.

    Returns:
        None. The outcome is reported on stdout only.
    """
    print("\n" + "=" * 70)
    print(f"VERIFICATION ({name_a} vs {name_b})")
    print("=" * 70)

    n_compare = min(len(states_a), len(states_b))
    if n_compare == 0:
        # Reducing an empty tensor with max() raises; report and bail out.
        print(">> SKIPPED: No states to compare.")
        return

    # Bring both operands onto one device; fall back to CPU without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tensor_a = states_a[:n_compare].to(device)
    tensor_b = states_b[:n_compare].to(device)

    if tensor_a.shape != tensor_b.shape:
        # Without this check, shapes like (N, 1) vs (N, D) would broadcast
        # silently and produce a meaningless verdict.
        print(
            f">> FAILURE: Shape mismatch "
            f"{tuple(tensor_a.shape)} vs {tuple(tensor_b.shape)}."
        )
        return

    # Compare probabilities (|psi|^2); this ignores phase differences.
    diff_probs = (tensor_a.abs() ** 2 - tensor_b.abs() ** 2).abs().max().item()
    print(f"Max Probability Difference: {diff_probs:.2e}")

    # Compare raw amplitudes via the magnitude of the element-wise difference.
    diff_amps = (tensor_a - tensor_b).abs().max().item()
    print(f"Max Amplitude Difference: {diff_amps:.2e}")

    # The pass/fail verdict is based on probabilities only.
    if diff_probs < 1e-5:
        print(">> SUCCESS: Quantum States Match!")
    else:
        print(">> FAILURE: States do not match.")
288+
289+
def verify_correctness(states_dict):
    """Cross-check every pair of frameworks that produced states.

    Args:
        states_dict: Mapping of framework name to its collected states, or
            ``None`` for a framework that produced none.

    Returns:
        None. Nothing is compared unless at least two entries are non-None.
    """
    # Drop frameworks that produced no states.
    valid_states = {
        name: states for name, states in states_dict.items() if states is not None
    }

    if len(valid_states) < 2:
        return

    # Iterating names in sorted order keeps the report order deterministic.
    for name_a, name_b in itertools.combinations(sorted(valid_states), 2):
        compare_states(name_a, valid_states[name_a], name_b, valid_states[name_b])
259302
260303
261304if __name__ == "__main__" :
@@ -268,8 +311,26 @@ def run_mahout(engine, n_qubits, n_samples):
268311 parser .add_argument (
269312 "--samples" , type = int , default = 200 , help = "Number of training samples"
270313 )
314+ parser .add_argument (
315+ "--frameworks" ,
316+ nargs = "+" ,
317+ default = ["mahout" , "pennylane" ],
318+ choices = ["mahout" , "pennylane" , "qiskit" , "all" ],
319+ help = "Frameworks to benchmark (default: mahout pennylane). Use 'all' to run all available frameworks." ,
320+ )
271321 args = parser .parse_args ()
272322
323+ # Expand "all" option
324+ if "all" in args .frameworks :
325+ all_frameworks = []
326+ if "mahout" in args .frameworks or "all" in args .frameworks :
327+ all_frameworks .append ("mahout" )
328+ if "pennylane" in args .frameworks or "all" in args .frameworks :
329+ all_frameworks .append ("pennylane" )
330+ if "qiskit" in args .frameworks or "all" in args .frameworks :
331+ all_frameworks .append ("qiskit" )
332+ args .frameworks = all_frameworks
333+
273334 generate_data (args .qubits , args .samples )
274335
275336 try :
@@ -282,10 +343,20 @@ def run_mahout(engine, n_qubits, n_samples):
282343 print (f"E2E BENCHMARK: { args .qubits } Qubits, { args .samples } Samples" )
283344 print ("=" * 70 )
284345
346+ # Initialize results
347+ t_pl , pl_all_states = 0.0 , None
348+ t_mahout , mahout_all_states = 0.0 , None
349+ t_qiskit , qiskit_all_states = 0.0 , None
350+
285351 # Run benchmarks
286- t_pl = run_pennylane (args .qubits , args .samples )
287- t_qiskit = run_qiskit (args .qubits , args .samples )
288- t_mahout = run_mahout (engine , args .qubits , args .samples )
352+ if "pennylane" in args .frameworks :
353+ t_pl , pl_all_states = run_pennylane (args .qubits , args .samples )
354+
355+ if "qiskit" in args .frameworks :
356+ t_qiskit , qiskit_all_states = run_qiskit (args .qubits , args .samples )
357+
358+ if "mahout" in args .frameworks :
359+ t_mahout , mahout_all_states = run_mahout (engine , args .qubits , args .samples )
289360
290361 print ("\n " + "=" * 70 )
291362 print ("E2E LATENCY (Lower is Better)" )
@@ -311,3 +382,12 @@ def run_mahout(engine, n_qubits, n_samples):
311382 print (f"Speedup vs PennyLane: { t_pl / t_mahout :10.2f} x" )
312383 if t_qiskit > 0 :
313384 print (f"Speedup vs Qiskit: { t_qiskit / t_mahout :10.2f} x" )
385+
386+ # Run Verification after benchmarks
387+ verify_correctness (
388+ {
389+ "Mahout" : mahout_all_states ,
390+ "PennyLane" : pl_all_states ,
391+ "Qiskit" : qiskit_all_states ,
392+ }
393+ )
0 commit comments