Skip to content

Commit 954a3d3

Browse files
committed
Add script to determine ideal number of shards
This re-uses some code from the `monitoring/` dir with random modifications (maybe in the future we could de-duplicate these) but for now the code is fairly spaghetti-esque. `determine_shards.py` helps generate content for PRs like apache/tvm#12473
1 parent 77dfd9d commit 954a3d3

File tree

9 files changed

+931
-0
lines changed

9 files changed

+931
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__pycache__/

dev/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
.httpcache/

dev/README.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# developer scripts
2+
3+
This is a collection of miscellaneous scripts that are helpful for developing on TVM's CI. As one-off scripts, these do not conform to the usual TVM quality standards, which is why they are stored out of tree.
4+
5+
## `determine_shards.py`
6+
7+
Given a goal runtime for each test shard and a Jenkins job, print out the number of shards that should be used for each step.
8+
9+
```bash
10+
# print out number of shards per test step
11+
python determine_shards.py --runtime-goal-m 90 --branch PR-12473
12+
13+
# see bottleneck steps individually
14+
python determine_shards.py --runtime-goal-m 90 --branch PR-12473 --list-steps
15+
```

dev/determine_shards.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
import argparse
2+
import asyncio
3+
import re
4+
import statistics
5+
import math
6+
import rich
7+
8+
from typing import *
9+
10+
from utils import forward
11+
from utils.forward import *
12+
13+
14+
def is_parallelizable(name: str, desc: str) -> bool:
    """Return True if the named Jenkins test step can be split across shards.

    Steps are allowlisted by their exact display name; anything not listed is
    treated as fixed-cost work that runs once per shard.

    Args:
        name: The step's display name in Jenkins.
        desc: The step's description. Currently unused, but kept so the
            signature stays compatible with existing callers.
    """
    parallelizable_step_names = {
        "Run CPU integration tests",
        "Run Hexagon tests",
        "Run Python GPU integration tests",
        "Run Python GPU unit tests",
        "Run Python frontend tests",
        "Run Python unit tests",
        "Run VTA tests in FSIM",
        "Run VTA tests in TSIM",
        "Run i386 integration tests",
        "Run test_arm_compute_lib test",
        "Run TOPI tests",
        "Run microTVM tests",
    }
    # Membership test directly instead of the hand-rolled
    # `if name in ...: return True / return False` pattern.
    return name in parallelizable_step_names
32+
33+
34+
def analyze_stages(stage_name: str, stages: List[Stage], goal_runtime_m: float):
    """Print how many shards `stage_name` would need to hit `goal_runtime_m`.

    Splits each step's runtime into a fixed part (runs once on every shard,
    counted by its median duration across shards) and a parallelizable part
    (divides across shards, counted by its total duration), then solves for
    the shard count.

    Args:
        stage_name: Display name used as the report header.
        stages: All shards of this stage (e.g. "foo 1 of 4", "foo 2 of 4").
        goal_runtime_m: Target wall-clock minutes for a single shard.
    """
    ms_per_minute = 1000.0 * 60.0

    # Group occurrences of each step name across every shard of the stage.
    steps_across_shards = {}
    for stage in stages:
        for step in stage.steps:
            steps_across_shards.setdefault(step.name, []).append(step)

    fixed_runtime_m = 0.0
    parallelizable_runtime_m = 0.0
    for name, steps in steps_across_shards.items():
        durations_m = [step.duration_ms / ms_per_minute for step in steps]
        if is_parallelizable(name, ""):
            # Parallelizable work divides across shards, so its total across
            # all shards is what matters.
            parallelizable_runtime_m += sum(durations_m)
        else:
            # Fixed steps (checkout, env setup, ...) repeat on every shard,
            # so only one representative (median) duration counts toward the
            # per-shard runtime.
            fixed_runtime_m += statistics.median(durations_m)

    # Minutes of the goal left over for parallelizable work on each shard.
    parallel_part = goal_runtime_m - fixed_runtime_m
    print(stage_name)
    if parallel_part <= 0:
        # No shard count can help: the fixed work alone exceeds the goal.
        print(
            f" fixed runtime is too long ({round(fixed_runtime_m, 2)}), cannot reach goal time"
        )
        return

    num_shards = math.ceil(parallelizable_runtime_m / parallel_part)

    print(f" fixed runtime (m): {round(fixed_runtime_m, 2)}")
    print(f" parallel runtime (m): {round(parallelizable_runtime_m, 2)}")
    print(f" required shards: {num_shards}")
69+
70+
71+
def list_steps(build: Build):
    """Print each stage's total runtime and per-step durations, slowest-highlighted.

    Stages are sorted by total runtime. Within a stage, steps above the
    median duration are printed in magenta and steps above the median of the
    above-median group (a rough 75th percentile) in bold red.

    Args:
        build: The Jenkins build whose stages/steps are reported.
    """

    def total_rt(stage: Stage):
        return sum(step.duration_ms for step in stage.steps)

    build.stages = sorted(build.stages, key=total_rt)
    print("For build at", build.blue_url)
    for stage in build.stages:
        # Parent stages just wrap the real per-shard stages; skip them.
        if stage.name in {"Build", "Test", "Deploy"}:
            continue
        durations = [step.duration_ms for step in stage.steps]
        total = sum(durations)
        if not durations:
            rich.print(f"{stage.name}: skipped")
            continue
        median = statistics.median(durations)
        above_median = [d for d in durations if d > median]
        # BUG FIX: when every step in the stage has the same duration,
        # `above_median` is empty and statistics.median() would raise
        # StatisticsError; fall back to the median itself in that case.
        m75 = statistics.median(above_median) if above_median else median
        rich.print(f"{stage.name}: {round(total / 1000.0 / 60.0)}m")
        for step in stage.steps:
            minutes = round(step.duration_ms / 1000.0 / 60.0, 2)
            if step.duration_ms > m75:
                rich.print(f" [bold red]{step.name}[/bold red]: {minutes}")
            elif step.duration_ms > median:
                rich.print(f" [magenta]{step.name}[/magenta]: {minutes}")
            else:
                rich.print(f" {step.name}: {minutes}")
102+
103+
104+
def analyze(build: Build, goal_runtime_m: float):
    """Run the shard analysis for every test stage of `build`.

    Collects the stages that sit between the "Test" and "Deploy" parent
    stages, merges shards of the same logical stage (named like
    "<base name> N of M") under their base name, and reports the required
    shard count for each via analyze_stages().

    Args:
        build: The Jenkins build to analyze.
        goal_runtime_m: Target wall-clock minutes per shard.
    """
    # Stages appear in order; everything after "Test" and before "Deploy"
    # is a test stage.
    test_stages: List[Stage] = []
    should_add = False
    for stage in build.stages:
        if stage.name == "Test":
            should_add = True
        elif stage.name == "Deploy":
            should_add = False
        elif should_add:
            test_stages.append(stage)

    # Merge shards such as "foo 1 of 4" / "foo 2 of 4" under "foo";
    # unsharded stages map to a singleton list.
    # (Removed an unused `names_to_stages` dict that was built here but
    # never read.)
    merged_shards = {}
    for stage in test_stages:
        m = re.match(r"(.*) \d+ of \d+", stage.name)
        if m:
            merged_shards.setdefault(m.groups()[0], []).append(stage)
        else:
            merged_shards[stage.name] = [stage]

    for name, stages in merged_shards.items():
        analyze_stages(name, stages, goal_runtime_m)
132+
133+
134+
async def main(args):
    """Fetch Jenkins branch data for `args.branch` and return it.

    Installs a shared aiohttp session into the `forward` module so its
    helpers can reuse it, then fetches the branch.
    """
    # BUG FIX: `aiohttp` was used here but never imported anywhere in this
    # module, so running the script raised NameError. Imported locally to
    # keep the fix self-contained.
    import aiohttp

    async with aiohttp.ClientSession() as s:
        forward.SESSION = s
        data = await fetch_branch(name=args.branch)
        return data
139+
140+
141+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Determine number of Jenkins shards to use"
    )
    # Parse the goal directly as a float instead of deferring the
    # conversion (and a possible ValueError) to the use site below.
    parser.add_argument(
        "--runtime-goal-m",
        required=True,
        type=float,
        help="goal runtime in minutes for each shard",
    )
    parser.add_argument(
        "--list-steps",
        action="store_true",
        help="list per-step runtimes instead of computing shard counts",
    )
    parser.add_argument("--branch", default="main", help="Jenkins branch/PR name")
    # NOTE(review): --build is parsed but never read below; kept for
    # compatibility with existing invocations — confirm before removing.
    parser.add_argument("--build", default="4082")
    args = parser.parse_args()
    init(dir=".httpcache")
    init_log()

    branch = asyncio.run(main(args))
    # Most recent build of the branch.
    build = branch.builds[0]

    if args.list_steps:
        list_steps(build)
    else:
        print(f"To reach goal runtime of {args.runtime_goal_m} for tests:")
        analyze(build, goal_runtime_m=args.runtime_goal_m)

dev/utils/db.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import os
2+
from sqlalchemy import create_engine
3+
4+
from sqlalchemy.dialects.postgresql import insert
5+
6+
7+
def connection_string(db="tvm"):
    """Build a PostgreSQL connection URL from environment variables.

    Reads ``db_host``, ``db_password`` and ``db_user`` from the environment
    (raising KeyError if any is missing) and returns a ``postgresql://``
    URL, scoped to ``db`` unless it is None.
    """
    host = os.environ["db_host"]
    password = os.environ["db_password"]
    user = os.environ["db_user"]

    base = f"postgresql://{user}:{password}@{host}"
    return base if db is None else f"{base}/{db}"
16+
17+
18+
# Process-wide SQLAlchemy engine singleton: created lazily by get_engine()
# and reset by clear_engine().
engine = None
19+
20+
21+
def get_engine(connection_string: str):
    """Return the module-level SQLAlchemy engine, creating it on first use.

    Note: the ``connection_string`` argument (which shadows the module-level
    function of the same name) only takes effect on the first call; later
    calls return the cached engine regardless of the argument.
    """
    global engine
    if engine is None:
        # BUG FIX: bool(os.getenv("ECHO", False)) was truthy for ANY
        # non-empty string — including "0" and "false" — because getenv
        # returns a string. Only enable echo for genuinely truthy values.
        echo = os.getenv("ECHO", "").lower() not in ("", "0", "false")
        engine = create_engine(connection_string, echo=echo)

    return engine
27+
28+
29+
def clear_engine():
    """Forget the cached engine so the next get_engine() call rebuilds it."""
    global engine
    engine = None
32+
33+
34+
def upsert(engine, model, insert_dict):
    """
    Insert ``insert_dict`` as a row of ``model``'s table, updating the
    existing row instead when it conflicts on the primary key
    (PostgreSQL ``INSERT ... ON CONFLICT DO UPDATE``).

    Args:
        engine: SQLAlchemy engine to execute against.
        model: Table/model whose ``_pks`` attribute lists the primary-key
            columns used as the conflict target.
        insert_dict: Column-name -> value mapping for the row.

    Returns:
        ``lastrowid`` of the executed statement.
        NOTE(review): ``lastrowid`` is generally not meaningful on
        PostgreSQL result proxies — confirm callers actually rely on it.
    """
    inserted = insert(model).values(**insert_dict)
    # MySQL version (inactive; kept for reference if this ever targets MySQL):
    # upserted = inserted.on_duplicate_key_update(
    #     **{k: inserted.inserted[k] for k, v in insert_dict.items()}
    # )

    # Postgres version:
    upserted = inserted.on_conflict_do_update(
        index_elements=model._pks,
        # index_where=my_table.c.user_email.like("%@gmail.com"),
        set_=insert_dict,
    )
    res = engine.execute(upserted)
    return res.lastrowid

0 commit comments

Comments
 (0)