Skip to content

Commit ce4b7ca

Browse files
authored
[QDP] add batch kernel support (#700)
* [QDP] Add batch encoding support * Refactor batch pre-processing
1 parent 335036a commit ce4b7ca

File tree

11 files changed

+583
-193
lines changed

11 files changed

+583
-193
lines changed

qdp/benchmark/benchmark_e2e_final.py

Lines changed: 119 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import torch.nn as nn
3535
import numpy as np
3636
import os
37+
import itertools
3738
import pyarrow as pa
3839
import pyarrow.parquet as pq
3940
from mahout_qdp import QdpEngine
@@ -73,10 +74,6 @@ def generate_data(n_qubits, n_samples):
7374
if os.path.exists(DATA_FILE):
7475
os.remove(DATA_FILE)
7576

76-
MAHOUT_DATA_FILE = DATA_FILE.replace(".parquet", "_mahout.parquet")
77-
if os.path.exists(MAHOUT_DATA_FILE):
78-
os.remove(MAHOUT_DATA_FILE)
79-
8077
print(f"Generating {n_samples} samples of {n_qubits} qubits...")
8178
dim = 1 << n_qubits
8279

@@ -93,20 +90,9 @@ def generate_data(n_qubits, n_samples):
9390
batch_table = pa.Table.from_arrays([arrays], names=["feature_vector"])
9491
writer.write_table(batch_table)
9592

96-
# Generate for Mahout (flat Float64 format, one sample per batch)
97-
schema_flat = pa.schema([("data", pa.float64())])
98-
with pq.ParquetWriter(MAHOUT_DATA_FILE, schema_flat) as writer:
99-
for i in range(n_samples):
100-
sample_data = np.random.rand(dim).astype(np.float64)
101-
array = pa.array(sample_data, type=pa.float64())
102-
batch_table = pa.Table.from_arrays([array], names=["data"])
103-
writer.write_table(batch_table)
104-
10593
file_size_mb = os.path.getsize(DATA_FILE) / (1024 * 1024)
106-
mahout_size_mb = os.path.getsize(MAHOUT_DATA_FILE) / (1024 * 1024)
10794
print(f" Generated {n_samples} samples")
108-
print(f" PennyLane/Qiskit format: {file_size_mb:.2f} MB")
109-
print(f" Mahout format: {mahout_size_mb:.2f} MB")
95+
print(f" Parquet file size: {file_size_mb:.2f} MB")
11096

11197

11298
# -----------------------------------------------------------
@@ -115,7 +101,7 @@ def generate_data(n_qubits, n_samples):
115101
def run_qiskit(n_qubits, n_samples):
116102
if not HAS_QISKIT:
117103
print("\n[Qiskit] Not installed, skipping.")
118-
return 0.0
104+
return 0.0, None
119105

120106
print("\n[Qiskit] Full Pipeline (Disk -> GPU)...")
121107
model = DummyQNN(n_qubits).cuda()
@@ -132,6 +118,8 @@ def run_qiskit(n_qubits, n_samples):
132118
io_time = time.perf_counter() - start_time
133119
print(f" IO Time: {io_time:.4f} s")
134120

121+
all_qiskit_states = []
122+
135123
# Process batches
136124
for i in range(0, n_samples, BATCH_SIZE):
137125
batch = raw_data[i : i + BATCH_SIZE]
@@ -158,12 +146,15 @@ def run_qiskit(n_qubits, n_samples):
158146
gpu_tensor = torch.tensor(
159147
np.array(batch_states), device="cuda", dtype=torch.complex64
160148
)
149+
all_qiskit_states.append(gpu_tensor)
161150
_ = model(gpu_tensor.abs())
162151

163152
torch.cuda.synchronize()
164153
total_time = time.perf_counter() - start_time
165154
print(f"\n Total Time: {total_time:.4f} s")
166-
return total_time
155+
156+
all_qiskit_tensor = torch.cat(all_qiskit_states, dim=0)
157+
return total_time, all_qiskit_tensor
167158

168159

169160
# -----------------------------------------------------------
@@ -172,7 +163,7 @@ def run_qiskit(n_qubits, n_samples):
172163
def run_pennylane(n_qubits, n_samples):
173164
if not HAS_PENNYLANE:
174165
print("\n[PennyLane] Not installed, skipping.")
175-
return 0.0
166+
return 0.0, None
176167

177168
print("\n[PennyLane] Full Pipeline (Disk -> GPU)...")
178169

@@ -198,6 +189,8 @@ def circuit(inputs):
198189
io_time = time.perf_counter() - start_time
199190
print(f" IO Time: {io_time:.4f} s")
200191

192+
all_pl_states = []
193+
201194
# Process batches
202195
for i in range(0, n_samples, BATCH_SIZE):
203196
batch_cpu = torch.tensor(raw_data[i : i + BATCH_SIZE])
@@ -208,14 +201,22 @@ def circuit(inputs):
208201
except Exception:
209202
state_cpu = torch.stack([circuit(x) for x in batch_cpu])
210203

204+
all_pl_states.append(state_cpu)
205+
211206
# Transfer to GPU
212207
state_gpu = state_cpu.to("cuda", dtype=torch.float32)
213208
_ = model(state_gpu.abs())
214209

215210
torch.cuda.synchronize()
216211
total_time = time.perf_counter() - start_time
217212
print(f" Total Time: {total_time:.4f} s")
218-
return total_time
213+
214+
# Stack all collected states
215+
all_pl_states_tensor = torch.cat(
216+
all_pl_states, dim=0
217+
) # Should handle cases where last batch is smaller
218+
219+
return total_time, all_pl_states_tensor
219220

220221

221222
# -----------------------------------------------------------
@@ -224,28 +225,31 @@ def circuit(inputs):
224225
def run_mahout(engine, n_qubits, n_samples):
225226
print("\n[Mahout] Full Pipeline (Disk -> GPU)...")
226227
model = DummyQNN(n_qubits).cuda()
227-
MAHOUT_DATA_FILE = DATA_FILE.replace(".parquet", "_mahout.parquet")
228228

229229
torch.cuda.synchronize()
230230
start_time = time.perf_counter()
231231

232-
# Read Parquet and encode all samples
233-
import pyarrow.parquet as pq
234-
235-
parquet_file = pq.ParquetFile(MAHOUT_DATA_FILE)
232+
# Direct Parquet to GPU pipeline
233+
parquet_encode_start = time.perf_counter()
234+
batched_tensor = engine.encode_from_parquet(DATA_FILE, n_qubits, "amplitude")
235+
parquet_encode_time = time.perf_counter() - parquet_encode_start
236+
print(f" Parquet->GPU (IO+Encode): {parquet_encode_time:.4f} s")
236237

237-
all_states = []
238-
for batch in parquet_file.iter_batches():
239-
sample_data = batch.column(0).to_numpy()
240-
qtensor = engine.encode(sample_data.tolist(), n_qubits, "amplitude")
241-
gpu_state = torch.from_dlpack(qtensor)
242-
all_states.append(gpu_state)
238+
# Convert to torch tensor (single DLPack call)
239+
dlpack_start = time.perf_counter()
240+
gpu_batched = torch.from_dlpack(batched_tensor)
241+
dlpack_time = time.perf_counter() - dlpack_start
242+
print(f" DLPack conversion: {dlpack_time:.4f} s")
243243

244-
# Stack and convert
245-
gpu_all_data = torch.stack(all_states).abs().to(torch.float32)
244+
# Reshape to [n_samples, state_len] (still complex)
245+
state_len = 1 << n_qubits
246+
gpu_reshaped = gpu_batched.view(n_samples, state_len)
246247

247-
encode_time = time.perf_counter() - start_time
248-
print(f" IO + Encode Time: {encode_time:.4f} s")
248+
# Convert to float for model (batch already on GPU)
249+
reshape_start = time.perf_counter()
250+
gpu_all_data = gpu_reshaped.abs().to(torch.float32)
251+
reshape_time = time.perf_counter() - reshape_start
252+
print(f" Reshape & convert: {reshape_time:.4f} s")
249253

250254
# Forward pass (data already on GPU)
251255
for i in range(0, n_samples, BATCH_SIZE):
@@ -255,7 +259,46 @@ def run_mahout(engine, n_qubits, n_samples):
255259
torch.cuda.synchronize()
256260
total_time = time.perf_counter() - start_time
257261
print(f" Total Time: {total_time:.4f} s")
258-
return total_time
262+
return total_time, gpu_reshaped
263+
264+
265+
def compare_states(name_a, states_a, name_b, states_b, tol=1e-5):
    """Print and report how closely two batches of encoded states agree.

    Only the first ``min(len(states_a), len(states_b))`` samples are
    compared. The verdict is based on probabilities (|psi|^2), so
    amplitudes that differ only in phase still count as matching; the raw
    amplitude difference is printed for information only.

    Args:
        name_a: Label printed for the first set of states.
        states_a: Tensor of states, indexed by sample along dim 0.
        name_b: Label printed for the second set of states.
        states_b: Tensor of states, indexed by sample along dim 0.
        tol: Largest probability difference still reported as a match.

    Returns:
        True when the maximum probability difference is below ``tol``;
        False otherwise, including when there are no samples to compare.
    """
    print("\n" + "=" * 70)
    print(f"VERIFICATION ({name_a} vs {name_b})")
    print("=" * 70)

    n_compare = min(len(states_a), len(states_b))
    if n_compare == 0:
        # Reducing an empty tensor with .max() raises, so report the
        # situation explicitly instead of crashing the whole benchmark.
        print(">> FAILURE: No samples to compare.")
        return False

    # Both operands must live on the same device. Prefer the GPU (the
    # original behavior) but fall back to the CPU when CUDA is unavailable.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tensor_a = states_a[:n_compare].to(device)
    tensor_b = states_b[:n_compare].to(device)

    # Compare probabilities (|psi|^2)
    diff_probs = (tensor_a.abs() ** 2 - tensor_b.abs() ** 2).abs().max().item()
    print(f"Max Probability Difference: {diff_probs:.2e}")

    # Compare raw amplitudes: magnitude of the full (possibly complex)
    # element-wise difference.
    diff_amps = (tensor_a - tensor_b).abs().max().item()
    print(f"Max Amplitude Difference:   {diff_amps:.2e}")

    matched = diff_probs < tol
    if matched:
        print(">> SUCCESS: Quantum States Match!")
    else:
        print(">> FAILURE: States do not match.")
    return matched
288+
289+
290+
def verify_correctness(states_dict):
    """Cross-check the states produced by every framework that ran.

    Entries whose value is ``None`` are ignored. When fewer than two
    entries remain nothing is compared; otherwise every unordered pair of
    remaining entries, taken in sorted name order, is passed to
    ``compare_states``.

    Args:
        states_dict: Mapping from framework name to its states, or to
            ``None`` when no states are available for that name.
    """
    available = {}
    for label, tensor in states_dict.items():
        if tensor is not None:
            available[label] = tensor

    # A comparison needs at least two participants.
    if len(available) >= 2:
        for first, second in itertools.combinations(sorted(available), 2):
            compare_states(first, available[first], second, available[second])
259302

260303

261304
if __name__ == "__main__":
@@ -268,8 +311,26 @@ def run_mahout(engine, n_qubits, n_samples):
268311
parser.add_argument(
269312
"--samples", type=int, default=200, help="Number of training samples"
270313
)
314+
parser.add_argument(
315+
"--frameworks",
316+
nargs="+",
317+
default=["mahout", "pennylane"],
318+
choices=["mahout", "pennylane", "qiskit", "all"],
319+
help="Frameworks to benchmark (default: mahout pennylane). Use 'all' to run all available frameworks.",
320+
)
271321
args = parser.parse_args()
272322

323+
# Expand "all" option
324+
if "all" in args.frameworks:
325+
all_frameworks = []
326+
if "mahout" in args.frameworks or "all" in args.frameworks:
327+
all_frameworks.append("mahout")
328+
if "pennylane" in args.frameworks or "all" in args.frameworks:
329+
all_frameworks.append("pennylane")
330+
if "qiskit" in args.frameworks or "all" in args.frameworks:
331+
all_frameworks.append("qiskit")
332+
args.frameworks = all_frameworks
333+
273334
generate_data(args.qubits, args.samples)
274335

275336
try:
@@ -282,10 +343,20 @@ def run_mahout(engine, n_qubits, n_samples):
282343
print(f"E2E BENCHMARK: {args.qubits} Qubits, {args.samples} Samples")
283344
print("=" * 70)
284345

346+
# Initialize results
347+
t_pl, pl_all_states = 0.0, None
348+
t_mahout, mahout_all_states = 0.0, None
349+
t_qiskit, qiskit_all_states = 0.0, None
350+
285351
# Run benchmarks
286-
t_pl = run_pennylane(args.qubits, args.samples)
287-
t_qiskit = run_qiskit(args.qubits, args.samples)
288-
t_mahout = run_mahout(engine, args.qubits, args.samples)
352+
if "pennylane" in args.frameworks:
353+
t_pl, pl_all_states = run_pennylane(args.qubits, args.samples)
354+
355+
if "qiskit" in args.frameworks:
356+
t_qiskit, qiskit_all_states = run_qiskit(args.qubits, args.samples)
357+
358+
if "mahout" in args.frameworks:
359+
t_mahout, mahout_all_states = run_mahout(engine, args.qubits, args.samples)
289360

290361
print("\n" + "=" * 70)
291362
print("E2E LATENCY (Lower is Better)")
@@ -311,3 +382,12 @@ def run_mahout(engine, n_qubits, n_samples):
311382
print(f"Speedup vs PennyLane: {t_pl / t_mahout:10.2f}x")
312383
if t_qiskit > 0:
313384
print(f"Speedup vs Qiskit: {t_qiskit / t_mahout:10.2f}x")
385+
386+
# Run Verification after benchmarks
387+
verify_correctness(
388+
{
389+
"Mahout": mahout_all_states,
390+
"PennyLane": pl_all_states,
391+
"Qiskit": qiskit_all_states,
392+
}
393+
)

qdp/qdp-core/src/error.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ pub enum MahoutError {
3636

3737
#[error("I/O error: {0}")]
3838
Io(String),
39+
40+
#[error("Not implemented: {0}")]
41+
NotImplemented(String),
3942
}
4043

4144
/// Result type alias for Mahout operations

0 commit comments

Comments
 (0)