Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 44 additions & 32 deletions memprof_plotter/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
"""


def download_artefact(url: str) -> bytes | None:
def download_artefact(url: str) -> zipfile.ZipFile | None:
"""
PyGithub does not support retrieving artefacts into buffers, so we have to resort
to requests
Expand All @@ -58,13 +58,15 @@ def download_artefact(url: str) -> bytes | None:
return None
zf = zipfile.ZipFile(BytesIO(req.content))
if "tsp_db.sqlite3" in zf.namelist():
return zf.read("tsp_db.sqlite3")
return zf
else:
print("Artefact does not contain required TSP database")
return None


def get_artefacts(nruns: int, workflow: github.Workflow.Workflow, artefact: str, filter: list[str]) -> dict[int, bytes]:
def get_artefacts(
nruns: int, workflow: github.Workflow.Workflow, artefact: str, filter: list[str]
) -> dict[str, zipfile.ZipFile]:
irun = 0
runs = {}
for run in workflow.get_runs(status="success"):
Expand All @@ -84,6 +86,27 @@ def get_artefacts(nruns: int, workflow: github.Workflow.Workflow, artefact: str,
return runs


class Zip_to_sql_conn:
    """Context manager yielding a sqlite3.Connection to the TSP database
    stored inside a downloaded workflow-artefact zip.

    On Python builds where ``sqlite3.Connection.deserialize`` exists
    (CPython 3.11+), the database bytes are loaded straight into an
    in-memory connection; otherwise they are spilled to a named temporary
    file and the connection is opened from disk.
    """

    def __init__(self, zip: zipfile.ZipFile):
        # NOTE(review): parameter name `zip` shadows the builtin; kept
        # as-is for backward compatibility with existing callers.
        self.db = zip.read("tsp_db.sqlite3")
        self.conn = sqlite3.connect(":memory:")
        self.tmpfile = None
        if hasattr(self.conn, "deserialize"):
            # Fast path (3.11+): load the raw database bytes in memory.
            self.conn.deserialize(self.db)
        else:
            # Close the now-unused in-memory connection before replacing
            # it — previously it was leaked.
            self.conn.close()
            self.tmpfile = tempfile.NamedTemporaryFile()
            self.tmpfile.write(self.db)
            # Flush so sqlite3 opens the complete file, not whatever
            # happens to have left the NamedTemporaryFile's buffer.
            self.tmpfile.flush()
            self.conn = sqlite3.connect(self.tmpfile.name)

    def __enter__(self) -> sqlite3.Connection:
        return self.conn

    def __exit__(self, type, value, traceback):
        self.conn.close()
        # Closing the NamedTemporaryFile also deletes it from disk.
        if self.tmpfile:
            self.tmpfile.close()


def main():
if gh_token == "BAD_KEY":
raise KeyError("GH_TOKEN must be set in environment")
Expand Down Expand Up @@ -136,35 +159,24 @@ def main():
d_cat = {}
d_names = {}

tmpfile = None

for runid, run in runs.items():
conn = sqlite3.connect(":memory:")
if hasattr(conn, "deserialize"):
conn.deserialize(run)
else:
tmpfile = tempfile.NamedTemporaryFile()
tmpfile.write(run)
conn = sqlite3.connect(tmpfile.name)
cur = conn.cursor()
cur.execute(get_all_cmds_query)
for cmd, cat in cur.fetchall():
d_cat[f"{cat}_{cmd}"] = cat or "other"
d_times[f"{cat}_{cmd}"][runid] = []
d_rss[f"{cat}_{cmd}"][runid] = []
d_names[f"{cat}_{cmd}"] = cmd

try:
cur.execute(get_mem_query)
except sqlite3.OperationalError:
### No such table memprof
continue
for cmd, cat, time, rss in cur.fetchall():
d_times[f"{cat}_{cmd}"][runid].append(time)
d_rss[f"{cat}_{cmd}"][runid].append(rss)
conn.close()
if tmpfile:
tmpfile.close()
for runid, zf in runs.items():
with Zip_to_sql_conn(zf) as conn:
cur = conn.cursor()
cur.execute(get_all_cmds_query)
for cmd, cat in cur.fetchall():
d_cat[f"{cat}_{cmd}"] = cat or "other"
d_times[f"{cat}_{cmd}"][runid] = []
d_rss[f"{cat}_{cmd}"][runid] = []
d_names[f"{cat}_{cmd}"] = cmd

try:
cur.execute(get_mem_query)
except sqlite3.OperationalError:
### No such table memprof
continue
for cmd, cat, time, rss in cur.fetchall():
d_times[f"{cat}_{cmd}"][runid].append(time)
d_rss[f"{cat}_{cmd}"][runid].append(rss)

for k, v in d_rss.items():
os.makedirs(f"{ns.outdir}/{d_cat[k]}", exist_ok=True)
Expand Down