Skip to content

Commit 97171f9

Browse files
Lucio Anderlini
authored and committed
fixed CompressedHepMCLoader to cope with new bz2 format without path in tar members. Also, reducing 1 copy, extracting and patching in one go.
1 parent bb4db6e commit 97171f9

File tree

1 file changed

+55
-60
lines changed

1 file changed

+55
-60
lines changed

PyLamarr/loaders/CompressedHepMCLoader.py

Lines changed: 55 additions & 60 deletions
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,6 @@
1212

1313
from typing import Optional, Collection, Any
1414

15-
from pkg_resources import require
16-
1715
from PyLamarr import EventBatch
1816

1917
@dataclass
@@ -81,76 +79,73 @@ def _get_number_of_events(self, filename: str, default: int) -> int:
8179
@contextmanager
8280
def archive_mirror(self, filename: str):
8381
tmp_dir = os.path.join(
84-
self.tmpdir,
82+
self.tmpdir,
8583
f"pylamarr.tmp.{random.randint(0, 0xFFFFFF):06x}"
8684
)
85+
self.logger.info(f"Creating temporary directory {tmp_dir}")
8786
os.mkdir(tmp_dir)
88-
tmp_archive = f"{tmp_dir}.tar.bz2"
8987

9088
try:
91-
yield self.files_in_archive(filename, tmp_dir=tmp_dir, tmp_archive=tmp_archive)
89+
yield self.files_in_archive(filename, tmp_dir=tmp_dir)
9290
finally:
9391
self.logger.info(f"Removing directory {tmp_dir}")
9492
shutil.rmtree(tmp_dir)
95-
self.logger.info(f"Removing temporary file {tmp_archive}")
96-
os.remove(tmp_archive)
9793

98-
def copy_and_maybe_patch_hepmc(self, filename):
99-
"Apply patches to the HepMC2 file to avoid segmentation fault in HepMC3 ascii reader"
94+
def copy_and_maybe_patch_hepmc(self, input_file_data: str):
95+
"""
96+
Apply patches to the HepMC2 file to avoid segmentation fault in HepMC3 ascii reader
97+
"""
10098
requires_particle_gun_patch = False
101-
with open(filename) as input_file:
102-
lines = []
103-
for line in input_file:
104-
line = line[:-1] if line[-1] == '\n' else line
105-
if len(line) > 0 and line[0] == 'E': ## Event line
106-
tokens = line.split(" ")
107-
# Documentation at https://hepmc.web.cern.ch/hepmc/releases/HepMC2_user_manual.pdf
108-
# Section 6.2
109-
if int(tokens[6]) == 1: # For Particle Gun process
110-
self._particle_gun_patched_events += 1
111-
n_vertices = int(tokens[8])
112-
tokens[8] = str(n_vertices + 1)
113-
tokens[12 + int(tokens[11])] = str(1)
114-
tokens += ["1.0"]
115-
requires_particle_gun_patch = True
116-
lines += [" ".join(tokens), 'N 1 "0"']
117-
else:
118-
lines.append(line)
119-
elif len(line) > 0 and line[0] == 'V' and requires_particle_gun_patch: # First vertex
120-
# PGUN Patch:
121-
# HepMC3::HepMC2Reader does not tolerate a PV with no incoming particles,
122-
# so we create a fake vertex and a fake beam particle.
123-
vertex_id = line.split(" ")[1]
124-
lines += ["V -99999 0 0 0 0 0 0 0 0", "P 0 0 0. 0. 0. 0. 0. 3 0 0 %s 0" % vertex_id, line]
125-
requires_particle_gun_patch = False
99+
src_lines = input_file_data.split('\n')
100+
if len([li for li in src_lines if li.replace("\n", "").replace(" ", "") != ""]) == 0:
101+
self.logger.warning(f"No valid line found in input file")
102+
dst_lines = []
103+
104+
for line in src_lines:
105+
line = line[:-1] if len(line) > 0 and line[-1] == '\n' else line
106+
if len(line) > 0 and line[0] == 'E': ## Event line
107+
tokens = line.split(" ")
108+
# Documentation at https://hepmc.web.cern.ch/hepmc/releases/HepMC2_user_manual.pdf
109+
# Section 6.2
110+
if int(tokens[6]) == 1: # For Particle Gun process
111+
self._particle_gun_patched_events += 1
112+
n_vertices = int(tokens[8])
113+
tokens[8] = str(n_vertices + 1)
114+
tokens[12 + int(tokens[11])] = str(1)
115+
tokens += ["1.0"]
116+
requires_particle_gun_patch = True
117+
dst_lines += [" ".join(tokens), 'N 1 "0"']
126118
else:
127-
lines.append(line)
128-
129-
return "\n".join(lines)
119+
dst_lines.append(line)
120+
elif len(line) > 0 and line[0] == 'V' and requires_particle_gun_patch: # First vertex
121+
# PGUN Patch:
122+
# HepMC3::HepMC2Reader does not tolerate a PV with no incoming particles,
123+
# so we create a fake vertex and a fake beam particle.
124+
vertex_id = line.split(" ")[1]
125+
dst_lines += ["V -99999 0 0 0 0 0 0 0 0", "P 0 0 0. 0. 0. 0. 0. 3 0 0 %s 0" % vertex_id, line]
126+
requires_particle_gun_patch = False
127+
else:
128+
dst_lines.append(line)
129+
130+
return "\n".join(dst_lines)
130131

131-
def files_in_archive(self, filename: str, tmp_dir: str, tmp_archive: str):
132-
self.logger.info(f"Copying archive to local storage")
133-
shutil.copy(filename, tmp_archive)
134-
self.logger.info(f"Extracting archive {filename} in {tmp_dir}")
135-
with tarfile.open(tmp_archive) as archive:
136-
archive.extractall(tmp_dir)
137-
138-
for (root, dirs, filenames) in os.walk(tmp_dir):
139-
for filename in filenames:
140-
if filename.endswith(".mc2"):
141-
self.logger.info(f"Found {filename} in archive.")
142-
with open(os.path.join(tmp_dir, filename), 'w') as file_copy:
143-
file_copy.write(self.copy_and_maybe_patch_hepmc(os.path.join(root, filename)))
144-
yield os.path.join(tmp_dir, filename)
145-
146-
132+
def files_in_archive(self, filename: str, tmp_dir: str):
133+
with tarfile.open(filename, mode='r:*') as tar:
134+
for member in tar.getmembers():
135+
if member.isfile() and member.name.endswith("mc2"):
136+
key = os.path.basename(member.name)
137+
file_content = tar.extractfile(member).read().decode('utf-8')
138+
patched_filename = os.path.join(tmp_dir, os.path.basename(filename))
139+
with open(patched_filename, 'w') as file_copy:
140+
file_copy.write(self.copy_and_maybe_patch_hepmc(file_content))
141+
yield patched_filename
147142

148143
def load(self, filename: str):
149144
"""
150145
Internal.
151146
"""
152147
if self._db is None:
153-
raise ValueError("PandasLoader tried loading with uninitialized db.\n"
148+
raise ValueError("CompressedHepMCLoader tried loading with uninitialized db.\n"
154149
"Missed ()?")
155150

156151
event_counter = 0
@@ -164,23 +159,23 @@ def load(self, filename: str):
164159
batch_info = dict()
165160
with self.archive_mirror(filename) as files_in_archive:
166161
for i_file, hepmc_file in enumerate(files_in_archive):
167-
run_number = self._get_run_number(filename)
162+
run_number = self._get_run_number(filename)
168163
event_number = self._get_evt_number(hepmc_file, i_file)
169164
n_events = len(batches['event_numbers'])
170165
batch_info.update(dict(
171166
n_events=n_events,
172167
batch_id=batch_counter,
173-
description=f"Run {run_number}",
168+
description=f"Run {run_number}",
174169
_hepmcloader=self._hepmcloader,
175170
))
176-
171+
177172
if tot_events > 0 and self._events_per_batch is not None:
178-
batch_info['n_batches'] = ceil(tot_events/self._events_per_batch)
173+
batch_info['n_batches'] = ceil(tot_events/self._events_per_batch)
179174

180-
if self._max_event is not None and event_counter >= self._max_event:
175+
if self._max_event is not None and event_counter >= self._max_event:
181176
break
182177

183-
if self._events_per_batch is not None and n_events >= self._events_per_batch:
178+
if self._events_per_batch is not None and n_events >= self._events_per_batch:
184179
yield HepMC2EventBatch(
185180
**batch_info,
186181
**batches

0 commit comments

Comments
 (0)