From 27a67bf3166e9a5e8542c46490eb9bfa57ed5432 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Wed, 8 Oct 2025 11:11:30 -0400 Subject: [PATCH 01/33] Added changes/additions to Dockerfile --- Dockerfile | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index b043df81..cc493e68 100644 --- a/Dockerfile +++ b/Dockerfile @@ -124,7 +124,9 @@ ARG MIOPEN_DIR=$ROCM_LIBS_DIR/projects/miopen RUN git clone --filter=blob:none --sparse https://github.com/ROCm/rocm-libraries.git $ROCM_LIBS_DIR WORKDIR $MIOPEN_DIR RUN git sparse-checkout set projects/miopen -ARG MIOPEN_BRANCH=4940cf3ec +# not sure what this commit is, using latest develop for now +# ARG MIOPEN_BRANCH=4940cf3ec +ARG MIOPEN_BRANCH=develop RUN git pull && git checkout $MIOPEN_BRANCH ARG PREFIX=/opt/rocm @@ -209,3 +211,26 @@ RUN python3 setup.py install # reset WORKDIR to /tuna WORKDIR /tuna + +# save BASEIMAGE as env variable +ENV BASEIMAGE=${BASEIMAGE} + +# install mysql-server and mysql-client +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -f -y --allow-unauthenticated \ + mysql-server \ + mysql-client + +# install redis-server +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -f -y --allow-unauthenticated \ + redis-server + +# install RabbitMQ server +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -f -y --allow-unauthenticated \ + rabbitmq-server + +# install iproute2 +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -f -y --allow-unauthenticated \ + iproute2 + +# clean up apt cache +RUN apt-get clean && rm -rf /var/lib/apt/lists/* \ No newline at end of file From 11462b3e604c61c806fadacb6e6902aee0163db7 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Thu, 16 Oct 2025 08:51:42 +0000 Subject: [PATCH 02/33] auto format --- tuna/mituna_interface.py | 1196 +++++++++++++++++++------------------- 1 file changed, 611 insertions(+), 585 deletions(-) diff --git a/tuna/mituna_interface.py b/tuna/mituna_interface.py index 140d1542..345a915b 100644 --- a/tuna/mituna_interface.py +++ b/tuna/mituna_interface.py @@ -61,600 +61,626 @@ job_counter_lock = threading.Lock() -class MITunaInterface(): #pylint:disable=too-many-instance-attributes,too-many-public-methods - """ Interface class extended by libraries. The purpose of this class is to define - common functionalities. """ - - def __init__(self, library=Library.MIOPEN) -> None: - - self.self: Library = self - - self.logger: logging.Logger = setup_logger(logger_name=library.value, - add_streamhandler=True) - self.args: argparse.Namespace - - self.fetch_state: set = set() - self.max_job_retries = 10 - self.dbt = None - self.operation = None - self.db_name = os.environ['TUNA_DB_NAME'] - self.prefix = None - - def check_docker(self, - worker: WorkerInterface, - dockername="miopentuna") -> bool: - """! Checking for docker - @param worker The worker interface instance - @param dockername The name of the docker - """ - out2: ChannelFile - _, out2, _ = worker.exec_command("sudo docker info") - while not out2.channel.exit_status_ready(): - self.logger.warning(out2.readline()) - if out2.channel.exit_status > 0: - self.logger.warning( - "docker not installed or failed to run with sudo .... 
") - return False - - out: StringIO = StringIO() - line: Optional[str] = None - _, out, _ = worker.exec_command(f"sudo docker images | grep {dockername}") - for line in out.readlines(): - if line is not None: - if line.find(dockername) != -1: - self.logger.warning('%s docker image exists', dockername) - return True - if line is None: - self.logger.warning('%s docker image does not exist', dockername) - return False - - return False - - def check_status(self, - worker: WorkerInterface, - b_first: int, - gpu_idx: int, - machine: Machine, - dockername: str = "miopentuna") -> bool: - """! Function to check gpu_status - @param worker The worker interface instance - @param b_first Flag to keep track of visited GPU - @param gpu_idx Unique ID of the GPU - @param machine The machine instance - @param dockername The name of the docker - """ - - if machine.chk_gpu_status(worker.gpu_id): - self.logger.info('Machine: (%s, %u) GPU_ID: %u OK', machine.hostname, - machine.port, gpu_idx) - else: - self.logger.info('Machine: (%s, %u) GPU_ID: %u ERROR', machine.hostname, - machine.port, gpu_idx) - - if not b_first: - return False - b_first = False - _, out, _ = worker.exec_command("docker info") - while not out.channel.exit_status_ready(): - pass - - if out.channel.exit_status > 0: - self.check_docker(worker, dockername) - else: - _, out, _ = worker.exec_command(f"docker images | grep {dockername}") - line: Optional[str] = None - for line in out.readlines(): - if line is not None: - if line.find(dockername) != -1: - self.logger.warning('%s docker image exists', dockername) - break +class MITunaInterface: # pylint:disable=too-many-instance-attributes,too-many-public-methods + """Interface class extended by libraries. The purpose of this class is to define + common functionalities.""" + + def __init__(self, library=Library.MIOPEN) -> None: + + self.self: Library = self + + self.logger: logging.Logger = setup_logger( + logger_name=library.value, add_streamhandler=True + ) + self.args: argparse.Namespace + + self.fetch_state: set = set() + self.max_job_retries = 10 + self.dbt = None + self.operation = None + self.db_name = os.environ["TUNA_DB_NAME"] + self.prefix = None + + def check_docker(self, worker: WorkerInterface, dockername="miopentuna") -> bool: + """! Checking for docker + @param worker The worker interface instance + @param dockername The name of the docker + """ + out2: ChannelFile + _, out2, _ = worker.exec_command("sudo docker info") + while not out2.channel.exit_status_ready(): + self.logger.warning(out2.readline()) + if out2.channel.exit_status > 0: + self.logger.warning("docker not installed or failed to run with sudo .... ") + return False + + out: StringIO = StringIO() + line: Optional[str] = None + _, out, _ = worker.exec_command(f"sudo docker images | grep {dockername}") + for line in out.readlines(): + if line is not None: + if line.find(dockername) != -1: + self.logger.warning("%s docker image exists", dockername) + return True + if line is None: + self.logger.warning("%s docker image does not exist", dockername) + return False + + return False + + def check_status( + self, + worker: WorkerInterface, + b_first: int, + gpu_idx: int, + machine: Machine, + dockername: str = "miopentuna", + ) -> bool: + """! 
Function to check gpu_status + @param worker The worker interface instance + @param b_first Flag to keep track of visited GPU + @param gpu_idx Unique ID of the GPU + @param machine The machine instance + @param dockername The name of the docker + """ + + if machine.chk_gpu_status(worker.gpu_id): + self.logger.info( + "Machine: (%s, %u) GPU_ID: %u OK", + machine.hostname, + machine.port, + gpu_idx, + ) + else: + self.logger.info( + "Machine: (%s, %u) GPU_ID: %u ERROR", + machine.hostname, + machine.port, + gpu_idx, + ) + + if not b_first: + return False + b_first = False + _, out, _ = worker.exec_command("docker info") + while not out.channel.exit_status_ready(): + pass + + if out.channel.exit_status > 0: + self.check_docker(worker, dockername) + else: + _, out, _ = worker.exec_command(f"docker images | grep {dockername}") + line: Optional[str] = None + for line in out.readlines(): + if line is not None: + if line.find(dockername) != -1: + self.logger.warning("%s docker image exists", dockername) + break + else: + self.logger.warning("%s docker image does not exist", dockername) + + return True + + def add_tables(self) -> bool: + """Add self specific tables""" + return self.add_tables() + + def get_num_procs(self, machine: Machine) -> List: + """Determine number of processes by compute capacity""" + worker_ids: List = [] + num_procs: int + env: Dict[str, Any] + env = get_env_vars() + if env["slurm_cpus"] > 0: + num_procs = int(env["slurm_cpus"]) else: - self.logger.warning('%s docker image does not exist', dockername) - - return True - - def add_tables(self) -> bool: - """Add self specific tables""" - return self.add_tables() - - def get_num_procs(self, machine: Machine) -> List: - """Determine number of processes by compute capacity""" - worker_ids: List = [] - num_procs: int - env: Dict[str, Any] - env = get_env_vars() - if env['slurm_cpus'] > 0: - num_procs = int(env['slurm_cpus']) - else: - num_procs = int(machine.get_num_cpus() * .6) - - worker_ids = list(range(num_procs)) - - if len(worker_ids) == 0: - self.logger.error('num_procs must be bigger than zero to launch worker') - self.logger.error('Cannot launch worker on machine: %s', machine.id) - worker_ids = [] - - return worker_ids - - def get_f_vals(self, - machine: Machine, - worker_ids: range, - tuning=False) -> Dict[str, Any]: - #pylint:disable=unused-argument - """Determine kwargs for worker_interface""" - f_vals: Dict[str, Any] - f_vals = self.compose_f_vals(machine) - f_vals['envmt'] = self.get_envmt() - - if not tuning: - f_vals["num_procs"] = Value('i', len(worker_ids)) - - return f_vals - - def get_envmt(self): - """Get runtime envmt""" - raise NotImplementedError("Not implemented") - - def compose_f_vals(self, machine: Machine, tuning=False) -> Dict[str, Any]: - """! Compose dict for WorkerInterface constructor - @param args The command line arguments - @param machine Machine instance - """ - f_vals: Dict[str, Any] = {} - f_vals["b_first"] = True - - #adding non-serializable obj when not running through celery - if not tuning: - f_vals["machine"] = machine - f_vals["bar_lock"] = Lock() - #multiprocess queue for jobs, shared on machine - f_vals["job_queue"] = mpQueue() - f_vals["job_queue_lock"] = Lock() - f_vals["end_jobs"] = Value('i', 0) - - return f_vals - - def get_kwargs(self, - gpu_idx: int, - f_vals: Dict[str, Any], - tuning=False) -> Dict[str, Any]: - """! 
Helper function to set up kwargs for worker instances - @param gpu_idx Unique ID of the GPU - @param f_vals Dict containing runtime information - """ - envmt: Dict[str, Any] = f_vals["envmt"].copy() - kwargs: Dict[str, Any] = {} - - kwargs = { - 'gpu_id': gpu_idx, - 'envmt': envmt, - 'label': self.args.label, - 'docker_name': self.args.docker_name, - 'session_id': self.args.session_id - } - - #adding non-serializable obj when not running through celery - if not tuning: - kwargs["machine"] = f_vals["machine"] - kwargs["job_queue"] = f_vals["job_queue"] - kwargs["job_queue_lock"] = f_vals["job_queue_lock"] - kwargs["num_procs"] = f_vals["num_procs"] - kwargs["bar_lock"] = f_vals["bar_lock"] - kwargs["end_jobs"] = f_vals["end_jobs"] - kwargs["job_queue"] = f_vals["job_queue"] - kwargs["job_queue_lock"] = f_vals["job_queue_lock"] - - return kwargs - - def get_job_list(self, session, find_state, claim_num): - """Get list of jobs""" - raise NotImplementedError("Not implemented") - - def get_jobs(self, - session: DbSession, - find_state: List[str], - set_state: str, - session_id: int, - claim_num: int = None, - no_update=False): - """Interface function to get jobs based on session and find_state""" - #job_rows: List[SimpleDict] - ids: list - row: SimpleDict - - self.logger.info('Fetching DB rows...') - job_list = self.get_job_list(session, find_state, claim_num) - - if not self.check_jobs_found(job_list, find_state, session_id): - return [] - - if no_update: - return job_list - - ids = [row.id for row in job_list] - self.logger.info("%s jobs %s", find_state, ids) - self.logger.info('Updating job state to %s', set_state) - for job in job_list: - job.state = set_state - if self.dbt is not None: - query: str = gen_update_query(job, ['state'], - self.dbt.job_table.__tablename__) - else: - raise CustomError('DBTable must be set') - session.execute(query) - - session.commit() - - return job_list - - def shutdown_workers(self): - """Shutdown all active celery workers regardless of queue""" - return stop_active_workers() - - def cancel_consumer(self, queue): - """Cancel consumers for queue""" - try: - cmd = f"celery -A tuna.celery_app.celery_app control cancel_consumer {queue}" - subp = subprocess.Popen( #pylint: disable=consider-using-with - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - shell=True, - universal_newlines=True) - - #filter the workers by session id - sess_str = "sess_" + queue.split('_')[-1] - stdout, _ = subp.stdout, subp.stderr - while True: - line = stdout.readline() - if not line: - break - #stop workers that were feeding from this queue - if "->" in line and sess_str in line: - hostname = line.split('->')[1].split()[0].split(':')[0] - stop_named_worker(hostname) - - except Exception as exp: #pylint: disable=broad-exception-caught - self.logger.warning( - 'Error occurred trying to cancel consumer for queue: %s ', queue) - self.logger.warning(exp) - return False - - self.logger.info('Sucessfully cancelled consumer for queue: %s', queue) - - return True - - def celery_enqueue_call(self, context, q_name, task_id=False): - """Wrapper function for celery enqueue func""" - raise NotImplementedError('Not implemented') - - def enqueue_jobs(self, job_counter, job_batch_size, q_name): - """Enqueue celery jobs""" - self.logger.info('Starting enqueue') - with DbSession() as session: - while True: - job_list = [] - #get all the jobs from mySQL - job_list = self.get_jobs( - session, - self.fetch_state, - self.set_state, #pylint: disable=no-member - self.args.session_id, #pylint: 
disable=no-member - job_batch_size) - - with job_counter_lock: - job_counter.value = job_counter.value + len(job_list) - - for i in range(0, len(job_list), job_batch_size): - batch_jobs = job_list[i:min(i + job_batch_size, len(job_list))] - context_list = self.get_context_list(session, batch_jobs) - for context in context_list: - #calling celery task, enqueuing to celery queue - self.celery_enqueue_call(context, q_name=q_name) - - self.logger.info('Job counter: %s', job_counter.value) - if not job_list: - self.logger.info('All tasks added to queue') - break - - async def cleanup_redis_results(self, prefix): - """Remove stale redis results by key""" - backend_port, backend_host = get_backend_env() - redis = await aioredis.from_url(f"redis://{backend_host}:{backend_port}/15") - - keys = [] - cursor = "0" - if prefix: - #a prefix is necessary when the need to different results in redis based on operation - #withough a prefix the redis key defaults to: "celery-task-meta-" - #with a prefix the key will look like: "celery-task-meta--" - #the prefix can be applied when filtering the redis keys as bellow - cursor, results = await redis.scan(cursor, match=f"*{prefix}*") - else: - #no prefix, match any key - cursor, results = await redis.scan(cursor, match="*") - keys.extend(results) - self.logger.info('Found %s old results', len(results)) - for key in keys: - try: - await redis.delete(key) - except aioredis.exceptions.ResponseError as red_err: - self.logger.error(red_err) - self.logger.info(key.decode('utf-8')) - continue - - self.logger.info('Done removing old redis results for prefix: %s', prefix) - - return True - - async def consume(self, job_counter, prefix): - """Retrieve celery results from redis db""" - - backend_port, backend_host = get_backend_env() - redis = await aioredis.from_url(f"redis://{backend_host}:{backend_port}/15") - - while job_counter.value > 0: - cursor = "0" - keys = [] - while cursor != 0: + num_procs = int(machine.get_num_cpus() * 0.6) + + worker_ids = list(range(num_procs)) + + if len(worker_ids) == 0: + self.logger.error("num_procs must be bigger than zero to launch worker") + self.logger.error("Cannot launch worker on machine: %s", machine.id) + worker_ids = [] + + return worker_ids + + def get_f_vals( + self, machine: Machine, worker_ids: range, tuning=False + ) -> Dict[str, Any]: + # pylint:disable=unused-argument + """Determine kwargs for worker_interface""" + f_vals: Dict[str, Any] + f_vals = self.compose_f_vals(machine) + f_vals["envmt"] = self.get_envmt() + + if not tuning: + f_vals["num_procs"] = Value("i", len(worker_ids)) + + return f_vals + + def get_envmt(self): + """Get runtime envmt""" + raise NotImplementedError("Not implemented") + + def compose_f_vals(self, machine: Machine, tuning=False) -> Dict[str, Any]: + """! Compose dict for WorkerInterface constructor + @param args The command line arguments + @param machine Machine instance + """ + f_vals: Dict[str, Any] = {} + f_vals["b_first"] = True + + # adding non-serializable obj when not running through celery + if not tuning: + f_vals["machine"] = machine + f_vals["bar_lock"] = Lock() + # multiprocess queue for jobs, shared on machine + f_vals["job_queue"] = mpQueue() + f_vals["job_queue_lock"] = Lock() + f_vals["end_jobs"] = Value("i", 0) + + return f_vals + + def get_kwargs( + self, gpu_idx: int, f_vals: Dict[str, Any], tuning=False + ) -> Dict[str, Any]: + """! 
Helper function to set up kwargs for worker instances + @param gpu_idx Unique ID of the GPU + @param f_vals Dict containing runtime information + """ + envmt: Dict[str, Any] = f_vals["envmt"].copy() + kwargs: Dict[str, Any] = {} + + kwargs = { + "gpu_id": gpu_idx, + "envmt": envmt, + "label": self.args.label, + "docker_name": self.args.docker_name, + "session_id": self.args.session_id, + } + + # adding non-serializable obj when not running through celery + if not tuning: + kwargs["machine"] = f_vals["machine"] + kwargs["job_queue"] = f_vals["job_queue"] + kwargs["job_queue_lock"] = f_vals["job_queue_lock"] + kwargs["num_procs"] = f_vals["num_procs"] + kwargs["bar_lock"] = f_vals["bar_lock"] + kwargs["end_jobs"] = f_vals["end_jobs"] + kwargs["job_queue"] = f_vals["job_queue"] + kwargs["job_queue_lock"] = f_vals["job_queue_lock"] + + return kwargs + + def get_job_list(self, session, find_state, claim_num): + """Get list of jobs""" + raise NotImplementedError("Not implemented") + + def get_jobs( + self, + session: DbSession, + find_state: List[str], + set_state: str, + session_id: int, + claim_num: int = None, + no_update=False, + ): + """Interface function to get jobs based on session and find_state""" + # job_rows: List[SimpleDict] + ids: list + row: SimpleDict + + self.logger.info("Fetching DB rows...") + job_list = self.get_job_list(session, find_state, claim_num) + + if not self.check_jobs_found(job_list, find_state, session_id): + return [] + + if no_update: + return job_list + + ids = [row.id for row in job_list] + self.logger.info("%s jobs %s", find_state, ids) + self.logger.info("Updating job state to %s", set_state) + for job in job_list: + job.state = set_state + if self.dbt is not None: + query: str = gen_update_query( + job, ["state"], self.dbt.job_table.__tablename__ + ) + else: + raise CustomError("DBTable must be set") + session.execute(query) + + session.commit() + + return job_list + + def shutdown_workers(self): + """Shutdown all active celery workers regardless of queue""" + return stop_active_workers() + + def cancel_consumer(self, queue): + """Cancel consumers for queue""" + try: + cmd = ( + f"celery -A tuna.celery_app.celery_app control cancel_consumer {queue}" + ) + subp = subprocess.Popen( # pylint: disable=consider-using-with + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + shell=True, + universal_newlines=True, + ) + + # filter the workers by session id + sess_str = "sess_" + queue.split("_")[-1] + stdout, _ = subp.stdout, subp.stderr + while True: + line = stdout.readline() + if not line: + break + # stop workers that were feeding from this queue + if "->" in line and sess_str in line: + hostname = line.split("->")[1].split()[0].split(":")[0] + stop_named_worker(hostname) + + except Exception as exp: # pylint: disable=broad-exception-caught + self.logger.warning( + "Error occurred trying to cancel consumer for queue: %s ", queue + ) + self.logger.warning(exp) + return False + + self.logger.info("Sucessfully cancelled consumer for queue: %s", queue) + + return True + + def celery_enqueue_call(self, context, q_name, task_id=False): + """Wrapper function for celery enqueue func""" + raise NotImplementedError("Not implemented") + + def enqueue_jobs(self, job_counter, job_batch_size, q_name): + """Enqueue celery jobs""" + self.logger.info("Starting enqueue") + with DbSession() as session: + while True: + job_list = [] + # get all the jobs from mySQL + job_list = self.get_jobs( + session, + self.fetch_state, + self.set_state, # pylint: disable=no-member + 
self.args.session_id, # pylint: disable=no-member + job_batch_size, + ) + + with job_counter_lock: + job_counter.value = job_counter.value + len(job_list) + + for i in range(0, len(job_list), job_batch_size): + batch_jobs = job_list[i : min(i + job_batch_size, len(job_list))] + context_list = self.get_context_list(session, batch_jobs) + for context in context_list: + # calling celery task, enqueuing to celery queue + self.celery_enqueue_call(context, q_name=q_name) + + self.logger.info("Job counter: %s", job_counter.value) + if not job_list: + self.logger.info("All tasks added to queue") + break + + async def cleanup_redis_results(self, prefix): + """Remove stale redis results by key""" + backend_port, backend_host = get_backend_env() + redis = await aioredis.from_url(f"redis://{backend_host}:{backend_port}/15") + + keys = [] + cursor = "0" if prefix: - #a prefix is necessary when the need to different results in redis based on operation - #withough a prefix the redis key defaults to: "celery-task-meta-" - #with a prefix the key will look like: "celery-task-meta--" - #the prefix can be applied when filtering the redis keys as bellow - cursor, results = await redis.scan(cursor, match=f"*{prefix}*") + # a prefix is necessary when the need to different results in redis based on operation + # withough a prefix the redis key defaults to: "celery-task-meta-" + # with a prefix the key will look like: "celery-task-meta--" + # the prefix can be applied when filtering the redis keys as bellow + cursor, results = await redis.scan(cursor, match=f"*{prefix}*") else: - #no prefix, match any key - cursor, results = await redis.scan(cursor, match="*") + # no prefix, match any key + cursor, results = await redis.scan(cursor, match="*") keys.extend(results) - self.logger.info('Found %s results', len(results)) - for key in keys: + self.logger.info("Found %s old results", len(results)) + for key in keys: + try: + await redis.delete(key) + except aioredis.exceptions.ResponseError as red_err: + self.logger.error(red_err) + self.logger.info(key.decode("utf-8")) + continue + + self.logger.info("Done removing old redis results for prefix: %s", prefix) + + return True + + async def consume(self, job_counter, prefix): + """Retrieve celery results from redis db""" + + backend_port, backend_host = get_backend_env() + redis = await aioredis.from_url(f"redis://{backend_host}:{backend_port}/15") + + while job_counter.value > 0: + cursor = "0" + keys = [] + while cursor != 0: + if prefix: + # a prefix is necessary when the need to different results in redis based on operation + # withough a prefix the redis key defaults to: "celery-task-meta-" + # with a prefix the key will look like: "celery-task-meta--" + # the prefix can be applied when filtering the redis keys as bellow + cursor, results = await redis.scan(cursor, match=f"*{prefix}*") + else: + # no prefix, match any key + cursor, results = await redis.scan(cursor, match="*") + keys.extend(results) + self.logger.info("Found %s results", len(results)) + for key in keys: + try: + data = await redis.get(key) + if data: + _ = await self.parse_result(data.decode("utf-8")) + await redis.delete(key) + with job_counter_lock: + job_counter.value = job_counter.value - 1 + except aioredis.exceptions.ResponseError as red_err: + self.logger.error(red_err) + self.logger.info(key.decode("utf-8")) + + await asyncio.sleep(1) + self.logger.info("Job counter reached 0") + await redis.close() + + return True + + def prep_tuning(self): + """Prep env for tuning start""" + cmd = None + 
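        # Build the celery worker command and queue name for the current operation:
        # the COMPILE path uses one worker per host (tuna_HOSTNAME_sess_<id>), while
        # the eval path adds "-c 1" and a GPUID placeholder in the worker name.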
subp_list = [] + q_name = None + if self.operation == Operation.COMPILE: + q_name = get_q_name(self, op_compile=True) + cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -n tuna_HOSTNAME_sess_{self.args.session_id} -Q {q_name}" # pylint: disable=line-too-long + else: + q_name = get_q_name(self, op_eval=True) + cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -c 1 -n tuna_HOSTNAME_sess_{self.args.session_id}_gpu_id_GPUID -Q {q_name}" # pylint: disable=line-too-long + + self.logger.info("celery Q name: %s", q_name) + if not self.args.enqueue_only: + try: + self.logger.info("Launching celery workers for queue %s", q_name) + subp_list = launch_celery_worker(self.operation, cmd, self.args, True) + self.logger.info("Done launching celery workers") + if not subp_list: + raise CustomError("Could not launch celery worker") + except kombu.exceptions.OperationalError as k_err: + self.logger.error("Redis error ocurred: %s", k_err) + return False + else: + purge_queue([q_name]) + + return q_name, subp_list + + # pylint: disable=too-many-locals + def tune(self, job_batch_size=1000): + """tuning loop to spin out celery tasks""" + + if self.args.shutdown_workers: + self.logger.info("Shutting down all celery workers") + stop_active_workers() + return True + + try: + q_name, subp_list = self.prep_tuning() + except CustomError as verr: + self.logger.error(verr) + return False + try: - data = await redis.get(key) - if data: - _ = await self.parse_result(data.decode('utf-8')) - await redis.delete(key) + # if enqueue_only is False, we launch the celery workers + if not self.args.enqueue_only: + for subp in subp_list: + subp.wait() + return True + except KeyboardInterrupt: + for subp in subp_list: + subp.kill() + return False + + start = time.time() + + # set job count to 1 until first job fetch is finished + job_counter = Value("i", 1) + try: + enqueue_proc = Process( + target=self.enqueue_jobs, args=[job_counter, job_batch_size, q_name] + ) + # Start enqueue proc + enqueue_proc.start() + + # cleanup old results + cleanup_proc = Process( + target=self.async_wrap, args=(self.cleanup_redis_results, self.prefix) + ) + cleanup_proc.start() + cleanup_proc.join() + + # start async consume thread, blocking + consume_proc = Process( + target=self.async_wrap, args=(self.consume, job_counter, self.prefix) + ) + self.logger.info("Starting consume thread") + consume_proc.start() + + enqueue_proc.join() + # enqueue finished first fetch, remove hold on job_counter with job_counter_lock: - job_counter.value = job_counter.value - 1 - except aioredis.exceptions.ResponseError as red_err: - self.logger.error(red_err) - self.logger.info(key.decode('utf-8')) - - await asyncio.sleep(1) - self.logger.info('Job counter reached 0') - await redis.close() - - return True - - def prep_tuning(self): - """Prep env for tuning start""" - cmd = None - subp_list = [] - q_name = None - if self.operation == Operation.COMPILE: - q_name = get_q_name(self, op_compile=True) - cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -n tuna_HOSTNAME_sess_{self.args.session_id} -Q {q_name}" #pylint: disable=line-too-long - else: - q_name = get_q_name(self, op_eval=True) - cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -c 1 -n tuna_HOSTNAME_sess_{self.args.session_id}_gpu_id_GPUID -Q {q_name}" #pylint: disable=line-too-long - - self.logger.info('celery Q name: %s', q_name) - if not self.args.enqueue_only: - try: - self.logger.info('Launching celery workers for queue %s', q_name) - subp_list = 
launch_celery_worker(self.operation, cmd, self.args, True) - self.logger.info('Done launching celery workers') - if not subp_list: - raise CustomError('Could not launch celery worker') - except kombu.exceptions.OperationalError as k_err: - self.logger.error('Redis error ocurred: %s', k_err) - return False - else: - purge_queue([q_name]) - - return q_name, subp_list - - #pylint: disable=too-many-locals - def tune(self, job_batch_size=1000): - """tuning loop to spin out celery tasks""" - - if self.args.shutdown_workers: - self.logger.info('Shutting down all celery workers') - stop_active_workers() - return True - - try: - q_name, subp_list = self.prep_tuning() - except CustomError as verr: - self.logger.error(verr) - return False - - try: - #if enqueue_only is False, we launch the celery workers - if not self.args.enqueue_only: - for subp in subp_list: - subp.wait() - return True - except KeyboardInterrupt: - for subp in subp_list: - subp.kill() - return False - - start = time.time() - - #set job count to 1 until first job fetch is finished - job_counter = Value('i', 1) - try: - enqueue_proc = Process(target=self.enqueue_jobs, - args=[job_counter, job_batch_size, q_name]) - #Start enqueue proc - enqueue_proc.start() - - #cleanup old results - cleanup_proc = Process(target=self.async_wrap, - args=(self.cleanup_redis_results, self.prefix)) - cleanup_proc.start() - cleanup_proc.join() - - #start async consume thread, blocking - consume_proc = Process(target=self.async_wrap, - args=(self.consume, job_counter, self.prefix)) - self.logger.info('Starting consume thread') - consume_proc.start() - - enqueue_proc.join() - #enqueue finished first fetch, remove hold on job_counter - with job_counter_lock: - job_counter.value = job_counter.value - 1 - - #check for new jobs - while consume_proc.is_alive(): - enqueue_proc = Process(target=self.enqueue_jobs, - args=[job_counter, job_batch_size, q_name]) - enqueue_proc.start() - enqueue_proc.join() - time.sleep(10) - - consume_proc.join() - - except (KeyboardInterrupt, Exception) as exp: #pylint: disable=broad-exception-caught - self.logger.error('Error ocurred %s', exp) - purge_queue([q_name]) - self.cancel_consumer(q_name) - self.reset_job_state_on_ctrl_c() - with job_counter_lock: - job_counter.value = 0 - - self.cancel_consumer(q_name) - end = time.time() - self.logger.info("Took {:0>8} to tune".format( #pylint: disable=consider-using-f-string - str(timedelta(seconds=end - start)))) - - return True - - async def async_callback(self, async_func, *args): - """Wrapper function to await on async function""" - await async_func(*args) - - def async_wrap(self, async_func, *args): - """Run async function""" - try: - asyncio.run(self.async_callback(async_func, *args)) - except KeyboardInterrupt: - self.logger.warning('Keyboard interrupt caught, terminating') - - def reset_job_state_on_ctrl_c(self): - """Reset job state for jobs in flight""" - temp_obj = SimpleDict() - temp_obj.session_id = self.args.session_id #pylint: disable=invalid-name - attribs = ['state'] - temp_obj.state = 1 - - self.logger.info('Resetting job state in DB for in flight jobs') - - if self.operation == Operation.COMPILE: - state = 16 - elif self.operation == Operation.EVAL: - state = 12 - - query = gen_update_query(temp_obj, attribs, - self.dbt.job_table.__tablename__, - [('session', self.args.session_id), - ('state', state)]) - with DbSession() as session: - - #pylint: disable=duplicate-code - def callback() -> bool: - session.execute(query) - session.commit() + job_counter.value = 
job_counter.value - 1 + + # check for new jobs + while consume_proc.is_alive(): + enqueue_proc = Process( + target=self.enqueue_jobs, args=[job_counter, job_batch_size, q_name] + ) + enqueue_proc.start() + enqueue_proc.join() + time.sleep(10) + + consume_proc.join() + + except ( + KeyboardInterrupt, + Exception, + ) as exp: # pylint: disable=broad-exception-caught + self.logger.error("Error ocurred %s", exp) + purge_queue([q_name]) + self.cancel_consumer(q_name) + self.reset_job_state_on_ctrl_c() + with job_counter_lock: + job_counter.value = 0 + + self.cancel_consumer(q_name) + end = time.time() + self.logger.info( + "Took {:0>8} to tune".format( # pylint: disable=consider-using-f-string + str(timedelta(seconds=end - start)) + ) + ) + return True - #pylint: enable=duplicate-code - - assert session_retry(session, callback, lambda x: x(), self.logger) - self.logger.info('Sucessfully reset job state') - return True - - return False - - def has_tunable_operation(self): - """Check if current operation is a tuning operation""" - raise NotImplementedError("Not implemented") - - def get_job_attr(self): - """Get job attr for row selection""" - job_attr: List[str] = None - try: - job_attr = [column.name for column in inspect(self.dbt.job_table).c] - job_attr.remove("insert_ts") - job_attr.remove("update_ts") - except NoInspectionAvailable as error: - self.logger.warning("Ignoring error for init_session: %s", error) - return job_attr - - def check_jobs_found(self, job_rows: List[SimpleDict], find_state: List[Any], - session_id: int) -> bool: - """check for end of jobs""" - if not job_rows: - # we are done - self.logger.warning('No %s jobs found, session %s', find_state, - session_id) - return False - return True - - @lru_cache(1) - def get_context_items(self): - """Helper function to get items for celery job context""" - kwargs = None - f_vals = self.get_f_vals(Machine(local_machine=True), range(0), tuning=True) - kwargs = self.get_kwargs(0, f_vals, tuning=True) - return kwargs - - def serialize_jobs(self, session, batch_jobs): - """Return list of serialize jobs""" - raise NotImplementedError("Not implemented") - - def build_context(self, serialized_jobs): - """Build context list for enqueue job""" - raise NotImplementedError("Not implemented") - - def get_context_list(self, session, batch_jobs): - """Return list of jobs (context) for celery queue""" - - context_list: List[dict] = None - serialized_jobs = self.serialize_jobs(session, batch_jobs) - #build context for each celery task - context_list = self.build_context(serialized_jobs) - - return context_list - - async def parse_result(self, data): - """Function callback for celery async jobs to store results""" - data = json.loads(data) - - with DbSession() as session: - try: - fin_json = data['result']['ret'] - context = data['result']['context'] - except KeyError as kerr: - self.logger.error(kerr) - return False + async def async_callback(self, async_func, *args): + """Wrapper function to await on async function""" + await async_func(*args) - self.logger.info('Parsing: %s', fin_json) - if self.operation == Operation.COMPILE: - self.process_compile_results(session, fin_json, context) - elif self.operation == Operation.EVAL: - self.process_eval_results(session, fin_json, context) - else: - raise CustomError('Unsupported tuning operation') + def async_wrap(self, async_func, *args): + """Run async function""" + try: + asyncio.run(self.async_callback(async_func, *args)) + except KeyboardInterrupt: + self.logger.warning("Keyboard interrupt caught, 
terminating") + + def reset_job_state_on_ctrl_c(self): + """Reset job state for jobs in flight""" + temp_obj = SimpleDict() + temp_obj.session_id = self.args.session_id # pylint: disable=invalid-name + attribs = ["state"] + temp_obj.state = 1 + + self.logger.info("Resetting job state in DB for in flight jobs") + + if self.operation == Operation.COMPILE: + state = 16 + elif self.operation == Operation.EVAL: + state = 12 + + query = gen_update_query( + temp_obj, + attribs, + self.dbt.job_table.__tablename__, + [("session", self.args.session_id), ("state", state)], + ) + with DbSession() as session: + + # pylint: disable=duplicate-code + def callback() -> bool: + session.execute(query) + session.commit() + return True + + # pylint: enable=duplicate-code + + assert session_retry(session, callback, lambda x: x(), self.logger) + self.logger.info("Sucessfully reset job state") + return True + + return False - return True + def has_tunable_operation(self): + """Check if current operation is a tuning operation""" + raise NotImplementedError("Not implemented") - def process_compile_results(self, session, fin_json, context): - """Process result from fin_build worker""" - raise NotImplementedError("Not implemented") + def get_job_attr(self): + """Get job attr for row selection""" + job_attr: List[str] = None + try: + job_attr = [column.name for column in inspect(self.dbt.job_table).c] + job_attr.remove("insert_ts") + job_attr.remove("update_ts") + except NoInspectionAvailable as error: + self.logger.warning("Ignoring error for init_session: %s", error) + return job_attr + + def check_jobs_found( + self, job_rows: List[SimpleDict], find_state: List[Any], session_id: int + ) -> bool: + """check for end of jobs""" + if not job_rows: + # we are done + self.logger.warning("No %s jobs found, session %s", find_state, session_id) + return False + return True - def process_eval_results(self, session, fin_json, context): - """Process fin_json result""" - raise NotImplementedError("Not implemented") + @lru_cache(1) + def get_context_items(self): + """Helper function to get items for celery job context""" + kwargs = None + f_vals = self.get_f_vals(Machine(local_machine=True), range(0), tuning=True) + kwargs = self.get_kwargs(0, f_vals, tuning=True) + return kwargs + + def serialize_jobs(self, session, batch_jobs): + """Return list of serialize jobs""" + raise NotImplementedError("Not implemented") + + def build_context(self, serialized_jobs): + """Build context list for enqueue job""" + raise NotImplementedError("Not implemented") + + def get_context_list(self, session, batch_jobs): + """Return list of jobs (context) for celery queue""" + + context_list: List[dict] = None + serialized_jobs = self.serialize_jobs(session, batch_jobs) + # build context for each celery task + context_list = self.build_context(serialized_jobs) + + return context_list + + async def parse_result(self, data): + """Function callback for celery async jobs to store results""" + data = json.loads(data) + + with DbSession() as session: + try: + fin_json = data["result"]["ret"] + context = data["result"]["context"] + except KeyError as kerr: + self.logger.error(kerr) + return False + + self.logger.info("Parsing: %s", fin_json) + if self.operation == Operation.COMPILE: + self.process_compile_results(session, fin_json, context) + elif self.operation == Operation.EVAL: + self.process_eval_results(session, fin_json, context) + else: + raise CustomError("Unsupported tuning operation") + + return True + + def process_compile_results(self, session, 
fin_json, context): + """Process result from fin_build worker""" + raise NotImplementedError("Not implemented") + + def process_eval_results(self, session, fin_json, context): + """Process fin_json result""" + raise NotImplementedError("Not implemented") From 06fbe3cc6f7cbbef064ecf8d7b2fe7feaf7ed88b Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Thu, 16 Oct 2025 11:09:49 +0000 Subject: [PATCH 03/33] auto format --- tuna/miopen/miopen_lib.py | 1631 +++++++++++++++++++------------------ 1 file changed, 860 insertions(+), 771 deletions(-) diff --git a/tuna/miopen/miopen_lib.py b/tuna/miopen/miopen_lib.py index a20e1a23..3450be34 100644 --- a/tuna/miopen/miopen_lib.py +++ b/tuna/miopen/miopen_lib.py @@ -60,7 +60,8 @@ from tuna.miopen.db.triggers import drop_miopen_triggers, get_miopen_triggers from tuna.miopen.utils.config_type import ConfigType from tuna.miopen.db.tables import MIOpenDBTables -#from tuna.miopen.celery_tuning.celery_tasks import celery_enqueue + +# from tuna.miopen.celery_tuning.celery_tasks import celery_enqueue from tuna.miopen.utils.json_to_sql import process_fdb_w_kernels, process_tuning_data from tuna.miopen.utils.json_to_sql import process_pdb_compile from tuna.miopen.utils.json_to_sql import clean_cache_table @@ -75,782 +76,870 @@ class MIOpen(MITunaInterface): - """Class to support MIOpen specific tuning functionality""" - - # pylint: disable=too-many-public-methods - - def __init__(self): - super().__init__(library=Library.MIOPEN) - self.args = None - self.set_state = None - - def parse_args(self): - # pylint: disable=too-many-statements - """Function to parse arguments""" - parser = setup_arg_parser( - 'Run Performance Tuning on a certain architecture', [ - TunaArgs.ARCH, TunaArgs.NUM_CU, TunaArgs.VERSION, - TunaArgs.CONFIG_TYPE, TunaArgs.SESSION_ID, TunaArgs.MACHINES, - TunaArgs.REMOTE_MACHINE, TunaArgs.LABEL, TunaArgs.RESTART_MACHINE, - TunaArgs.DOCKER_NAME, TunaArgs.SHUTDOWN_WORKERS, - TunaArgs.ENQUEUE_ONLY - ]) - parser.add_argument( - '--find_mode', - dest='find_mode', - type=int, - default=1, - help='Set the MIOPEN_FIND_MODE environment variable for MIOpen', - choices=['1', '3']) - parser.add_argument('--ticket', - dest='ticket', - type=str, - default=None, - help='Specify tuning ticket number') - parser.add_argument( - '--solver_id', - type=int, - dest='solver_id', - default=None, - help='Specify solver_id. 
Use --list_solvers to see options') - parser.add_argument('--dynamic_solvers_only', - dest='dynamic_solvers_only', - action='store_true', - default=False, - help='Only tune dynamic solvers.') - parser.add_argument( - '-B', - '--blacklist', - dest='blacklist', - type=str, - default=None, - help='MIOpen blacklist algorithm, if multiple then comma separate') - parser.add_argument('-i', - '--reset_interval', - type=int, - dest='reset_interval', - required=False, - help='Restart interval for job in hours.') - parser.add_argument( - '--gpu_lim', - dest='gpu_lim', - type=int, - default=None, - help='Limit the number of gpu workers created by Tuna, index from 0') - - parser.add_argument( - '-R', - '--rich_data', - dest='rich_data', - action='store_true', - default=False, - help='record intermediate parameter results from perf tuning') - - subcommands = parser.add_subcommands(required=False) - subcommands.add_subcommand('import_configs', - get_import_cfg_parser(), - required=False) - - subcommands.add_subcommand('load_job', - get_load_job_parser(), - required=False) - - subcommands.add_subcommand('export_db', - get_export_db_parser(), - required=False) - - subcommands.add_subcommand('update_golden', - get_update_golden_parser(), - required=False) - - group = parser.add_mutually_exclusive_group() - group.add_argument('--add_tables', - dest='add_tables', - action='store_true', - help='Add MIOpen library specific tables') - - group.add_argument('--init_session', - action='store_true', - dest='init_session', - help='Set up a new tuning session.') - group.add_argument( - '--fin_steps', - type=str, - dest='fin_steps', - help='Specify fin steps. Multiple steps should be comma separated.') - group.add_argument('--list_solvers', - action='store_true', - dest='list_solvers', - help='List of solvers from the solver table') - - # JD: implement the following two using fin_steps - group.add_argument('--update_solvers', - dest='update_solvers', - action='store_true', - help='Update the list of solvers in the database') - group.add_argument('--update_applicability', - dest='update_applicability', - action='store_true', - help='Update the applicability table in the database') - group.add_argument('-s', - '--status', - dest='check_status', - action='store_true', - default=False, - help='Check the status of machines') - - group.add_argument('-e', - '--exec', - dest='execute_cmd', - type=str, - default=None, - help='execute on each machine') - - self.args = parser.parse_args() - - if self.args.config_type is None: - self.args.config_type = ConfigType.convolution - - #overwritte common lib args with subcommand args value - if self.args.subcommand is not None: - self.overwrite_common_args() - - if len(sys.argv) == 1: - parser.print_help() - sys.exit(-1) - - if self.args.list_solvers: - print_solvers() - raise CustomError('Printing solvers...') - - if self.args.fin_steps and self.args.subcommand != 'load_job': - self.check_fin_args(parser) - self.set_prefix() - - if self.args.find_mode is None and not (self.args.check_status or - self.args.restart_machine or - self.args.execute_cmd): - parser.error('find_mode must be specified for a tuning run') - - if self.args.blacklist: - self.check_blacklist(parser) - - args_check(self.args, parser) - - fin_session_steps = [ - 'miopen_find_compile', 'miopen_find_eval', 'miopen_perf_compile', - 'miopen_perf_eval', 'get_applicability', 'find_compile', 'find_eval' - ] - has_fin = False - if self.args.fin_steps: - has_fin = all(x in fin_session_steps for x in self.args.fin_steps) - - if 
(self.args.update_applicability or has_fin) and not self.args.session_id: - parser.error("session_id must be specified with this operation") - - self.dbt = MIOpenDBTables(session_id=self.args.session_id, - config_type=self.args.config_type) - self.update_operation() - - def set_prefix(self): - """Set redis key prefix""" - if isinstance(self.args.fin_steps, Iterable): - steps_str = ('-').join(x for x in self.args.fin_steps) - self.prefix = f"d_{self.db_name}_sess_{self.args.session_id}_"\ - f"{steps_str}" - else: - steps_str = self.args.fin_steps[0] - self.prefix = f"d_{self.db_name}_sess_{self.args.session_id}_{steps_str}" - - self.logger.info('redis prefix: %s', self.prefix) - - def overwrite_common_args(self): - """Overwrite common MIOpen_lib args with subcommand args""" - if self.args.subcommand is not None: - subc_dict = vars(self.args.get(self.args.subcommand)) - for sub_key in subc_dict: - if sub_key in vars(self.args): - self.args[sub_key] = subc_dict.get(sub_key) - - def check_fin_args(self, parser): - """! Helper function for fin args - @param parser The command line argument parser + """Class to support MIOpen specific tuning functionality""" + + # pylint: disable=too-many-public-methods + + def __init__(self): + super().__init__(library=Library.MIOPEN) + self.args = None + self.set_state = None + + def parse_args(self): + # pylint: disable=too-many-statements + """Function to parse arguments""" + parser = setup_arg_parser( + "Run Performance Tuning on a certain architecture", + [ + TunaArgs.ARCH, + TunaArgs.NUM_CU, + TunaArgs.VERSION, + TunaArgs.CONFIG_TYPE, + TunaArgs.SESSION_ID, + TunaArgs.MACHINES, + TunaArgs.REMOTE_MACHINE, + TunaArgs.LABEL, + TunaArgs.RESTART_MACHINE, + TunaArgs.DOCKER_NAME, + TunaArgs.SHUTDOWN_WORKERS, + TunaArgs.ENQUEUE_ONLY, + ], + ) + parser.add_argument( + "--find_mode", + dest="find_mode", + type=int, + default=1, + help="Set the MIOPEN_FIND_MODE environment variable for MIOpen", + choices=["1", "3"], + ) + parser.add_argument( + "--ticket", + dest="ticket", + type=str, + default=None, + help="Specify tuning ticket number", + ) + parser.add_argument( + "--solver_id", + type=int, + dest="solver_id", + default=None, + help="Specify solver_id. 
Use --list_solvers to see options", + ) + parser.add_argument( + "--dynamic_solvers_only", + dest="dynamic_solvers_only", + action="store_true", + default=False, + help="Only tune dynamic solvers.", + ) + parser.add_argument( + "-B", + "--blacklist", + dest="blacklist", + type=str, + default=None, + help="MIOpen blacklist algorithm, if multiple then comma separate", + ) + parser.add_argument( + "-i", + "--reset_interval", + type=int, + dest="reset_interval", + required=False, + help="Restart interval for job in hours.", + ) + parser.add_argument( + "--gpu_lim", + dest="gpu_lim", + type=int, + default=None, + help="Limit the number of gpu workers created by Tuna, index from 0", + ) + + parser.add_argument( + "-R", + "--rich_data", + dest="rich_data", + action="store_true", + default=False, + help="record intermediate parameter results from perf tuning", + ) + + subcommands = parser.add_subcommands(required=False) + subcommands.add_subcommand( + "import_configs", get_import_cfg_parser(), required=False + ) + + subcommands.add_subcommand("load_job", get_load_job_parser(), required=False) + + subcommands.add_subcommand("export_db", get_export_db_parser(), required=False) + + subcommands.add_subcommand( + "update_golden", get_update_golden_parser(), required=False + ) + + group = parser.add_mutually_exclusive_group() + group.add_argument( + "--add_tables", + dest="add_tables", + action="store_true", + help="Add MIOpen library specific tables", + ) + + group.add_argument( + "--init_session", + action="store_true", + dest="init_session", + help="Set up a new tuning session.", + ) + group.add_argument( + "--fin_steps", + type=str, + dest="fin_steps", + help="Specify fin steps. Multiple steps should be comma separated.", + ) + group.add_argument( + "--list_solvers", + action="store_true", + dest="list_solvers", + help="List of solvers from the solver table", + ) + + # JD: implement the following two using fin_steps + group.add_argument( + "--update_solvers", + dest="update_solvers", + action="store_true", + help="Update the list of solvers in the database", + ) + group.add_argument( + "--update_applicability", + dest="update_applicability", + action="store_true", + help="Update the applicability table in the database", + ) + group.add_argument( + "-s", + "--status", + dest="check_status", + action="store_true", + default=False, + help="Check the status of machines", + ) + + group.add_argument( + "-e", + "--exec", + dest="execute_cmd", + type=str, + default=None, + help="execute on each machine", + ) + + self.args = parser.parse_args() + + if self.args.config_type is None: + self.args.config_type = ConfigType.convolution + + # overwritte common lib args with subcommand args value + if self.args.subcommand is not None: + self.overwrite_common_args() + + if len(sys.argv) == 1: + parser.print_help() + sys.exit(-1) + + if self.args.list_solvers: + print_solvers() + raise CustomError("Printing solvers...") + + if self.args.fin_steps and self.args.subcommand != "load_job": + self.check_fin_args(parser) + self.set_prefix() + + if self.args.find_mode is None and not ( + self.args.check_status or self.args.restart_machine or self.args.execute_cmd + ): + parser.error("find_mode must be specified for a tuning run") + + if self.args.blacklist: + self.check_blacklist(parser) + + args_check(self.args, parser) + + fin_session_steps = [ + "miopen_find_compile", + "miopen_find_eval", + "miopen_perf_compile", + "miopen_perf_eval", + "get_applicability", + "find_compile", + "find_eval", + ] + has_fin = False + if 
self.args.fin_steps: + has_fin = all(x in fin_session_steps for x in self.args.fin_steps) + + if (self.args.update_applicability or has_fin) and not self.args.session_id: + parser.error("session_id must be specified with this operation") + + self.dbt = MIOpenDBTables( + session_id=self.args.session_id, config_type=self.args.config_type + ) + self.update_operation() + + def set_prefix(self): + """Set redis key prefix""" + if isinstance(self.args.fin_steps, Iterable): + steps_str = ("-").join(x for x in self.args.fin_steps) + self.prefix = ( + f"d_{self.db_name}_sess_{self.args.session_id}_" f"{steps_str}" + ) + else: + steps_str = self.args.fin_steps[0] + self.prefix = f"d_{self.db_name}_sess_{self.args.session_id}_{steps_str}" + + self.logger.info("redis prefix: %s", self.prefix) + + def overwrite_common_args(self): + """Overwrite common MIOpen_lib args with subcommand args""" + if self.args.subcommand is not None: + subc_dict = vars(self.args.get(self.args.subcommand)) + for sub_key in subc_dict: + if sub_key in vars(self.args): + self.args[sub_key] = subc_dict.get(sub_key) + + def check_fin_args(self, parser): + """! Helper function for fin args + @param parser The command line argument parser + """ + valid_fin_steps = list(k for k in FinStep.__members__) + if "," in self.args.fin_steps: + parser.error("Multiple fin_steps currently not supported") + f_steps = self.args.fin_steps.split(",") + self.args.fin_steps = f_steps + for step in self.args.fin_steps: + if step not in valid_fin_steps: + parser.error(f"Supported fin steps are: {valid_fin_steps}") + assert len(self.args.fin_steps) == 1 + + def check_blacklist(self, parser): + """! Helper function + @param parser The command line argument parser + @return ret Boolean value + """ + self.args.blacklist = self.args.blacklist.split(",") + for sol in self.args.blacklist: + if sol not in MIOPEN_ALG_LIST: + parser.error("Incorrect blacklist value") + + def do_fin_work(self, gpu, f_vals): + """! Helper function to execute job independendent fin work + @param gpu Unique ID of the GPU + @param f_vals Dict containing runtime information + """ + kwargs = self.get_kwargs(gpu, f_vals) + fin_worker = FinClass(**kwargs) + + if self.args.update_solvers: + if not fin_worker.get_solvers(): + self.logger.error("No solvers returned from Fin class") + + return True + + def launch_worker(self, gpu_idx, f_vals, worker_lst): + """! Function to launch worker + @param gpu_idx Unique ID of the GPU + @param f_vals Dict containing runtime information + @param worker_lst List containing worker instances + @return ret Boolean value """ - valid_fin_steps = list(k for k in FinStep.__members__) - if ',' in self.args.fin_steps: - parser.error('Multiple fin_steps currently not supported') - f_steps = self.args.fin_steps.split(',') - self.args.fin_steps = f_steps - for step in self.args.fin_steps: - if step not in valid_fin_steps: - parser.error(f"Supported fin steps are: {valid_fin_steps}") - assert len(self.args.fin_steps) == 1 - - def check_blacklist(self, parser): - """! Helper function - @param parser The command line argument parser - @return ret Boolean value - """ - self.args.blacklist = self.args.blacklist.split(',') - for sol in self.args.blacklist: - if sol not in MIOPEN_ALG_LIST: - parser.error("Incorrect blacklist value") - - def do_fin_work(self, gpu, f_vals): - """! 
Helper function to execute job independendent fin work - @param gpu Unique ID of the GPU - @param f_vals Dict containing runtime information - """ - kwargs = self.get_kwargs(gpu, f_vals) - fin_worker = FinClass(**kwargs) - - if self.args.update_solvers: - if not fin_worker.get_solvers(): - self.logger.error('No solvers returned from Fin class') - - return True - - def launch_worker(self, gpu_idx, f_vals, worker_lst): - """! Function to launch worker - @param gpu_idx Unique ID of the GPU - @param f_vals Dict containing runtime information - @param worker_lst List containing worker instances - @return ret Boolean value - """ - # pylint: disable=too-many-branches - worker = None - kwargs = self.get_kwargs(gpu_idx, f_vals) - if self.args.update_applicability: - kwargs['fin_steps'] = ['applicability'] - worker = FinClass(**kwargs) - worker.start() - worker_lst.append(worker) - return True - - worker = FinClass(**kwargs) - ret = False - if self.args.check_status: - if not super().check_status(worker, f_vals["b_first"], gpu_idx, - f_vals["machine"], self.args.docker_name): - ret = True - elif self.args.init_session: - Session().add_new_session(self.args, worker) - elif self.args.execute_cmd: - # JD: Move the worker.exec_command to machine - self.logger.info(self.args.execute_cmd) - _, _, _ = worker.exec_command(self.args.execute_cmd + " 2>&1 ") - - return ret - - def compose_worker_list(self, machines): - # pylint: disable=too-many-branches - """! Helper function to compose worker_list - @param machines List of machines to execute on - """ - worker_lst = [] - fin_work_done = False - for machine in machines: - if self.args.restart_machine: - machine.restart_server(wait=False) - continue - - #fin_steps should only contain one step - worker_ids = None - if self.args.fin_steps and 'eval' in self.args.fin_steps[0]: - worker_ids = machine.get_avail_gpus() - if self.args.gpu_lim and self.args.gpu_lim < len(worker_ids): - worker_ids = range(self.args.gpu_lim) - else: - worker_ids = super().get_num_procs(machine) - - if self.args.update_applicability: - f_vals = super().get_f_vals(machine, [1]) - kwargs = self.get_kwargs(0, f_vals) - kwargs['fin_steps'] = ['applicability'] + # pylint: disable=too-many-branches + worker = None + kwargs = self.get_kwargs(gpu_idx, f_vals) + if self.args.update_applicability: + kwargs["fin_steps"] = ["applicability"] + worker = FinClass(**kwargs) + worker.start() + worker_lst.append(worker) + return True + worker = FinClass(**kwargs) - query = worker.query_cfgs(self.args.label) - cfg_rows = query.all() - len_rows = len(cfg_rows) - proc_lim = (len_rows + 99) / 100 - if 32 < proc_lim: - proc_lim = 32 - while len(worker_ids) > proc_lim: - worker_ids.pop() - - if len(worker_ids) == 0: - return None - - f_vals = super().get_f_vals(machine, worker_ids) - - if (self.args.update_solvers) and not fin_work_done: - self.do_fin_work(0, f_vals) - fin_work_done = True - break - - for gpu_idx in worker_ids: - self.logger.info('launch mid %u, proc %u', machine.id, gpu_idx) - if not self.launch_worker(gpu_idx, f_vals, worker_lst): - break - - return worker_lst - - def add_tables(self): - """! Function to create new DB tables - @return Bool - """ - ret_t = create_tables(get_miopen_tables()) - self.logger.info('DB creation successful: %s', ret_t) - recreate_triggers(drop_miopen_triggers(), get_miopen_triggers()) - return True - - def run(self): - # pylint: disable=duplicate-code - """! 
Main function to launch library""" - res = None - if self.args is None: - self.parse_args() - - if self.args.add_tables: - self.add_tables() - return None - - if self.args.subcommand is not None and self.args.subcommand == 'import_configs': - run_import_configs(self.args.import_configs, self.logger) - return None - - if self.args.subcommand is not None and self.args.subcommand == 'load_job': - run_load_job(self.args.load_job, self.logger) - return None - - if self.args.subcommand is not None and self.args.subcommand == 'export_db': - run_export_db(self.args.export_db, self.logger) - return None - - if self.args.subcommand is not None and self.args.subcommand == 'update_golden': - run_update_golden(self.args.update_golden, self.logger) - return None - - machines = load_machines(self.args) - res = self.compose_worker_list(machines) - return res - - def get_envmt(self): - """! Function to construct environment var - """ - envmt = ["MIOPEN_LOG_LEVEL=4"] - - envmt.append("MIOPEN_SQLITE_KERN_CACHE=ON") - envmt.append("MIOPEN_DEBUG_IMPLICIT_GEMM_FIND_ALL_SOLUTIONS=1") - - if self.args.find_mode: - envmt.append(f"MIOPEN_FIND_MODE={self.args.find_mode}") - - if self.args.blacklist: - bk_str = ", ".join([f"{arg}=0" for arg in self.args.blacklist]) - for bk_var in bk_str.split(','): - envmt.append(bk_var) - - return envmt - - def get_kwargs(self, gpu_idx, f_vals, tuning=False): - """! Helper function to set up kwargs for worker instances - @param gpu_idx Unique ID of the GPU - @param f_vals Dict containing runtime information - @param tuning Boolean that indicates if kwargs are for a tuning step - @return kwargs Dictionary - """ - kwargs = super().get_kwargs(gpu_idx, f_vals, tuning) - kwargs['fin_steps'] = self.args.fin_steps - kwargs['dynamic_solvers_only'] = self.args.dynamic_solvers_only - kwargs['config_type'] = self.args.config_type - kwargs['reset_interval'] = self.args.reset_interval - - return kwargs - - def get_job_list(self, session, find_state, claim_num): - """! Get list of jobs - @param session DB session - @param find_state DB job state - @param claim_num Number of DB jobs to pick up - @return List of DB jobs - - """ - job_list = self.get_job_objs(session, find_state, self.args.label, self.dbt, - self.get_job_attr(), claim_num, - self.args.fin_steps) - - return job_list - - def get_job_objs(self, - session: DbSession, - find_state: list, - label: str, - dbt: DBTablesInterface, - job_attr: List[str], - claim_num: int = None, - fin_steps: List[str] = None) -> List[SimpleDict]: - """! Get list of job objects - @param session DB session - @param find_state DB job state - @param label DB job reason - @param dbt Class representing all DB tables associated with this class - @param job_attr List of DB job columns - @param claim_num Number of DB jobs to pick up - @param fin_steps List of MIFin steps - @return List of DB jobs - """ - entries: List[Tuple[SimpleDict, ...]] - conds: List[str] = [f"session={dbt.session.id}", "valid=1"] - - if label: - conds.append(f"reason='{label}'") - - conds.append(f"retries<{self.max_job_retries}") - conds.append("state in (" + str(find_state).strip('{').strip('}') + ")") - - entries = self.compose_work_objs(session, conds, dbt, job_attr, claim_num, - fin_steps) - return entries - - def compose_work_objs(self, - session: DbSession, - conds: List[str], - dbt: DBTablesInterface, - job_attr: List[str], - claim_num: int = None, - fin_steps: List[str] = None) -> List[SimpleDict]: - """! 
Query a job list for update - @param session DB session - @param conds List of conditions for DB job WHERE clause - @param dbt Class representing all DB tables associated with this class - @param job_attr List of DB job columns - @param fin_steps List of MIFin steps - @return List of MIFin work objects - """ - job_entries = [] - if fin_steps: - conds.append(f"fin_step like '%{fin_steps[0]}%'") - else: - conds.append("fin_step='not_fin'") - - cond_str = ' AND '.join(conds) - if cond_str: - cond_str = f"WHERE {cond_str}" - if claim_num: - cond_str += f" ORDER BY retries,config ASC LIMIT {claim_num} FOR UPDATE SKIP LOCKED" - else: - cond_str += " ORDER BY retries,config ASC FOR UPDATE SKIP LOCKED" - - job_entries = gen_select_objs(session, job_attr, - dbt.job_table.__tablename__, cond_str) - - return job_entries - - def compose_work_objs_fin(self, session, job_entries, - dbt) -> List[Tuple[SimpleDict, SimpleDict]]: - """! Return jobs for fin work - @param session DB session - @param job_entries List of DB jobs - @param dbt Class representing all DB tables associated with this class - @return ret Job tuple - """ - ret = [] - - cfg_rel = { - key: { - 'key': list(val.local_columns)[0].name, - 'ftble': str(list(val.remote_side)[0]).split('.', maxsplit=1)[0], - 'fkey': str(list(val.remote_side)[0]).split('.')[1] - } for key, val in inspect(dbt.config_table).relationships.items() - } - - if job_entries: - id_str = ','.join({str(job.config) for job in job_entries}) - cfg_cond_str = f"where valid=1 and id in ({id_str})" - cfg_attr = [column.name for column in inspect(dbt.config_table).c] - cfg_entries = gen_select_objs(session, cfg_attr, - dbt.config_table.__tablename__, - cfg_cond_str) - - cfg_entries = self.attach_tensors(session, cfg_rel, cfg_entries) - - cfg_map = {cfg.id: cfg for cfg in cfg_entries} - - for job in job_entries: - ret.append((job, cfg_map[job.config])) - - return ret - - def attach_tensors(self, session, cfg_rel, cfg_entries): - """! Attach tensor relationship information to config entries - @param session DB session - @param cfg_rel DB Config col value - @param cfg_entries List of DB Config entries - @return cfg_entries List of DB Config entries with attached tensors (foreign keys) - - """ - for key, val in cfg_rel.items(): - rel_attr = [ - column.name - for column in inspect(get_class_by_tablename(val['ftble'])).c - ] - val['fattr'] = rel_attr - - for cfg in cfg_entries: - for key, val in cfg_rel.items(): - rel_val = getattr(cfg, val['key']) - rel_cond_str = f"where {val['fkey']}={rel_val}" - setattr( - cfg, key, - gen_select_objs(session, val['fattr'], val['ftble'], - rel_cond_str)[0]) - return cfg_entries - - #deprecated - def get_job_tables(self, job_rows: List[Tuple[SimpleDict, ...]], - job_attr: List[str]) -> List[SimpleDict]: - """Find job tables in query results""" - if has_attr_set(job_rows[0], job_attr): - job_tables: List[SimpleDict] = job_rows - else: - job_i: int = 0 - tble: SimpleDict - for i, tble in enumerate(job_rows[0]): - if has_attr_set(tble, job_attr): - job_i = i - break - job_tables = [row[job_i] for row in job_rows] - - return job_tables - - def update_operation(self): - """! 
Update the workers type that this library needs""" - if self.args.fin_steps: - if 'miopen_find_compile' in self.args.fin_steps \ - or 'miopen_perf_compile' in self.args.fin_steps: - self.fetch_state.add('new') - self.set_state = 'compile_start' - self.operation = Operation.COMPILE - elif 'miopen_find_eval' in self.args.fin_steps or 'miopen_perf_eval' in self.args.fin_steps: - self.fetch_state.add('new') - self.fetch_state.add('compiled') - self.set_state = 'eval_start' - self.operation = Operation.EVAL - - if self.args.update_applicability: - self.fetch_state.add("new") - - def has_tunable_operation(self): - """! Check if its a tuning loop operation - @return Bool value that represents if operation is tuning - """ - if self.args is None: - self.parse_args() - if self.args.subcommand and "load_job" in self.args.subcommand: - return False - if self.args.shutdown_workers: - return True - - return self.args.fin_steps and any( - s in self.args.fin_steps for s in MIOPEN_CELERY_STEPS) - - @lru_cache(1) - def get_fdb_attr(self): - """! Get find_db table attrs - @return fdb_attr find_db table attributes without timestamps - """ - fdb_attr = None - fdb_attr = [column.name for column in inspect(self.dbt.find_db_table).c] - fdb_attr.remove("insert_ts") - fdb_attr.remove("update_ts") - return fdb_attr - - @lru_cache(1) - def get_tuning_data_attr(self): - """! Get tuning_data table attrs - @return tuning_data_attr tuning_data table attributes without timestamps - """ - tuning_data_attr = None - tuning_data_attr = [ - column.name for column in inspect(self.dbt.tuning_data_table).c - ] - tuning_data_attr.remove("insert_ts") - tuning_data_attr.remove("update_ts") - return tuning_data_attr - - def serialize_jobs(self, session: DbSession, batch_jobs: List[Any]): - """! Return list of serialize jobs - @param session DB session - @param batch_jobs List of DB jobs - @return DB jobs, serialized - """ - entries = self.compose_work_objs_fin(session, batch_jobs, self.dbt) - return serialize_chunk(entries) - - def build_context( - self, serialized_jobs: Tuple[SimpleDict, SimpleDict]) -> List[dict]: - """Build context list for enqueue job""" - context_list = [] - kwargs = self.get_context_items() - fdb_attr = self.get_fdb_attr() - tuning_data_attr = self.get_tuning_data_attr() - for job, config in serialized_jobs: - context = { - 'job': job, - 'config': config, - 'operation': self.operation, - 'arch': self.dbt.session.arch, - 'num_cu': self.dbt.session.num_cu, - 'kwargs': kwargs, - 'rich_data': self.args.rich_data, - 'fdb_attr': fdb_attr, - 'tuning_data_attr': tuning_data_attr - } - context_list.append(context) - - return context_list - - def celery_enqueue_call(self, context: dict, q_name: str, task_id=False): - """! Enqueue job (context) for queue:q_name - @param context Context for Celery job - @param q_name Custom Celery queue name - @param task_id Custom Redis Key - """ - - #hacky way to get the Q_NAME to the task decorator for interpreter to decorate the - #function with correct q_name arg - #if import is moved to top it will result in circular imports - Q_NAME = q_name #pylint: disable=import-outside-toplevel,unused-variable,invalid-name,redefined-outer-name - from tuna.miopen.celery_tuning.celery_tasks import celery_enqueue #pylint: disable=import-outside-toplevel - - return celery_enqueue.apply_async((context,), - task_id=('-').join([self.prefix, - uuid()]), - queue=q_name, - reply_to=q_name) - - def process_compile_results(self, session, fin_json, context): - """! 
Process result from fin_build worker - @param session DB session - @param fin_json MIFin results for job - @param context Context for Celery job - @return Boolean value - """ - job = SimpleDict(**context['job']) - pending = [] - solver_id_map = get_solver_ids() - - failed_job = False - result_str = '' - status = None - try: - if fin_json: - if 'success' in fin_json and fin_json["success"] is False: - status = [fin_json] + ret = False + if self.args.check_status: + if not super().check_status( + worker, + f_vals["b_first"], + gpu_idx, + f_vals["machine"], + self.args.docker_name, + ): + ret = True + elif self.args.init_session: + Session().add_new_session(self.args, worker) + elif self.args.execute_cmd: + # JD: Move the worker.exec_command to machine + self.logger.info(self.args.execute_cmd) + _, _, _ = worker.exec_command(self.args.execute_cmd + " 2>&1 ") + + return ret + + def compose_worker_list(self, machines): + # pylint: disable=too-many-branches + """! Helper function to compose worker_list + @param machines List of machines to execute on + """ + worker_lst = [] + fin_work_done = False + for machine in machines: + if self.args.restart_machine: + machine.restart_server(wait=False) + continue + + # fin_steps should only contain one step + worker_ids = None + if self.args.fin_steps and "eval" in self.args.fin_steps[0]: + worker_ids = machine.get_avail_gpus() + if self.args.gpu_lim and self.args.gpu_lim < len(worker_ids): + worker_ids = range(self.args.gpu_lim) + else: + worker_ids = super().get_num_procs(machine) + + if self.args.update_applicability: + f_vals = super().get_f_vals(machine, [1]) + kwargs = self.get_kwargs(0, f_vals) + kwargs["fin_steps"] = ["applicability"] + worker = FinClass(**kwargs) + query = worker.query_cfgs(self.args.label) + cfg_rows = query.all() + len_rows = len(cfg_rows) + proc_lim = (len_rows + 99) / 100 + if 32 < proc_lim: + proc_lim = 32 + while len(worker_ids) > proc_lim: + worker_ids.pop() + + if len(worker_ids) == 0: + return None + + f_vals = super().get_f_vals(machine, worker_ids) + + if (self.args.update_solvers) and not fin_work_done: + self.do_fin_work(0, f_vals) + fin_work_done = True + break + + for gpu_idx in worker_ids: + self.logger.info("launch mid %u, proc %u", machine.id, gpu_idx) + if not self.launch_worker(gpu_idx, f_vals, worker_lst): + break + + return worker_lst + + def add_tables(self): + """! Function to create new DB tables + @return Bool + """ + ret_t = create_tables(get_miopen_tables()) + self.logger.info("DB creation successful: %s", ret_t) + recreate_triggers(drop_miopen_triggers(), get_miopen_triggers()) + return True + + def run(self): + # pylint: disable=duplicate-code + """! 
Main function to launch library""" + res = None + if self.args is None: + self.parse_args() + + if self.args.add_tables: + self.add_tables() + return None + + if ( + self.args.subcommand is not None + and self.args.subcommand == "import_configs" + ): + run_import_configs(self.args.import_configs, self.logger) + return None + + if self.args.subcommand is not None and self.args.subcommand == "load_job": + run_load_job(self.args.load_job, self.logger) + return None + + if self.args.subcommand is not None and self.args.subcommand == "export_db": + run_export_db(self.args.export_db, self.logger) + return None + + if self.args.subcommand is not None and self.args.subcommand == "update_golden": + run_update_golden(self.args.update_golden, self.logger) + return None + + machines = load_machines(self.args) + res = self.compose_worker_list(machines) + return res + + def get_envmt(self): + """! Function to construct environment var""" + envmt = ["MIOPEN_LOG_LEVEL=4"] + + envmt.append("MIOPEN_SQLITE_KERN_CACHE=ON") + envmt.append("MIOPEN_DEBUG_IMPLICIT_GEMM_FIND_ALL_SOLUTIONS=1") + + if self.args.find_mode: + envmt.append(f"MIOPEN_FIND_MODE={self.args.find_mode}") + + if self.args.blacklist: + bk_str = ", ".join([f"{arg}=0" for arg in self.args.blacklist]) + for bk_var in bk_str.split(","): + envmt.append(bk_var) + + return envmt + + def get_kwargs(self, gpu_idx, f_vals, tuning=False): + """! Helper function to set up kwargs for worker instances + @param gpu_idx Unique ID of the GPU + @param f_vals Dict containing runtime information + @param tuning Boolean that indicates if kwargs are for a tuning step + @return kwargs Dictionary + """ + kwargs = super().get_kwargs(gpu_idx, f_vals, tuning) + kwargs["fin_steps"] = self.args.fin_steps + kwargs["dynamic_solvers_only"] = self.args.dynamic_solvers_only + kwargs["config_type"] = self.args.config_type + kwargs["reset_interval"] = self.args.reset_interval + + return kwargs + + def get_job_list(self, session, find_state, claim_num): + """! Get list of jobs + @param session DB session + @param find_state DB job state + @param claim_num Number of DB jobs to pick up + @return List of DB jobs + + """ + job_list = self.get_job_objs( + session, + find_state, + self.args.label, + self.dbt, + self.get_job_attr(), + claim_num, + self.args.fin_steps, + ) + + return job_list + + def get_job_objs( + self, + session: DbSession, + find_state: list, + label: str, + dbt: DBTablesInterface, + job_attr: List[str], + claim_num: int = None, + fin_steps: List[str] = None, + ) -> List[SimpleDict]: + """! Get list of job objects + @param session DB session + @param find_state DB job state + @param label DB job reason + @param dbt Class representing all DB tables associated with this class + @param job_attr List of DB job columns + @param claim_num Number of DB jobs to pick up + @param fin_steps List of MIFin steps + @return List of DB jobs + """ + entries: List[Tuple[SimpleDict, ...]] + conds: List[str] = [f"session={dbt.session.id}", "valid=1"] + + if label: + conds.append(f"reason='{label}'") + + conds.append(f"retries<{self.max_job_retries}") + conds.append("state in (" + str(find_state).strip("{").strip("}") + ")") + + entries = self.compose_work_objs( + session, conds, dbt, job_attr, claim_num, fin_steps + ) + return entries + + def compose_work_objs( + self, + session: DbSession, + conds: List[str], + dbt: DBTablesInterface, + job_attr: List[str], + claim_num: int = None, + fin_steps: List[str] = None, + ) -> List[SimpleDict]: + """! 
Query a job list for update + @param session DB session + @param conds List of conditions for DB job WHERE clause + @param dbt Class representing all DB tables associated with this class + @param job_attr List of DB job columns + @param fin_steps List of MIFin steps + @return List of MIFin work objects + """ + job_entries = [] + if fin_steps: + conds.append(f"fin_step like '%{fin_steps[0]}%'") else: - if 'miopen_find_compile_result' in fin_json: - status = process_fdb_w_kernels(session, fin_json, - copy.deepcopy(context), self.dbt, - context['fdb_attr'], pending) - - elif 'miopen_perf_compile_result' in fin_json: - status = process_pdb_compile(session, fin_json, job, self.dbt, - solver_id_map) - - success, result_str = get_fin_result(status) - failed_job = not success - - except (OperationalError, IntegrityError) as err: - self.logger.warning('FinBuild: Unable to update Database %s', err) - session.rollback() - failed_job = True - except DataError as err: - self.logger.warning( - 'FinBuild: Invalid data, likely large workspace. DB Error: %s', err) - session.rollback() - failed_job = True - - if failed_job: - set_job_state(session, job, self.dbt, 'errored', False, result=result_str) - else: - set_job_state(session, - job, - self.dbt, - 'compiled', - False, - result=result_str) - - return True - - def process_eval_results(self, session, fin_json, context): - """! Process fin_json result - @param session DB session - @param fin_json MIFin results for job - @param context Context for Celery job - @return Boolean value - """ - job = SimpleDict(**context['job']) - failed_job = True - result_str = '' - pending = [] - orig_state = 'compiled' - - try: - if fin_json: - if 'success' in fin_json and fin_json["success"] is False: - status = [fin_json] + conds.append("fin_step='not_fin'") + + cond_str = " AND ".join(conds) + if cond_str: + cond_str = f"WHERE {cond_str}" + if claim_num: + cond_str += ( + f" ORDER BY retries,config ASC LIMIT {claim_num} FOR UPDATE SKIP LOCKED" + ) else: - if 'miopen_find_eval_result' in fin_json: - status = process_fdb_w_kernels(session, - fin_json, - copy.deepcopy(context), - self.dbt, - context['fdb_attr'], - pending, - result_str='miopen_find_eval_result', - check_str='evaluated') - elif 'miopen_perf_eval_result' in fin_json: - status = process_fdb_w_kernels(session, - fin_json, - copy.deepcopy(context), - self.dbt, - context['fdb_attr'], - pending, - result_str='miopen_perf_eval_result', - check_str='evaluated') - if context["rich_data"]: - status = process_tuning_data(session, - fin_json, - copy.deepcopy(context), - self.dbt, - context['tuning_data_attr'], - pending, - result_str='miopen_perf_eval_result', - check_str='evaluated') - - success, result_str = get_fin_result(status) - failed_job = not success - - if failed_job: - if job.retries >= (MAX_ERRORED_JOB_RETRIES - 1): #pylint: disable=no-member - self.logger.warning('max job retries exhausted, setting to errored') - set_job_state(session, job, self.dbt, 'errored', result=result_str) + cond_str += " ORDER BY retries,config ASC FOR UPDATE SKIP LOCKED" + + job_entries = gen_select_objs( + session, job_attr, dbt.job_table.__tablename__, cond_str + ) + + return job_entries + + def compose_work_objs_fin( + self, session, job_entries, dbt + ) -> List[Tuple[SimpleDict, SimpleDict]]: + """! 
Return jobs for fin work + @param session DB session + @param job_entries List of DB jobs + @param dbt Class representing all DB tables associated with this class + @return ret Job tuple + """ + ret = [] + + cfg_rel = { + key: { + "key": list(val.local_columns)[0].name, + "ftble": str(list(val.remote_side)[0]).split(".", maxsplit=1)[0], + "fkey": str(list(val.remote_side)[0]).split(".")[1], + } + for key, val in inspect(dbt.config_table).relationships.items() + } + + if job_entries: + id_str = ",".join({str(job.config) for job in job_entries}) + cfg_cond_str = f"where valid=1 and id in ({id_str})" + cfg_attr = [column.name for column in inspect(dbt.config_table).c] + cfg_entries = gen_select_objs( + session, cfg_attr, dbt.config_table.__tablename__, cfg_cond_str + ) + + cfg_entries = self.attach_tensors(session, cfg_rel, cfg_entries) + + cfg_map = {cfg.id: cfg for cfg in cfg_entries} + + for job in job_entries: + ret.append((job, cfg_map[job.config])) + + return ret + + def attach_tensors(self, session, cfg_rel, cfg_entries): + """! Attach tensor relationship information to config entries + @param session DB session + @param cfg_rel DB Config col value + @param cfg_entries List of DB Config entries + @return cfg_entries List of DB Config entries with attached tensors (foreign keys) + + """ + for key, val in cfg_rel.items(): + rel_attr = [ + column.name + for column in inspect(get_class_by_tablename(val["ftble"])).c + ] + val["fattr"] = rel_attr + + for cfg in cfg_entries: + for key, val in cfg_rel.items(): + rel_val = getattr(cfg, val["key"]) + rel_cond_str = f"where {val['fkey']}={rel_val}" + setattr( + cfg, + key, + gen_select_objs(session, val["fattr"], val["ftble"], rel_cond_str)[ + 0 + ], + ) + return cfg_entries + + # deprecated + def get_job_tables( + self, job_rows: List[Tuple[SimpleDict, ...]], job_attr: List[str] + ) -> List[SimpleDict]: + """Find job tables in query results""" + if has_attr_set(job_rows[0], job_attr): + job_tables: List[SimpleDict] = job_rows else: - self.logger.warning('resetting job state to %s, incrementing retries', - orig_state) - set_job_state(session, + job_i: int = 0 + tble: SimpleDict + for i, tble in enumerate(job_rows[0]): + if has_attr_set(tble, job_attr): + job_i = i + break + job_tables = [row[job_i] for row in job_rows] + + return job_tables + + def update_operation(self): + """! Update the workers type that this library needs""" + if self.args.fin_steps: + if ( + "miopen_find_compile" in self.args.fin_steps + or "miopen_perf_compile" in self.args.fin_steps + ): + self.fetch_state.add("new") + self.set_state = "compile_start" + self.operation = Operation.COMPILE + elif ( + "miopen_find_eval" in self.args.fin_steps + or "miopen_perf_eval" in self.args.fin_steps + ): + self.fetch_state.add("new") + self.fetch_state.add("compiled") + self.set_state = "eval_start" + self.operation = Operation.EVAL + + if self.args.update_applicability: + self.fetch_state.add("new") + + def has_tunable_operation(self): + """! Check if its a tuning loop operation + @return Bool value that represents if operation is tuning + """ + if self.args is None: + self.parse_args() + if self.args.subcommand and "load_job" in self.args.subcommand: + return False + if self.args.shutdown_workers: + return True + + return self.args.fin_steps and any( + s in self.args.fin_steps for s in MIOPEN_CELERY_STEPS + ) + + @lru_cache(1) + def get_fdb_attr(self): + """! 
Get find_db table attrs + @return fdb_attr find_db table attributes without timestamps + """ + fdb_attr = None + fdb_attr = [column.name for column in inspect(self.dbt.find_db_table).c] + fdb_attr.remove("insert_ts") + fdb_attr.remove("update_ts") + return fdb_attr + + @lru_cache(1) + def get_tuning_data_attr(self): + """! Get tuning_data table attrs + @return tuning_data_attr tuning_data table attributes without timestamps + """ + tuning_data_attr = None + tuning_data_attr = [ + column.name for column in inspect(self.dbt.tuning_data_table).c + ] + tuning_data_attr.remove("insert_ts") + tuning_data_attr.remove("update_ts") + return tuning_data_attr + + def serialize_jobs(self, session: DbSession, batch_jobs: List[Any]): + """! Return list of serialize jobs + @param session DB session + @param batch_jobs List of DB jobs + @return DB jobs, serialized + """ + entries = self.compose_work_objs_fin(session, batch_jobs, self.dbt) + return serialize_chunk(entries) + + def build_context( + self, serialized_jobs: Tuple[SimpleDict, SimpleDict] + ) -> List[dict]: + """Build context list for enqueue job""" + context_list = [] + kwargs = self.get_context_items() + fdb_attr = self.get_fdb_attr() + tuning_data_attr = self.get_tuning_data_attr() + for job, config in serialized_jobs: + context = { + "job": job, + "config": config, + "operation": self.operation, + "arch": self.dbt.session.arch, + "num_cu": self.dbt.session.num_cu, + "kwargs": kwargs, + "rich_data": self.args.rich_data, + "fdb_attr": fdb_attr, + "tuning_data_attr": tuning_data_attr, + } + context_list.append(context) + + return context_list + + def celery_enqueue_call(self, context: dict, q_name: str, task_id=False): + """! Enqueue job (context) for queue:q_name + @param context Context for Celery job + @param q_name Custom Celery queue name + @param task_id Custom Redis Key + """ + + # hacky way to get the Q_NAME to the task decorator for interpreter to decorate the + # function with correct q_name arg + # if import is moved to top it will result in circular imports + Q_NAME = q_name # pylint: disable=import-outside-toplevel,unused-variable,invalid-name,redefined-outer-name + from tuna.miopen.celery_tuning.celery_tasks import ( + celery_enqueue, + ) # pylint: disable=import-outside-toplevel + + return celery_enqueue.apply_async( + (context,), + task_id=("-").join([self.prefix, uuid()]), + queue=q_name, + reply_to=q_name, + ) + + def process_compile_results(self, session, fin_json, context): + """! 
Process result from fin_build worker + @param session DB session + @param fin_json MIFin results for job + @param context Context for Celery job + @return Boolean value + """ + job = SimpleDict(**context["job"]) + pending = [] + solver_id_map = get_solver_ids() + + failed_job = False + result_str = "" + status = None + try: + if fin_json: + if "success" in fin_json and fin_json["success"] is False: + status = [fin_json] + else: + if "miopen_find_compile_result" in fin_json: + status = process_fdb_w_kernels( + session, + fin_json, + copy.deepcopy(context), + self.dbt, + context["fdb_attr"], + pending, + ) + + elif "miopen_perf_compile_result" in fin_json: + status = process_pdb_compile( + session, fin_json, job, self.dbt, solver_id_map + ) + + success, result_str = get_fin_result(status) + failed_job = not success + + except (OperationalError, IntegrityError) as err: + self.logger.warning("FinBuild: Unable to update Database %s", err) + session.rollback() + failed_job = True + except DataError as err: + self.logger.warning( + "FinBuild: Invalid data, likely large workspace. DB Error: %s", err + ) + session.rollback() + failed_job = True + + if failed_job: + set_job_state(session, job, self.dbt, "errored", False, result=result_str) + else: + set_job_state(session, job, self.dbt, "compiled", False, result=result_str) + + return True + + def process_eval_results(self, session, fin_json, context): + """! Process fin_json result + @param session DB session + @param fin_json MIFin results for job + @param context Context for Celery job + @return Boolean value + """ + job = SimpleDict(**context["job"]) + failed_job = True + result_str = "" + pending = [] + orig_state = "compiled" + + try: + if fin_json: + if "success" in fin_json and fin_json["success"] is False: + status = [fin_json] + else: + if "miopen_find_eval_result" in fin_json: + status = process_fdb_w_kernels( + session, + fin_json, + copy.deepcopy(context), + self.dbt, + context["fdb_attr"], + pending, + result_str="miopen_find_eval_result", + check_str="evaluated", + ) + elif "miopen_perf_eval_result" in fin_json: + status = process_fdb_w_kernels( + session, + fin_json, + copy.deepcopy(context), + self.dbt, + context["fdb_attr"], + pending, + result_str="miopen_perf_eval_result", + check_str="evaluated", + ) + if context["rich_data"]: + status = process_tuning_data( + session, + fin_json, + copy.deepcopy(context), + self.dbt, + context["tuning_data_attr"], + pending, + result_str="miopen_perf_eval_result", + check_str="evaluated", + ) + + success, result_str = get_fin_result(status) + failed_job = not success + + if failed_job: + if job.retries >= ( + MAX_ERRORED_JOB_RETRIES - 1 + ): # pylint: disable=no-member + self.logger.warning("max job retries exhausted, setting to errored") + set_job_state(session, job, self.dbt, "errored", result=result_str) + else: + self.logger.warning( + "resetting job state to %s, incrementing retries", orig_state + ) + set_job_state( + session, job, self.dbt, orig_state, increment_retries=True, - result=result_str) - else: - self.logger.info("\n\n Setting job state to evaluated") - set_job_state(session, job, self.dbt, 'evaluated', result=result_str) - clean_cache_table(self.dbt, job) - except (OperationalError, IntegrityError) as err: - self.logger.warning('FinBuild: Unable to update Database %s', err) - session.rollback() - set_job_state(session, job, self.dbt, 'errored', result=result_str) - - return True + result=result_str, + ) + else: + self.logger.info("\n\n Setting job state to evaluated") + 
set_job_state(session, job, self.dbt, "evaluated", result=result_str) + clean_cache_table(self.dbt, job) + except (OperationalError, IntegrityError) as err: + self.logger.warning("FinBuild: Unable to update Database %s", err) + session.rollback() + set_job_state(session, job, self.dbt, "errored", result=result_str) + + return True From ebc50d874ee6bf511382c6cfdaaee3c5caeda4ee Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Fri, 17 Oct 2025 08:34:43 +0000 Subject: [PATCH 04/33] WIP: parallell functionality --- tuna/miopen/miopen_lib.py | 8 +++ tuna/mituna_interface.py | 102 ++++++++++++++++++++++++++++++++------ 2 files changed, 96 insertions(+), 14 deletions(-) diff --git a/tuna/miopen/miopen_lib.py b/tuna/miopen/miopen_lib.py index 3450be34..f8be7c31 100644 --- a/tuna/miopen/miopen_lib.py +++ b/tuna/miopen/miopen_lib.py @@ -943,3 +943,11 @@ def process_eval_results(self, session, fin_json, context): set_job_state(session, job, self.dbt, "errored", result=result_str) return True + + def extract_job_id_from_context(self, context): + """Extract job ID from MIOpen celery task context""" + try: + # Extract job ID from the job context + return context.get("job", {}).get("id") + except (AttributeError, KeyError): + return None diff --git a/tuna/mituna_interface.py b/tuna/mituna_interface.py index 345a915b..51a11474 100644 --- a/tuna/mituna_interface.py +++ b/tuna/mituna_interface.py @@ -81,6 +81,12 @@ def __init__(self, library=Library.MIOPEN) -> None: self.db_name = os.environ["TUNA_DB_NAME"] self.prefix = None + # Track jobs claimed by this specific instance when in distributor mode + self.claimed_job_ids = set() + self.completed_job_ids = set() + # if less than 25% of the jobs are remaining, we can grab more jobs + self.progress_factor = 0.25 + def check_docker(self, worker: WorkerInterface, dockername="miopentuna") -> bool: """! 
Checking for docker @param worker The worker interface instance @@ -343,12 +349,21 @@ def celery_enqueue_call(self, context, q_name, task_id=False): raise NotImplementedError("Not implemented") def enqueue_jobs(self, job_counter, job_batch_size, q_name): - """Enqueue celery jobs""" + """Enqueue celery jobs with machine-specific progress tracking""" self.logger.info("Starting enqueue") + current_batch_size = 0 + with DbSession() as session: while True: - job_list = [] - # get all the jobs from mySQL + # Check if we should enqueue more jobs based on OUR progress + if current_batch_size > 0: + if not self.should_enqueue_more_jobs(session, current_batch_size): + self.logger.info( + "Waiting for our current batch to progress before enqueuing more" + ) + break + + # Get jobs from database job_list = self.get_jobs( session, self.fetch_state, @@ -357,20 +372,63 @@ def enqueue_jobs(self, job_counter, job_batch_size, q_name): job_batch_size, ) + if not job_list: + self.logger.info("No more jobs available to enqueue") + break + + # Track the jobs we just claimed + new_job_ids = {job.id for job in job_list} + self.claimed_job_ids.update(new_job_ids) + + self.logger.info("Claimed jobs: %s", list(new_job_ids)) + with job_counter_lock: job_counter.value = job_counter.value + len(job_list) - for i in range(0, len(job_list), job_batch_size): - batch_jobs = job_list[i : min(i + job_batch_size, len(job_list))] - context_list = self.get_context_list(session, batch_jobs) - for context in context_list: - # calling celery task, enqueuing to celery queue - self.celery_enqueue_call(context, q_name=q_name) + # Process all jobs in this batch (remove the inner for loop) + context_list = self.get_context_list(session, job_list) + for context in context_list: + # calling celery task, enqueuing to celery queue + self.celery_enqueue_call(context, q_name=q_name) + + current_batch_size = len(job_list) + self.logger.info( + "Job counter: %s, enqueued batch size: %s", + job_counter.value, + current_batch_size, + ) - self.logger.info("Job counter: %s", job_counter.value) - if not job_list: - self.logger.info("All tasks added to queue") - break + # Cleanup old tracking data periodically + self.cleanup_completed_jobs() + + def should_enqueue_more_jobs(self, session, current_batch_size): + """Check if we should enqueue more jobs based on THIS instance's progress""" + # Count only jobs claimed by this machine instance + our_in_progress_count = len(self.claimed_job_ids - self.completed_job_ids) + + # Allow enqueuing when less than 25% of our claimed jobs are still in progress + progress_threshold = current_batch_size * self.progress_factor + + self.logger.info( + "Our jobs in progress: %d, completed: %d, threshold: %d", + our_in_progress_count, + len(self.completed_job_ids), + progress_threshold, + ) + + return our_in_progress_count < progress_threshold + + def cleanup_completed_jobs(self): + """Periodically clean up old job tracking data""" + # Keep sets from growing indefinitely + max_tracking_size = 10000 + if len(self.completed_job_ids) > max_tracking_size: + # Keep only the most recent completions + recent_completions = list(self.completed_job_ids)[-5000:] + self.completed_job_ids = set(recent_completions) + + # Remove old claimed jobs that are completed + self.claimed_job_ids -= set(recent_completions[:-1000]) async def cleanup_redis_results(self, prefix): """Remove stale redis results by key""" @@ -525,6 +583,9 @@ def tune(self, job_batch_size=1000): with job_counter_lock: job_counter.value = job_counter.value - 1 + # 
Progress-aware polling - shorter intervals, smarter enqueuing + poll_interval = int(os.environ.get("TUNA_POLL_INTERVAL", 5)) + # check for new jobs while consume_proc.is_alive(): enqueue_proc = Process( @@ -532,7 +593,7 @@ def tune(self, job_batch_size=1000): ) enqueue_proc.start() enqueue_proc.join() - time.sleep(10) + time.sleep(poll_interval) # Shorter, configurable polling consume_proc.join() @@ -663,6 +724,13 @@ async def parse_result(self, data): try: fin_json = data["result"]["ret"] context = data["result"]["context"] + + # Extract job ID from context to track completion + job_id = self.extract_job_id_from_context(context) + if job_id and job_id in self.claimed_job_ids: + self.completed_job_ids.add(job_id) + self.logger.info("Marked job %s as completed", job_id) + except KeyError as kerr: self.logger.error(kerr) return False @@ -677,6 +745,12 @@ async def parse_result(self, data): return True + def extract_job_id_from_context(self, context): + """Extract job ID from celery task context""" + # This needs to be implemented in the MIOpen subclass + # based on how job IDs are stored in the context + raise NotImplementedError("Subclass must implement job ID extraction") + def process_compile_results(self, session, fin_json, context): """Process result from fin_build worker""" raise NotImplementedError("Not implemented") From b30ba371777b098682c38481db8bec0e7d3dc29d Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Fri, 17 Oct 2025 10:35:18 +0000 Subject: [PATCH 05/33] perf(mituna_interface): optimize job state updates and improve enqueue reliability - Replace individual UPDATE queries with bulk UPDATE for job state changes - Add retry logic with configurable max attempts for database operations - Implement consecutive empty fetch tracking to prevent infinite loops - Add proper error handling and recovery for database session failures - Track enqueued jobs to prevent duplicate processing - Add configurable TUNA_MAX_EMPTY_FETCHES environment variable - Improve logging for better observability of enqueue process This optimization significantly reduces database round-trips when updating multiple job states and makes the enqueue process more resilient to transient failures. 
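As an illustration of the bulk-update idea only (the sqlite3 backend and the
"job" table/column names below are stand-ins, not the actual Tuna schema or
MySQL session), the change collapses one UPDATE per claimed job into a single
UPDATE over the whole id list:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE job (id INTEGER PRIMARY KEY, state TEXT)")
    conn.executemany("INSERT INTO job (id, state) VALUES (?, 'new')",
                     [(i,) for i in range(1, 6)])

    claimed_ids = [1, 2, 3]

    # before: one statement (and one round-trip) per claimed job
    for job_id in claimed_ids:
        conn.execute("UPDATE job SET state = ? WHERE id = ?",
                     ("compile_start", job_id))

    # after: a single statement covering the whole batch
    placeholders = ",".join("?" for _ in claimed_ids)
    conn.execute(f"UPDATE job SET state = ? WHERE id IN ({placeholders})",
                 ["compile_start", *claimed_ids])
    conn.commit()

The batch form is what removes the per-job round-trips mentioned above; the
actual change applies the same shape to the Tuna job table.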
--- tuna/mituna_interface.py | 146 +++++++++++++++++++++++++-------------- 1 file changed, 96 insertions(+), 50 deletions(-) diff --git a/tuna/mituna_interface.py b/tuna/mituna_interface.py index 51a11474..da17340b 100644 --- a/tuna/mituna_interface.py +++ b/tuna/mituna_interface.py @@ -289,15 +289,22 @@ def get_jobs( ids = [row.id for row in job_list] self.logger.info("%s jobs %s", find_state, ids) self.logger.info("Updating job state to %s", set_state) - for job in job_list: - job.state = set_state - if self.dbt is not None: - query: str = gen_update_query( - job, ["state"], self.dbt.job_table.__tablename__ - ) - else: - raise CustomError("DBTable must be set") + + # OPTIMIZATION: Use bulk UPDATE instead of individual updates + if self.dbt is not None: + id_str = ','.join(map(str, ids)) + query = f""" + UPDATE {self.dbt.job_table.__tablename__} + SET state = '{set_state}' + WHERE id IN ({id_str}) + """ session.execute(query) + + # Update local objects to reflect new state + for job in job_list: + job.state = set_state + else: + raise CustomError("DBTable must be set") session.commit() @@ -349,57 +356,96 @@ def celery_enqueue_call(self, context, q_name, task_id=False): raise NotImplementedError("Not implemented") def enqueue_jobs(self, job_counter, job_batch_size, q_name): - """Enqueue celery jobs with machine-specific progress tracking""" + """Enqueue celery jobs with machine-specific progress tracking and error handling""" self.logger.info("Starting enqueue") current_batch_size = 0 - - with DbSession() as session: - while True: - # Check if we should enqueue more jobs based on OUR progress - if current_batch_size > 0: - if not self.should_enqueue_more_jobs(session, current_batch_size): - self.logger.info( - "Waiting for our current batch to progress before enqueuing more" + + max_retries = 3 + retry_delay = 5 # seconds + consecutive_empty_fetches = 0 + max_empty_fetches = int(os.environ.get('TUNA_MAX_EMPTY_FETCHES', 3)) + + while True: + # Retry loop for database operations + for attempt in range(max_retries): + try: + with DbSession() as session: + # Check if we should enqueue more jobs based on OUR progress + if current_batch_size > 0: + if not self.should_enqueue_more_jobs(session, current_batch_size): + self.logger.info( + "Waiting for our current batch to progress before enqueuing more" + ) + return # Exit gracefully + + # Get jobs from database + job_list = self.get_jobs( + session, + self.fetch_state, + self.set_state, # pylint: disable=no-member + self.args.session_id, # pylint: disable=no-member + job_batch_size, ) - break - # Get jobs from database - job_list = self.get_jobs( - session, - self.fetch_state, - self.set_state, # pylint: disable=no-member - self.args.session_id, # pylint: disable=no-member - job_batch_size, - ) - - if not job_list: - self.logger.info("No more jobs available to enqueue") - break + if not job_list: + consecutive_empty_fetches += 1 + self.logger.info('No jobs found (attempt %d/%d)', + consecutive_empty_fetches, max_empty_fetches) + + if consecutive_empty_fetches >= max_empty_fetches: + self.logger.info('No new jobs after %d attempts. 
Exiting enqueue loop.', + max_empty_fetches) + return # Exit gracefully + + time.sleep(60) # Wait before next check + break # Break retry loop, continue main loop - # Track the jobs we just claimed - new_job_ids = {job.id for job in job_list} - self.claimed_job_ids.update(new_job_ids) + # Reset counter when jobs are found + consecutive_empty_fetches = 0 - self.logger.info("Claimed jobs: %s", list(new_job_ids)) + # Track the jobs we just claimed + new_job_ids = {job.id for job in job_list} + self.claimed_job_ids.update(new_job_ids) - with job_counter_lock: - job_counter.value = job_counter.value + len(job_list) + self.logger.info("Claimed jobs: %s", list(new_job_ids)) - # Process all jobs in this batch (remove the inner for loop) - context_list = self.get_context_list(session, job_list) - for context in context_list: - # calling celery task, enqueuing to celery queue - self.celery_enqueue_call(context, q_name=q_name) - - current_batch_size = len(job_list) - self.logger.info( - "Job counter: %s, enqueued batch size: %s", - job_counter.value, - current_batch_size, - ) + with job_counter_lock: + job_counter.value = job_counter.value + len(job_list) + + # Process all jobs in this batch + context_list = self.get_context_list(session, job_list) + for context in context_list: + try: + # calling celery task, enqueuing to celery queue + self.celery_enqueue_call(context, q_name=q_name) + except Exception as enqueue_err: # pylint: disable=broad-exception-caught + self.logger.error('Failed to enqueue job: %s', enqueue_err) + # Continue with other jobs rather than failing completely + continue + + current_batch_size = len(job_list) + self.logger.info( + "Job counter: %s, enqueued batch size: %s", + job_counter.value, + current_batch_size, + ) - # Cleanup old tracking data periodically - self.cleanup_completed_jobs() + # Cleanup old tracking data periodically + self.cleanup_completed_jobs() + break # Success, break retry loop + + except Exception as db_err: # pylint: disable=broad-exception-caught + self.logger.warning('Database error on attempt %d/%d: %s', + attempt + 1, max_retries, db_err) + if attempt < max_retries - 1: + time.sleep(retry_delay * (attempt + 1)) # Exponential backoff + else: + self.logger.error('Max retries exceeded for database operation. 
Exiting.') + raise + + # If we got here with no jobs, the consecutive_empty_fetches logic handled it + if not job_list: + continue def should_enqueue_more_jobs(self, session, current_batch_size): """Check if we should enqueue more jobs based on THIS instance's progress""" From 3001f9e21d20c4df1da2dfeb40e687be453a44e4 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Tue, 21 Oct 2025 07:44:09 +0000 Subject: [PATCH 06/33] used yapf formatter --- tests/test_celery.py | 1 + tuna/miopen/miopen_lib.py | 1551 ++++++++++++++++++------------------- tuna/mituna_interface.py | 1390 +++++++++++++++++---------------- 3 files changed, 1464 insertions(+), 1478 deletions(-) diff --git a/tests/test_celery.py b/tests/test_celery.py index 1e4aa01d..4e2a20e8 100644 --- a/tests/test_celery.py +++ b/tests/test_celery.py @@ -51,6 +51,7 @@ from tuna.miopen.worker.fin_utils import compose_config_obj, fin_job from tuna.miopen.utils.lib_helper import get_worker + @pytest.mark.asyncio async def test_celery_workers(): miopen = MIOpen() diff --git a/tuna/miopen/miopen_lib.py b/tuna/miopen/miopen_lib.py index f8be7c31..ab55d23c 100644 --- a/tuna/miopen/miopen_lib.py +++ b/tuna/miopen/miopen_lib.py @@ -76,474 +76,474 @@ class MIOpen(MITunaInterface): - """Class to support MIOpen specific tuning functionality""" - - # pylint: disable=too-many-public-methods - - def __init__(self): - super().__init__(library=Library.MIOPEN) - self.args = None - self.set_state = None - - def parse_args(self): - # pylint: disable=too-many-statements - """Function to parse arguments""" - parser = setup_arg_parser( - "Run Performance Tuning on a certain architecture", - [ - TunaArgs.ARCH, - TunaArgs.NUM_CU, - TunaArgs.VERSION, - TunaArgs.CONFIG_TYPE, - TunaArgs.SESSION_ID, - TunaArgs.MACHINES, - TunaArgs.REMOTE_MACHINE, - TunaArgs.LABEL, - TunaArgs.RESTART_MACHINE, - TunaArgs.DOCKER_NAME, - TunaArgs.SHUTDOWN_WORKERS, - TunaArgs.ENQUEUE_ONLY, - ], - ) - parser.add_argument( - "--find_mode", - dest="find_mode", - type=int, - default=1, - help="Set the MIOPEN_FIND_MODE environment variable for MIOpen", - choices=["1", "3"], - ) - parser.add_argument( - "--ticket", - dest="ticket", - type=str, - default=None, - help="Specify tuning ticket number", - ) - parser.add_argument( - "--solver_id", - type=int, - dest="solver_id", - default=None, - help="Specify solver_id. 
Use --list_solvers to see options", - ) - parser.add_argument( - "--dynamic_solvers_only", - dest="dynamic_solvers_only", - action="store_true", - default=False, - help="Only tune dynamic solvers.", - ) - parser.add_argument( - "-B", - "--blacklist", - dest="blacklist", - type=str, - default=None, - help="MIOpen blacklist algorithm, if multiple then comma separate", - ) - parser.add_argument( - "-i", - "--reset_interval", - type=int, - dest="reset_interval", - required=False, - help="Restart interval for job in hours.", - ) - parser.add_argument( - "--gpu_lim", - dest="gpu_lim", - type=int, - default=None, - help="Limit the number of gpu workers created by Tuna, index from 0", - ) - - parser.add_argument( - "-R", - "--rich_data", - dest="rich_data", - action="store_true", - default=False, - help="record intermediate parameter results from perf tuning", - ) - - subcommands = parser.add_subcommands(required=False) - subcommands.add_subcommand( - "import_configs", get_import_cfg_parser(), required=False - ) - - subcommands.add_subcommand("load_job", get_load_job_parser(), required=False) - - subcommands.add_subcommand("export_db", get_export_db_parser(), required=False) - - subcommands.add_subcommand( - "update_golden", get_update_golden_parser(), required=False - ) - - group = parser.add_mutually_exclusive_group() - group.add_argument( - "--add_tables", - dest="add_tables", - action="store_true", - help="Add MIOpen library specific tables", - ) - - group.add_argument( - "--init_session", - action="store_true", - dest="init_session", - help="Set up a new tuning session.", - ) - group.add_argument( - "--fin_steps", - type=str, - dest="fin_steps", - help="Specify fin steps. Multiple steps should be comma separated.", - ) - group.add_argument( - "--list_solvers", - action="store_true", - dest="list_solvers", - help="List of solvers from the solver table", - ) - - # JD: implement the following two using fin_steps - group.add_argument( - "--update_solvers", - dest="update_solvers", - action="store_true", - help="Update the list of solvers in the database", - ) - group.add_argument( - "--update_applicability", - dest="update_applicability", - action="store_true", - help="Update the applicability table in the database", - ) - group.add_argument( - "-s", - "--status", - dest="check_status", - action="store_true", - default=False, - help="Check the status of machines", - ) - - group.add_argument( - "-e", - "--exec", - dest="execute_cmd", - type=str, - default=None, - help="execute on each machine", - ) - - self.args = parser.parse_args() - - if self.args.config_type is None: - self.args.config_type = ConfigType.convolution - - # overwritte common lib args with subcommand args value - if self.args.subcommand is not None: - self.overwrite_common_args() - - if len(sys.argv) == 1: - parser.print_help() - sys.exit(-1) - - if self.args.list_solvers: - print_solvers() - raise CustomError("Printing solvers...") - - if self.args.fin_steps and self.args.subcommand != "load_job": - self.check_fin_args(parser) - self.set_prefix() - - if self.args.find_mode is None and not ( - self.args.check_status or self.args.restart_machine or self.args.execute_cmd - ): - parser.error("find_mode must be specified for a tuning run") - - if self.args.blacklist: - self.check_blacklist(parser) - - args_check(self.args, parser) - - fin_session_steps = [ - "miopen_find_compile", - "miopen_find_eval", - "miopen_perf_compile", - "miopen_perf_eval", - "get_applicability", - "find_compile", - "find_eval", - ] - has_fin = False - if 
self.args.fin_steps: - has_fin = all(x in fin_session_steps for x in self.args.fin_steps) - - if (self.args.update_applicability or has_fin) and not self.args.session_id: - parser.error("session_id must be specified with this operation") - - self.dbt = MIOpenDBTables( - session_id=self.args.session_id, config_type=self.args.config_type - ) - self.update_operation() - - def set_prefix(self): - """Set redis key prefix""" - if isinstance(self.args.fin_steps, Iterable): - steps_str = ("-").join(x for x in self.args.fin_steps) - self.prefix = ( - f"d_{self.db_name}_sess_{self.args.session_id}_" f"{steps_str}" - ) - else: - steps_str = self.args.fin_steps[0] - self.prefix = f"d_{self.db_name}_sess_{self.args.session_id}_{steps_str}" - - self.logger.info("redis prefix: %s", self.prefix) - - def overwrite_common_args(self): - """Overwrite common MIOpen_lib args with subcommand args""" - if self.args.subcommand is not None: - subc_dict = vars(self.args.get(self.args.subcommand)) - for sub_key in subc_dict: - if sub_key in vars(self.args): - self.args[sub_key] = subc_dict.get(sub_key) - - def check_fin_args(self, parser): - """! Helper function for fin args + """Class to support MIOpen specific tuning functionality""" + + # pylint: disable=too-many-public-methods + + def __init__(self): + super().__init__(library=Library.MIOPEN) + self.args = None + self.set_state = None + + def parse_args(self): + # pylint: disable=too-many-statements + """Function to parse arguments""" + parser = setup_arg_parser( + "Run Performance Tuning on a certain architecture", + [ + TunaArgs.ARCH, + TunaArgs.NUM_CU, + TunaArgs.VERSION, + TunaArgs.CONFIG_TYPE, + TunaArgs.SESSION_ID, + TunaArgs.MACHINES, + TunaArgs.REMOTE_MACHINE, + TunaArgs.LABEL, + TunaArgs.RESTART_MACHINE, + TunaArgs.DOCKER_NAME, + TunaArgs.SHUTDOWN_WORKERS, + TunaArgs.ENQUEUE_ONLY, + ], + ) + parser.add_argument( + "--find_mode", + dest="find_mode", + type=int, + default=1, + help="Set the MIOPEN_FIND_MODE environment variable for MIOpen", + choices=["1", "3"], + ) + parser.add_argument( + "--ticket", + dest="ticket", + type=str, + default=None, + help="Specify tuning ticket number", + ) + parser.add_argument( + "--solver_id", + type=int, + dest="solver_id", + default=None, + help="Specify solver_id. 
Use --list_solvers to see options", + ) + parser.add_argument( + "--dynamic_solvers_only", + dest="dynamic_solvers_only", + action="store_true", + default=False, + help="Only tune dynamic solvers.", + ) + parser.add_argument( + "-B", + "--blacklist", + dest="blacklist", + type=str, + default=None, + help="MIOpen blacklist algorithm, if multiple then comma separate", + ) + parser.add_argument( + "-i", + "--reset_interval", + type=int, + dest="reset_interval", + required=False, + help="Restart interval for job in hours.", + ) + parser.add_argument( + "--gpu_lim", + dest="gpu_lim", + type=int, + default=None, + help="Limit the number of gpu workers created by Tuna, index from 0", + ) + + parser.add_argument( + "-R", + "--rich_data", + dest="rich_data", + action="store_true", + default=False, + help="record intermediate parameter results from perf tuning", + ) + + subcommands = parser.add_subcommands(required=False) + subcommands.add_subcommand("import_configs", + get_import_cfg_parser(), + required=False) + + subcommands.add_subcommand("load_job", + get_load_job_parser(), + required=False) + + subcommands.add_subcommand("export_db", + get_export_db_parser(), + required=False) + + subcommands.add_subcommand("update_golden", + get_update_golden_parser(), + required=False) + + group = parser.add_mutually_exclusive_group() + group.add_argument( + "--add_tables", + dest="add_tables", + action="store_true", + help="Add MIOpen library specific tables", + ) + + group.add_argument( + "--init_session", + action="store_true", + dest="init_session", + help="Set up a new tuning session.", + ) + group.add_argument( + "--fin_steps", + type=str, + dest="fin_steps", + help="Specify fin steps. Multiple steps should be comma separated.", + ) + group.add_argument( + "--list_solvers", + action="store_true", + dest="list_solvers", + help="List of solvers from the solver table", + ) + + # JD: implement the following two using fin_steps + group.add_argument( + "--update_solvers", + dest="update_solvers", + action="store_true", + help="Update the list of solvers in the database", + ) + group.add_argument( + "--update_applicability", + dest="update_applicability", + action="store_true", + help="Update the applicability table in the database", + ) + group.add_argument( + "-s", + "--status", + dest="check_status", + action="store_true", + default=False, + help="Check the status of machines", + ) + + group.add_argument( + "-e", + "--exec", + dest="execute_cmd", + type=str, + default=None, + help="execute on each machine", + ) + + self.args = parser.parse_args() + + if self.args.config_type is None: + self.args.config_type = ConfigType.convolution + + # overwritte common lib args with subcommand args value + if self.args.subcommand is not None: + self.overwrite_common_args() + + if len(sys.argv) == 1: + parser.print_help() + sys.exit(-1) + + if self.args.list_solvers: + print_solvers() + raise CustomError("Printing solvers...") + + if self.args.fin_steps and self.args.subcommand != "load_job": + self.check_fin_args(parser) + self.set_prefix() + + if self.args.find_mode is None and not (self.args.check_status or + self.args.restart_machine or + self.args.execute_cmd): + parser.error("find_mode must be specified for a tuning run") + + if self.args.blacklist: + self.check_blacklist(parser) + + args_check(self.args, parser) + + fin_session_steps = [ + "miopen_find_compile", + "miopen_find_eval", + "miopen_perf_compile", + "miopen_perf_eval", + "get_applicability", + "find_compile", + "find_eval", + ] + has_fin = False + if 
self.args.fin_steps: + has_fin = all(x in fin_session_steps for x in self.args.fin_steps) + + if (self.args.update_applicability or has_fin) and not self.args.session_id: + parser.error("session_id must be specified with this operation") + + self.dbt = MIOpenDBTables(session_id=self.args.session_id, + config_type=self.args.config_type) + self.update_operation() + + def set_prefix(self): + """Set redis key prefix""" + if isinstance(self.args.fin_steps, Iterable): + steps_str = ("-").join(x for x in self.args.fin_steps) + self.prefix = (f"d_{self.db_name}_sess_{self.args.session_id}_" + f"{steps_str}") + else: + steps_str = self.args.fin_steps[0] + self.prefix = f"d_{self.db_name}_sess_{self.args.session_id}_{steps_str}" + + self.logger.info("redis prefix: %s", self.prefix) + + def overwrite_common_args(self): + """Overwrite common MIOpen_lib args with subcommand args""" + if self.args.subcommand is not None: + subc_dict = vars(self.args.get(self.args.subcommand)) + for sub_key in subc_dict: + if sub_key in vars(self.args): + self.args[sub_key] = subc_dict.get(sub_key) + + def check_fin_args(self, parser): + """! Helper function for fin args @param parser The command line argument parser """ - valid_fin_steps = list(k for k in FinStep.__members__) - if "," in self.args.fin_steps: - parser.error("Multiple fin_steps currently not supported") - f_steps = self.args.fin_steps.split(",") - self.args.fin_steps = f_steps - for step in self.args.fin_steps: - if step not in valid_fin_steps: - parser.error(f"Supported fin steps are: {valid_fin_steps}") - assert len(self.args.fin_steps) == 1 - - def check_blacklist(self, parser): - """! Helper function + valid_fin_steps = list(k for k in FinStep.__members__) + if "," in self.args.fin_steps: + parser.error("Multiple fin_steps currently not supported") + f_steps = self.args.fin_steps.split(",") + self.args.fin_steps = f_steps + for step in self.args.fin_steps: + if step not in valid_fin_steps: + parser.error(f"Supported fin steps are: {valid_fin_steps}") + assert len(self.args.fin_steps) == 1 + + def check_blacklist(self, parser): + """! Helper function @param parser The command line argument parser @return ret Boolean value """ - self.args.blacklist = self.args.blacklist.split(",") - for sol in self.args.blacklist: - if sol not in MIOPEN_ALG_LIST: - parser.error("Incorrect blacklist value") + self.args.blacklist = self.args.blacklist.split(",") + for sol in self.args.blacklist: + if sol not in MIOPEN_ALG_LIST: + parser.error("Incorrect blacklist value") - def do_fin_work(self, gpu, f_vals): - """! Helper function to execute job independendent fin work + def do_fin_work(self, gpu, f_vals): + """! Helper function to execute job independendent fin work @param gpu Unique ID of the GPU @param f_vals Dict containing runtime information """ - kwargs = self.get_kwargs(gpu, f_vals) - fin_worker = FinClass(**kwargs) + kwargs = self.get_kwargs(gpu, f_vals) + fin_worker = FinClass(**kwargs) - if self.args.update_solvers: - if not fin_worker.get_solvers(): - self.logger.error("No solvers returned from Fin class") + if self.args.update_solvers: + if not fin_worker.get_solvers(): + self.logger.error("No solvers returned from Fin class") - return True + return True - def launch_worker(self, gpu_idx, f_vals, worker_lst): - """! Function to launch worker + def launch_worker(self, gpu_idx, f_vals, worker_lst): + """! 
Function to launch worker @param gpu_idx Unique ID of the GPU @param f_vals Dict containing runtime information @param worker_lst List containing worker instances @return ret Boolean value """ - # pylint: disable=too-many-branches - worker = None - kwargs = self.get_kwargs(gpu_idx, f_vals) - if self.args.update_applicability: - kwargs["fin_steps"] = ["applicability"] - worker = FinClass(**kwargs) - worker.start() - worker_lst.append(worker) - return True - - worker = FinClass(**kwargs) - ret = False - if self.args.check_status: - if not super().check_status( - worker, - f_vals["b_first"], - gpu_idx, - f_vals["machine"], - self.args.docker_name, - ): - ret = True - elif self.args.init_session: - Session().add_new_session(self.args, worker) - elif self.args.execute_cmd: - # JD: Move the worker.exec_command to machine - self.logger.info(self.args.execute_cmd) - _, _, _ = worker.exec_command(self.args.execute_cmd + " 2>&1 ") - - return ret - - def compose_worker_list(self, machines): - # pylint: disable=too-many-branches - """! Helper function to compose worker_list + # pylint: disable=too-many-branches + worker = None + kwargs = self.get_kwargs(gpu_idx, f_vals) + if self.args.update_applicability: + kwargs["fin_steps"] = ["applicability"] + worker = FinClass(**kwargs) + worker.start() + worker_lst.append(worker) + return True + + worker = FinClass(**kwargs) + ret = False + if self.args.check_status: + if not super().check_status( + worker, + f_vals["b_first"], + gpu_idx, + f_vals["machine"], + self.args.docker_name, + ): + ret = True + elif self.args.init_session: + Session().add_new_session(self.args, worker) + elif self.args.execute_cmd: + # JD: Move the worker.exec_command to machine + self.logger.info(self.args.execute_cmd) + _, _, _ = worker.exec_command(self.args.execute_cmd + " 2>&1 ") + + return ret + + def compose_worker_list(self, machines): + # pylint: disable=too-many-branches + """! Helper function to compose worker_list @param machines List of machines to execute on """ - worker_lst = [] - fin_work_done = False - for machine in machines: - if self.args.restart_machine: - machine.restart_server(wait=False) - continue - - # fin_steps should only contain one step - worker_ids = None - if self.args.fin_steps and "eval" in self.args.fin_steps[0]: - worker_ids = machine.get_avail_gpus() - if self.args.gpu_lim and self.args.gpu_lim < len(worker_ids): - worker_ids = range(self.args.gpu_lim) - else: - worker_ids = super().get_num_procs(machine) - - if self.args.update_applicability: - f_vals = super().get_f_vals(machine, [1]) - kwargs = self.get_kwargs(0, f_vals) - kwargs["fin_steps"] = ["applicability"] - worker = FinClass(**kwargs) - query = worker.query_cfgs(self.args.label) - cfg_rows = query.all() - len_rows = len(cfg_rows) - proc_lim = (len_rows + 99) / 100 - if 32 < proc_lim: - proc_lim = 32 - while len(worker_ids) > proc_lim: - worker_ids.pop() - - if len(worker_ids) == 0: - return None - - f_vals = super().get_f_vals(machine, worker_ids) - - if (self.args.update_solvers) and not fin_work_done: - self.do_fin_work(0, f_vals) - fin_work_done = True - break - - for gpu_idx in worker_ids: - self.logger.info("launch mid %u, proc %u", machine.id, gpu_idx) - if not self.launch_worker(gpu_idx, f_vals, worker_lst): - break - - return worker_lst - - def add_tables(self): - """! 
Function to create new DB tables + worker_lst = [] + fin_work_done = False + for machine in machines: + if self.args.restart_machine: + machine.restart_server(wait=False) + continue + + # fin_steps should only contain one step + worker_ids = None + if self.args.fin_steps and "eval" in self.args.fin_steps[0]: + worker_ids = machine.get_avail_gpus() + if self.args.gpu_lim and self.args.gpu_lim < len(worker_ids): + worker_ids = range(self.args.gpu_lim) + else: + worker_ids = super().get_num_procs(machine) + + if self.args.update_applicability: + f_vals = super().get_f_vals(machine, [1]) + kwargs = self.get_kwargs(0, f_vals) + kwargs["fin_steps"] = ["applicability"] + worker = FinClass(**kwargs) + query = worker.query_cfgs(self.args.label) + cfg_rows = query.all() + len_rows = len(cfg_rows) + proc_lim = (len_rows + 99) / 100 + if 32 < proc_lim: + proc_lim = 32 + while len(worker_ids) > proc_lim: + worker_ids.pop() + + if len(worker_ids) == 0: + return None + + f_vals = super().get_f_vals(machine, worker_ids) + + if (self.args.update_solvers) and not fin_work_done: + self.do_fin_work(0, f_vals) + fin_work_done = True + break + + for gpu_idx in worker_ids: + self.logger.info("launch mid %u, proc %u", machine.id, gpu_idx) + if not self.launch_worker(gpu_idx, f_vals, worker_lst): + break + + return worker_lst + + def add_tables(self): + """! Function to create new DB tables @return Bool """ - ret_t = create_tables(get_miopen_tables()) - self.logger.info("DB creation successful: %s", ret_t) - recreate_triggers(drop_miopen_triggers(), get_miopen_triggers()) - return True - - def run(self): - # pylint: disable=duplicate-code - """! Main function to launch library""" - res = None - if self.args is None: - self.parse_args() - - if self.args.add_tables: - self.add_tables() - return None - - if ( - self.args.subcommand is not None - and self.args.subcommand == "import_configs" - ): - run_import_configs(self.args.import_configs, self.logger) - return None - - if self.args.subcommand is not None and self.args.subcommand == "load_job": - run_load_job(self.args.load_job, self.logger) - return None - - if self.args.subcommand is not None and self.args.subcommand == "export_db": - run_export_db(self.args.export_db, self.logger) - return None - - if self.args.subcommand is not None and self.args.subcommand == "update_golden": - run_update_golden(self.args.update_golden, self.logger) - return None - - machines = load_machines(self.args) - res = self.compose_worker_list(machines) - return res - - def get_envmt(self): - """! Function to construct environment var""" - envmt = ["MIOPEN_LOG_LEVEL=4"] - - envmt.append("MIOPEN_SQLITE_KERN_CACHE=ON") - envmt.append("MIOPEN_DEBUG_IMPLICIT_GEMM_FIND_ALL_SOLUTIONS=1") - - if self.args.find_mode: - envmt.append(f"MIOPEN_FIND_MODE={self.args.find_mode}") - - if self.args.blacklist: - bk_str = ", ".join([f"{arg}=0" for arg in self.args.blacklist]) - for bk_var in bk_str.split(","): - envmt.append(bk_var) - - return envmt - - def get_kwargs(self, gpu_idx, f_vals, tuning=False): - """! Helper function to set up kwargs for worker instances + ret_t = create_tables(get_miopen_tables()) + self.logger.info("DB creation successful: %s", ret_t) + recreate_triggers(drop_miopen_triggers(), get_miopen_triggers()) + return True + + def run(self): + # pylint: disable=duplicate-code + """! 
Main function to launch library""" + res = None + if self.args is None: + self.parse_args() + + if self.args.add_tables: + self.add_tables() + return None + + if (self.args.subcommand is not None and + self.args.subcommand == "import_configs"): + run_import_configs(self.args.import_configs, self.logger) + return None + + if self.args.subcommand is not None and self.args.subcommand == "load_job": + run_load_job(self.args.load_job, self.logger) + return None + + if self.args.subcommand is not None and self.args.subcommand == "export_db": + run_export_db(self.args.export_db, self.logger) + return None + + if self.args.subcommand is not None and self.args.subcommand == "update_golden": + run_update_golden(self.args.update_golden, self.logger) + return None + + machines = load_machines(self.args) + res = self.compose_worker_list(machines) + return res + + def get_envmt(self): + """! Function to construct environment var""" + envmt = ["MIOPEN_LOG_LEVEL=4"] + + envmt.append("MIOPEN_SQLITE_KERN_CACHE=ON") + envmt.append("MIOPEN_DEBUG_IMPLICIT_GEMM_FIND_ALL_SOLUTIONS=1") + + if self.args.find_mode: + envmt.append(f"MIOPEN_FIND_MODE={self.args.find_mode}") + + if self.args.blacklist: + bk_str = ", ".join([f"{arg}=0" for arg in self.args.blacklist]) + for bk_var in bk_str.split(","): + envmt.append(bk_var) + + return envmt + + def get_kwargs(self, gpu_idx, f_vals, tuning=False): + """! Helper function to set up kwargs for worker instances @param gpu_idx Unique ID of the GPU @param f_vals Dict containing runtime information @param tuning Boolean that indicates if kwargs are for a tuning step @return kwargs Dictionary """ - kwargs = super().get_kwargs(gpu_idx, f_vals, tuning) - kwargs["fin_steps"] = self.args.fin_steps - kwargs["dynamic_solvers_only"] = self.args.dynamic_solvers_only - kwargs["config_type"] = self.args.config_type - kwargs["reset_interval"] = self.args.reset_interval + kwargs = super().get_kwargs(gpu_idx, f_vals, tuning) + kwargs["fin_steps"] = self.args.fin_steps + kwargs["dynamic_solvers_only"] = self.args.dynamic_solvers_only + kwargs["config_type"] = self.args.config_type + kwargs["reset_interval"] = self.args.reset_interval - return kwargs + return kwargs - def get_job_list(self, session, find_state, claim_num): - """! Get list of jobs + def get_job_list(self, session, find_state, claim_num): + """! Get list of jobs @param session DB session @param find_state DB job state @param claim_num Number of DB jobs to pick up @return List of DB jobs """ - job_list = self.get_job_objs( - session, - find_state, - self.args.label, - self.dbt, - self.get_job_attr(), - claim_num, - self.args.fin_steps, - ) - - return job_list - - def get_job_objs( - self, - session: DbSession, - find_state: list, - label: str, - dbt: DBTablesInterface, - job_attr: List[str], - claim_num: int = None, - fin_steps: List[str] = None, - ) -> List[SimpleDict]: - """! Get list of job objects + job_list = self.get_job_objs( + session, + find_state, + self.args.label, + self.dbt, + self.get_job_attr(), + claim_num, + self.args.fin_steps, + ) + + return job_list + + def get_job_objs( + self, + session: DbSession, + find_state: list, + label: str, + dbt: DBTablesInterface, + job_attr: List[str], + claim_num: int = None, + fin_steps: List[str] = None, + ) -> List[SimpleDict]: + """! 
Get list of job objects @param session DB session @param find_state DB job state @param label DB job reason @@ -553,30 +553,29 @@ def get_job_objs( @param fin_steps List of MIFin steps @return List of DB jobs """ - entries: List[Tuple[SimpleDict, ...]] - conds: List[str] = [f"session={dbt.session.id}", "valid=1"] - - if label: - conds.append(f"reason='{label}'") - - conds.append(f"retries<{self.max_job_retries}") - conds.append("state in (" + str(find_state).strip("{").strip("}") + ")") - - entries = self.compose_work_objs( - session, conds, dbt, job_attr, claim_num, fin_steps - ) - return entries - - def compose_work_objs( - self, - session: DbSession, - conds: List[str], - dbt: DBTablesInterface, - job_attr: List[str], - claim_num: int = None, - fin_steps: List[str] = None, - ) -> List[SimpleDict]: - """! Query a job list for update + entries: List[Tuple[SimpleDict, ...]] + conds: List[str] = [f"session={dbt.session.id}", "valid=1"] + + if label: + conds.append(f"reason='{label}'") + + conds.append(f"retries<{self.max_job_retries}") + conds.append("state in (" + str(find_state).strip("{").strip("}") + ")") + + entries = self.compose_work_objs(session, conds, dbt, job_attr, claim_num, + fin_steps) + return entries + + def compose_work_objs( + self, + session: DbSession, + conds: List[str], + dbt: DBTablesInterface, + job_attr: List[str], + claim_num: int = None, + fin_steps: List[str] = None, + ) -> List[SimpleDict]: + """! Query a job list for update @param session DB session @param conds List of conditions for DB job WHERE clause @param dbt Class representing all DB tables associated with this class @@ -584,370 +583,358 @@ def compose_work_objs( @param fin_steps List of MIFin steps @return List of MIFin work objects """ - job_entries = [] - if fin_steps: - conds.append(f"fin_step like '%{fin_steps[0]}%'") - else: - conds.append("fin_step='not_fin'") - - cond_str = " AND ".join(conds) - if cond_str: - cond_str = f"WHERE {cond_str}" - if claim_num: - cond_str += ( - f" ORDER BY retries,config ASC LIMIT {claim_num} FOR UPDATE SKIP LOCKED" - ) - else: - cond_str += " ORDER BY retries,config ASC FOR UPDATE SKIP LOCKED" - - job_entries = gen_select_objs( - session, job_attr, dbt.job_table.__tablename__, cond_str - ) - - return job_entries - - def compose_work_objs_fin( - self, session, job_entries, dbt - ) -> List[Tuple[SimpleDict, SimpleDict]]: - """! Return jobs for fin work + job_entries = [] + if fin_steps: + conds.append(f"fin_step like '%{fin_steps[0]}%'") + else: + conds.append("fin_step='not_fin'") + + cond_str = " AND ".join(conds) + if cond_str: + cond_str = f"WHERE {cond_str}" + if claim_num: + cond_str += ( + f" ORDER BY retries,config ASC LIMIT {claim_num} FOR UPDATE SKIP LOCKED" + ) + else: + cond_str += " ORDER BY retries,config ASC FOR UPDATE SKIP LOCKED" + + job_entries = gen_select_objs(session, job_attr, + dbt.job_table.__tablename__, cond_str) + + return job_entries + + def compose_work_objs_fin(self, session, job_entries, + dbt) -> List[Tuple[SimpleDict, SimpleDict]]: + """! 
Return jobs for fin work @param session DB session @param job_entries List of DB jobs @param dbt Class representing all DB tables associated with this class @return ret Job tuple """ - ret = [] - - cfg_rel = { - key: { - "key": list(val.local_columns)[0].name, - "ftble": str(list(val.remote_side)[0]).split(".", maxsplit=1)[0], - "fkey": str(list(val.remote_side)[0]).split(".")[1], - } - for key, val in inspect(dbt.config_table).relationships.items() - } - - if job_entries: - id_str = ",".join({str(job.config) for job in job_entries}) - cfg_cond_str = f"where valid=1 and id in ({id_str})" - cfg_attr = [column.name for column in inspect(dbt.config_table).c] - cfg_entries = gen_select_objs( - session, cfg_attr, dbt.config_table.__tablename__, cfg_cond_str - ) + ret = [] + + cfg_rel = { + key: { + "key": list(val.local_columns)[0].name, + "ftble": str(list(val.remote_side)[0]).split(".", maxsplit=1)[0], + "fkey": str(list(val.remote_side)[0]).split(".")[1], + } for key, val in inspect(dbt.config_table).relationships.items() + } + + if job_entries: + id_str = ",".join({str(job.config) for job in job_entries}) + cfg_cond_str = f"where valid=1 and id in ({id_str})" + cfg_attr = [column.name for column in inspect(dbt.config_table).c] + cfg_entries = gen_select_objs(session, cfg_attr, + dbt.config_table.__tablename__, + cfg_cond_str) - cfg_entries = self.attach_tensors(session, cfg_rel, cfg_entries) + cfg_entries = self.attach_tensors(session, cfg_rel, cfg_entries) - cfg_map = {cfg.id: cfg for cfg in cfg_entries} + cfg_map = {cfg.id: cfg for cfg in cfg_entries} - for job in job_entries: - ret.append((job, cfg_map[job.config])) + for job in job_entries: + ret.append((job, cfg_map[job.config])) - return ret + return ret - def attach_tensors(self, session, cfg_rel, cfg_entries): - """! Attach tensor relationship information to config entries + def attach_tensors(self, session, cfg_rel, cfg_entries): + """! Attach tensor relationship information to config entries @param session DB session @param cfg_rel DB Config col value @param cfg_entries List of DB Config entries @return cfg_entries List of DB Config entries with attached tensors (foreign keys) """ - for key, val in cfg_rel.items(): - rel_attr = [ - column.name - for column in inspect(get_class_by_tablename(val["ftble"])).c - ] - val["fattr"] = rel_attr - - for cfg in cfg_entries: - for key, val in cfg_rel.items(): - rel_val = getattr(cfg, val["key"]) - rel_cond_str = f"where {val['fkey']}={rel_val}" - setattr( - cfg, - key, - gen_select_objs(session, val["fattr"], val["ftble"], rel_cond_str)[ - 0 - ], - ) - return cfg_entries - - # deprecated - def get_job_tables( - self, job_rows: List[Tuple[SimpleDict, ...]], job_attr: List[str] - ) -> List[SimpleDict]: - """Find job tables in query results""" - if has_attr_set(job_rows[0], job_attr): - job_tables: List[SimpleDict] = job_rows - else: - job_i: int = 0 - tble: SimpleDict - for i, tble in enumerate(job_rows[0]): - if has_attr_set(tble, job_attr): - job_i = i - break - job_tables = [row[job_i] for row in job_rows] - - return job_tables - - def update_operation(self): - """! 
Update the workers type that this library needs""" - if self.args.fin_steps: - if ( - "miopen_find_compile" in self.args.fin_steps - or "miopen_perf_compile" in self.args.fin_steps - ): - self.fetch_state.add("new") - self.set_state = "compile_start" - self.operation = Operation.COMPILE - elif ( - "miopen_find_eval" in self.args.fin_steps - or "miopen_perf_eval" in self.args.fin_steps - ): - self.fetch_state.add("new") - self.fetch_state.add("compiled") - self.set_state = "eval_start" - self.operation = Operation.EVAL - - if self.args.update_applicability: - self.fetch_state.add("new") - - def has_tunable_operation(self): - """! Check if its a tuning loop operation + for key, val in cfg_rel.items(): + rel_attr = [ + column.name + for column in inspect(get_class_by_tablename(val["ftble"])).c + ] + val["fattr"] = rel_attr + + for cfg in cfg_entries: + for key, val in cfg_rel.items(): + rel_val = getattr(cfg, val["key"]) + rel_cond_str = f"where {val['fkey']}={rel_val}" + setattr( + cfg, + key, + gen_select_objs(session, val["fattr"], val["ftble"], + rel_cond_str)[0], + ) + return cfg_entries + + # deprecated + def get_job_tables(self, job_rows: List[Tuple[SimpleDict, ...]], + job_attr: List[str]) -> List[SimpleDict]: + """Find job tables in query results""" + if has_attr_set(job_rows[0], job_attr): + job_tables: List[SimpleDict] = job_rows + else: + job_i: int = 0 + tble: SimpleDict + for i, tble in enumerate(job_rows[0]): + if has_attr_set(tble, job_attr): + job_i = i + break + job_tables = [row[job_i] for row in job_rows] + + return job_tables + + def update_operation(self): + """! Update the workers type that this library needs""" + if self.args.fin_steps: + if ("miopen_find_compile" in self.args.fin_steps or + "miopen_perf_compile" in self.args.fin_steps): + self.fetch_state.add("new") + self.set_state = "compile_start" + self.operation = Operation.COMPILE + elif ("miopen_find_eval" in self.args.fin_steps or + "miopen_perf_eval" in self.args.fin_steps): + self.fetch_state.add("new") + self.fetch_state.add("compiled") + self.set_state = "eval_start" + self.operation = Operation.EVAL + + if self.args.update_applicability: + self.fetch_state.add("new") + + def has_tunable_operation(self): + """! Check if its a tuning loop operation @return Bool value that represents if operation is tuning """ - if self.args is None: - self.parse_args() - if self.args.subcommand and "load_job" in self.args.subcommand: - return False - if self.args.shutdown_workers: - return True - - return self.args.fin_steps and any( - s in self.args.fin_steps for s in MIOPEN_CELERY_STEPS - ) - - @lru_cache(1) - def get_fdb_attr(self): - """! Get find_db table attrs + if self.args is None: + self.parse_args() + if self.args.subcommand and "load_job" in self.args.subcommand: + return False + if self.args.shutdown_workers: + return True + + return self.args.fin_steps and any( + s in self.args.fin_steps for s in MIOPEN_CELERY_STEPS) + + @lru_cache(1) + def get_fdb_attr(self): + """! Get find_db table attrs @return fdb_attr find_db table attributes without timestamps """ - fdb_attr = None - fdb_attr = [column.name for column in inspect(self.dbt.find_db_table).c] - fdb_attr.remove("insert_ts") - fdb_attr.remove("update_ts") - return fdb_attr - - @lru_cache(1) - def get_tuning_data_attr(self): - """! 
Get tuning_data table attrs + fdb_attr = None + fdb_attr = [column.name for column in inspect(self.dbt.find_db_table).c] + fdb_attr.remove("insert_ts") + fdb_attr.remove("update_ts") + return fdb_attr + + @lru_cache(1) + def get_tuning_data_attr(self): + """! Get tuning_data table attrs @return tuning_data_attr tuning_data table attributes without timestamps """ - tuning_data_attr = None - tuning_data_attr = [ - column.name for column in inspect(self.dbt.tuning_data_table).c - ] - tuning_data_attr.remove("insert_ts") - tuning_data_attr.remove("update_ts") - return tuning_data_attr - - def serialize_jobs(self, session: DbSession, batch_jobs: List[Any]): - """! Return list of serialize jobs + tuning_data_attr = None + tuning_data_attr = [ + column.name for column in inspect(self.dbt.tuning_data_table).c + ] + tuning_data_attr.remove("insert_ts") + tuning_data_attr.remove("update_ts") + return tuning_data_attr + + def serialize_jobs(self, session: DbSession, batch_jobs: List[Any]): + """! Return list of serialize jobs @param session DB session @param batch_jobs List of DB jobs @return DB jobs, serialized """ - entries = self.compose_work_objs_fin(session, batch_jobs, self.dbt) - return serialize_chunk(entries) - - def build_context( - self, serialized_jobs: Tuple[SimpleDict, SimpleDict] - ) -> List[dict]: - """Build context list for enqueue job""" - context_list = [] - kwargs = self.get_context_items() - fdb_attr = self.get_fdb_attr() - tuning_data_attr = self.get_tuning_data_attr() - for job, config in serialized_jobs: - context = { - "job": job, - "config": config, - "operation": self.operation, - "arch": self.dbt.session.arch, - "num_cu": self.dbt.session.num_cu, - "kwargs": kwargs, - "rich_data": self.args.rich_data, - "fdb_attr": fdb_attr, - "tuning_data_attr": tuning_data_attr, - } - context_list.append(context) - - return context_list - - def celery_enqueue_call(self, context: dict, q_name: str, task_id=False): - """! Enqueue job (context) for queue:q_name + entries = self.compose_work_objs_fin(session, batch_jobs, self.dbt) + return serialize_chunk(entries) + + def build_context( + self, serialized_jobs: Tuple[SimpleDict, SimpleDict]) -> List[dict]: + """Build context list for enqueue job""" + context_list = [] + kwargs = self.get_context_items() + fdb_attr = self.get_fdb_attr() + tuning_data_attr = self.get_tuning_data_attr() + for job, config in serialized_jobs: + context = { + "job": job, + "config": config, + "operation": self.operation, + "arch": self.dbt.session.arch, + "num_cu": self.dbt.session.num_cu, + "kwargs": kwargs, + "rich_data": self.args.rich_data, + "fdb_attr": fdb_attr, + "tuning_data_attr": tuning_data_attr, + } + context_list.append(context) + + return context_list + + def celery_enqueue_call(self, context: dict, q_name: str, task_id=False): + """! 
Enqueue job (context) for queue:q_name @param context Context for Celery job @param q_name Custom Celery queue name @param task_id Custom Redis Key """ - # hacky way to get the Q_NAME to the task decorator for interpreter to decorate the - # function with correct q_name arg - # if import is moved to top it will result in circular imports - Q_NAME = q_name # pylint: disable=import-outside-toplevel,unused-variable,invalid-name,redefined-outer-name - from tuna.miopen.celery_tuning.celery_tasks import ( - celery_enqueue, - ) # pylint: disable=import-outside-toplevel - - return celery_enqueue.apply_async( - (context,), - task_id=("-").join([self.prefix, uuid()]), - queue=q_name, - reply_to=q_name, - ) - - def process_compile_results(self, session, fin_json, context): - """! Process result from fin_build worker + # hacky way to get the Q_NAME to the task decorator for interpreter to decorate the + # function with correct q_name arg + # if import is moved to top it will result in circular imports + Q_NAME = q_name # pylint: disable=import-outside-toplevel,unused-variable,invalid-name,redefined-outer-name + from tuna.miopen.celery_tuning.celery_tasks import ( + celery_enqueue,) # pylint: disable=import-outside-toplevel + + return celery_enqueue.apply_async( + (context,), + task_id=("-").join([self.prefix, uuid()]), + queue=q_name, + reply_to=q_name, + ) + + def process_compile_results(self, session, fin_json, context): + """! Process result from fin_build worker @param session DB session @param fin_json MIFin results for job @param context Context for Celery job @return Boolean value """ - job = SimpleDict(**context["job"]) - pending = [] - solver_id_map = get_solver_ids() - - failed_job = False - result_str = "" - status = None - try: - if fin_json: - if "success" in fin_json and fin_json["success"] is False: - status = [fin_json] - else: - if "miopen_find_compile_result" in fin_json: - status = process_fdb_w_kernels( - session, - fin_json, - copy.deepcopy(context), - self.dbt, - context["fdb_attr"], - pending, - ) - - elif "miopen_perf_compile_result" in fin_json: - status = process_pdb_compile( - session, fin_json, job, self.dbt, solver_id_map - ) - - success, result_str = get_fin_result(status) - failed_job = not success - - except (OperationalError, IntegrityError) as err: - self.logger.warning("FinBuild: Unable to update Database %s", err) - session.rollback() - failed_job = True - except DataError as err: - self.logger.warning( - "FinBuild: Invalid data, likely large workspace. DB Error: %s", err - ) - session.rollback() - failed_job = True - - if failed_job: - set_job_state(session, job, self.dbt, "errored", False, result=result_str) + job = SimpleDict(**context["job"]) + pending = [] + solver_id_map = get_solver_ids() + + failed_job = False + result_str = "" + status = None + try: + if fin_json: + if "success" in fin_json and fin_json["success"] is False: + status = [fin_json] else: - set_job_state(session, job, self.dbt, "compiled", False, result=result_str) - - return True + if "miopen_find_compile_result" in fin_json: + status = process_fdb_w_kernels( + session, + fin_json, + copy.deepcopy(context), + self.dbt, + context["fdb_attr"], + pending, + ) - def process_eval_results(self, session, fin_json, context): - """! 
Process fin_json result + elif "miopen_perf_compile_result" in fin_json: + status = process_pdb_compile(session, fin_json, job, self.dbt, + solver_id_map) + + success, result_str = get_fin_result(status) + failed_job = not success + + except (OperationalError, IntegrityError) as err: + self.logger.warning("FinBuild: Unable to update Database %s", err) + session.rollback() + failed_job = True + except DataError as err: + self.logger.warning( + "FinBuild: Invalid data, likely large workspace. DB Error: %s", err) + session.rollback() + failed_job = True + + if failed_job: + set_job_state(session, job, self.dbt, "errored", False, result=result_str) + else: + set_job_state(session, + job, + self.dbt, + "compiled", + False, + result=result_str) + + return True + + def process_eval_results(self, session, fin_json, context): + """! Process fin_json result @param session DB session @param fin_json MIFin results for job @param context Context for Celery job @return Boolean value """ - job = SimpleDict(**context["job"]) - failed_job = True - result_str = "" - pending = [] - orig_state = "compiled" - - try: - if fin_json: - if "success" in fin_json and fin_json["success"] is False: - status = [fin_json] - else: - if "miopen_find_eval_result" in fin_json: - status = process_fdb_w_kernels( - session, - fin_json, - copy.deepcopy(context), - self.dbt, - context["fdb_attr"], - pending, - result_str="miopen_find_eval_result", - check_str="evaluated", - ) - elif "miopen_perf_eval_result" in fin_json: - status = process_fdb_w_kernels( - session, - fin_json, - copy.deepcopy(context), - self.dbt, - context["fdb_attr"], - pending, - result_str="miopen_perf_eval_result", - check_str="evaluated", - ) - if context["rich_data"]: - status = process_tuning_data( - session, - fin_json, - copy.deepcopy(context), - self.dbt, - context["tuning_data_attr"], - pending, - result_str="miopen_perf_eval_result", - check_str="evaluated", - ) - - success, result_str = get_fin_result(status) - failed_job = not success - - if failed_job: - if job.retries >= ( - MAX_ERRORED_JOB_RETRIES - 1 - ): # pylint: disable=no-member - self.logger.warning("max job retries exhausted, setting to errored") - set_job_state(session, job, self.dbt, "errored", result=result_str) - else: - self.logger.warning( - "resetting job state to %s, incrementing retries", orig_state - ) - set_job_state( - session, - job, - self.dbt, - orig_state, - increment_retries=True, - result=result_str, - ) - else: - self.logger.info("\n\n Setting job state to evaluated") - set_job_state(session, job, self.dbt, "evaluated", result=result_str) - clean_cache_table(self.dbt, job) - except (OperationalError, IntegrityError) as err: - self.logger.warning("FinBuild: Unable to update Database %s", err) - session.rollback() - set_job_state(session, job, self.dbt, "errored", result=result_str) - - return True - - def extract_job_id_from_context(self, context): - """Extract job ID from MIOpen celery task context""" - try: - # Extract job ID from the job context - return context.get("job", {}).get("id") - except (AttributeError, KeyError): - return None + job = SimpleDict(**context["job"]) + failed_job = True + result_str = "" + pending = [] + orig_state = "compiled" + + try: + if fin_json: + if "success" in fin_json and fin_json["success"] is False: + status = [fin_json] + else: + if "miopen_find_eval_result" in fin_json: + status = process_fdb_w_kernels( + session, + fin_json, + copy.deepcopy(context), + self.dbt, + context["fdb_attr"], + pending, + 
result_str="miopen_find_eval_result", + check_str="evaluated", + ) + elif "miopen_perf_eval_result" in fin_json: + status = process_fdb_w_kernels( + session, + fin_json, + copy.deepcopy(context), + self.dbt, + context["fdb_attr"], + pending, + result_str="miopen_perf_eval_result", + check_str="evaluated", + ) + if context["rich_data"]: + status = process_tuning_data( + session, + fin_json, + copy.deepcopy(context), + self.dbt, + context["tuning_data_attr"], + pending, + result_str="miopen_perf_eval_result", + check_str="evaluated", + ) + + success, result_str = get_fin_result(status) + failed_job = not success + + if failed_job: + if job.retries >= (MAX_ERRORED_JOB_RETRIES - 1): # pylint: disable=no-member + self.logger.warning("max job retries exhausted, setting to errored") + set_job_state(session, job, self.dbt, "errored", result=result_str) + else: + self.logger.warning("resetting job state to %s, incrementing retries", + orig_state) + set_job_state( + session, + job, + self.dbt, + orig_state, + increment_retries=True, + result=result_str, + ) + else: + self.logger.info("\n\n Setting job state to evaluated") + set_job_state(session, job, self.dbt, "evaluated", result=result_str) + clean_cache_table(self.dbt, job) + except (OperationalError, IntegrityError) as err: + self.logger.warning("FinBuild: Unable to update Database %s", err) + session.rollback() + set_job_state(session, job, self.dbt, "errored", result=result_str) + + return True + + def extract_job_id_from_context(self, context): + """Extract job ID from MIOpen celery task context""" + try: + # Extract job ID from the job context + return context.get("job", {}).get("id") + except (AttributeError, KeyError): + return None diff --git a/tuna/mituna_interface.py b/tuna/mituna_interface.py index da17340b..4c03ed0b 100644 --- a/tuna/mituna_interface.py +++ b/tuna/mituna_interface.py @@ -62,67 +62,69 @@ class MITunaInterface: # pylint:disable=too-many-instance-attributes,too-many-public-methods - """Interface class extended by libraries. The purpose of this class is to define + """Interface class extended by libraries. The purpose of this class is to define common functionalities.""" - def __init__(self, library=Library.MIOPEN) -> None: + def __init__(self, library=Library.MIOPEN) -> None: - self.self: Library = self + self.self: Library = self - self.logger: logging.Logger = setup_logger( - logger_name=library.value, add_streamhandler=True - ) - self.args: argparse.Namespace + self.logger: logging.Logger = setup_logger(logger_name=library.value, + add_streamhandler=True) + self.args: argparse.Namespace - self.fetch_state: set = set() - self.max_job_retries = 10 - self.dbt = None - self.operation = None - self.db_name = os.environ["TUNA_DB_NAME"] - self.prefix = None + self.fetch_state: set = set() + self.max_job_retries = 10 + self.dbt = None + self.operation = None + self.db_name = os.environ["TUNA_DB_NAME"] + self.prefix = None - # Track jobs claimed by this specific instance when in distributor mode - self.claimed_job_ids = set() - self.completed_job_ids = set() - # if less than 25% of the jobs are remaining, we can grab more jobs - self.progress_factor = 0.25 + # Track jobs claimed by this specific instance when in distributor mode + self.claimed_job_ids = set() + self.completed_job_ids = set() + # if less than 25% of the jobs are remaining, we can grab more jobs + self.progress_factor = 0.25 - def check_docker(self, worker: WorkerInterface, dockername="miopentuna") -> bool: - """! 
Checking for docker + def check_docker(self, + worker: WorkerInterface, + dockername="miopentuna") -> bool: + """! Checking for docker @param worker The worker interface instance @param dockername The name of the docker """ - out2: ChannelFile - _, out2, _ = worker.exec_command("sudo docker info") - while not out2.channel.exit_status_ready(): - self.logger.warning(out2.readline()) - if out2.channel.exit_status > 0: - self.logger.warning("docker not installed or failed to run with sudo .... ") - return False - - out: StringIO = StringIO() - line: Optional[str] = None - _, out, _ = worker.exec_command(f"sudo docker images | grep {dockername}") - for line in out.readlines(): - if line is not None: - if line.find(dockername) != -1: - self.logger.warning("%s docker image exists", dockername) - return True - if line is None: - self.logger.warning("%s docker image does not exist", dockername) - return False - - return False - - def check_status( - self, - worker: WorkerInterface, - b_first: int, - gpu_idx: int, - machine: Machine, - dockername: str = "miopentuna", - ) -> bool: - """! Function to check gpu_status + out2: ChannelFile + _, out2, _ = worker.exec_command("sudo docker info") + while not out2.channel.exit_status_ready(): + self.logger.warning(out2.readline()) + if out2.channel.exit_status > 0: + self.logger.warning( + "docker not installed or failed to run with sudo .... ") + return False + + out: StringIO = StringIO() + line: Optional[str] = None + _, out, _ = worker.exec_command(f"sudo docker images | grep {dockername}") + for line in out.readlines(): + if line is not None: + if line.find(dockername) != -1: + self.logger.warning("%s docker image exists", dockername) + return True + if line is None: + self.logger.warning("%s docker image does not exist", dockername) + return False + + return False + + def check_status( + self, + worker: WorkerInterface, + b_first: int, + gpu_idx: int, + machine: Machine, + dockername: str = "miopentuna", + ) -> bool: + """! 
Function to check gpu_status @param worker The worker interface instance @param b_first Flag to keep track of visited GPU @param gpu_idx Unique ID of the GPU @@ -130,677 +132,673 @@ def check_status( @param dockername The name of the docker """ - if machine.chk_gpu_status(worker.gpu_id): - self.logger.info( - "Machine: (%s, %u) GPU_ID: %u OK", - machine.hostname, - machine.port, - gpu_idx, - ) - else: - self.logger.info( - "Machine: (%s, %u) GPU_ID: %u ERROR", - machine.hostname, - machine.port, - gpu_idx, - ) - - if not b_first: - return False - b_first = False - _, out, _ = worker.exec_command("docker info") - while not out.channel.exit_status_ready(): - pass - - if out.channel.exit_status > 0: - self.check_docker(worker, dockername) + if machine.chk_gpu_status(worker.gpu_id): + self.logger.info( + "Machine: (%s, %u) GPU_ID: %u OK", + machine.hostname, + machine.port, + gpu_idx, + ) + else: + self.logger.info( + "Machine: (%s, %u) GPU_ID: %u ERROR", + machine.hostname, + machine.port, + gpu_idx, + ) + + if not b_first: + return False + b_first = False + _, out, _ = worker.exec_command("docker info") + while not out.channel.exit_status_ready(): + pass + + if out.channel.exit_status > 0: + self.check_docker(worker, dockername) + else: + _, out, _ = worker.exec_command(f"docker images | grep {dockername}") + line: Optional[str] = None + for line in out.readlines(): + if line is not None: + if line.find(dockername) != -1: + self.logger.warning("%s docker image exists", dockername) + break else: - _, out, _ = worker.exec_command(f"docker images | grep {dockername}") - line: Optional[str] = None - for line in out.readlines(): - if line is not None: - if line.find(dockername) != -1: - self.logger.warning("%s docker image exists", dockername) - break - else: - self.logger.warning("%s docker image does not exist", dockername) - - return True - - def add_tables(self) -> bool: - """Add self specific tables""" - return self.add_tables() - - def get_num_procs(self, machine: Machine) -> List: - """Determine number of processes by compute capacity""" - worker_ids: List = [] - num_procs: int - env: Dict[str, Any] - env = get_env_vars() - if env["slurm_cpus"] > 0: - num_procs = int(env["slurm_cpus"]) - else: - num_procs = int(machine.get_num_cpus() * 0.6) - - worker_ids = list(range(num_procs)) - - if len(worker_ids) == 0: - self.logger.error("num_procs must be bigger than zero to launch worker") - self.logger.error("Cannot launch worker on machine: %s", machine.id) - worker_ids = [] - - return worker_ids - - def get_f_vals( - self, machine: Machine, worker_ids: range, tuning=False - ) -> Dict[str, Any]: - # pylint:disable=unused-argument - """Determine kwargs for worker_interface""" - f_vals: Dict[str, Any] - f_vals = self.compose_f_vals(machine) - f_vals["envmt"] = self.get_envmt() - - if not tuning: - f_vals["num_procs"] = Value("i", len(worker_ids)) - - return f_vals - - def get_envmt(self): - """Get runtime envmt""" - raise NotImplementedError("Not implemented") - - def compose_f_vals(self, machine: Machine, tuning=False) -> Dict[str, Any]: - """! 
Compose dict for WorkerInterface constructor + self.logger.warning("%s docker image does not exist", dockername) + + return True + + def add_tables(self) -> bool: + """Add self specific tables""" + return self.add_tables() + + def get_num_procs(self, machine: Machine) -> List: + """Determine number of processes by compute capacity""" + worker_ids: List = [] + num_procs: int + env: Dict[str, Any] + env = get_env_vars() + if env["slurm_cpus"] > 0: + num_procs = int(env["slurm_cpus"]) + else: + num_procs = int(machine.get_num_cpus() * 0.6) + + worker_ids = list(range(num_procs)) + + if len(worker_ids) == 0: + self.logger.error("num_procs must be bigger than zero to launch worker") + self.logger.error("Cannot launch worker on machine: %s", machine.id) + worker_ids = [] + + return worker_ids + + def get_f_vals(self, + machine: Machine, + worker_ids: range, + tuning=False) -> Dict[str, Any]: + # pylint:disable=unused-argument + """Determine kwargs for worker_interface""" + f_vals: Dict[str, Any] + f_vals = self.compose_f_vals(machine) + f_vals["envmt"] = self.get_envmt() + + if not tuning: + f_vals["num_procs"] = Value("i", len(worker_ids)) + + return f_vals + + def get_envmt(self): + """Get runtime envmt""" + raise NotImplementedError("Not implemented") + + def compose_f_vals(self, machine: Machine, tuning=False) -> Dict[str, Any]: + """! Compose dict for WorkerInterface constructor @param args The command line arguments @param machine Machine instance """ - f_vals: Dict[str, Any] = {} - f_vals["b_first"] = True - - # adding non-serializable obj when not running through celery - if not tuning: - f_vals["machine"] = machine - f_vals["bar_lock"] = Lock() - # multiprocess queue for jobs, shared on machine - f_vals["job_queue"] = mpQueue() - f_vals["job_queue_lock"] = Lock() - f_vals["end_jobs"] = Value("i", 0) - - return f_vals - - def get_kwargs( - self, gpu_idx: int, f_vals: Dict[str, Any], tuning=False - ) -> Dict[str, Any]: - """! Helper function to set up kwargs for worker instances + f_vals: Dict[str, Any] = {} + f_vals["b_first"] = True + + # adding non-serializable obj when not running through celery + if not tuning: + f_vals["machine"] = machine + f_vals["bar_lock"] = Lock() + # multiprocess queue for jobs, shared on machine + f_vals["job_queue"] = mpQueue() + f_vals["job_queue_lock"] = Lock() + f_vals["end_jobs"] = Value("i", 0) + + return f_vals + + def get_kwargs(self, + gpu_idx: int, + f_vals: Dict[str, Any], + tuning=False) -> Dict[str, Any]: + """! 
Helper function to set up kwargs for worker instances @param gpu_idx Unique ID of the GPU @param f_vals Dict containing runtime information """ - envmt: Dict[str, Any] = f_vals["envmt"].copy() - kwargs: Dict[str, Any] = {} - - kwargs = { - "gpu_id": gpu_idx, - "envmt": envmt, - "label": self.args.label, - "docker_name": self.args.docker_name, - "session_id": self.args.session_id, - } - - # adding non-serializable obj when not running through celery - if not tuning: - kwargs["machine"] = f_vals["machine"] - kwargs["job_queue"] = f_vals["job_queue"] - kwargs["job_queue_lock"] = f_vals["job_queue_lock"] - kwargs["num_procs"] = f_vals["num_procs"] - kwargs["bar_lock"] = f_vals["bar_lock"] - kwargs["end_jobs"] = f_vals["end_jobs"] - kwargs["job_queue"] = f_vals["job_queue"] - kwargs["job_queue_lock"] = f_vals["job_queue_lock"] - - return kwargs - - def get_job_list(self, session, find_state, claim_num): - """Get list of jobs""" - raise NotImplementedError("Not implemented") - - def get_jobs( - self, - session: DbSession, - find_state: List[str], - set_state: str, - session_id: int, - claim_num: int = None, - no_update=False, - ): - """Interface function to get jobs based on session and find_state""" - # job_rows: List[SimpleDict] - ids: list - row: SimpleDict - - self.logger.info("Fetching DB rows...") - job_list = self.get_job_list(session, find_state, claim_num) - - if not self.check_jobs_found(job_list, find_state, session_id): - return [] - - if no_update: - return job_list - - ids = [row.id for row in job_list] - self.logger.info("%s jobs %s", find_state, ids) - self.logger.info("Updating job state to %s", set_state) - - # OPTIMIZATION: Use bulk UPDATE instead of individual updates - if self.dbt is not None: - id_str = ','.join(map(str, ids)) - query = f""" + envmt: Dict[str, Any] = f_vals["envmt"].copy() + kwargs: Dict[str, Any] = {} + + kwargs = { + "gpu_id": gpu_idx, + "envmt": envmt, + "label": self.args.label, + "docker_name": self.args.docker_name, + "session_id": self.args.session_id, + } + + # adding non-serializable obj when not running through celery + if not tuning: + kwargs["machine"] = f_vals["machine"] + kwargs["job_queue"] = f_vals["job_queue"] + kwargs["job_queue_lock"] = f_vals["job_queue_lock"] + kwargs["num_procs"] = f_vals["num_procs"] + kwargs["bar_lock"] = f_vals["bar_lock"] + kwargs["end_jobs"] = f_vals["end_jobs"] + kwargs["job_queue"] = f_vals["job_queue"] + kwargs["job_queue_lock"] = f_vals["job_queue_lock"] + + return kwargs + + def get_job_list(self, session, find_state, claim_num): + """Get list of jobs""" + raise NotImplementedError("Not implemented") + + def get_jobs( + self, + session: DbSession, + find_state: List[str], + set_state: str, + session_id: int, + claim_num: int = None, + no_update=False, + ): + """Interface function to get jobs based on session and find_state""" + # job_rows: List[SimpleDict] + ids: list + row: SimpleDict + + self.logger.info("Fetching DB rows...") + job_list = self.get_job_list(session, find_state, claim_num) + + if not self.check_jobs_found(job_list, find_state, session_id): + return [] + + if no_update: + return job_list + + ids = [row.id for row in job_list] + self.logger.info("%s jobs %s", find_state, ids) + self.logger.info("Updating job state to %s", set_state) + + # OPTIMIZATION: Use bulk UPDATE instead of individual updates + if self.dbt is not None: + id_str = ','.join(map(str, ids)) + query = f""" UPDATE {self.dbt.job_table.__tablename__} SET state = '{set_state}' WHERE id IN ({id_str}) """ - session.execute(query) - 
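+      # The UPDATE above uses id_str built from the job ids returned by
+      # get_job_list(), so every claimed row is moved to set_state in a single
+      # round trip rather than one UPDATE per job.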
- # Update local objects to reflect new state - for job in job_list: - job.state = set_state - else: - raise CustomError("DBTable must be set") - - session.commit() - - return job_list - - def shutdown_workers(self): - """Shutdown all active celery workers regardless of queue""" - return stop_active_workers() - - def cancel_consumer(self, queue): - """Cancel consumers for queue""" + session.execute(query) + + # Update local objects to reflect new state + for job in job_list: + job.state = set_state + else: + raise CustomError("DBTable must be set") + + session.commit() + + return job_list + + def shutdown_workers(self): + """Shutdown all active celery workers regardless of queue""" + return stop_active_workers() + + def cancel_consumer(self, queue): + """Cancel consumers for queue""" + try: + cmd = ( + f"celery -A tuna.celery_app.celery_app control cancel_consumer {queue}" + ) + subp = subprocess.Popen( # pylint: disable=consider-using-with + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + shell=True, + universal_newlines=True, + ) + + # filter the workers by session id + sess_str = "sess_" + queue.split("_")[-1] + stdout, _ = subp.stdout, subp.stderr + while True: + line = stdout.readline() + if not line: + break + # stop workers that were feeding from this queue + if "->" in line and sess_str in line: + hostname = line.split("->")[1].split()[0].split(":")[0] + stop_named_worker(hostname) + + except Exception as exp: # pylint: disable=broad-exception-caught + self.logger.warning( + "Error occurred trying to cancel consumer for queue: %s ", queue) + self.logger.warning(exp) + return False + + self.logger.info("Sucessfully cancelled consumer for queue: %s", queue) + + return True + + def celery_enqueue_call(self, context, q_name, task_id=False): + """Wrapper function for celery enqueue func""" + raise NotImplementedError("Not implemented") + + def enqueue_jobs(self, job_counter, job_batch_size, q_name): + """Enqueue celery jobs with machine-specific progress tracking and error handling""" + self.logger.info("Starting enqueue") + current_batch_size = 0 + + max_retries = 3 + retry_delay = 5 # seconds + consecutive_empty_fetches = 0 + max_empty_fetches = int(os.environ.get('TUNA_MAX_EMPTY_FETCHES', 3)) + + while True: + # Retry loop for database operations + for attempt in range(max_retries): try: - cmd = ( - f"celery -A tuna.celery_app.celery_app control cancel_consumer {queue}" - ) - subp = subprocess.Popen( # pylint: disable=consider-using-with - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - shell=True, - universal_newlines=True, - ) - - # filter the workers by session id - sess_str = "sess_" + queue.split("_")[-1] - stdout, _ = subp.stdout, subp.stderr - while True: - line = stdout.readline() - if not line: - break - # stop workers that were feeding from this queue - if "->" in line and sess_str in line: - hostname = line.split("->")[1].split()[0].split(":")[0] - stop_named_worker(hostname) - - except Exception as exp: # pylint: disable=broad-exception-caught - self.logger.warning( - "Error occurred trying to cancel consumer for queue: %s ", queue + with DbSession() as session: + # Check if we should enqueue more jobs based on OUR progress + if current_batch_size > 0: + if not self.should_enqueue_more_jobs(session, current_batch_size): + self.logger.info( + "Waiting for our current batch to progress before enqueuing more" + ) + return # Exit gracefully + + # Get jobs from database + job_list = self.get_jobs( + session, + self.fetch_state, + self.set_state, # 
pylint: disable=no-member + self.args.session_id, # pylint: disable=no-member + job_batch_size, ) - self.logger.warning(exp) - return False - - self.logger.info("Sucessfully cancelled consumer for queue: %s", queue) - - return True - def celery_enqueue_call(self, context, q_name, task_id=False): - """Wrapper function for celery enqueue func""" - raise NotImplementedError("Not implemented") - - def enqueue_jobs(self, job_counter, job_batch_size, q_name): - """Enqueue celery jobs with machine-specific progress tracking and error handling""" - self.logger.info("Starting enqueue") - current_batch_size = 0 - - max_retries = 3 - retry_delay = 5 # seconds - consecutive_empty_fetches = 0 - max_empty_fetches = int(os.environ.get('TUNA_MAX_EMPTY_FETCHES', 3)) - - while True: - # Retry loop for database operations - for attempt in range(max_retries): - try: - with DbSession() as session: - # Check if we should enqueue more jobs based on OUR progress - if current_batch_size > 0: - if not self.should_enqueue_more_jobs(session, current_batch_size): - self.logger.info( - "Waiting for our current batch to progress before enqueuing more" - ) - return # Exit gracefully - - # Get jobs from database - job_list = self.get_jobs( - session, - self.fetch_state, - self.set_state, # pylint: disable=no-member - self.args.session_id, # pylint: disable=no-member - job_batch_size, - ) - - if not job_list: - consecutive_empty_fetches += 1 - self.logger.info('No jobs found (attempt %d/%d)', - consecutive_empty_fetches, max_empty_fetches) - - if consecutive_empty_fetches >= max_empty_fetches: - self.logger.info('No new jobs after %d attempts. Exiting enqueue loop.', - max_empty_fetches) - return # Exit gracefully - - time.sleep(60) # Wait before next check - break # Break retry loop, continue main loop - - # Reset counter when jobs are found - consecutive_empty_fetches = 0 - - # Track the jobs we just claimed - new_job_ids = {job.id for job in job_list} - self.claimed_job_ids.update(new_job_ids) - - self.logger.info("Claimed jobs: %s", list(new_job_ids)) - - with job_counter_lock: - job_counter.value = job_counter.value + len(job_list) - - # Process all jobs in this batch - context_list = self.get_context_list(session, job_list) - for context in context_list: - try: - # calling celery task, enqueuing to celery queue - self.celery_enqueue_call(context, q_name=q_name) - except Exception as enqueue_err: # pylint: disable=broad-exception-caught - self.logger.error('Failed to enqueue job: %s', enqueue_err) - # Continue with other jobs rather than failing completely - continue - - current_batch_size = len(job_list) - self.logger.info( - "Job counter: %s, enqueued batch size: %s", - job_counter.value, - current_batch_size, - ) - - # Cleanup old tracking data periodically - self.cleanup_completed_jobs() - break # Success, break retry loop - - except Exception as db_err: # pylint: disable=broad-exception-caught - self.logger.warning('Database error on attempt %d/%d: %s', - attempt + 1, max_retries, db_err) - if attempt < max_retries - 1: - time.sleep(retry_delay * (attempt + 1)) # Exponential backoff - else: - self.logger.error('Max retries exceeded for database operation. 
Exiting.') - raise - - # If we got here with no jobs, the consecutive_empty_fetches logic handled it if not job_list: - continue + consecutive_empty_fetches += 1 + self.logger.info('No jobs found (attempt %d/%d)', + consecutive_empty_fetches, max_empty_fetches) - def should_enqueue_more_jobs(self, session, current_batch_size): - """Check if we should enqueue more jobs based on THIS instance's progress""" - # Count only jobs claimed by this machine instance - our_in_progress_count = len(self.claimed_job_ids - self.completed_job_ids) - - # Allow enqueuing when less than 25% of our claimed jobs are still in progress - progress_threshold = current_batch_size * self.progress_factor - - self.logger.info( - "Our jobs in progress: %d, completed: %d, threshold: %d", - our_in_progress_count, - len(self.completed_job_ids), - progress_threshold, - ) - - return our_in_progress_count < progress_threshold - - def cleanup_completed_jobs(self): - """Periodically clean up old job tracking data""" - # Keep sets from growing indefinitely - max_tracking_size = 10000 - if len(self.completed_job_ids) > max_tracking_size: - # Keep only the most recent completions - recent_completions = list(self.completed_job_ids)[-5000:] - self.completed_job_ids = set(recent_completions) - - # Remove old claimed jobs that are completed - self.claimed_job_ids -= set(recent_completions[:-1000]) - - async def cleanup_redis_results(self, prefix): - """Remove stale redis results by key""" - backend_port, backend_host = get_backend_env() - redis = await aioredis.from_url(f"redis://{backend_host}:{backend_port}/15") - - keys = [] - cursor = "0" - if prefix: - # a prefix is necessary when the need to different results in redis based on operation - # withough a prefix the redis key defaults to: "celery-task-meta-" - # with a prefix the key will look like: "celery-task-meta--" - # the prefix can be applied when filtering the redis keys as bellow - cursor, results = await redis.scan(cursor, match=f"*{prefix}*") - else: - # no prefix, match any key - cursor, results = await redis.scan(cursor, match="*") - keys.extend(results) - self.logger.info("Found %s old results", len(results)) - for key in keys: - try: - await redis.delete(key) - except aioredis.exceptions.ResponseError as red_err: - self.logger.error(red_err) - self.logger.info(key.decode("utf-8")) - continue + if consecutive_empty_fetches >= max_empty_fetches: + self.logger.info( + 'No new jobs after %d attempts. 
Exiting enqueue loop.', + max_empty_fetches) + return # Exit gracefully - self.logger.info("Done removing old redis results for prefix: %s", prefix) + time.sleep(60) # Wait before next check + break # Break retry loop, continue main loop - return True - - async def consume(self, job_counter, prefix): - """Retrieve celery results from redis db""" - - backend_port, backend_host = get_backend_env() - redis = await aioredis.from_url(f"redis://{backend_host}:{backend_port}/15") - - while job_counter.value > 0: - cursor = "0" - keys = [] - while cursor != 0: - if prefix: - # a prefix is necessary when the need to different results in redis based on operation - # withough a prefix the redis key defaults to: "celery-task-meta-" - # with a prefix the key will look like: "celery-task-meta--" - # the prefix can be applied when filtering the redis keys as bellow - cursor, results = await redis.scan(cursor, match=f"*{prefix}*") - else: - # no prefix, match any key - cursor, results = await redis.scan(cursor, match="*") - keys.extend(results) - self.logger.info("Found %s results", len(results)) - for key in keys: - try: - data = await redis.get(key) - if data: - _ = await self.parse_result(data.decode("utf-8")) - await redis.delete(key) - with job_counter_lock: - job_counter.value = job_counter.value - 1 - except aioredis.exceptions.ResponseError as red_err: - self.logger.error(red_err) - self.logger.info(key.decode("utf-8")) - - await asyncio.sleep(1) - self.logger.info("Job counter reached 0") - await redis.close() - - return True - - def prep_tuning(self): - """Prep env for tuning start""" - cmd = None - subp_list = [] - q_name = None - if self.operation == Operation.COMPILE: - q_name = get_q_name(self, op_compile=True) - cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -n tuna_HOSTNAME_sess_{self.args.session_id} -Q {q_name}" # pylint: disable=line-too-long - else: - q_name = get_q_name(self, op_eval=True) - cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -c 1 -n tuna_HOSTNAME_sess_{self.args.session_id}_gpu_id_GPUID -Q {q_name}" # pylint: disable=line-too-long - - self.logger.info("celery Q name: %s", q_name) - if not self.args.enqueue_only: - try: - self.logger.info("Launching celery workers for queue %s", q_name) - subp_list = launch_celery_worker(self.operation, cmd, self.args, True) - self.logger.info("Done launching celery workers") - if not subp_list: - raise CustomError("Could not launch celery worker") - except kombu.exceptions.OperationalError as k_err: - self.logger.error("Redis error ocurred: %s", k_err) - return False - else: - purge_queue([q_name]) - - return q_name, subp_list - - # pylint: disable=too-many-locals - def tune(self, job_batch_size=1000): - """tuning loop to spin out celery tasks""" - - if self.args.shutdown_workers: - self.logger.info("Shutting down all celery workers") - stop_active_workers() - return True + # Reset counter when jobs are found + consecutive_empty_fetches = 0 - try: - q_name, subp_list = self.prep_tuning() - except CustomError as verr: - self.logger.error(verr) - return False - - try: - # if enqueue_only is False, we launch the celery workers - if not self.args.enqueue_only: - for subp in subp_list: - subp.wait() - return True - except KeyboardInterrupt: - for subp in subp_list: - subp.kill() - return False - - start = time.time() - - # set job count to 1 until first job fetch is finished - job_counter = Value("i", 1) - try: - enqueue_proc = Process( - target=self.enqueue_jobs, args=[job_counter, job_batch_size, q_name] - ) - 
# Start enqueue proc - enqueue_proc.start() + # Track the jobs we just claimed + new_job_ids = {job.id for job in job_list} + self.claimed_job_ids.update(new_job_ids) - # cleanup old results - cleanup_proc = Process( - target=self.async_wrap, args=(self.cleanup_redis_results, self.prefix) - ) - cleanup_proc.start() - cleanup_proc.join() - - # start async consume thread, blocking - consume_proc = Process( - target=self.async_wrap, args=(self.consume, job_counter, self.prefix) - ) - self.logger.info("Starting consume thread") - consume_proc.start() + self.logger.info("Claimed jobs: %s", list(new_job_ids)) - enqueue_proc.join() - # enqueue finished first fetch, remove hold on job_counter with job_counter_lock: - job_counter.value = job_counter.value - 1 - - # Progress-aware polling - shorter intervals, smarter enqueuing - poll_interval = int(os.environ.get("TUNA_POLL_INTERVAL", 5)) - - # check for new jobs - while consume_proc.is_alive(): - enqueue_proc = Process( - target=self.enqueue_jobs, args=[job_counter, job_batch_size, q_name] - ) - enqueue_proc.start() - enqueue_proc.join() - time.sleep(poll_interval) # Shorter, configurable polling - - consume_proc.join() - - except ( - KeyboardInterrupt, - Exception, - ) as exp: # pylint: disable=broad-exception-caught - self.logger.error("Error ocurred %s", exp) - purge_queue([q_name]) - self.cancel_consumer(q_name) - self.reset_job_state_on_ctrl_c() - with job_counter_lock: - job_counter.value = 0 + job_counter.value = job_counter.value + len(job_list) + + # Process all jobs in this batch + context_list = self.get_context_list(session, job_list) + for context in context_list: + try: + # calling celery task, enqueuing to celery queue + self.celery_enqueue_call(context, q_name=q_name) + except Exception as enqueue_err: # pylint: disable=broad-exception-caught + self.logger.error('Failed to enqueue job: %s', enqueue_err) + # Continue with other jobs rather than failing completely + continue - self.cancel_consumer(q_name) - end = time.time() - self.logger.info( - "Took {:0>8} to tune".format( # pylint: disable=consider-using-f-string - str(timedelta(seconds=end - start)) + current_batch_size = len(job_list) + self.logger.info( + "Job counter: %s, enqueued batch size: %s", + job_counter.value, + current_batch_size, ) - ) - - return True - - async def async_callback(self, async_func, *args): - """Wrapper function to await on async function""" - await async_func(*args) - def async_wrap(self, async_func, *args): - """Run async function""" + # Cleanup old tracking data periodically + self.cleanup_completed_jobs() + break # Success, break retry loop + + except Exception as db_err: # pylint: disable=broad-exception-caught + self.logger.warning('Database error on attempt %d/%d: %s', + attempt + 1, max_retries, db_err) + if attempt < max_retries - 1: + time.sleep(retry_delay * (attempt + 1)) # Exponential backoff + else: + self.logger.error( + 'Max retries exceeded for database operation. 
Exiting.') + raise + + # If we got here with no jobs, the consecutive_empty_fetches logic handled it + if not job_list: + continue + + def should_enqueue_more_jobs(self, session, current_batch_size): + """Check if we should enqueue more jobs based on THIS instance's progress""" + # Count only jobs claimed by this machine instance + our_in_progress_count = len(self.claimed_job_ids - self.completed_job_ids) + + # Allow enqueuing when less than 25% of our claimed jobs are still in progress + progress_threshold = current_batch_size * self.progress_factor + + self.logger.info( + "Our jobs in progress: %d, completed: %d, threshold: %d", + our_in_progress_count, + len(self.completed_job_ids), + progress_threshold, + ) + + return our_in_progress_count < progress_threshold + + def cleanup_completed_jobs(self): + """Periodically clean up old job tracking data""" + # Keep sets from growing indefinitely + max_tracking_size = 10000 + if len(self.completed_job_ids) > max_tracking_size: + # Keep only the most recent completions + recent_completions = list(self.completed_job_ids)[-5000:] + self.completed_job_ids = set(recent_completions) + + # Remove old claimed jobs that are completed + self.claimed_job_ids -= set(recent_completions[:-1000]) + + async def cleanup_redis_results(self, prefix): + """Remove stale redis results by key""" + backend_port, backend_host = get_backend_env() + redis = await aioredis.from_url(f"redis://{backend_host}:{backend_port}/15") + + keys = [] + cursor = "0" + if prefix: + # a prefix is necessary when the need to different results in redis based on operation + # withough a prefix the redis key defaults to: "celery-task-meta-" + # with a prefix the key will look like: "celery-task-meta--" + # the prefix can be applied when filtering the redis keys as bellow + cursor, results = await redis.scan(cursor, match=f"*{prefix}*") + else: + # no prefix, match any key + cursor, results = await redis.scan(cursor, match="*") + keys.extend(results) + self.logger.info("Found %s old results", len(results)) + for key in keys: + try: + await redis.delete(key) + except aioredis.exceptions.ResponseError as red_err: + self.logger.error(red_err) + self.logger.info(key.decode("utf-8")) + continue + + self.logger.info("Done removing old redis results for prefix: %s", prefix) + + return True + + async def consume(self, job_counter, prefix): + """Retrieve celery results from redis db""" + + backend_port, backend_host = get_backend_env() + redis = await aioredis.from_url(f"redis://{backend_host}:{backend_port}/15") + + while job_counter.value > 0: + cursor = "0" + keys = [] + while cursor != 0: + if prefix: + # a prefix is necessary when the need to different results in redis based on operation + # withough a prefix the redis key defaults to: "celery-task-meta-" + # with a prefix the key will look like: "celery-task-meta--" + # the prefix can be applied when filtering the redis keys as bellow + cursor, results = await redis.scan(cursor, match=f"*{prefix}*") + else: + # no prefix, match any key + cursor, results = await redis.scan(cursor, match="*") + keys.extend(results) + self.logger.info("Found %s results", len(results)) + for key in keys: try: - asyncio.run(self.async_callback(async_func, *args)) - except KeyboardInterrupt: - self.logger.warning("Keyboard interrupt caught, terminating") - - def reset_job_state_on_ctrl_c(self): - """Reset job state for jobs in flight""" - temp_obj = SimpleDict() - temp_obj.session_id = self.args.session_id # pylint: disable=invalid-name - attribs = ["state"] - 
temp_obj.state = 1 - - self.logger.info("Resetting job state in DB for in flight jobs") - - if self.operation == Operation.COMPILE: - state = 16 - elif self.operation == Operation.EVAL: - state = 12 - - query = gen_update_query( - temp_obj, - attribs, - self.dbt.job_table.__tablename__, - [("session", self.args.session_id), ("state", state)], - ) - with DbSession() as session: - - # pylint: disable=duplicate-code - def callback() -> bool: - session.execute(query) - session.commit() - return True - - # pylint: enable=duplicate-code - - assert session_retry(session, callback, lambda x: x(), self.logger) - self.logger.info("Sucessfully reset job state") - return True - + data = await redis.get(key) + if data: + _ = await self.parse_result(data.decode("utf-8")) + await redis.delete(key) + with job_counter_lock: + job_counter.value = job_counter.value - 1 + except aioredis.exceptions.ResponseError as red_err: + self.logger.error(red_err) + self.logger.info(key.decode("utf-8")) + + await asyncio.sleep(1) + self.logger.info("Job counter reached 0") + await redis.close() + + return True + + def prep_tuning(self): + """Prep env for tuning start""" + cmd = None + subp_list = [] + q_name = None + if self.operation == Operation.COMPILE: + q_name = get_q_name(self, op_compile=True) + cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -n tuna_HOSTNAME_sess_{self.args.session_id} -Q {q_name}" # pylint: disable=line-too-long + else: + q_name = get_q_name(self, op_eval=True) + cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -c 1 -n tuna_HOSTNAME_sess_{self.args.session_id}_gpu_id_GPUID -Q {q_name}" # pylint: disable=line-too-long + + self.logger.info("celery Q name: %s", q_name) + if not self.args.enqueue_only: + try: + self.logger.info("Launching celery workers for queue %s", q_name) + subp_list = launch_celery_worker(self.operation, cmd, self.args, True) + self.logger.info("Done launching celery workers") + if not subp_list: + raise CustomError("Could not launch celery worker") + except kombu.exceptions.OperationalError as k_err: + self.logger.error("Redis error ocurred: %s", k_err) return False - - def has_tunable_operation(self): - """Check if current operation is a tuning operation""" - raise NotImplementedError("Not implemented") - - def get_job_attr(self): - """Get job attr for row selection""" - job_attr: List[str] = None - try: - job_attr = [column.name for column in inspect(self.dbt.job_table).c] - job_attr.remove("insert_ts") - job_attr.remove("update_ts") - except NoInspectionAvailable as error: - self.logger.warning("Ignoring error for init_session: %s", error) - return job_attr - - def check_jobs_found( - self, job_rows: List[SimpleDict], find_state: List[Any], session_id: int - ) -> bool: - """check for end of jobs""" - if not job_rows: - # we are done - self.logger.warning("No %s jobs found, session %s", find_state, session_id) - return False + else: + purge_queue([q_name]) + + return q_name, subp_list + + # pylint: disable=too-many-locals + def tune(self, job_batch_size=1000): + """tuning loop to spin out celery tasks""" + + if self.args.shutdown_workers: + self.logger.info("Shutting down all celery workers") + stop_active_workers() + return True + + try: + q_name, subp_list = self.prep_tuning() + except CustomError as verr: + self.logger.error(verr) + return False + + try: + # if enqueue_only is False, we launch the celery workers + if not self.args.enqueue_only: + for subp in subp_list: + subp.wait() return True + except KeyboardInterrupt: + for subp in 
subp_list: + subp.kill() + return False + + start = time.time() + + # set job count to 1 until first job fetch is finished + job_counter = Value("i", 1) + try: + enqueue_proc = Process(target=self.enqueue_jobs, + args=[job_counter, job_batch_size, q_name]) + # Start enqueue proc + enqueue_proc.start() + + # cleanup old results + cleanup_proc = Process(target=self.async_wrap, + args=(self.cleanup_redis_results, self.prefix)) + cleanup_proc.start() + cleanup_proc.join() + + # start async consume thread, blocking + consume_proc = Process(target=self.async_wrap, + args=(self.consume, job_counter, self.prefix)) + self.logger.info("Starting consume thread") + consume_proc.start() + + enqueue_proc.join() + # enqueue finished first fetch, remove hold on job_counter + with job_counter_lock: + job_counter.value = job_counter.value - 1 + + # Progress-aware polling - shorter intervals, smarter enqueuing + poll_interval = int(os.environ.get("TUNA_POLL_INTERVAL", 5)) + + # check for new jobs + while consume_proc.is_alive(): + enqueue_proc = Process(target=self.enqueue_jobs, + args=[job_counter, job_batch_size, q_name]) + enqueue_proc.start() + enqueue_proc.join() + time.sleep(poll_interval) # Shorter, configurable polling + + consume_proc.join() + + except ( + KeyboardInterrupt, + Exception, + ) as exp: # pylint: disable=broad-exception-caught + self.logger.error("Error ocurred %s", exp) + purge_queue([q_name]) + self.cancel_consumer(q_name) + self.reset_job_state_on_ctrl_c() + with job_counter_lock: + job_counter.value = 0 + + self.cancel_consumer(q_name) + end = time.time() + self.logger.info("Took {:0>8} to tune".format( # pylint: disable=consider-using-f-string + str(timedelta(seconds=end - start)))) + + return True + + async def async_callback(self, async_func, *args): + """Wrapper function to await on async function""" + await async_func(*args) + + def async_wrap(self, async_func, *args): + """Run async function""" + try: + asyncio.run(self.async_callback(async_func, *args)) + except KeyboardInterrupt: + self.logger.warning("Keyboard interrupt caught, terminating") + + def reset_job_state_on_ctrl_c(self): + """Reset job state for jobs in flight""" + temp_obj = SimpleDict() + temp_obj.session_id = self.args.session_id # pylint: disable=invalid-name + attribs = ["state"] + temp_obj.state = 1 + + self.logger.info("Resetting job state in DB for in flight jobs") + + if self.operation == Operation.COMPILE: + state = 16 + elif self.operation == Operation.EVAL: + state = 12 + + query = gen_update_query( + temp_obj, + attribs, + self.dbt.job_table.__tablename__, + [("session", self.args.session_id), ("state", state)], + ) + with DbSession() as session: + + # pylint: disable=duplicate-code + def callback() -> bool: + session.execute(query) + session.commit() + return True + + # pylint: enable=duplicate-code + + assert session_retry(session, callback, lambda x: x(), self.logger) + self.logger.info("Sucessfully reset job state") + return True + + return False + + def has_tunable_operation(self): + """Check if current operation is a tuning operation""" + raise NotImplementedError("Not implemented") + + def get_job_attr(self): + """Get job attr for row selection""" + job_attr: List[str] = None + try: + job_attr = [column.name for column in inspect(self.dbt.job_table).c] + job_attr.remove("insert_ts") + job_attr.remove("update_ts") + except NoInspectionAvailable as error: + self.logger.warning("Ignoring error for init_session: %s", error) + return job_attr + + def check_jobs_found(self, job_rows: 
List[SimpleDict], find_state: List[Any], + session_id: int) -> bool: + """check for end of jobs""" + if not job_rows: + # we are done + self.logger.warning("No %s jobs found, session %s", find_state, + session_id) + return False + return True + + @lru_cache(1) + def get_context_items(self): + """Helper function to get items for celery job context""" + kwargs = None + f_vals = self.get_f_vals(Machine(local_machine=True), range(0), tuning=True) + kwargs = self.get_kwargs(0, f_vals, tuning=True) + return kwargs + + def serialize_jobs(self, session, batch_jobs): + """Return list of serialize jobs""" + raise NotImplementedError("Not implemented") + + def build_context(self, serialized_jobs): + """Build context list for enqueue job""" + raise NotImplementedError("Not implemented") + + def get_context_list(self, session, batch_jobs): + """Return list of jobs (context) for celery queue""" + + context_list: List[dict] = None + serialized_jobs = self.serialize_jobs(session, batch_jobs) + # build context for each celery task + context_list = self.build_context(serialized_jobs) + + return context_list + + async def parse_result(self, data): + """Function callback for celery async jobs to store results""" + data = json.loads(data) + + with DbSession() as session: + try: + fin_json = data["result"]["ret"] + context = data["result"]["context"] + + # Extract job ID from context to track completion + job_id = self.extract_job_id_from_context(context) + if job_id and job_id in self.claimed_job_ids: + self.completed_job_ids.add(job_id) + self.logger.info("Marked job %s as completed", job_id) + + except KeyError as kerr: + self.logger.error(kerr) + return False - @lru_cache(1) - def get_context_items(self): - """Helper function to get items for celery job context""" - kwargs = None - f_vals = self.get_f_vals(Machine(local_machine=True), range(0), tuning=True) - kwargs = self.get_kwargs(0, f_vals, tuning=True) - return kwargs - - def serialize_jobs(self, session, batch_jobs): - """Return list of serialize jobs""" - raise NotImplementedError("Not implemented") - - def build_context(self, serialized_jobs): - """Build context list for enqueue job""" - raise NotImplementedError("Not implemented") - - def get_context_list(self, session, batch_jobs): - """Return list of jobs (context) for celery queue""" - - context_list: List[dict] = None - serialized_jobs = self.serialize_jobs(session, batch_jobs) - # build context for each celery task - context_list = self.build_context(serialized_jobs) - - return context_list - - async def parse_result(self, data): - """Function callback for celery async jobs to store results""" - data = json.loads(data) - - with DbSession() as session: - try: - fin_json = data["result"]["ret"] - context = data["result"]["context"] - - # Extract job ID from context to track completion - job_id = self.extract_job_id_from_context(context) - if job_id and job_id in self.claimed_job_ids: - self.completed_job_ids.add(job_id) - self.logger.info("Marked job %s as completed", job_id) - - except KeyError as kerr: - self.logger.error(kerr) - return False - - self.logger.info("Parsing: %s", fin_json) - if self.operation == Operation.COMPILE: - self.process_compile_results(session, fin_json, context) - elif self.operation == Operation.EVAL: - self.process_eval_results(session, fin_json, context) - else: - raise CustomError("Unsupported tuning operation") - - return True - - def extract_job_id_from_context(self, context): - """Extract job ID from celery task context""" - # This needs to be implemented in the 
MIOpen subclass - # based on how job IDs are stored in the context - raise NotImplementedError("Subclass must implement job ID extraction") - - def process_compile_results(self, session, fin_json, context): - """Process result from fin_build worker""" - raise NotImplementedError("Not implemented") - - def process_eval_results(self, session, fin_json, context): - """Process fin_json result""" - raise NotImplementedError("Not implemented") + self.logger.info("Parsing: %s", fin_json) + if self.operation == Operation.COMPILE: + self.process_compile_results(session, fin_json, context) + elif self.operation == Operation.EVAL: + self.process_eval_results(session, fin_json, context) + else: + raise CustomError("Unsupported tuning operation") + + return True + + def extract_job_id_from_context(self, context): + """Extract job ID from celery task context""" + # This needs to be implemented in the MIOpen subclass + # based on how job IDs are stored in the context + raise NotImplementedError("Subclass must implement job ID extraction") + + def process_compile_results(self, session, fin_json, context): + """Process result from fin_build worker""" + raise NotImplementedError("Not implemented") + + def process_eval_results(self, session, fin_json, context): + """Process fin_json result""" + raise NotImplementedError("Not implemented") From a19f1bb642e61446e798cb0191bc848c07860376 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Thu, 6 Nov 2025 04:39:15 -0600 Subject: [PATCH 07/33] changed default base image and properly passed it through to second docker build stage --- Dockerfile | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index cc493e68..0a9ce625 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,7 +6,7 @@ ARG OSDB_BKC_VERSION= ARG HASVER=${ROCMVERSION:+$ROCMVERSION} ARG HASVER=${HASVER:-$OSDB_BKC_VERSION} -ARG BASEIMAGE=rocm/miopen:ci_3708da +ARG BASEIMAGE=rocm/miopen:ci_7c45f0 ARG UBUNTU=ubuntu:22.04 #use UBUNTU with rocm version set @@ -18,6 +18,8 @@ FROM $USEIMAGE as dtuna-ver-0 #args before from are wiped ARG ROCMVERSION= ARG OSDB_BKC_VERSION= +# pass through baseimage for later use +ARG BASEIMAGE RUN test -d /opt/rocm*; \ if [ $? -eq 0 ] ; then \ From c396afcba4e36b1e54483efc76cb0dc7bb0dd8e3 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Fri, 7 Nov 2025 03:48:42 -0600 Subject: [PATCH 08/33] changed to newer version of clang-format (12 no longer available) --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 0a9ce625..459f1da4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -73,7 +73,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -f -y --all apt-utils \ build-essential \ cmake \ - clang-format-12 \ + clang-format \ curl \ doxygen \ gdb \ From c284083b0fce84854b873355580af94c7d8fc5e4 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Mon, 10 Nov 2025 09:00:03 -0600 Subject: [PATCH 09/33] Big overhaul in order to build docker image and run on mi355. 
Used new dependencies to be able to run on python 3.12 --- Dockerfile | 146 +++++++++++++++++++++++----- requirements.txt | 22 ++--- tests/test_celery.py | 2 +- tuna/db/session_mixin.py | 21 ++-- tuna/miopen/db/mixin_tables.py | 8 +- tuna/miopen/scripts/dupe_resolve.py | 17 ++-- tuna/mituna_interface.py | 6 +- tuna/rocmlir/rocmlir_tables.py | 5 +- tuna/utils/db_utility.py | 26 +++-- 9 files changed, 181 insertions(+), 72 deletions(-) diff --git a/Dockerfile b/Dockerfile index 459f1da4..043fd76d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -80,10 +80,14 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -f -y --all git \ lbzip2 \ lcov \ + libboost-filesystem-dev \ + libbz2-dev \ + libeigen3-dev \ libncurses5-dev \ libnuma-dev \ libpthread-stubs0-dev \ mysql-client \ + nlohmann-json3-dev \ openssh-server \ pkg-config \ python3 \ @@ -119,17 +123,64 @@ ENV UBSAN_OPTIONS=print_stacktrace=1 RUN wget https://github.com/Yelp/dumb-init/releases/download/v1.2.0/dumb-init_1.2.0_amd64.deb RUN dpkg -i dumb-init_*.deb && rm dumb-init_*.deb +# Install frugally-deep and its dependencies (header-only libraries) +RUN . /env; if [ -z $SKIP_MIOPEN_BUILD ]; then \ + # Clone FunctionalPlus + git clone https://github.com/Dobiasd/FunctionalPlus.git /tmp/FunctionalPlus && \ + cd /tmp/FunctionalPlus && \ + mkdir build && cd build && \ + cmake -DCMAKE_INSTALL_PREFIX=/usr/local .. && \ + make install && \ + # Clone frugally-deep + git clone https://github.com/Dobiasd/frugally-deep.git /tmp/frugally-deep && \ + cd /tmp/frugally-deep && \ + mkdir build && cd build && \ + cmake -DCMAKE_INSTALL_PREFIX=/usr/local .. && \ + make install && \ + # Clean up + rm -rf /tmp/FunctionalPlus /tmp/frugally-deep; \ + fi + + +# ============================================ +# Check if BOTH MIOpen and Fin are already installed +# ============================================ +# We check both together because Fin depends on MIOpen headers +# If either is missing, we build both to ensure compatibility +RUN if [ -f /opt/rocm/lib/libMIOpen.so ] && [ -d /opt/rocm/include/miopen ] && \ + ([ -f /opt/rocm/bin/fin ] || [ -f /opt/rocm/miopen/bin/fin ]); then \ + echo "=== Both MIOpen and Fin already installed, skipping builds ==="; \ + echo "export SKIP_MIOPEN_BUILD=1" >> /env; \ + echo "export SKIP_FIN_BUILD=1" >> /env; \ + else \ + echo "=== Building MIOpen and Fin from source (Fin needs MIOpen headers) ==="; \ + fi +# ============================================ +# Clone MIOpen (if needed) +# ============================================ ARG ROCM_LIBS_DIR=/root/rocm-libraries ARG MIOPEN_DIR=$ROCM_LIBS_DIR/projects/miopen -#Clone MIOpen -RUN git clone --filter=blob:none --sparse https://github.com/ROCm/rocm-libraries.git $ROCM_LIBS_DIR + +RUN . /env; if [ -z $SKIP_MIOPEN_BUILD ]; then \ + git clone --filter=blob:none --sparse https://github.com/ROCm/rocm-libraries.git $ROCM_LIBS_DIR; \ + else \ + mkdir -p $ROCM_LIBS_DIR/projects && mkdir -p $MIOPEN_DIR; \ + fi + +# Run sparse-checkout from the git repo root +RUN . /env; if [ -z $SKIP_MIOPEN_BUILD ]; then \ + cd $ROCM_LIBS_DIR && git sparse-checkout set projects/miopen; \ + fi + WORKDIR $MIOPEN_DIR -RUN git sparse-checkout set projects/miopen + # not sure what this commit is, using latest develop for now # ARG MIOPEN_BRANCH=4940cf3ec ARG MIOPEN_BRANCH=develop -RUN git pull && git checkout $MIOPEN_BRANCH +RUN . 
/env; if [ -z $SKIP_MIOPEN_BUILD ]; then \ + git pull && git checkout $MIOPEN_BRANCH; \ + fi ARG PREFIX=/opt/rocm ARG MIOPEN_DEPS=$MIOPEN_DIR/deps @@ -137,7 +188,7 @@ ARG MIOPEN_DEPS=$MIOPEN_DIR/deps # Install dependencies # included in rocm/miopen:ci_xxxxxx ARG BUILD_MIOPEN_DEPS= ARG ARCH_TARGET= -RUN . /env; if [ -z $NO_ROCM_INST ] || ! [ -z $BUILD_MIOPEN_DEPS ]; then\ +RUN . /env; if [ -z $SKIP_MIOPEN_BUILD ] && ([ -z $NO_ROCM_INST ] || ! [ -z $BUILD_MIOPEN_DEPS ]); then\ pip install cget; \ if ! [ -z $ARCH_TARGET ]; then \ sed -i "s#\(composable_kernel.*\)#\1 -DGPU_TARGETS=\"$ARCH_TARGET\"#" requirements.txt; \ @@ -145,6 +196,8 @@ RUN . /env; if [ -z $NO_ROCM_INST ] || ! [ -z $BUILD_MIOPEN_DEPS ]; then\ apt-get remove -y composablekernel-dev miopen-hip; \ CXX=/opt/rocm/llvm/bin/clang++ cget install -f ./dev-requirements.txt --prefix $MIOPEN_DEPS -DCMAKE_POLICY_VERSION_MINIMUM=3.5; \ git checkout requirements.txt; \ + echo "=== DEBUG: cget install completed, checking for composable_kernel ==="; \ + ls -la $MIOPEN_DEPS/lib/cmake/ || echo "No cmake configs found"; \ fi ARG TUNA_USER=miopenpdb @@ -154,36 +207,85 @@ WORKDIR $MIOPEN_DIR/build ARG MIOPEN_CACHE_DIR=/tmp/${TUNA_USER}/cache ARG MIOPEN_USER_DB_PATH=/tmp/$TUNA_USER/config/miopen # build kdb objects with offline clang compiler, disable comgr + hiprtc (which would make target id specific code objects) -ARG MIOPEN_CMAKE_ARGS="-DMIOPEN_USE_COMGR=Off -DMIOPEN_USE_HIPRTC=Off -DMIOPEN_INSTALL_CXX_HEADERS=On -DMIOPEN_CACHE_DIR=${MIOPEN_CACHE_DIR} -DMIOPEN_USER_DB_PATH=${MIOPEN_USER_DB_PATH} -DMIOPEN_BACKEND=${BACKEND} -DCMAKE_PREFIX_PATH=${MIOPEN_DEPS}" +ARG MIOPEN_CMAKE_ARGS="-DMIOPEN_USE_COMGR=Off -DMIOPEN_USE_HIPRTC=Off -DMIOPEN_INSTALL_CXX_HEADERS=On -DMIOPEN_CACHE_DIR=${MIOPEN_CACHE_DIR} -DMIOPEN_USER_DB_PATH=${MIOPEN_USER_DB_PATH} -DMIOPEN_BACKEND=${BACKEND} -DCMAKE_PREFIX_PATH=${MIOPEN_DEPS} -DBUILD_TESTING=Off -DMIOPEN_USE_MLIR=OFF" -RUN echo "MIOPEN: Selected $BACKEND backend." -RUN if [ $BACKEND = "OpenCL" ]; then \ - cmake -DMIOPEN_HIP_COMPILER=/opt/rocm/llvm/bin/clang++ ${MIOPEN_CMAKE_ARGS} $MIOPEN_DIR ; \ - else \ - CXX=/opt/rocm/llvm/bin/clang++ cmake ${MIOPEN_CMAKE_ARGS} $MIOPEN_DIR ; \ +RUN . /env; if [ -z $SKIP_MIOPEN_BUILD ]; then \ + echo "MIOPEN: Selected $BACKEND backend."; \ fi -RUN make -j $(nproc) -RUN make install + +# Debug: Check if cmake directory exists and list its contents +RUN . /env; if [ -z $SKIP_MIOPEN_BUILD ]; then \ + echo "=== DEBUG: Current directory ==="; \ + pwd; \ + echo "=== DEBUG: Parent directory contents ==="; \ + ls -la ..; \ + echo "=== DEBUG: Parent cmake directory ==="; \ + ls -la ../cmake/ || echo "cmake directory not found!"; \ + echo "=== DEBUG: CMAKE_MODULE_PATH value ==="; \ + echo "../cmake"; \ + echo "=== DEBUG: Checking if cmake files exist ==="; \ + test -f ../cmake/ClangCheck.cmake && echo "ClangCheck.cmake EXISTS" || echo "ClangCheck.cmake NOT FOUND"; \ + test -f ../cmake/TargetFlags.cmake && echo "TargetFlags.cmake EXISTS" || echo "TargetFlags.cmake NOT FOUND"; \ + test -f ../cmake/CheckCXXLinkerFlag.cmake && echo "CheckCXXLinkerFlag.cmake EXISTS" || echo "CheckCXXLinkerFlag.cmake NOT FOUND"; \ + fi -#Build Fin -WORKDIR $MIOPEN_DIR -RUN git submodule update --init --recursive + +RUN . /env; if [ -z $SKIP_MIOPEN_BUILD ]; then \ + if [ $BACKEND = "OpenCL" ]; then \ + cmake -DMIOPEN_HIP_COMPILER=/opt/rocm/llvm/bin/clang++ ${MIOPEN_CMAKE_ARGS} .. ; \ + else \ + CXX=/opt/rocm/llvm/bin/clang++ cmake ${MIOPEN_CMAKE_ARGS} .. ; \ + fi; \ + fi + +RUN . 
/env; if [ -z $SKIP_MIOPEN_BUILD ]; then \ + make -j $(nproc) MIOpen; \ + make -j $(nproc) MIOpenDriver; \ + fi + +RUN . /env; if [ -z $SKIP_MIOPEN_BUILD ]; then \ + make install; \ + fi + +# ============================================ +# Build Fin (if needed) +# ============================================ +# Fin is built as a submodule of MIOpen, so we only build it if MIOpen was also built ARG FIN_DIR=$MIOPEN_DIR/fin + +# Initialize Fin submodule (only runs if MIOpen was built) +RUN . /env; if [ -z $SKIP_FIN_BUILD ]; then \ + echo "=== Initializing Fin as MIOpen submodule ==="; \ + cd $MIOPEN_DIR && git submodule update --init --recursive; \ + fi + WORKDIR $FIN_DIR + # Can be a branch or a SHA ARG FIN_BRANCH=develop -RUN if ! [ -z $FIN_BRANCH ]; then \ - git fetch && git checkout $FIN_BRANCH; \ +RUN . /env; if [ -z $SKIP_FIN_BUILD ]; then \ + if ! [ -z $FIN_BRANCH ]; then \ + git fetch && git checkout $FIN_BRANCH; \ + fi; \ fi + # Install dependencies #RUN cmake -P install_deps.cmake WORKDIR $FIN_DIR/_hip -RUN CXX=/opt/rocm/llvm/bin/clang++ cmake -DCMAKE_BUILD_TYPE=Debug -DCMAKE_PREFIX_PATH=$MIOPEN_DEPS $FIN_DIR -RUN make -j $(nproc) -RUN make install +RUN . /env; if [ -z $SKIP_FIN_BUILD ]; then \ + CXX=/opt/rocm/llvm/bin/clang++ cmake -DCMAKE_BUILD_TYPE=Debug -DCMAKE_PREFIX_PATH=$MIOPEN_DEPS $FIN_DIR; \ + fi + +RUN . /env; if [ -z $SKIP_FIN_BUILD ]; then \ + make -j $(nproc); \ + fi + +RUN . /env; if [ -z $SKIP_FIN_BUILD ]; then \ + make install; \ + fi #SET MIOPEN ENVIRONMENT VARIABLES ENV MIOPEN_LOG_LEVEL=6 @@ -235,4 +337,4 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -f -y --all iproute2 # clean up apt cache -RUN apt-get clean && rm -rf /var/lib/apt/lists/* \ No newline at end of file +RUN apt-get clean && rm -rf /var/lib/apt/lists/* diff --git a/requirements.txt b/requirements.txt index 43d5c0d2..de874cef 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ -aioredis==2.0.1 alembic==1.8.1 asn1crypto==0.24.0 -astroid==2.15.4 +astroid>=3.0.0 asyncio==3.4.3 attrs==19.3.0 backcall==0.1.0 @@ -11,7 +10,7 @@ celery==5.3.4 cryptography==43.0.1 decorator==4.3.0 docutils==0.20 -flask==2.2.5 +flask>=3.0.0 flower==2.0.1 idna==3.7 importlib-metadata>=6.6.0 @@ -23,18 +22,18 @@ markdown-it-py==3.0.0 mccabe==0.6.1 myst-parser==3.0.1 more-itertools==8.3.0 -numpy==1.24.2 +numpy>=1.26.0 opentelemetry-api==1.12.0rc2 opentelemetry-distro==0.32b0 opentelemetry-exporter-otlp-proto-http==1.11.1 packaging==24.1 -pandas==1.5.3 +pandas>=2.1.0 paramiko==3.5.0 parso==0.3.1 pathlib2==2.3.5 pexpect==4.6.0 pickleshare==0.7.5 -pluggy==0.13.1 +pluggy>=1.5.0 prompt-toolkit==3.0.36 protobuf<5.0.0dev,>=3.19.5 ptyprocess==0.6.0 @@ -42,21 +41,20 @@ py==1.10.0 pyasn1==0.4.4 pycparser==2.19 Pygments==2.18.0 -pylint<=2.17.0-dev0,>=2.15.4 +pylint>=3.0.0 pymysql==1.1.1 PyNaCl==1.5 pyparsing==2.4.7 -pytest==7.4.4 +pytest>=8.0.0 pytest-asyncio==0.21 -pyyaml==6.0 +pyyaml redis==5.0.1 -six==1.12.0 -sqlalchemy==1.3.23 +six>=1.16.0 +sqlalchemy>=2.0.0 sphinx==7.4.7 sphinx_rtd_theme==2.0.0 traitlets==4.3.2 twine==5.1.1 -typed-ast==1.5.4 types-PyYAML==6.0.12.6 types-paramiko==3.0.0.4 types-PyMySQL==1.0.19.5 diff --git a/tests/test_celery.py b/tests/test_celery.py index 4e2a20e8..13d58718 100644 --- a/tests/test_celery.py +++ b/tests/test_celery.py @@ -29,7 +29,7 @@ import pytest from time import sleep from multiprocessing import Value -import aioredis +import redis.asyncio as aioredis import pytest_asyncio from sqlalchemy.inspection import inspect diff --git a/tuna/db/session_mixin.py 
b/tuna/db/session_mixin.py index 52b66a30..6df15af7 100644 --- a/tuna/db/session_mixin.py +++ b/tuna/db/session_mixin.py @@ -41,14 +41,14 @@ class SessionMixin(): """Session Mixin to provide interface for the session table""" - arch: str = Column(String(length=20), nullable=False, server_default="") - num_cu: int = Column(Integer, nullable=False) - rocm_v: str = Column(String(length=64), nullable=False) - reason: str = Column(String(length=60), nullable=False) - ticket: str = Column(String(length=64), nullable=False, server_default="N/A") - docker: str = Column(String(length=128), - nullable=False, - server_default="miopentuna") + arch = Column(String(length=20), nullable=False, server_default="") + num_cu = Column(Integer, nullable=False) + rocm_v = Column(String(length=64), nullable=False) + reason = Column(String(length=60), nullable=False) + ticket = Column(String(length=64), nullable=False, server_default="N/A") + docker = Column(String(length=128), + nullable=False, + server_default="miopentuna") def __init__(self): self.id: int = 0 # pylint: disable=invalid-name @@ -60,7 +60,10 @@ def get_query(self, sess: Session, sess_obj, entry) -> Query: def add_new_session(self, args: argparse.Namespace, worker) -> None: """Add new session entry""" self.reason = args.label - self.docker = args.docker_name + if len(args.docker_name) >= 128: + self.docker = args.docker_name[:128] + else: + self.docker = args.docker_name if hasattr(args, 'arch') and args.arch: self.arch = args.arch else: diff --git a/tuna/miopen/db/mixin_tables.py b/tuna/miopen/db/mixin_tables.py index b2bfdc1f..a543b7f1 100644 --- a/tuna/miopen/db/mixin_tables.py +++ b/tuna/miopen/db/mixin_tables.py @@ -27,7 +27,7 @@ """Represents Mixin type table class definitions """ import enum from sqlalchemy.sql import func as sqla_func -from sqlalchemy.databases import mysql +from sqlalchemy.dialects import mysql from sqlalchemy import Float, Boolean from sqlalchemy.dialects.mysql import TINYINT, MEDIUMBLOB, LONGBLOB from sqlalchemy.ext.declarative import declared_attr @@ -64,9 +64,9 @@ class MIOpenJobMixin(JobMixin): solver = Column(String(length=128), nullable=True, server_default="") eval_mid = Column(Integer, server_default="-1") - fin_step = Column(mysql.MSSet(*(list(k for k in FinStep.__members__))), - nullable=False, - server_default="not_fin") + fin_step = Column(mysql.SET(*(list(k for k in FinStep.__members__))), + nullable=False, + server_default="not_fin") class ConfigTagMixin(): diff --git a/tuna/miopen/scripts/dupe_resolve.py b/tuna/miopen/scripts/dupe_resolve.py index e8f2ed2b..915297ca 100755 --- a/tuna/miopen/scripts/dupe_resolve.py +++ b/tuna/miopen/scripts/dupe_resolve.py @@ -3,6 +3,7 @@ #!/usr/bin/env python3 from sqlalchemy.exc import IntegrityError, OperationalError +from sqlalchemy import text from tuna.dbBase.sql_alchemy import DbSession from tuna.miopen.utils.helper import handle_op_error @@ -60,15 +61,15 @@ def main(): """main""" with DbSession() as session: - session.execute(view_perf_cfg_rep) + session.execute(text(view_perf_cfg_rep)) session.commit() - res = session.execute("select id, cfg from perf_cfg_rep").all() + res = session.execute(text("select id, cfg from perf_cfg_rep")).all() invalid = 0 for session_id, cfg in res: try: query = f"update conv_perf_config set config={cfg} where id={session_id};" print(query) - #session.execute(query) + #session.execute(text(query)) #session.commit() except OperationalError as error: handle_op_error(LOGGER, error) @@ -79,21 +80,21 @@ def main(): query = f"update 
conv_perf_config set valid=0 where id={session_id};" LOGGER.warning('Invalidating entry (%s)', query) invalid += 1 - session.execute(query) + session.execute(text(query)) session.commit() if invalid: LOGGER.warning('Invalidated %u perf_config entries', invalid) - session.execute(view_perf_db_rep) + session.execute(text(view_perf_db_rep)) session.commit() - res = session.execute("select theid, mcfg from perf_db_rep").all() + res = session.execute(text("select theid, mcfg from perf_db_rep")).all() invalid = 0 for session_id, cfg in res: try: query = f"update conv_perf_db set miopen_config={cfg} where id={session_id};" print(query) - session.execute(query) + session.execute(text(query)) session.commit() except OperationalError as error: handle_op_error(LOGGER, error) @@ -104,7 +105,7 @@ def main(): query = f"update conv_perf_db set valid=0 where id={session_id};" LOGGER.warning('Invalidating entry (%s)', query) invalid += 1 - session.execute(query) + session.execute(text(query)) session.commit() if invalid: diff --git a/tuna/mituna_interface.py b/tuna/mituna_interface.py index 4c03ed0b..938ab9fd 100644 --- a/tuna/mituna_interface.py +++ b/tuna/mituna_interface.py @@ -40,7 +40,7 @@ from datetime import timedelta from sqlalchemy.exc import NoInspectionAvailable from sqlalchemy.inspection import inspect -import aioredis +import redis.asyncio as aioredis import kombu from paramiko.channel import ChannelFile @@ -502,7 +502,7 @@ async def cleanup_redis_results(self, prefix): for key in keys: try: await redis.delete(key) - except aioredis.exceptions.ResponseError as red_err: + except Exception as red_err: self.logger.error(red_err) self.logger.info(key.decode("utf-8")) continue @@ -540,7 +540,7 @@ async def consume(self, job_counter, prefix): await redis.delete(key) with job_counter_lock: job_counter.value = job_counter.value - 1 - except aioredis.exceptions.ResponseError as red_err: + except Exception as red_err: self.logger.error(red_err) self.logger.info(key.decode("utf-8")) diff --git a/tuna/rocmlir/rocmlir_tables.py b/tuna/rocmlir/rocmlir_tables.py index 49b81a50..a451fab0 100644 --- a/tuna/rocmlir/rocmlir_tables.py +++ b/tuna/rocmlir/rocmlir_tables.py @@ -770,11 +770,10 @@ def get_tables() -> List[BASE]: tables: List[BASE] = [] with DbSession() as session: engine = session.bind - connect = session.connection() def append_if_not_exists(table): - # Note: this changes in sqlalchemy 1.4. 
- if not inspect(engine).dialect.has_table(connect, table.__tablename__): + # Updated for SQLAlchemy 2.0 + if not inspect(engine).has_table(table.__tablename__): tables.append(table) append_if_not_exists(SessionRocMLIR()) diff --git a/tuna/utils/db_utility.py b/tuna/utils/db_utility.py index 400559c1..e0e02579 100644 --- a/tuna/utils/db_utility.py +++ b/tuna/utils/db_utility.py @@ -35,7 +35,7 @@ from typing import Callable, Any, List, Dict import pymysql from sqlalchemy.exc import OperationalError, IntegrityError, ProgrammingError -from sqlalchemy import create_engine +from sqlalchemy import create_engine, text from tuna.dbBase.sql_alchemy import DbSession from tuna.dbBase.base_class import BASE @@ -49,8 +49,7 @@ ENV_VARS = get_env_vars() ENGINE = create_engine(f"mysql+pymysql://{ENV_VARS['user_name']}:{ENV_VARS['user_password']}" +\ - f"@{ENV_VARS['db_hostname']}:3306/{ENV_VARS['db_name']}", - encoding="utf8") + f"@{ENV_VARS['db_hostname']}:3306/{ENV_VARS['db_name']}") def connect_db(): @@ -62,19 +61,25 @@ def connect_db(): raise ValueError('DB name must be specified in env variable: TUNA_DB_NAME') try: - ENGINE.execute(f'Use {db_name}') + with ENGINE.connect() as conn: + conn.execute(text(f'Use {db_name}')) + conn.commit() return except OperationalError: # as err: LOGGER.warning('Database %s does not exist, attempting to create database', db_name) try: - ENGINE.execute(f'Create database if not exists {db_name}') + with ENGINE.connect() as conn: + conn.execute(text(f'Create database if not exists {db_name}')) + conn.commit() except OperationalError as err: LOGGER.error('Database creation failed %s for username: %s', err, ENV_VARS['user_name']) - ENGINE.execute(f'Use {db_name}') - ENGINE.execute('SET GLOBAL max_allowed_packet=4294967296') + with ENGINE.connect() as conn: + conn.execute(text(f'Use {db_name}')) + conn.execute(text('SET GLOBAL max_allowed_packet=4294967296')) + conn.commit() def create_tables(all_tables): @@ -100,7 +105,8 @@ def create_indices(all_indices): with ENGINE.connect() as conn: for idx in all_indices: try: - conn.execute(idx) + conn.execute(text(idx)) + conn.commit() LOGGER.info('Idx created successfully: %s', idx) except (OperationalError, ProgrammingError) as oerr: LOGGER.info('%s \n', oerr) @@ -213,7 +219,7 @@ def get_job_rows(session, attribs, tablename, cond_str): LOGGER.info('Query Select: %s', query) try: - ret = session.execute(query) + ret = session.execute(text(query)) except (Exception, KeyboardInterrupt) as ex: #pylint: disable=broad-except LOGGER.warning(ex) ret = None @@ -245,7 +251,7 @@ def has_attr_set(obj, attribs): def get_class_by_tablename(tablename): """use tablename to find class""" # pylint: disable=protected-access - for class_name in BASE._decl_class_registry.values(): + for class_name in BASE.registry._class_registry.values(): if hasattr(class_name, '__tablename__') and class_name.__tablename__ == tablename: return class_name From d80dd298b146164f426a4442f46bcc57100994f5 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Tue, 11 Nov 2025 02:47:44 -0600 Subject: [PATCH 10/33] Added text() wrapper to SQL queries --- tuna/miopen/worker/fin_class.py | 4 ++-- tuna/mituna_interface.py | 5 +++-- tuna/worker_interface.py | 5 +++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tuna/miopen/worker/fin_class.py b/tuna/miopen/worker/fin_class.py index 69186925..faac2217 100644 --- a/tuna/miopen/worker/fin_class.py +++ b/tuna/miopen/worker/fin_class.py @@ -39,7 +39,7 @@ except ImportError: import Queue as queue #type: ignore -from 
sqlalchemy import func as sqlalchemy_func +from sqlalchemy import func as sqlalchemy_func, text from sqlalchemy.exc import IntegrityError, InvalidRequestError #pylint: disable=wrong-import-order from sqlalchemy.inspection import inspect @@ -491,7 +491,7 @@ def __insert_applicability(self, session: DbSession, self.logger.info('Commit bulk configs (%s), entries (%s), please wait', len(app_cfgs), len(app_values)) for sql_str in inserts: - session.execute(sql_str) + session.execute(text(sql_str)) session.commit() self.logger.info('End bulk inserts') diff --git a/tuna/mituna_interface.py b/tuna/mituna_interface.py index 938ab9fd..d45d0fa9 100644 --- a/tuna/mituna_interface.py +++ b/tuna/mituna_interface.py @@ -40,6 +40,7 @@ from datetime import timedelta from sqlalchemy.exc import NoInspectionAvailable from sqlalchemy.inspection import inspect +from sqlalchemy import text import redis.asyncio as aioredis import kombu from paramiko.channel import ChannelFile @@ -302,7 +303,7 @@ def get_jobs( SET state = '{set_state}' WHERE id IN ({id_str}) """ - session.execute(query) + session.execute(text(query)) # Update local objects to reflect new state for job in job_list: @@ -697,7 +698,7 @@ def reset_job_state_on_ctrl_c(self): # pylint: disable=duplicate-code def callback() -> bool: - session.execute(query) + session.execute(text(query)) session.commit() return True diff --git a/tuna/worker_interface.py b/tuna/worker_interface.py index 846d57af..d7be120e 100644 --- a/tuna/worker_interface.py +++ b/tuna/worker_interface.py @@ -44,6 +44,7 @@ from typing import List, Tuple, Union, Set, Optional, Any, Dict from sqlalchemy.exc import IntegrityError, OperationalError, NoInspectionAvailable from sqlalchemy.inspection import inspect +from sqlalchemy import text from tuna.dbBase.sql_alchemy import DbSession from tuna.machine import Machine @@ -283,7 +284,7 @@ def get_job(self, find_state: str, set_state: str, imply_end: bool) -> bool: job_set_attr = ['state'] query: str = gen_update_query(job, job_set_attr, self.dbt.job_table.__tablename__) - session.execute(query) + session.execute(text(query)) session.commit() self.job_queue_push(job_rows) @@ -349,7 +350,7 @@ def set_job_state(self, self.dbt.job_table.__tablename__) def callback() -> bool: - session.execute(query) + session.execute(text(query)) session.commit() return True From 5f67df36f773f868d7bb0b03e8cae2d8743d1d14 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Tue, 11 Nov 2025 03:10:30 -0600 Subject: [PATCH 11/33] More text() wrapping for load_job etc --- tuna/miopen/subcmd/load_job.py | 3 ++- tuna/miopen/utils/helper.py | 3 ++- tuna/miopen/utils/json_to_sql.py | 11 ++++++----- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tuna/miopen/subcmd/load_job.py b/tuna/miopen/subcmd/load_job.py index 1cc9fe95..a45f3b6a 100755 --- a/tuna/miopen/subcmd/load_job.py +++ b/tuna/miopen/subcmd/load_job.py @@ -34,6 +34,7 @@ from sqlalchemy.exc import IntegrityError #pylint: disable=wrong-import-order from sqlalchemy.sql.expression import true +from sqlalchemy import text from tuna.miopen.utils.metadata import ALG_SLV_MAP, TENSOR_PRECISION from tuna.miopen.db.solver import get_solver_ids @@ -160,7 +161,7 @@ def add_jobs(args: argparse.Namespace, dbt: MIOpenDBTables, where session={args.session_id} and fin_step='{fin_step_str}'" logger.info(query) - ret = session.execute(query) + ret = session.execute(text(query)) pre_ex: Dict[str, Dict[str, bool]] = {} for config, solver in ret: if config not in pre_ex: diff --git a/tuna/miopen/utils/helper.py 
b/tuna/miopen/utils/helper.py index fea4ffad..835e4d49 100644 --- a/tuna/miopen/utils/helper.py +++ b/tuna/miopen/utils/helper.py @@ -31,6 +31,7 @@ from time import sleep from sqlalchemy.exc import IntegrityError, OperationalError from sqlalchemy.orm import Query +from sqlalchemy import text from tuna.utils.logger import setup_logger from tuna.dbBase.sql_alchemy import DbSession @@ -229,7 +230,7 @@ def set_job_state(session, job, dbt, state, increment_retries=False, result=""): query: str = gen_update_query(job, job_set_attr, dbt.job_table.__tablename__) def callback() -> bool: - session.execute(query) + session.execute(text(query)) session.commit() return True diff --git a/tuna/miopen/utils/json_to_sql.py b/tuna/miopen/utils/json_to_sql.py index 3b3eb31d..5e025626 100644 --- a/tuna/miopen/utils/json_to_sql.py +++ b/tuna/miopen/utils/json_to_sql.py @@ -27,6 +27,7 @@ """Utility module for parsing fin json results""" import functools from sqlalchemy.exc import OperationalError +from sqlalchemy import text from tuna.utils.logger import setup_logger from tuna.dbBase.sql_alchemy import DbSession @@ -68,13 +69,13 @@ def __update_fdb_w_kernels( #pylint: disable=too-many-arguments,too-many-locals if not pending: query = gen_update_query(fdb_entry, fdb_attr, dbt.find_db_table.__tablename__) - session.execute(query) + session.execute(text(query)) else: assert len(pending) == 1 pending.pop() query = gen_insert_query(fdb_entry, fdb_attr, dbt.find_db_table.__tablename__) - session.execute(query) + session.execute(text(query)) fdb_entry = __update_fdb_entry(session, solver_id_map[fdb_obj['solver_name']], @@ -83,7 +84,7 @@ def __update_fdb_w_kernels( #pylint: disable=too-many-arguments,too-many-locals fdb_entry.kernel_group = fdb_entry.id query = gen_update_query(fdb_entry, ['kernel_group'], dbt.find_db_table.__tablename__) - session.execute(query) + session.execute(text(query)) if fdb_obj['reason'] == 'Success': __compose_kernel_entry(session, fdb_obj, fdb_entry, dbt) @@ -377,11 +378,11 @@ def __submit_tuning_data_entry( #pylint: disable=too-many-arguments pending.remove(tuning_data_entry) query = gen_insert_query(tuning_data_entry, tuning_data_attr, dbt.tuning_data_table.__tablename__) - session.execute(query) + session.execute(text(query)) else: query = gen_update_query(tuning_data_entry, tuning_data_attr, dbt.tuning_data_table.__tablename__) - session.execute(query) + session.execute(text(query)) def process_fdb_w_kernels(session, From ad1c3713a828aa663259d7b087ef0cc7201ea95b Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Tue, 11 Nov 2025 03:23:01 -0600 Subject: [PATCH 12/33] Further updates to work with new sqlalchemy version (mostly text() wrapping) --- tuna/miopen/db/build_schema.py | 6 ++++-- tuna/miopen/scripts/report.py | 5 +++-- tuna/miopen/subcmd/update_golden.py | 14 ++++++++------ tuna/rocmlir/rocmlir_worker.py | 3 ++- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/tuna/miopen/db/build_schema.py b/tuna/miopen/db/build_schema.py index 583b8443..501a2c57 100755 --- a/tuna/miopen/db/build_schema.py +++ b/tuna/miopen/db/build_schema.py @@ -26,6 +26,7 @@ ############################################################################### """ Module for creating DB tables""" from sqlalchemy.exc import OperationalError +from sqlalchemy import text from tuna.miopen.db.get_db_tables import get_miopen_tables from tuna.miopen.db.triggers import get_miopen_triggers, drop_miopen_triggers from tuna.db_engine import ENGINE @@ -41,15 +42,16 @@ def recreate_triggers(drop_triggers, 
create_triggers): with ENGINE.connect() as conn: for dtg in drop_triggers: - conn.execute(f"drop trigger if exists {dtg}") + conn.execute(text(f"drop trigger if exists {dtg}")) for trg in create_triggers: try: - conn.execute(trg) + conn.execute(text(trg)) except OperationalError as oerr: LOGGER.warning("Operational Error occurred while adding trigger: '%s'", trg) LOGGER.info('%s \n', oerr) continue + conn.commit() return True diff --git a/tuna/miopen/scripts/report.py b/tuna/miopen/scripts/report.py index 1c6bc042..0f4b35d9 100755 --- a/tuna/miopen/scripts/report.py +++ b/tuna/miopen/scripts/report.py @@ -28,6 +28,7 @@ import numpy as np import pandas as pd +from sqlalchemy import text from tuna.parse_args import TunaArgs, setup_arg_parser from tuna.utils.logger import setup_logger from tuna.miopen.db.tables import MIOpenDBTables @@ -66,14 +67,14 @@ def get_data(args, dbt, arch, num_cu): query = f"select config, solver, kernel_time from {dbt.find_db_table.__tablename__} "\ f"where session={args.session_id} order by config" pd.options.display.max_rows = 100 - query_data = session.execute(query).fetchall() + query_data = session.execute(text(query)).fetchall() all_cfgs = [x[0] for x in query_data] configs = set(all_cfgs) session_data = pd.DataFrame(data=query_data) query = f"select config, solver, kernel_time from conv_golden where golden_miopen_v="\ f"{args.golden_v} and arch='{arch}' and num_cu={num_cu} and config in "\ f"{tuple(configs)} order by config" - golden_data = pd.DataFrame(data=session.execute(query).fetchall()) + golden_data = pd.DataFrame(data=session.execute(text(query)).fetchall()) session_data.columns = golden_data.columns = ['config', 'solver', 'ktime'] dfr = pd.merge(session_data, diff --git a/tuna/miopen/subcmd/update_golden.py b/tuna/miopen/subcmd/update_golden.py index 7ec485cf..e0d3f268 100755 --- a/tuna/miopen/subcmd/update_golden.py +++ b/tuna/miopen/subcmd/update_golden.py @@ -31,6 +31,7 @@ from typing import Dict, Any from sqlalchemy.sql.expression import func as sqlfunc from sqlalchemy.exc import OperationalError +from sqlalchemy import text from tuna.miopen.parse_miopen_args import get_update_golden_parser from tuna.dbBase.sql_alchemy import DbSession @@ -142,9 +143,10 @@ def create_perf_table(args: argparse.Namespace, logger: logging.Logger): print(table_name) with ENGINE.connect() as conn: try: - conn.execute(f'drop table if exists {table_name}') + conn.execute(text(f'drop table if exists {table_name}')) logger.info('Creating new performance table %s', table_name) - conn.execute(get_perf_str(args, table_name)) + conn.execute(text(get_perf_str(args, table_name))) + conn.commit() logger.info('Done creating new performance table %s', table_name) except OperationalError as oerr: logger.info('%s \n', oerr) @@ -169,7 +171,7 @@ def gold_base_update(session: DbSession, f" where cg.golden_miopen_v={gold_v} and ps.golden_miopen_v={base_gold_v} and ps.valid=1"\ " and ps.kernel_time>0;" logger.info(update_q) - session.execute(update_q) + session.execute(text(update_q)) logger.info("Inserting golden version %s -> %s.", base_gold_v, gold_v) insert_q = "insert ignore into conv_golden (valid, golden_miopen_v, arch, num_cu, config"\ @@ -178,7 +180,7 @@ def gold_base_update(session: DbSession, ", workspace_sz, alg_lib, opencl, kernel_group, session, solver"\ f" from conv_golden where golden_miopen_v={base_gold_v} and valid=1 and kernel_time>0;" logger.info(insert_q) - session.execute(insert_q) + session.execute(text(insert_q)) session.commit() return True @@ -200,7 +202,7 @@ 
def gold_session_update(session: DbSession, ", cg.kernel_time=ps.kernel_time, cg.kernel_group=ps.kernel_group, cg.session=ps.session"\ f" where cg.golden_miopen_v={gold_v} and ps.session={tune_s} and ps.valid=1"\ " and ps.kernel_time>0;" - session.execute(update_q) + session.execute(text(update_q)) logger.info("Gold %s Insert session %s.", gold_v, tune_s) insert_q = "insert ignore into conv_golden (valid, golden_miopen_v, arch, num_cu, config"\ @@ -209,7 +211,7 @@ def gold_session_update(session: DbSession, ", workspace_sz, alg_lib, opencl, kernel_group, session, solver"\ " from conv_find_db as cfd inner join session as s on cfd.session=s.id"\ f" where session={tune_s} and cfd.valid=1 and kernel_time>0;" - session.execute(insert_q) + session.execute(text(insert_q)) session.commit() return True diff --git a/tuna/rocmlir/rocmlir_worker.py b/tuna/rocmlir/rocmlir_worker.py index 2964c6c5..df3a3c17 100644 --- a/tuna/rocmlir/rocmlir_worker.py +++ b/tuna/rocmlir/rocmlir_worker.py @@ -35,6 +35,7 @@ import traceback from sqlalchemy.inspection import inspect +from sqlalchemy import text from tenacity import Retrying, stop_after_attempt, before_sleep_log, wait_random @@ -94,7 +95,7 @@ def update_result_table(self, session, result_str): self.logger.info('Inserting results for job_id=%s', self.job.id) query = gen_insert_query(obj, self.result_attr, self.dbt.results.__tablename__) - session.execute(query) + session.execute(text(query)) session.commit() return True From 3fc26d5aaeadad74c3c4915d9dc0d7e731438d80 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Tue, 11 Nov 2025 04:49:49 -0600 Subject: [PATCH 13/33] Added string sanitize function to avoid errors with Sql queries --- tuna/utils/db_utility.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/tuna/utils/db_utility.py b/tuna/utils/db_utility.py index e0e02579..277a5fc2 100644 --- a/tuna/utils/db_utility.py +++ b/tuna/utils/db_utility.py @@ -138,6 +138,20 @@ def session_retry(session: DbSession, return False +def sanitize_sql_string(value: str, max_length: int = 2000) -> str: + """Sanitize string for safe SQL insertion by escaping special characters""" + # Truncate to safe length to avoid excessively long queries + if len(value) > max_length: + value = value[:max_length] + '... 
[truncated]' + + # Escape backslashes first (must be done before quotes) + value = value.replace('\\', '\\\\') + # Escape single quotes by doubling them (SQL standard) + value = value.replace("'", "''") + + return value + + def get_attr_vals(obj, attr_list): """create the dictionary of values for the attribute list """ attr_vals = {} @@ -146,10 +160,14 @@ def get_attr_vals(obj, attr_list): if val is None: val = 'NULL' elif isinstance(val, (datetime, str)): - val = f"'{val}'" + # Sanitize and escape the string value + sanitized = sanitize_sql_string(str(val)) + val = f"'{sanitized}'" elif isinstance(val, bytes): val = val.decode('utf-8') - val = f"'{val}'" + # Sanitize and escape the string value + sanitized = sanitize_sql_string(val) + val = f"'{sanitized}'" else: val = str(val) attr_vals[attr] = val From 98a1a68de2d1f52e72935a2e7f4c1acf93cc0950 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Tue, 11 Nov 2025 06:11:36 -0600 Subject: [PATCH 14/33] fixed bug with first batch grabbing all available jobs (instead of being limited to batch size) --- tuna/mituna_interface.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tuna/mituna_interface.py b/tuna/mituna_interface.py index d45d0fa9..e43d7617 100644 --- a/tuna/mituna_interface.py +++ b/tuna/mituna_interface.py @@ -363,6 +363,7 @@ def enqueue_jobs(self, job_counter, job_batch_size, q_name): """Enqueue celery jobs with machine-specific progress tracking and error handling""" self.logger.info("Starting enqueue") current_batch_size = 0 + first_batch = True max_retries = 3 retry_delay = 5 # seconds @@ -375,7 +376,8 @@ def enqueue_jobs(self, job_counter, job_batch_size, q_name): try: with DbSession() as session: # Check if we should enqueue more jobs based on OUR progress - if current_batch_size > 0: + # Skip check only on the very first batch + if not first_batch: if not self.should_enqueue_more_jobs(session, current_batch_size): self.logger.info( "Waiting for our current batch to progress before enqueuing more" @@ -429,6 +431,7 @@ def enqueue_jobs(self, job_counter, job_batch_size, q_name): continue current_batch_size = len(job_list) + first_batch = False # Mark that we've completed the first batch self.logger.info( "Job counter: %s, enqueued batch size: %s", job_counter.value, From 95b6d4e2dc5383cbd3a25d4bea3dc6018e257aaf Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Tue, 11 Nov 2025 07:50:23 -0600 Subject: [PATCH 15/33] continuous polling loop for job queue Refactor the enqueue_jobs process from periodic restarts to a single continuous process with internal polling. This change: - Moves poll interval configuration into enqueue_jobs loop - Replaces early returns with sleep+continue when waiting for batch progress - Eliminates repeated process creation/joining in favor of one long-running process - Simplifies main process coordination by letting enqueue/consume run independently - Makes TUNA_POLL_INTERVAL configurable (default 60s) for both wait scenarios This reduces process overhead and maintains clearer state management by keeping the enqueue process alive throughout the entire job lifecycle rather than repeatedly spawning new processes. 
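For orientation, the new control flow reduces to a single long-lived polling loop of roughly this shape (an illustrative sketch only, not the code added by this patch; fetch_batch, enqueue_batch and should_wait stand in for the real helpers on MITunaInterface):

    import os
    import time

    def enqueue_loop(fetch_batch, enqueue_batch, should_wait):
        poll_interval = int(os.environ.get("TUNA_POLL_INTERVAL", 60))
        max_empty = int(os.environ.get("TUNA_MAX_EMPTY_FETCHES", 3))
        empty_fetches = 0
        while True:
            if should_wait():
                # too many in-flight jobs: sleep and re-check instead of exiting
                time.sleep(poll_interval)
                continue
            jobs = fetch_batch()
            if not jobs:
                empty_fetches += 1
                if empty_fetches >= max_empty:
                    return  # no new work after several polls: end the process
                time.sleep(poll_interval)
                continue
            empty_fetches = 0
            enqueue_batch(jobs)  # claim and push one batch to the celery queue

The same TUNA_POLL_INTERVAL governs both wait paths, so a single environment variable controls how aggressively the long-lived process re-checks the database.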
--- tuna/mituna_interface.py | 39 +++++++++++++-------------------------- 1 file changed, 13 insertions(+), 26 deletions(-) diff --git a/tuna/mituna_interface.py b/tuna/mituna_interface.py index e43d7617..72871cfd 100644 --- a/tuna/mituna_interface.py +++ b/tuna/mituna_interface.py @@ -369,6 +369,7 @@ def enqueue_jobs(self, job_counter, job_batch_size, q_name): retry_delay = 5 # seconds consecutive_empty_fetches = 0 max_empty_fetches = int(os.environ.get('TUNA_MAX_EMPTY_FETCHES', 3)) + poll_interval = int(os.environ.get("TUNA_POLL_INTERVAL", 60)) while True: # Retry loop for database operations @@ -382,7 +383,8 @@ def enqueue_jobs(self, job_counter, job_batch_size, q_name): self.logger.info( "Waiting for our current batch to progress before enqueuing more" ) - return # Exit gracefully + time.sleep(poll_interval) + break # Break retry loop, continue main loop to check again # Get jobs from database job_list = self.get_jobs( @@ -402,9 +404,9 @@ def enqueue_jobs(self, job_counter, job_batch_size, q_name): self.logger.info( 'No new jobs after %d attempts. Exiting enqueue loop.', max_empty_fetches) - return # Exit gracefully + return # Exit gracefully - truly no more jobs - time.sleep(60) # Wait before next check + time.sleep(poll_interval) # Wait before next check break # Break retry loop, continue main loop # Reset counter when jobs are found @@ -452,9 +454,8 @@ def enqueue_jobs(self, job_counter, job_batch_size, q_name): 'Max retries exceeded for database operation. Exiting.') raise - # If we got here with no jobs, the consecutive_empty_fetches logic handled it - if not job_list: - continue + # Continue polling - either waiting for progress or for new jobs + # The loop will naturally continue checking def should_enqueue_more_jobs(self, session, current_batch_size): """Check if we should enqueue more jobs based on THIS instance's progress""" @@ -613,11 +614,6 @@ def tune(self, job_batch_size=1000): # set job count to 1 until first job fetch is finished job_counter = Value("i", 1) try: - enqueue_proc = Process(target=self.enqueue_jobs, - args=[job_counter, job_batch_size, q_name]) - # Start enqueue proc - enqueue_proc.start() - # cleanup old results cleanup_proc = Process(target=self.async_wrap, args=(self.cleanup_redis_results, self.prefix)) @@ -630,23 +626,14 @@ def tune(self, job_batch_size=1000): self.logger.info("Starting consume thread") consume_proc.start() - enqueue_proc.join() - # enqueue finished first fetch, remove hold on job_counter - with job_counter_lock: - job_counter.value = job_counter.value - 1 - - # Progress-aware polling - shorter intervals, smarter enqueuing - poll_interval = int(os.environ.get("TUNA_POLL_INTERVAL", 5)) - - # check for new jobs - while consume_proc.is_alive(): - enqueue_proc = Process(target=self.enqueue_jobs, - args=[job_counter, job_batch_size, q_name]) - enqueue_proc.start() - enqueue_proc.join() - time.sleep(poll_interval) # Shorter, configurable polling + # Start enqueue proc - let it run continuously with persistent state + enqueue_proc = Process(target=self.enqueue_jobs, + args=[job_counter, job_batch_size, q_name]) + enqueue_proc.start() + # Wait for both processes to complete naturally consume_proc.join() + enqueue_proc.join() except ( KeyboardInterrupt, From b104982d187f8f2f9ac9134146bcab5579175e82 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Tue, 11 Nov 2025 09:22:52 -0600 Subject: [PATCH 16/33] refactor: improve job state tracking and retry handling This ensures jobs that fail and get reset are properly unclaimed and can be retried, 
preventing them from being incorrectly counted as completed. --- tuna/miopen/utils/helper.py | 11 +++++++++- tuna/mituna_interface.py | 42 +++++++++++++++++++++++++++++++------ 2 files changed, 46 insertions(+), 7 deletions(-) diff --git a/tuna/miopen/utils/helper.py b/tuna/miopen/utils/helper.py index 835e4d49..2ca9873b 100644 --- a/tuna/miopen/utils/helper.py +++ b/tuna/miopen/utils/helper.py @@ -215,7 +215,16 @@ def set_job_state(session, job, dbt, state, increment_retries=False, result=""): job.result = result if increment_retries: job_set_attr.append('retries') - job.retries += 1 + # Query current retry count from database to avoid using stale context data + query_retries = f"SELECT retries FROM {dbt.job_table.__tablename__} WHERE id = {job.id}" + current_retries = session.execute(text(query_retries)).scalar() + if current_retries is not None: + job.retries = current_retries + 1 + LOGGER.info('Job %s retry count: %d -> %d', job.id, current_retries, job.retries) + else: + # Fallback if query fails + job.retries = getattr(job, 'retries', 0) + 1 + LOGGER.warning('Could not query current retries for job %s, using fallback', job.id) #pylint: disable=duplicate-code if '_start' in state: diff --git a/tuna/mituna_interface.py b/tuna/mituna_interface.py index 72871cfd..4a0684fa 100644 --- a/tuna/mituna_interface.py +++ b/tuna/mituna_interface.py @@ -379,7 +379,7 @@ def enqueue_jobs(self, job_counter, job_batch_size, q_name): # Check if we should enqueue more jobs based on OUR progress # Skip check only on the very first batch if not first_batch: - if not self.should_enqueue_more_jobs(session, current_batch_size): + if not self.should_enqueue_more_jobs(session, job_batch_size): self.logger.info( "Waiting for our current batch to progress before enqueuing more" ) @@ -457,13 +457,13 @@ def enqueue_jobs(self, job_counter, job_batch_size, q_name): # Continue polling - either waiting for progress or for new jobs # The loop will naturally continue checking - def should_enqueue_more_jobs(self, session, current_batch_size): + def should_enqueue_more_jobs(self, session, job_batch_size): """Check if we should enqueue more jobs based on THIS instance's progress""" # Count only jobs claimed by this machine instance our_in_progress_count = len(self.claimed_job_ids - self.completed_job_ids) # Allow enqueuing when less than 25% of our claimed jobs are still in progress - progress_threshold = current_batch_size * self.progress_factor + progress_threshold = job_batch_size * self.progress_factor self.logger.info( "Our jobs in progress: %d, completed: %d, threshold: %d", @@ -762,9 +762,6 @@ async def parse_result(self, data): # Extract job ID from context to track completion job_id = self.extract_job_id_from_context(context) - if job_id and job_id in self.claimed_job_ids: - self.completed_job_ids.add(job_id) - self.logger.info("Marked job %s as completed", job_id) except KeyError as kerr: self.logger.error(kerr) @@ -778,8 +775,41 @@ async def parse_result(self, data): else: raise CustomError("Unsupported tuning operation") + # Update tracking after processing to get the final job state + if job_id and job_id in self.claimed_job_ids: + # Check the final state of the job after processing + final_state = self.get_job_final_state(session, job_id) + + if final_state in ['evaluated', 'errored']: + # Job is truly complete + self.completed_job_ids.add(job_id) + self.logger.info("Marked job %s as completed with state: %s", job_id, final_state) + elif final_state == 'compiled': + # Job failed and was reset to compiled for 
retry + # Remove from claimed so it can be re-grabbed + self.claimed_job_ids.discard(job_id) + self.logger.info("Job %s failed and reset to 'compiled' - removed from claimed set for retry", job_id) + else: + self.logger.warning("Job %s has unexpected final state: %s", job_id, final_state) + return True + def get_job_final_state(self, session, job_id): + """Query the database to get the current state of a job""" + try: + if self.dbt is not None: + query = f""" + SELECT state FROM {self.dbt.job_table.__tablename__} + WHERE id = {job_id} + """ + result = session.execute(text(query)).fetchone() + if result: + return result[0] + return None + except Exception as err: # pylint: disable=broad-exception-caught + self.logger.error("Error querying job state for job %s: %s", job_id, err) + return None + def extract_job_id_from_context(self, context): """Extract job ID from celery task context""" # This needs to be implemented in the MIOpen subclass From 160258808cf42ae5eba943161305cbe45028c21d Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Tue, 11 Nov 2025 10:07:27 -0600 Subject: [PATCH 17/33] feat(docker): enable COMGR and HIPRTC for MIOpen build Enable MIOPEN_USE_COMGR and MIOPEN_USE_HIPRTC in the CMake build configuration. This allows MIOpen to use the Code Object Manager and HIP runtime compilation for generating target-specific code objects, replacing the previous offline clang compiler approach. This was required, since otherwise MIOpen was not able to run the GEMM kernels/solvers --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 043fd76d..9599b6f9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -207,7 +207,7 @@ WORKDIR $MIOPEN_DIR/build ARG MIOPEN_CACHE_DIR=/tmp/${TUNA_USER}/cache ARG MIOPEN_USER_DB_PATH=/tmp/$TUNA_USER/config/miopen # build kdb objects with offline clang compiler, disable comgr + hiprtc (which would make target id specific code objects) -ARG MIOPEN_CMAKE_ARGS="-DMIOPEN_USE_COMGR=Off -DMIOPEN_USE_HIPRTC=Off -DMIOPEN_INSTALL_CXX_HEADERS=On -DMIOPEN_CACHE_DIR=${MIOPEN_CACHE_DIR} -DMIOPEN_USER_DB_PATH=${MIOPEN_USER_DB_PATH} -DMIOPEN_BACKEND=${BACKEND} -DCMAKE_PREFIX_PATH=${MIOPEN_DEPS} -DBUILD_TESTING=Off -DMIOPEN_USE_MLIR=OFF" +ARG MIOPEN_CMAKE_ARGS="-DMIOPEN_USE_COMGR=on -DMIOPEN_USE_HIPRTC=On -DMIOPEN_INSTALL_CXX_HEADERS=On -DMIOPEN_CACHE_DIR=${MIOPEN_CACHE_DIR} -DMIOPEN_USER_DB_PATH=${MIOPEN_USER_DB_PATH} -DMIOPEN_BACKEND=${BACKEND} -DCMAKE_PREFIX_PATH=${MIOPEN_DEPS} -DBUILD_TESTING=Off -DMIOPEN_USE_MLIR=OFF" RUN . /env; if [ -z $SKIP_MIOPEN_BUILD ]; then \ echo "MIOPEN: Selected $BACKEND backend."; \ From 1055570d0592ab3c129fb13a6072d81ff9384bc1 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Tue, 11 Nov 2025 11:17:27 -0600 Subject: [PATCH 18/33] refactor(mituna): simplify job enqueue logic and improve progress tracking - Rename `should_enqueue_more_jobs` to `_should_wait_for_progress` for clarity - Extract progress checking logic into a private method - Streamline enqueue_jobs method by removing nested retry logic - Improve separation of concerns between job fetching and progress tracking - Make progress tracking more focused on instance-specific job state This refactoring improves code readability and maintainability by better organizing the job enqueueing workflow and making the progress tracking logic more explicit. 
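Conceptually, the extracted progress gate is just the following comparison (illustrative sketch, assuming progress_factor is the 25% ratio mentioned in the in-code comments; claimed_job_ids and completed_job_ids are the per-instance tracking sets introduced earlier in this series):

    def should_wait_for_progress(claimed_job_ids, completed_job_ids,
                                 job_batch_size, progress_factor=0.25):
        """True while too many of this instance's claimed jobs are unfinished."""
        in_progress = len(claimed_job_ids - completed_job_ids)
        threshold = job_batch_size * progress_factor
        return in_progress >= threshold

so new batches are fetched only once the number of unfinished claimed jobs drops below progress_factor of a full batch.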
--- tuna/mituna_interface.py | 207 +++++++++++++++++++-------------------- 1 file changed, 102 insertions(+), 105 deletions(-) diff --git a/tuna/mituna_interface.py b/tuna/mituna_interface.py index 4a0684fa..959fd759 100644 --- a/tuna/mituna_interface.py +++ b/tuna/mituna_interface.py @@ -359,120 +359,117 @@ def celery_enqueue_call(self, context, q_name, task_id=False): """Wrapper function for celery enqueue func""" raise NotImplementedError("Not implemented") - def enqueue_jobs(self, job_counter, job_batch_size, q_name): - """Enqueue celery jobs with machine-specific progress tracking and error handling""" - self.logger.info("Starting enqueue") - current_batch_size = 0 - first_batch = True - - max_retries = 3 - retry_delay = 5 # seconds - consecutive_empty_fetches = 0 - max_empty_fetches = int(os.environ.get('TUNA_MAX_EMPTY_FETCHES', 3)) - poll_interval = int(os.environ.get("TUNA_POLL_INTERVAL", 60)) - - while True: - # Retry loop for database operations - for attempt in range(max_retries): - try: - with DbSession() as session: - # Check if we should enqueue more jobs based on OUR progress - # Skip check only on the very first batch - if not first_batch: - if not self.should_enqueue_more_jobs(session, job_batch_size): - self.logger.info( - "Waiting for our current batch to progress before enqueuing more" - ) - time.sleep(poll_interval) - break # Break retry loop, continue main loop to check again - - # Get jobs from database - job_list = self.get_jobs( - session, - self.fetch_state, - self.set_state, # pylint: disable=no-member - self.args.session_id, # pylint: disable=no-member - job_batch_size, - ) - - if not job_list: - consecutive_empty_fetches += 1 - self.logger.info('No jobs found (attempt %d/%d)', - consecutive_empty_fetches, max_empty_fetches) - - if consecutive_empty_fetches >= max_empty_fetches: - self.logger.info( - 'No new jobs after %d attempts. Exiting enqueue loop.', - max_empty_fetches) - return # Exit gracefully - truly no more jobs - - time.sleep(poll_interval) # Wait before next check - break # Break retry loop, continue main loop - - # Reset counter when jobs are found - consecutive_empty_fetches = 0 - - # Track the jobs we just claimed - new_job_ids = {job.id for job in job_list} - self.claimed_job_ids.update(new_job_ids) - - self.logger.info("Claimed jobs: %s", list(new_job_ids)) - - with job_counter_lock: - job_counter.value = job_counter.value + len(job_list) - - # Process all jobs in this batch - context_list = self.get_context_list(session, job_list) - for context in context_list: - try: - # calling celery task, enqueuing to celery queue - self.celery_enqueue_call(context, q_name=q_name) - except Exception as enqueue_err: # pylint: disable=broad-exception-caught - self.logger.error('Failed to enqueue job: %s', enqueue_err) - # Continue with other jobs rather than failing completely - continue - - current_batch_size = len(job_list) - first_batch = False # Mark that we've completed the first batch - self.logger.info( - "Job counter: %s, enqueued batch size: %s", - job_counter.value, - current_batch_size, - ) - - # Cleanup old tracking data periodically - self.cleanup_completed_jobs() - break # Success, break retry loop - - except Exception as db_err: # pylint: disable=broad-exception-caught - self.logger.warning('Database error on attempt %d/%d: %s', - attempt + 1, max_retries, db_err) - if attempt < max_retries - 1: - time.sleep(retry_delay * (attempt + 1)) # Exponential backoff - else: - self.logger.error( - 'Max retries exceeded for database operation. 
Exiting.') - raise - - # Continue polling - either waiting for progress or for new jobs - # The loop will naturally continue checking - - def should_enqueue_more_jobs(self, session, job_batch_size): - """Check if we should enqueue more jobs based on THIS instance's progress""" - # Count only jobs claimed by this machine instance + def _should_wait_for_progress(self, job_batch_size): + """Check if we should wait before fetching more jobs based on progress""" our_in_progress_count = len(self.claimed_job_ids - self.completed_job_ids) - - # Allow enqueuing when less than 25% of our claimed jobs are still in progress progress_threshold = job_batch_size * self.progress_factor self.logger.info( - "Our jobs in progress: %d, completed: %d, threshold: %d", + "Jobs in progress: %d, completed: %d, threshold: %.0f", our_in_progress_count, len(self.completed_job_ids), progress_threshold, ) - return our_in_progress_count < progress_threshold + return our_in_progress_count >= progress_threshold + + def _fetch_jobs_with_retry(self, job_batch_size, max_retries=3, retry_delay=5): + """Fetch jobs from database with retry logic + + Returns: + List of jobs if successful, empty list if no jobs found, None if error + """ + for attempt in range(max_retries): + try: + with DbSession() as session: + job_list = self.get_jobs( + session, + self.fetch_state, + self.set_state, # pylint: disable=no-member + self.args.session_id, # pylint: disable=no-member + job_batch_size, + ) + return job_list + + except Exception as db_err: # pylint: disable=broad-exception-caught + self.logger.warning('Database error on attempt %d/%d: %s', + attempt + 1, max_retries, db_err) + if attempt < max_retries - 1: + time.sleep(retry_delay * (attempt + 1)) # Exponential backoff + else: + self.logger.error('Max retries exceeded for database operation.') + raise + + return None + + def _process_job_batch(self, job_list, job_counter, q_name): + """Process a batch of jobs by enqueuing them to Celery""" + # Track the jobs we just claimed + new_job_ids = {job.id for job in job_list} + self.claimed_job_ids.update(new_job_ids) + self.logger.info("Claimed jobs: %s", list(new_job_ids)) + + # Update job counter + with job_counter_lock: + job_counter.value = job_counter.value + len(job_list) + + # Get context and enqueue each job + with DbSession() as session: + context_list = self.get_context_list(session, job_list) + + for context in context_list: + try: + self.celery_enqueue_call(context, q_name=q_name) + except Exception as enqueue_err: # pylint: disable=broad-exception-caught + self.logger.error('Failed to enqueue job: %s', enqueue_err) + continue + + self.logger.info( + "Job counter: %s, enqueued batch size: %s", + job_counter.value, + len(job_list), + ) + + # Cleanup old tracking data periodically + self.cleanup_completed_jobs() + + def enqueue_jobs(self, job_counter, job_batch_size, q_name): + """Enqueue celery jobs with simplified progress tracking""" + self.logger.info("Starting enqueue") + + is_first_batch = True + consecutive_empty_fetches = 0 + max_empty_fetches = int(os.environ.get('TUNA_MAX_EMPTY_FETCHES', 3)) + poll_interval = int(os.environ.get("TUNA_POLL_INTERVAL", 60)) + + while True: + # 1. Check if we should wait for progress (skip on first batch) + if not is_first_batch and self._should_wait_for_progress(job_batch_size): + self.logger.info("Waiting for current batch to progress before fetching more jobs") + time.sleep(poll_interval) + continue + + # 2. 
Fetch jobs with built-in retry logic + job_list = self._fetch_jobs_with_retry(job_batch_size) + + # 3. Handle empty results + if not job_list: + consecutive_empty_fetches += 1 + self.logger.info('No jobs found (attempt %d/%d)', + consecutive_empty_fetches, max_empty_fetches) + + if consecutive_empty_fetches >= max_empty_fetches: + self.logger.info('No more jobs available after %d attempts. Exiting enqueue loop.', + max_empty_fetches) + return + + time.sleep(poll_interval) + continue + + # 4. Process the batch + consecutive_empty_fetches = 0 + self._process_job_batch(job_list, job_counter, q_name) + is_first_batch = False def cleanup_completed_jobs(self): """Periodically clean up old job tracking data""" From 93e1c2d1847e5aed1d629ac5df9fd5f7f92ee7bd Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Wed, 12 Nov 2025 01:58:43 -0600 Subject: [PATCH 19/33] refactor: replace sets with Manager lists for multiprocess job tracking MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace set-based job tracking (claimed_job_ids, completed_job_ids) with multiprocessing.Manager lists to enable proper sharing across processes. Changes: - Import Manager from multiprocessing module - Initialize shared lists using Manager() for cross-process communication - Update set operations to list operations (add→append, update→extend, discard→remove with exception handling) - Convert to sets temporarily where set operations are needed - Modify cleanup logic to work with lists instead of sets This fixes race conditions and data inconsistencies that occurred when multiple processes attempted to modify the job tracking sets, as regular Python sets are not process-safe. --- tuna/mituna_interface.py | 48 +++++++++++++++++++++++++++------------- 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/tuna/mituna_interface.py b/tuna/mituna_interface.py index 959fd759..ca9f8b70 100644 --- a/tuna/mituna_interface.py +++ b/tuna/mituna_interface.py @@ -26,7 +26,7 @@ ############################################################################### """Interface class to set up and launch tuning functionality""" import os -from multiprocessing import Value, Lock, Queue as mpQueue, Process +from multiprocessing import Value, Lock, Queue as mpQueue, Process, Manager from typing import Optional, Dict, Any, List from io import StringIO from functools import lru_cache @@ -361,13 +361,16 @@ def celery_enqueue_call(self, context, q_name, task_id=False): def _should_wait_for_progress(self, job_batch_size): """Check if we should wait before fetching more jobs based on progress""" - our_in_progress_count = len(self.claimed_job_ids - self.completed_job_ids) + # Convert to sets for set operations + claimed_set = set(self.claimed_job_ids) + completed_set = set(self.completed_job_ids) + our_in_progress_count = len(claimed_set - completed_set) progress_threshold = job_batch_size * self.progress_factor self.logger.info( "Jobs in progress: %d, completed: %d, threshold: %.0f", our_in_progress_count, - len(self.completed_job_ids), + len(completed_set), progress_threshold, ) @@ -404,10 +407,10 @@ def _fetch_jobs_with_retry(self, job_batch_size, max_retries=3, retry_delay=5): def _process_job_batch(self, job_list, job_counter, q_name): """Process a batch of jobs by enqueuing them to Celery""" - # Track the jobs we just claimed - new_job_ids = {job.id for job in job_list} - self.claimed_job_ids.update(new_job_ids) - self.logger.info("Claimed jobs: %s", list(new_job_ids)) + # Track the jobs we just claimed 
(extend list with new job IDs) + new_job_ids = [job.id for job in job_list] + self.claimed_job_ids.extend(new_job_ids) + self.logger.info("Claimed %d jobs", len(new_job_ids)) # Update job counter with job_counter_lock: @@ -473,15 +476,20 @@ def enqueue_jobs(self, job_counter, job_batch_size, q_name): def cleanup_completed_jobs(self): """Periodically clean up old job tracking data""" - # Keep sets from growing indefinitely + # Keep lists from growing indefinitely max_tracking_size = 10000 if len(self.completed_job_ids) > max_tracking_size: # Keep only the most recent completions recent_completions = list(self.completed_job_ids)[-5000:] - self.completed_job_ids = set(recent_completions) - + # Clear and repopulate the shared list + del self.completed_job_ids[:] + self.completed_job_ids.extend(recent_completions) + # Remove old claimed jobs that are completed - self.claimed_job_ids -= set(recent_completions[:-1000]) + completed_set = set(recent_completions[:-1000]) + claimed_list = [job_id for job_id in self.claimed_job_ids if job_id not in completed_set] + del self.claimed_job_ids[:] + self.claimed_job_ids.extend(claimed_list) async def cleanup_redis_results(self, prefix): """Remove stale redis results by key""" @@ -610,6 +618,12 @@ def tune(self, job_batch_size=1000): # set job count to 1 until first job fetch is finished job_counter = Value("i", 1) + + # Create shared data structures for cross-process communication + manager = Manager() + self.claimed_job_ids = manager.list() # Shared list across processes + self.completed_job_ids = manager.list() # Shared list across processes + try: # cleanup old results cleanup_proc = Process(target=self.async_wrap, @@ -778,14 +792,18 @@ async def parse_result(self, data): final_state = self.get_job_final_state(session, job_id) if final_state in ['evaluated', 'errored']: - # Job is truly complete - self.completed_job_ids.add(job_id) + # Job is truly complete - append to completed list + self.completed_job_ids.append(job_id) self.logger.info("Marked job %s as completed with state: %s", job_id, final_state) elif final_state == 'compiled': # Job failed and was reset to compiled for retry # Remove from claimed so it can be re-grabbed - self.claimed_job_ids.discard(job_id) - self.logger.info("Job %s failed and reset to 'compiled' - removed from claimed set for retry", job_id) + try: + self.claimed_job_ids.remove(job_id) + self.logger.info("Job %s failed and reset to 'compiled' - removed from claimed list for retry", job_id) + except ValueError: + # Job ID not in list, ignore + pass else: self.logger.warning("Job %s has unexpected final state: %s", job_id, final_state) From 372102cc739630c02636937b4768b217b1b6426d Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Mon, 17 Nov 2025 01:49:01 -0600 Subject: [PATCH 20/33] Added reset for consecutive_empty_fetches to make sure the process doesn't prematurely finish --- tuna/mituna_interface.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tuna/mituna_interface.py b/tuna/mituna_interface.py index ca9f8b70..00bc851c 100644 --- a/tuna/mituna_interface.py +++ b/tuna/mituna_interface.py @@ -449,6 +449,8 @@ def enqueue_jobs(self, job_counter, job_batch_size, q_name): # 1. 
Check if we should wait for progress (skip on first batch) if not is_first_batch and self._should_wait_for_progress(job_batch_size): self.logger.info("Waiting for current batch to progress before fetching more jobs") + # Reset consecutive_empty_fetches since we're waiting for progress, not out of jobs + consecutive_empty_fetches = 0 time.sleep(poll_interval) continue From ca4a2e4a82c133ff4850ed46b077c1cde626dc69 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Mon, 17 Nov 2025 10:25:57 -0600 Subject: [PATCH 21/33] feat(celery): add machine registration and tracking for tuning jobs - Register worker machines in database on celery worker startup - Track machine_id throughout job execution pipeline - Update set_job_state to record which machine processed each job - Ensure machine exists in DB before processing jobs to maintain referential integrity This change enables better job tracking and debugging by recording which physical machine executed each tuning job. The machine is registered during celery worker initialization and its ID is propagated through the context to be stored with job state updates. --- tuna/miopen/celery_tuning/celery_tasks.py | 26 ++++++++++++++++++++++- tuna/miopen/miopen_lib.py | 8 +++++-- tuna/miopen/utils/helper.py | 9 +++++++- 3 files changed, 39 insertions(+), 4 deletions(-) diff --git a/tuna/miopen/celery_tuning/celery_tasks.py b/tuna/miopen/celery_tuning/celery_tasks.py index 3dfe464a..9aab6f9d 100644 --- a/tuna/miopen/celery_tuning/celery_tasks.py +++ b/tuna/miopen/celery_tuning/celery_tasks.py @@ -36,14 +36,34 @@ from tuna.utils.utility import SimpleDict from tuna.utils.celery_utils import prep_default_kwargs, get_cached_worker from tuna.miopen.miopen_lib import Q_NAME +from tuna.dbBase.sql_alchemy import DbSession logger = get_task_logger(__name__) @celeryd_after_setup.connect def capture_worker_name(sender, instance, **kwargs): #pylint: disable=unused-argument - """Capture worker name""" + """Capture worker name and ensure machine is registered""" app.worker_name = sender + + # Ensure this machine is in the database + global cached_machine + with DbSession() as session: + # Check if machine exists by hostname + existing = session.query(Machine).filter( + Machine.hostname == cached_machine.hostname + ).first() + + if not existing: + # Insert the machine + session.add(cached_machine) + session.commit() + session.refresh(cached_machine) + logger.info("Registered machine %s with id %s", cached_machine.hostname, cached_machine.id) + else: + # Use existing machine id + cached_machine.id = existing.id + logger.info("Using existing machine %s with id %s", cached_machine.hostname, cached_machine.id) cached_machine = Machine(local_machine=True) @@ -91,4 +111,8 @@ def celery_enqueue(context): worker = prep_worker(copy.deepcopy(context)) ret = worker.run() + + # Add machine_id to the context before returning + context['machine_id'] = cached_machine.id + return {"ret": ret, "context": context} diff --git a/tuna/miopen/miopen_lib.py b/tuna/miopen/miopen_lib.py index ab55d23c..d8936813 100644 --- a/tuna/miopen/miopen_lib.py +++ b/tuna/miopen/miopen_lib.py @@ -862,6 +862,9 @@ def process_eval_results(self, session, fin_json, context): result_str = "" pending = [] orig_state = "compiled" + + # Extract machine_id from context + machine_id = context.get('machine_id', None) try: if fin_json: @@ -908,7 +911,7 @@ def process_eval_results(self, session, fin_json, context): if failed_job: if job.retries >=
(MAX_ERRORED_JOB_RETRIES - 1): # pylint: disable=no-member self.logger.warning("max job retries exhausted, setting to errored") - set_job_state(session, job, self.dbt, "errored", result=result_str) + set_job_state(session, job, self.dbt, "errored", result=result_str, machine_id=machine_id) else: self.logger.warning("resetting job state to %s, incrementing retries", orig_state) @@ -919,10 +922,11 @@ def process_eval_results(self, session, fin_json, context): orig_state, increment_retries=True, result=result_str, + machine_id=machine_id, ) else: self.logger.info("\n\n Setting job state to evaluated") - set_job_state(session, job, self.dbt, "evaluated", result=result_str) + set_job_state(session, job, self.dbt, "evaluated", result=result_str, machine_id=machine_id) clean_cache_table(self.dbt, job) except (OperationalError, IntegrityError) as err: self.logger.warning("FinBuild: Unable to update Database %s", err) diff --git a/tuna/miopen/utils/helper.py b/tuna/miopen/utils/helper.py index 2ca9873b..cfaa3151 100644 --- a/tuna/miopen/utils/helper.py +++ b/tuna/miopen/utils/helper.py @@ -204,12 +204,19 @@ def get_db_id(db_elems, config_table): return cid -def set_job_state(session, job, dbt, state, increment_retries=False, result=""): +def set_job_state(session, job, dbt, state, increment_retries=False, result="", machine_id=None): """Update job state for builder/evaluator job_set_attr: List[str]""" LOGGER.info('Setting job id %s state to %s', job.id, state) job_set_attr = ['state', 'gpu_id'] job.state = state + + # Add machine_id if provided + if machine_id is not None: + job_set_attr.append('machine_id') + job.machine_id = machine_id + LOGGER.info('Setting job %s machine_id to %s', job.id, machine_id) + if result: job_set_attr.append('result') job.result = result From 1240e799d4e205d5c201968873faddf33dbf0dd7 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Tue, 18 Nov 2025 02:46:03 -0600 Subject: [PATCH 22/33] feat(db): add unique constraint on machine hostname Add a unique index on the machine.hostname column to prevent duplicate machine entries and race conditions. The migration includes: - New Alembic migration to add unique constraint on hostname field - Automatic cleanup of existing duplicate hostnames (keeps oldest entry) - Runtime check to warn if unique constraint is missing - Improved machine registration logic with IntegrityError handling - Fixed import path for ModelEnum and FrameworkEnum enums This change ensures each machine hostname is registered only once in the database, preventing race conditions during worker initialization when multiple Celery workers start simultaneously. 
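The registration flow this migration protects is the usual insert-or-recover pattern around a unique index. The helper below is a hedged, generic sketch (register_machine is not the actual Tuna function; it only relies on standard SQLAlchemy session calls):

# Sketch of insert-or-recover machine registration under a unique hostname index.
from sqlalchemy.exc import IntegrityError


def register_machine(session, machine_cls, hostname, **attrs):
  """Return the id of the row for hostname, inserting it if missing."""
  existing = session.query(machine_cls).filter(
      machine_cls.hostname == hostname).first()
  if existing:
    return existing.id
  try:
    row = machine_cls(hostname=hostname, **attrs)
    session.add(row)
    session.commit()
    return row.id
  except IntegrityError:
    # Another worker inserted the same hostname first; the unique index
    # guarantees exactly one row exists, so fall back to it.
    session.rollback()
    return session.query(machine_cls).filter(
        machine_cls.hostname == hostname).first().id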
--- alembic/versions/054211043da5_benchmark.py | 2 +- ...1b2c3d4e5f6_add_machine_hostname_unique.py | 41 ++++++++++++ tuna/miopen/celery_tuning/celery_tasks.py | 66 +++++++++++++++++-- 3 files changed, 103 insertions(+), 6 deletions(-) create mode 100644 alembic/versions/a1b2c3d4e5f6_add_machine_hostname_unique.py diff --git a/alembic/versions/054211043da5_benchmark.py b/alembic/versions/054211043da5_benchmark.py index 60ff8940..324ffe69 100644 --- a/alembic/versions/054211043da5_benchmark.py +++ b/alembic/versions/054211043da5_benchmark.py @@ -9,7 +9,7 @@ import sqlalchemy as sa from sqlalchemy.sql import func as sqla_func from sqlalchemy import Column, Integer, DateTime, text, ForeignKey, String -from tuna.miopen.benchmark import ModelEnum, FrameworkEnum +from tuna.miopen.db.benchmark import ModelEnum, FrameworkEnum from sqlalchemy.dialects.mysql import TINYINT, DOUBLE, MEDIUMBLOB, LONGBLOB from sqlalchemy import Float, BigInteger, String from sqlalchemy import Enum diff --git a/alembic/versions/a1b2c3d4e5f6_add_machine_hostname_unique.py b/alembic/versions/a1b2c3d4e5f6_add_machine_hostname_unique.py new file mode 100644 index 00000000..46921680 --- /dev/null +++ b/alembic/versions/a1b2c3d4e5f6_add_machine_hostname_unique.py @@ -0,0 +1,41 @@ +"""add_machine_hostname_unique + +Revision ID: a1b2c3d4e5f6 +Revises: 219858383a66 +Create Date: 2025-11-18 02:38:00.000000 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = 'a1b2c3d4e5f6' +down_revision = '219858383a66' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # First, remove any duplicate hostnames if they exist + # Keep the oldest entry (lowest id) for each hostname + op.execute(""" + DELETE m1 FROM machine m1 + INNER JOIN machine m2 + WHERE m1.id > m2.id + AND m1.hostname = m2.hostname + """) + + # Then add the unique constraint on hostname + # Using prefix length of 255 since hostname is TEXT type + op.create_index( + 'idx_hostname', + 'machine', + ['hostname'], + unique=True, + mysql_length={'hostname': 255} + ) + + +def downgrade() -> None: + # Remove the unique constraint + op.drop_index('idx_hostname', 'machine') diff --git a/tuna/miopen/celery_tuning/celery_tasks.py b/tuna/miopen/celery_tuning/celery_tasks.py index 9aab6f9d..75419749 100644 --- a/tuna/miopen/celery_tuning/celery_tasks.py +++ b/tuna/miopen/celery_tuning/celery_tasks.py @@ -26,9 +26,11 @@ # ############################################################################### """Module to register MIOpen celery tasks""" +import os import copy from celery.signals import celeryd_after_setup from celery.utils.log import get_task_logger +from sqlalchemy.exc import IntegrityError from tuna.celery_app.celery_app import app from tuna.libraries import Operation from tuna.machine import Machine @@ -41,6 +43,23 @@ logger = get_task_logger(__name__) +def check_hostname_unique_constraint(session): + """Check if hostname has a unique constraint on the machine table""" + try: + from sqlalchemy import text + result = session.execute(text( + "SELECT COUNT(*) FROM information_schema.statistics " + "WHERE table_schema = DATABASE() " + "AND table_name = 'machine' " + "AND column_name = 'hostname' " + "AND non_unique = 0" + )).scalar() + return result > 0 + except Exception as e: # pylint: disable=broad-exception-caught + logger.warning("Could not check for hostname unique constraint: %s", e) + return None # Unknown state + + @celeryd_after_setup.connect def capture_worker_name(sender, instance, **kwargs): 
#pylint: disable=unused-argument """Capture worker name and ensure machine is registered""" @@ -49,17 +68,54 @@ def capture_worker_name(sender, instance, **kwargs): #pylint: disable=unused-ar # Ensure this machine is in the database global cached_machine with DbSession() as session: + # Check for unique constraint on hostname (only check once) + if not check_hostname_unique_constraint(session): + logger.warning( + "WARNING: The 'machine' table does not have a UNIQUE constraint on 'hostname'. " + "This may lead to duplicate machine entries and race conditions. " + "Please run: ALTER TABLE machine ADD UNIQUE INDEX idx_hostname (hostname(255)); " + "Or apply the Alembic migration: alembic upgrade head" + ) + # Check if machine exists by hostname existing = session.query(Machine).filter( Machine.hostname == cached_machine.hostname ).first() if not existing: - # Insert the machine - session.add(cached_machine) - session.commit() - session.refresh(cached_machine) - logger.info("Registered machine %s with id %s", cached_machine.hostname, cached_machine.id) + # Create a new machine object for database insertion + # Don't use cached_machine directly as it has id=0 hardcoded + new_machine = Machine( + hostname=cached_machine.hostname, + user=os.getenv('USER', 'unknown'), + password='', + arch=cached_machine.arch, + num_cu=cached_machine.num_cu, + avail_gpus=','.join(map(str, cached_machine.avail_gpus)) if isinstance(cached_machine.avail_gpus, list) else str(cached_machine.avail_gpus) + ) + + try: + # Insert the machine and let database auto-assign ID + session.add(new_machine) + session.commit() + session.refresh(new_machine) + cached_machine.id = new_machine.id + logger.info("Registered machine %s with id %s", cached_machine.hostname, cached_machine.id) + except IntegrityError: + # Race condition: another worker beat us to it + # Rollback and query again to get the existing record + session.rollback() + logger.info("Race condition detected during machine registration, querying existing record") + existing = session.query(Machine).filter( + Machine.hostname == cached_machine.hostname + ).first() + if existing: + cached_machine.id = existing.id + logger.info("Using existing machine %s with id %s (from race condition recovery)", cached_machine.hostname, cached_machine.id) + else: + # This should never happen, but log it if it does + logger.error("Failed to find machine after IntegrityError - this should not happen!") + raise else: # Use existing machine id cached_machine.id = existing.id From 089df048cc48fd95780e423a5fd4ba1f372589e5 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Tue, 18 Nov 2025 03:31:31 -0600 Subject: [PATCH 23/33] feat(celery): improve machine registration robustness and error handling - Move sqlalchemy.text import to module level for consistency - Add socket import for hostname initialization - Ensure cached_machine hostname is initialized before database operations - Add validation and default values for machine attributes (arch, num_cu) - Improve avail_gpus formatting with early conversion to string - Add comprehensive error handling with detailed logging for machine registration failures - Preserve IntegrityError exception context when re-raising - Add broad exception handler to catch and log unexpected registration errors These changes prevent potential NoneType errors and provide better diagnostics when machine registration fails during Celery worker initialization. 
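The hardening described here is essentially attribute normalisation before the row is inserted. A hypothetical helper showing the same defaults (hostname via socket.gethostname(), 'unknown' arch, 64 CUs, comma-separated GPU list); the function itself is not part of Tuna:

# Illustrative normalisation of machine attributes; defaults mirror the patch.
import socket


def normalize_machine_attrs(hostname=None, arch=None, num_cu=None,
                            avail_gpus=None):
  """Fill safe defaults and coerce avail_gpus to a comma-separated string."""
  if isinstance(avail_gpus, list):
    avail_gpus = ','.join(map(str, avail_gpus))
  return {
      'hostname': hostname or socket.gethostname(),
      'arch': arch or 'unknown',
      'num_cu': num_cu or 64,
      'avail_gpus': avail_gpus or '',
  }


print(normalize_machine_attrs(avail_gpus=[0, 1, 2, 3]))
# e.g. {'hostname': 'node42', 'arch': 'unknown', 'num_cu': 64, 'avail_gpus': '0,1,2,3'}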
--- tuna/miopen/celery_tuning/celery_tasks.py | 33 ++++++++++++++++++----- 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/tuna/miopen/celery_tuning/celery_tasks.py b/tuna/miopen/celery_tuning/celery_tasks.py index 75419749..13a37d81 100644 --- a/tuna/miopen/celery_tuning/celery_tasks.py +++ b/tuna/miopen/celery_tuning/celery_tasks.py @@ -27,10 +27,12 @@ ############################################################################### """Module to register MIOpen celery tasks""" import os +import socket import copy from celery.signals import celeryd_after_setup from celery.utils.log import get_task_logger from sqlalchemy.exc import IntegrityError +from sqlalchemy import text from tuna.celery_app.celery_app import app from tuna.libraries import Operation from tuna.machine import Machine @@ -46,7 +48,6 @@ def check_hostname_unique_constraint(session): """Check if hostname has a unique constraint on the machine table""" try: - from sqlalchemy import text result = session.execute(text( "SELECT COUNT(*) FROM information_schema.statistics " "WHERE table_schema = DATABASE() " @@ -67,6 +68,19 @@ def capture_worker_name(sender, instance, **kwargs): #pylint: disable=unused-ar # Ensure this machine is in the database global cached_machine + + # Ensure cached_machine is fully initialized + if not cached_machine.hostname: + cached_machine.hostname = socket.gethostname() + logger.info("Initialized hostname: %s", cached_machine.hostname) + + # Ensure avail_gpus is properly formatted as string + avail_gpus_str = cached_machine.avail_gpus + if isinstance(avail_gpus_str, list): + avail_gpus_str = ','.join(map(str, avail_gpus_str)) + elif avail_gpus_str is None: + avail_gpus_str = '' + with DbSession() as session: # Check for unique constraint on hostname (only check once) if not check_hostname_unique_constraint(session): @@ -89,9 +103,9 @@ def capture_worker_name(sender, instance, **kwargs): #pylint: disable=unused-ar hostname=cached_machine.hostname, user=os.getenv('USER', 'unknown'), password='', - arch=cached_machine.arch, - num_cu=cached_machine.num_cu, - avail_gpus=','.join(map(str, cached_machine.avail_gpus)) if isinstance(cached_machine.avail_gpus, list) else str(cached_machine.avail_gpus) + arch=cached_machine.arch if cached_machine.arch else 'unknown', + num_cu=cached_machine.num_cu if cached_machine.num_cu else 64, + avail_gpus=avail_gpus_str ) try: @@ -101,7 +115,7 @@ def capture_worker_name(sender, instance, **kwargs): #pylint: disable=unused-ar session.refresh(new_machine) cached_machine.id = new_machine.id logger.info("Registered machine %s with id %s", cached_machine.hostname, cached_machine.id) - except IntegrityError: + except IntegrityError as ie: # Race condition: another worker beat us to it # Rollback and query again to get the existing record session.rollback() @@ -115,7 +129,14 @@ def capture_worker_name(sender, instance, **kwargs): #pylint: disable=unused-ar else: # This should never happen, but log it if it does logger.error("Failed to find machine after IntegrityError - this should not happen!") - raise + raise ie + except Exception as e: # pylint: disable=broad-exception-caught + # Log any other errors during machine registration + session.rollback() + logger.error("Error registering machine: %s", e) + logger.error("Machine details - hostname: %s, arch: %s, num_cu: %s, avail_gpus: %s", + cached_machine.hostname, cached_machine.arch, cached_machine.num_cu, avail_gpus_str) + raise else: # Use existing machine id cached_machine.id = existing.id From 
29a1f054126745915a211ffd46e9a45c0ffeb03d Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Tue, 18 Nov 2025 04:18:05 -0600 Subject: [PATCH 24/33] feat(machine): add SQLAlchemy validator for avail_gpus field Add @validates decorator to Machine class to automatically convert avail_gpus list to comma-separated string for database storage. This centralizes the conversion logic in the model layer instead of handling it manually in celery_tasks.py. - Add validates import from sqlalchemy.orm - Implement validate_avail_gpus() method to handle list-to-string conversion - Remove manual string conversion logic from capture_worker_name() - Pass avail_gpus as list to Machine constructor, letting validator handle conversion - Update comments to document the automatic conversion behavior This improves code maintainability by following the DRY principle and ensures consistent data formatting across all Machine instances. --- tuna/machine.py | 8 ++++++++ tuna/miopen/celery_tuning/celery_tasks.py | 13 ++++--------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/tuna/machine.py b/tuna/machine.py index dfaaf75f..9b4b328b 100644 --- a/tuna/machine.py +++ b/tuna/machine.py @@ -38,6 +38,7 @@ from typing import Set, List, Optional, TextIO, Tuple, Dict, Union, Any, Callable from sqlalchemy import Text, Column, orm +from sqlalchemy.orm import validates from sqlalchemy.dialects.mysql import TINYINT, INTEGER from paramiko import SSHClient @@ -145,6 +146,13 @@ def __init__(self, **kwargs: dict) -> None: self.logger.info("avail gpus: %s", self.avail_gpus) + @validates('avail_gpus') + def validate_avail_gpus(self, key, value): + """Convert avail_gpus to comma-separated string for database storage""" + if isinstance(value, list): + return ','.join(map(str, value)) + return value if value else '' + def set_logger(self, logger: logging.Logger) -> bool: """set logging for machine, use this to associate the machine with a subprocess""" pid: int = os.getpid() diff --git a/tuna/miopen/celery_tuning/celery_tasks.py b/tuna/miopen/celery_tuning/celery_tasks.py index 13a37d81..60c90538 100644 --- a/tuna/miopen/celery_tuning/celery_tasks.py +++ b/tuna/miopen/celery_tuning/celery_tasks.py @@ -74,13 +74,6 @@ def capture_worker_name(sender, instance, **kwargs): #pylint: disable=unused-ar cached_machine.hostname = socket.gethostname() logger.info("Initialized hostname: %s", cached_machine.hostname) - # Ensure avail_gpus is properly formatted as string - avail_gpus_str = cached_machine.avail_gpus - if isinstance(avail_gpus_str, list): - avail_gpus_str = ','.join(map(str, avail_gpus_str)) - elif avail_gpus_str is None: - avail_gpus_str = '' - with DbSession() as session: # Check for unique constraint on hostname (only check once) if not check_hostname_unique_constraint(session): @@ -99,13 +92,15 @@ def capture_worker_name(sender, instance, **kwargs): #pylint: disable=unused-ar if not existing: # Create a new machine object for database insertion # Don't use cached_machine directly as it has id=0 hardcoded + # Note: avail_gpus can be passed as a list - the @validates decorator + # in Machine class will automatically convert it to a string for database storage new_machine = Machine( hostname=cached_machine.hostname, user=os.getenv('USER', 'unknown'), password='', arch=cached_machine.arch if cached_machine.arch else 'unknown', num_cu=cached_machine.num_cu if cached_machine.num_cu else 64, - avail_gpus=avail_gpus_str + avail_gpus=cached_machine.avail_gpus if cached_machine.avail_gpus else [] ) try: @@ -135,7 +130,7 @@ def 
capture_worker_name(sender, instance, **kwargs): #pylint: disable=unused-ar session.rollback() logger.error("Error registering machine: %s", e) logger.error("Machine details - hostname: %s, arch: %s, num_cu: %s, avail_gpus: %s", - cached_machine.hostname, cached_machine.arch, cached_machine.num_cu, avail_gpus_str) + cached_machine.hostname, cached_machine.arch, cached_machine.num_cu, cached_machine.avail_gpus) raise else: # Use existing machine id From f75e451b5c98d1b2b515003b9d292aa2b5f5d5f5 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Tue, 18 Nov 2025 04:28:33 -0600 Subject: [PATCH 25/33] yapf formatting --- ...1b2c3d4e5f6_add_machine_hostname_unique.py | 13 ++-- tuna/miopen/celery_tuning/celery_tasks.py | 66 ++++++++++--------- tuna/miopen/utils/helper.py | 18 +++-- 3 files changed, 54 insertions(+), 43 deletions(-) diff --git a/alembic/versions/a1b2c3d4e5f6_add_machine_hostname_unique.py b/alembic/versions/a1b2c3d4e5f6_add_machine_hostname_unique.py index 46921680..629a0438 100644 --- a/alembic/versions/a1b2c3d4e5f6_add_machine_hostname_unique.py +++ b/alembic/versions/a1b2c3d4e5f6_add_machine_hostname_unique.py @@ -24,16 +24,13 @@ def upgrade() -> None: WHERE m1.id > m2.id AND m1.hostname = m2.hostname """) - + # Then add the unique constraint on hostname # Using prefix length of 255 since hostname is TEXT type - op.create_index( - 'idx_hostname', - 'machine', - ['hostname'], - unique=True, - mysql_length={'hostname': 255} - ) + op.create_index('idx_hostname', + 'machine', ['hostname'], + unique=True, + mysql_length={'hostname': 255}) def downgrade() -> None: diff --git a/tuna/miopen/celery_tuning/celery_tasks.py b/tuna/miopen/celery_tuning/celery_tasks.py index 60c90538..61ade415 100644 --- a/tuna/miopen/celery_tuning/celery_tasks.py +++ b/tuna/miopen/celery_tuning/celery_tasks.py @@ -48,13 +48,12 @@ def check_hostname_unique_constraint(session): """Check if hostname has a unique constraint on the machine table""" try: - result = session.execute(text( - "SELECT COUNT(*) FROM information_schema.statistics " - "WHERE table_schema = DATABASE() " - "AND table_name = 'machine' " - "AND column_name = 'hostname' " - "AND non_unique = 0" - )).scalar() + result = session.execute( + text("SELECT COUNT(*) FROM information_schema.statistics " + "WHERE table_schema = DATABASE() " + "AND table_name = 'machine' " + "AND column_name = 'hostname' " + "AND non_unique = 0")).scalar() return result > 0 except Exception as e: # pylint: disable=broad-exception-caught logger.warning("Could not check for hostname unique constraint: %s", e) @@ -65,15 +64,15 @@ def check_hostname_unique_constraint(session): def capture_worker_name(sender, instance, **kwargs): #pylint: disable=unused-argument """Capture worker name and ensure machine is registered""" app.worker_name = sender - + # Ensure this machine is in the database global cached_machine - + # Ensure cached_machine is fully initialized if not cached_machine.hostname: cached_machine.hostname = socket.gethostname() logger.info("Initialized hostname: %s", cached_machine.hostname) - + with DbSession() as session: # Check for unique constraint on hostname (only check once) if not check_hostname_unique_constraint(session): @@ -81,14 +80,12 @@ def capture_worker_name(sender, instance, **kwargs): #pylint: disable=unused-ar "WARNING: The 'machine' table does not have a UNIQUE constraint on 'hostname'. " "This may lead to duplicate machine entries and race conditions. 
" "Please run: ALTER TABLE machine ADD UNIQUE INDEX idx_hostname (hostname(255)); " - "Or apply the Alembic migration: alembic upgrade head" - ) - + "Or apply the Alembic migration: alembic upgrade head") + # Check if machine exists by hostname existing = session.query(Machine).filter( - Machine.hostname == cached_machine.hostname - ).first() - + Machine.hostname == cached_machine.hostname).first() + if not existing: # Create a new machine object for database insertion # Don't use cached_machine directly as it has id=0 hardcoded @@ -100,42 +97,51 @@ def capture_worker_name(sender, instance, **kwargs): #pylint: disable=unused-ar password='', arch=cached_machine.arch if cached_machine.arch else 'unknown', num_cu=cached_machine.num_cu if cached_machine.num_cu else 64, - avail_gpus=cached_machine.avail_gpus if cached_machine.avail_gpus else [] - ) - + avail_gpus=cached_machine.avail_gpus + if cached_machine.avail_gpus else []) + try: # Insert the machine and let database auto-assign ID session.add(new_machine) session.commit() session.refresh(new_machine) cached_machine.id = new_machine.id - logger.info("Registered machine %s with id %s", cached_machine.hostname, cached_machine.id) + logger.info("Registered machine %s with id %s", cached_machine.hostname, + cached_machine.id) except IntegrityError as ie: # Race condition: another worker beat us to it # Rollback and query again to get the existing record session.rollback() - logger.info("Race condition detected during machine registration, querying existing record") + logger.info( + "Race condition detected during machine registration, querying existing record" + ) existing = session.query(Machine).filter( - Machine.hostname == cached_machine.hostname - ).first() + Machine.hostname == cached_machine.hostname).first() if existing: cached_machine.id = existing.id - logger.info("Using existing machine %s with id %s (from race condition recovery)", cached_machine.hostname, cached_machine.id) + logger.info( + "Using existing machine %s with id %s (from race condition recovery)", + cached_machine.hostname, cached_machine.id) else: # This should never happen, but log it if it does - logger.error("Failed to find machine after IntegrityError - this should not happen!") + logger.error( + "Failed to find machine after IntegrityError - this should not happen!" 
+ ) raise ie except Exception as e: # pylint: disable=broad-exception-caught # Log any other errors during machine registration session.rollback() logger.error("Error registering machine: %s", e) - logger.error("Machine details - hostname: %s, arch: %s, num_cu: %s, avail_gpus: %s", - cached_machine.hostname, cached_machine.arch, cached_machine.num_cu, cached_machine.avail_gpus) + logger.error( + "Machine details - hostname: %s, arch: %s, num_cu: %s, avail_gpus: %s", + cached_machine.hostname, cached_machine.arch, cached_machine.num_cu, + cached_machine.avail_gpus) raise else: # Use existing machine id cached_machine.id = existing.id - logger.info("Using existing machine %s with id %s", cached_machine.hostname, cached_machine.id) + logger.info("Using existing machine %s with id %s", + cached_machine.hostname, cached_machine.id) cached_machine = Machine(local_machine=True) @@ -183,8 +189,8 @@ def celery_enqueue(context): worker = prep_worker(copy.deepcopy(context)) ret = worker.run() - + # Add machine_id to the context before returning context['machine_id'] = cached_machine.id - + return {"ret": ret, "context": context} diff --git a/tuna/miopen/utils/helper.py b/tuna/miopen/utils/helper.py index cfaa3151..f71cee52 100644 --- a/tuna/miopen/utils/helper.py +++ b/tuna/miopen/utils/helper.py @@ -204,19 +204,25 @@ def get_db_id(db_elems, config_table): return cid -def set_job_state(session, job, dbt, state, increment_retries=False, result="", machine_id=None): +def set_job_state(session, + job, + dbt, + state, + increment_retries=False, + result="", + machine_id=None): """Update job state for builder/evaluator job_set_attr: List[str]""" LOGGER.info('Setting job id %s state to %s', job.id, state) job_set_attr = ['state', 'gpu_id'] job.state = state - + # Add machine_id if provided if machine_id is not None: job_set_attr.append('machine_id') job.machine_id = machine_id LOGGER.info('Setting job %s machine_id to %s', job.id, machine_id) - + if result: job_set_attr.append('result') job.result = result @@ -227,11 +233,13 @@ def set_job_state(session, job, dbt, state, increment_retries=False, result="", current_retries = session.execute(text(query_retries)).scalar() if current_retries is not None: job.retries = current_retries + 1 - LOGGER.info('Job %s retry count: %d -> %d', job.id, current_retries, job.retries) + LOGGER.info('Job %s retry count: %d -> %d', job.id, current_retries, + job.retries) else: # Fallback if query fails job.retries = getattr(job, 'retries', 0) + 1 - LOGGER.warning('Could not query current retries for job %s, using fallback', job.id) + LOGGER.warning( + 'Could not query current retries for job %s, using fallback', job.id) #pylint: disable=duplicate-code if '_start' in state: From 08ddbe0df6ebf69133a4fb3fba3ffcd7ca045b22 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Tue, 18 Nov 2025 05:03:46 -0600 Subject: [PATCH 26/33] refactor(machine): convert avail_gpus to hybrid_property with getter/setter Replace SQLAlchemy's `validates` decorator with `hybrid_property` for the `avail_gpus` field to provide better separation between database storage (comma-separated string) and application usage (list of integers). 
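As a quick illustration of the storage-vs-usage split described above, here is a minimal toy model using the same hybrid_property pattern; Host and its columns are made up for the example and are not the real Machine table (assumes SQLAlchemy 1.4+ for sqlalchemy.orm.declarative_base):

# Toy model: store a comma-separated string, expose a list of ints.
from sqlalchemy import Column, Integer, Text
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class Host(Base):
  __tablename__ = 'host'
  id = Column(Integer, primary_key=True)
  _gpus = Column('avail_gpus', Text, nullable=False, default='')

  @hybrid_property
  def gpus(self):
    """Expose the stored comma-separated string as a list of ints."""
    return [int(x) for x in self._gpus.split(',')] if self._gpus else []

  @gpus.setter
  def gpus(self, value):
    """Accept a list (or string) and store it as a comma-separated string."""
    if isinstance(value, list):
      self._gpus = ','.join(map(str, value))
    else:
      self._gpus = value or ''


host = Host(gpus=[0, 1, 2])
print(host._gpus, host.gpus)  # '0,1,2' [0, 1, 2]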
Changes: - Rename column to `_avail_gpus` with proper column mapping - Add `@hybrid_property` getter that returns List[int] for application use - Add setter that converts list to comma-separated string for DB storage - Handle both list and string inputs with proper type conversion - Minor code formatting improvements (whitespace, line breaks) This improves type safety and makes the interface more intuitive while maintaining backward compatibility with existing database schema. --- tuna/machine.py | 26 +++++++++++++----- tuna/miopen/db/mixin_tables.py | 4 +-- tuna/miopen/miopen_lib.py | 16 +++++++++-- tuna/mituna_interface.py | 50 +++++++++++++++++++++------------- tuna/utils/db_utility.py | 4 +-- 5 files changed, 67 insertions(+), 33 deletions(-) diff --git a/tuna/machine.py b/tuna/machine.py index 9b4b328b..a7105070 100644 --- a/tuna/machine.py +++ b/tuna/machine.py @@ -38,7 +38,7 @@ from typing import Set, List, Optional, TextIO, Tuple, Dict, Union, Any, Callable from sqlalchemy import Text, Column, orm -from sqlalchemy.orm import validates +from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.dialects.mysql import TINYINT, INTEGER from paramiko import SSHClient @@ -69,7 +69,7 @@ class Machine(BASE): #pylint: disable=too-many-instance-attributes local_port: int = Column(INTEGER, server_default="22") user: str = Column(Text, nullable=False) password: str = Column(Text, nullable=False) - avail_gpus: List[int] = Column(Text, nullable=False) + _avail_gpus: str = Column('avail_gpus', Text, nullable=False) arch: str = Column(Text, nullable=False) arch_full: str = '' num_cu: int = Column(INTEGER, nullable=False, server_default="64") @@ -146,12 +146,24 @@ def __init__(self, **kwargs: dict) -> None: self.logger.info("avail gpus: %s", self.avail_gpus) - @validates('avail_gpus') - def validate_avail_gpus(self, key, value): - """Convert avail_gpus to comma-separated string for database storage""" + @hybrid_property + def avail_gpus(self) -> List[int]: + """Return avail_gpus as a list of integers for application use""" + if isinstance(self._avail_gpus, str) and self._avail_gpus: + return [int(x) for x in self._avail_gpus.split(',')] + elif isinstance(self._avail_gpus, list): + return self._avail_gpus + return [] + + @avail_gpus.setter + def avail_gpus(self, value: Union[List[int], str]) -> None: + """Store avail_gpus as comma-separated string for database storage""" if isinstance(value, list): - return ','.join(map(str, value)) - return value if value else '' + self._avail_gpus = ','.join(map(str, value)) + elif value: + self._avail_gpus = str(value) + else: + self._avail_gpus = '' def set_logger(self, logger: logging.Logger) -> bool: """set logging for machine, use this to associate the machine with a subprocess""" diff --git a/tuna/miopen/db/mixin_tables.py b/tuna/miopen/db/mixin_tables.py index a543b7f1..da933314 100644 --- a/tuna/miopen/db/mixin_tables.py +++ b/tuna/miopen/db/mixin_tables.py @@ -65,8 +65,8 @@ class MIOpenJobMixin(JobMixin): solver = Column(String(length=128), nullable=True, server_default="") eval_mid = Column(Integer, server_default="-1") fin_step = Column(mysql.SET(*(list(k for k in FinStep.__members__))), - nullable=False, - server_default="not_fin") + nullable=False, + server_default="not_fin") class ConfigTagMixin(): diff --git a/tuna/miopen/miopen_lib.py b/tuna/miopen/miopen_lib.py index d8936813..5c51b935 100644 --- a/tuna/miopen/miopen_lib.py +++ b/tuna/miopen/miopen_lib.py @@ -862,7 +862,7 @@ def process_eval_results(self, session, fin_json, context): 
result_str = "" pending = [] orig_state = "compiled" - + # Extract machine_id from context machine_id = context.get('machine_id', None) @@ -911,7 +911,12 @@ def process_eval_results(self, session, fin_json, context): if failed_job: if job.retries >= (MAX_ERRORED_JOB_RETRIES - 1): # pylint: disable=no-member self.logger.warning("max job retries exhausted, setting to errored") - set_job_state(session, job, self.dbt, "errored", result=result_str, machine_id=machine_id) + set_job_state(session, + job, + self.dbt, + "errored", + result=result_str, + machine_id=machine_id) else: self.logger.warning("resetting job state to %s, incrementing retries", orig_state) @@ -926,7 +931,12 @@ def process_eval_results(self, session, fin_json, context): ) else: self.logger.info("\n\n Setting job state to evaluated") - set_job_state(session, job, self.dbt, "evaluated", result=result_str, machine_id=machine_id) + set_job_state(session, + job, + self.dbt, + "evaluated", + result=result_str, + machine_id=machine_id) clean_cache_table(self.dbt, job) except (OperationalError, IntegrityError) as err: self.logger.warning("FinBuild: Unable to update Database %s", err) diff --git a/tuna/mituna_interface.py b/tuna/mituna_interface.py index 00bc851c..9bf99f2f 100644 --- a/tuna/mituna_interface.py +++ b/tuna/mituna_interface.py @@ -376,7 +376,10 @@ def _should_wait_for_progress(self, job_batch_size): return our_in_progress_count >= progress_threshold - def _fetch_jobs_with_retry(self, job_batch_size, max_retries=3, retry_delay=5): + def _fetch_jobs_with_retry(self, + job_batch_size, + max_retries=3, + retry_delay=5): """Fetch jobs from database with retry logic Returns: @@ -395,8 +398,8 @@ def _fetch_jobs_with_retry(self, job_batch_size, max_retries=3, retry_delay=5): return job_list except Exception as db_err: # pylint: disable=broad-exception-caught - self.logger.warning('Database error on attempt %d/%d: %s', - attempt + 1, max_retries, db_err) + self.logger.warning('Database error on attempt %d/%d: %s', attempt + 1, + max_retries, db_err) if attempt < max_retries - 1: time.sleep(retry_delay * (attempt + 1)) # Exponential backoff else: @@ -419,7 +422,7 @@ def _process_job_batch(self, job_list, job_counter, q_name): # Get context and enqueue each job with DbSession() as session: context_list = self.get_context_list(session, job_list) - + for context in context_list: try: self.celery_enqueue_call(context, q_name=q_name) @@ -439,7 +442,7 @@ def _process_job_batch(self, job_list, job_counter, q_name): def enqueue_jobs(self, job_counter, job_batch_size, q_name): """Enqueue celery jobs with simplified progress tracking""" self.logger.info("Starting enqueue") - + is_first_batch = True consecutive_empty_fetches = 0 max_empty_fetches = int(os.environ.get('TUNA_MAX_EMPTY_FETCHES', 3)) @@ -448,7 +451,8 @@ def enqueue_jobs(self, job_counter, job_batch_size, q_name): while True: # 1. Check if we should wait for progress (skip on first batch) if not is_first_batch and self._should_wait_for_progress(job_batch_size): - self.logger.info("Waiting for current batch to progress before fetching more jobs") + self.logger.info( + "Waiting for current batch to progress before fetching more jobs") # Reset consecutive_empty_fetches since we're waiting for progress, not out of jobs consecutive_empty_fetches = 0 time.sleep(poll_interval) @@ -456,18 +460,19 @@ def enqueue_jobs(self, job_counter, job_batch_size, q_name): # 2. Fetch jobs with built-in retry logic job_list = self._fetch_jobs_with_retry(job_batch_size) - + # 3. 
Handle empty results if not job_list: consecutive_empty_fetches += 1 self.logger.info('No jobs found (attempt %d/%d)', consecutive_empty_fetches, max_empty_fetches) - + if consecutive_empty_fetches >= max_empty_fetches: - self.logger.info('No more jobs available after %d attempts. Exiting enqueue loop.', - max_empty_fetches) + self.logger.info( + 'No more jobs available after %d attempts. Exiting enqueue loop.', + max_empty_fetches) return - + time.sleep(poll_interval) continue @@ -486,10 +491,13 @@ def cleanup_completed_jobs(self): # Clear and repopulate the shared list del self.completed_job_ids[:] self.completed_job_ids.extend(recent_completions) - + # Remove old claimed jobs that are completed completed_set = set(recent_completions[:-1000]) - claimed_list = [job_id for job_id in self.claimed_job_ids if job_id not in completed_set] + claimed_list = [ + job_id for job_id in self.claimed_job_ids + if job_id not in completed_set + ] del self.claimed_job_ids[:] self.claimed_job_ids.extend(claimed_list) @@ -620,12 +628,12 @@ def tune(self, job_batch_size=1000): # set job count to 1 until first job fetch is finished job_counter = Value("i", 1) - + # Create shared data structures for cross-process communication manager = Manager() self.claimed_job_ids = manager.list() # Shared list across processes self.completed_job_ids = manager.list() # Shared list across processes - + try: # cleanup old results cleanup_proc = Process(target=self.async_wrap, @@ -792,22 +800,26 @@ async def parse_result(self, data): if job_id and job_id in self.claimed_job_ids: # Check the final state of the job after processing final_state = self.get_job_final_state(session, job_id) - + if final_state in ['evaluated', 'errored']: # Job is truly complete - append to completed list self.completed_job_ids.append(job_id) - self.logger.info("Marked job %s as completed with state: %s", job_id, final_state) + self.logger.info("Marked job %s as completed with state: %s", job_id, + final_state) elif final_state == 'compiled': # Job failed and was reset to compiled for retry # Remove from claimed so it can be re-grabbed try: self.claimed_job_ids.remove(job_id) - self.logger.info("Job %s failed and reset to 'compiled' - removed from claimed list for retry", job_id) + self.logger.info( + "Job %s failed and reset to 'compiled' - removed from claimed list for retry", + job_id) except ValueError: # Job ID not in list, ignore pass else: - self.logger.warning("Job %s has unexpected final state: %s", job_id, final_state) + self.logger.warning("Job %s has unexpected final state: %s", job_id, + final_state) return True diff --git a/tuna/utils/db_utility.py b/tuna/utils/db_utility.py index 277a5fc2..5c2dc169 100644 --- a/tuna/utils/db_utility.py +++ b/tuna/utils/db_utility.py @@ -143,12 +143,12 @@ def sanitize_sql_string(value: str, max_length: int = 2000) -> str: # Truncate to safe length to avoid excessively long queries if len(value) > max_length: value = value[:max_length] + '... [truncated]' - + # Escape backslashes first (must be done before quotes) value = value.replace('\\', '\\\\') # Escape single quotes by doubling them (SQL standard) value = value.replace("'", "''") - + return value From 9477576fd78144f8f5985144cd105347100103b4 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Tue, 18 Nov 2025 05:10:18 -0600 Subject: [PATCH 27/33] fix(machine): handle avail_gpus type conversion for hybrid property Add type checking before converting avail_gpus to prevent double conversion when value is already a list from hybrid property getter. 
This fixes a bug where string split was attempted on list objects, causing AttributeError. - Check if avail_gpus is already a list before attempting string conversion - Only split and convert when value is a non-empty string from database - Prevents runtime errors when hybrid property returns list type --- tuna/machine.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tuna/machine.py b/tuna/machine.py index a7105070..c8ad6a2a 100644 --- a/tuna/machine.py +++ b/tuna/machine.py @@ -137,10 +137,14 @@ def __init__(self, **kwargs: dict) -> None: self.ipmi_user, self.ipmi_password) if not self.avail_gpus is None: - self.avail_gpus = [ - int(val) for val in self.avail_gpus.split(',') #type: ignore - ] #type: ignore - self.num_gpus = len(self.avail_gpus) + # Check if it's already a list (from hybrid property getter) + if isinstance(self.avail_gpus, list): + # Already converted by hybrid property, just use it + self.num_gpus = len(self.avail_gpus) + elif isinstance(self.avail_gpus, str) and self.avail_gpus: + # String from database, convert to list + self.avail_gpus = [int(val) for val in self.avail_gpus.split(',')] + self.num_gpus = len(self.avail_gpus) self.cpus = [] self.gpus = [] From fb09e8d014bb0ae6a5001e5a3e22e7cbaa6669c5 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Fri, 21 Nov 2025 09:32:44 -0600 Subject: [PATCH 28/33] feat(miopen): add detection and handling of database-locked jobs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add `detect_and_handle_locked_jobs` method to identify jobs that are stuck due to stale database locks and preventing pipeline progress. The method: - Queries jobs without locking to detect if they're being skipped by FOR UPDATE SKIP LOCKED clauses - Automatically marks jobs with high retry counts (≥2) as errored to unblock the pipeline - Logs warnings about detected locked jobs for debugging Also enhance logging in `mituna_interface.py` to provide better visibility into job fetching and state transitions, including job counts and IDs. This resolves issues where stale database locks cause jobs to be perpetually skipped, blocking pipeline execution. --- tuna/miopen/miopen_lib.py | 68 ++++++++++++++++++++++++++++++++++++ tuna/mituna_interface.py | 72 +++++++++++++++++++++++++++++++++------ 2 files changed, 130 insertions(+), 10 deletions(-) diff --git a/tuna/miopen/miopen_lib.py b/tuna/miopen/miopen_lib.py index 5c51b935..665d5510 100644 --- a/tuna/miopen/miopen_lib.py +++ b/tuna/miopen/miopen_lib.py @@ -35,6 +35,7 @@ from kombu.utils.uuid import uuid from sqlalchemy.inspection import inspect from sqlalchemy.exc import OperationalError, DataError, IntegrityError +from sqlalchemy import text from tuna.mituna_interface import MITunaInterface from tuna.miopen.utils.helper import print_solvers from tuna.parse_args import TunaArgs, setup_arg_parser, args_check @@ -604,6 +605,73 @@ def compose_work_objs( return job_entries + def detect_and_handle_locked_jobs(self, session: DbSession, + find_state: List[str]) -> bool: + """Detect jobs that are locked and preventing progress + + This method queries for jobs without locking to detect if jobs exist + but are being skipped due to database locks. If found, it marks jobs + with high retry counts as errored to unblock the pipeline. 
+ + @param session DB session + @param find_state List of job states to check + @return True if locked jobs were found and handled, False otherwise + """ + # Query WITHOUT lock to see if jobs are being skipped + conds = [f"session={self.dbt.session.id}", "valid=1"] + + if self.args.label: + conds.append(f"reason='{self.args.label}'") + + conds.append(f"retries<{self.max_job_retries}") + conds.append("state in (" + str(find_state).strip("{").strip("}") + ")") + + if self.args.fin_steps: + conds.append(f"fin_step like '%{self.args.fin_steps[0]}%'") + + cond_str = " AND ".join(conds) + query = f""" + SELECT id, config, retries, state, solver + FROM {self.dbt.job_table.__tablename__} + WHERE {cond_str} + ORDER BY retries, config ASC + LIMIT 10 + """ + + unlocked_jobs = session.execute(text(query)).fetchall() + + if unlocked_jobs: + self.logger.warning( + "Found %d jobs in target state but they were skipped by FOR UPDATE SKIP LOCKED", + len(unlocked_jobs)) + self.logger.warning("Likely cause: stale database locks. Job IDs: %s", + [job[0] for job in unlocked_jobs]) + + # Mark jobs with high retries as errored to unblock + jobs_marked = 0 + for job_row in unlocked_jobs: + job_id, config_id, retries, state, solver = job_row + if retries >= (MAX_ERRORED_JOB_RETRIES - 1): # retries >= 2 + self.logger.warning( + "Marking locked job %d (config=%d, solver=%s, retries=%d) as errored to unblock pipeline", + job_id, config_id, solver, retries) + update_query = f""" + UPDATE {self.dbt.job_table.__tablename__} + SET state = 'errored', + result = 'Marked as errored due to stale lock or excessive retries', + update_ts = NOW() + WHERE id = {job_id} + """ + session.execute(text(update_query)) + jobs_marked += 1 + + if jobs_marked > 0: + session.commit() + self.logger.info("Marked %d locked jobs as errored", jobs_marked) + return True + + return False + def compose_work_objs_fin(self, session, job_entries, dbt) -> List[Tuple[SimpleDict, SimpleDict]]: """! 
Return jobs for fin work diff --git a/tuna/mituna_interface.py b/tuna/mituna_interface.py index 9bf99f2f..764d61e8 100644 --- a/tuna/mituna_interface.py +++ b/tuna/mituna_interface.py @@ -282,18 +282,24 @@ def get_jobs( ids: list row: SimpleDict - self.logger.info("Fetching DB rows...") + self.logger.info("Fetching DB rows for states=%s, session=%d, claim_num=%s", + find_state, session_id, claim_num) job_list = self.get_job_list(session, find_state, claim_num) + self.logger.info("get_job_list returned %d jobs", len(job_list) if job_list else 0) if not self.check_jobs_found(job_list, find_state, session_id): + self.logger.info("check_jobs_found returned False - no jobs available") return [] if no_update: + self.logger.info("no_update=True, returning %d jobs without state update", + len(job_list)) return job_list ids = [row.id for row in job_list] - self.logger.info("%s jobs %s", find_state, ids) - self.logger.info("Updating job state to %s", set_state) + self.logger.info("Found %d jobs with IDs: %s (showing first 10)", + len(ids), ids[:10] if len(ids) > 10 else ids) + self.logger.info("Updating job state from %s to %s", find_state, set_state) # OPTIMIZATION: Use bulk UPDATE instead of individual updates if self.dbt is not None: @@ -304,6 +310,7 @@ def get_jobs( WHERE id IN ({id_str}) """ session.execute(text(query)) + self.logger.info("Executed bulk UPDATE for %d jobs", len(ids)) # Update local objects to reflect new state for job in job_list: @@ -312,6 +319,8 @@ def get_jobs( raise CustomError("DBTable must be set") session.commit() + self.logger.info("Transaction committed - %d jobs now in state '%s'", + len(ids), set_state) return job_list @@ -441,45 +450,88 @@ def _process_job_batch(self, job_list, job_counter, q_name): def enqueue_jobs(self, job_counter, job_batch_size, q_name): """Enqueue celery jobs with simplified progress tracking""" - self.logger.info("Starting enqueue") + self.logger.info("Starting enqueue loop - batch_size=%d, queue=%s", + job_batch_size, q_name) + self.logger.info("Fetch states: %s, Set state: %s", self.fetch_state, + self.set_state) is_first_batch = True consecutive_empty_fetches = 0 max_empty_fetches = int(os.environ.get('TUNA_MAX_EMPTY_FETCHES', 3)) poll_interval = int(os.environ.get("TUNA_POLL_INTERVAL", 60)) + loop_iteration = 0 while True: + loop_iteration += 1 + self.logger.info("=== Enqueue loop iteration %d ===", loop_iteration) + # 1. Check if we should wait for progress (skip on first batch) if not is_first_batch and self._should_wait_for_progress(job_batch_size): self.logger.info( - "Waiting for current batch to progress before fetching more jobs") + "Waiting for current batch to progress before fetching more jobs (iteration %d)", + loop_iteration) # Reset consecutive_empty_fetches since we're waiting for progress, not out of jobs consecutive_empty_fetches = 0 + self.logger.info("Sleeping for %d seconds...", poll_interval) time.sleep(poll_interval) continue # 2. Fetch jobs with built-in retry logic + self.logger.info("Attempting to fetch %d jobs (iteration %d)", + job_batch_size, loop_iteration) job_list = self._fetch_jobs_with_retry(job_batch_size) + self.logger.info("Fetch returned %d jobs", len(job_list) if job_list else 0) # 3. 
Handle empty results if not job_list: consecutive_empty_fetches += 1 - self.logger.info('No jobs found (attempt %d/%d)', - consecutive_empty_fetches, max_empty_fetches) + self.logger.warning( + 'No jobs found (attempt %d/%d) - iteration %d', + consecutive_empty_fetches, max_empty_fetches, loop_iteration) + + # Check if jobs are being skipped due to database locks + if consecutive_empty_fetches == 2: # After 2nd empty fetch + self.logger.warning( + "Checking for locked jobs that may be blocking progress (iteration %d)", + loop_iteration) + with DbSession() as lock_check_session: + if hasattr(self, 'detect_and_handle_locked_jobs'): + try: + handled = self.detect_and_handle_locked_jobs( + lock_check_session, list(self.fetch_state)) + if handled: + self.logger.info( + "Handled locked jobs, resetting empty fetch counter") + consecutive_empty_fetches = 0 # Reset counter to retry + continue + else: + self.logger.info("No locked jobs found to handle") + except Exception as lock_err: # pylint: disable=broad-exception-caught + self.logger.error("Error checking for locked jobs: %s", lock_err) + else: + self.logger.warning( + "detect_and_handle_locked_jobs method not available") if consecutive_empty_fetches >= max_empty_fetches: - self.logger.info( - 'No more jobs available after %d attempts. Exiting enqueue loop.', - max_empty_fetches) + self.logger.warning( + 'EXITING: No more jobs available after %d attempts (iteration %d). Exiting enqueue loop.', + max_empty_fetches, loop_iteration) + self.logger.info("Final state - claimed: %d, completed: %d", + len(self.claimed_job_ids), len(self.completed_job_ids)) return + self.logger.info("Sleeping for %d seconds before retry (iteration %d)...", + poll_interval, loop_iteration) time.sleep(poll_interval) continue # 4. Process the batch + self.logger.info("Processing batch of %d jobs (iteration %d)", len(job_list), + loop_iteration) consecutive_empty_fetches = 0 self._process_job_batch(job_list, job_counter, q_name) is_first_batch = False + self.logger.info("Batch processed successfully (iteration %d)", loop_iteration) def cleanup_completed_jobs(self): """Periodically clean up old job tracking data""" From 1ff40472bfde351f0f0cdc521c3aa099881dbb8d Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Sat, 22 Nov 2025 02:37:44 -0600 Subject: [PATCH 29/33] refactor: reorganize imports and reduce verbose logging in MITunaInterface - Reorganize imports alphabetically and group by standard lib, third-party, and local - Consolidate related imports using parentheses for better readability - Remove excessive debug logging in get_jobs() method - Add new reconcile_tracking_state() method for job state reconciliation - Replace multiple verbose log statements with single summary log message This improves code maintainability by following PEP 8 import conventions and reduces log noise while maintaining essential debugging information. The new reconciliation method helps prevent the distributor from getting stuck on stale job states. 
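A toy illustration of the bookkeeping the new reconciliation method keeps honest (the attribute names mirror `claimed_job_ids` / `completed_job_ids` in `mituna_interface.py`; the job IDs are made up):

```python
# Toy illustration only: job IDs are invented, names mirror the
# tracking attributes on MITunaInterface.
claimed_job_ids = [101, 102, 103, 104]   # handed to celery by this distributor
completed_job_ids = [101, 103]           # results already processed

# The distributor throttles on this difference before fetching more work.
in_progress = set(claimed_job_ids) - set(completed_job_ids)
print(in_progress)  # {102, 104}

# reconcile_tracking_state() re-derives the same information from the job
# table: IDs the DB reports as 'evaluated' or 'errored' are appended to
# completed_job_ids, IDs reset to 'compiled'/'new' are dropped from
# claimed_job_ids, and 'eval_start' jobs are left alone as genuinely running.
```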
--- tuna/mituna_interface.py | 206 +++++++++++++++++++++++++++++++-------- 1 file changed, 164 insertions(+), 42 deletions(-) diff --git a/tuna/mituna_interface.py b/tuna/mituna_interface.py index 764d61e8..2fc82b36 100644 --- a/tuna/mituna_interface.py +++ b/tuna/mituna_interface.py @@ -25,39 +25,42 @@ # ############################################################################### """Interface class to set up and launch tuning functionality""" -import os -from multiprocessing import Value, Lock, Queue as mpQueue, Process, Manager -from typing import Optional, Dict, Any, List -from io import StringIO -from functools import lru_cache +import argparse +import asyncio import json import logging -import argparse +import os import subprocess -import time +import sys import threading -import asyncio +import time from datetime import timedelta -from sqlalchemy.exc import NoInspectionAvailable -from sqlalchemy.inspection import inspect -from sqlalchemy import text -import redis.asyncio as aioredis +from functools import lru_cache +from io import StringIO +from multiprocessing import Lock, Manager, Process +from multiprocessing import Queue as mpQueue +from multiprocessing import Value +from typing import Any, Dict, List, Optional + import kombu +import redis.asyncio as aioredis from paramiko.channel import ChannelFile +from sqlalchemy import text +from sqlalchemy.exc import NoInspectionAvailable +from sqlalchemy.inspection import inspect -from tuna.worker_interface import WorkerInterface -from tuna.machine import Machine -from tuna.libraries import Library -from tuna.utils.logger import setup_logger -from tuna.utils.utility import get_env_vars, SimpleDict -from tuna.dbBase.sql_alchemy import DbSession -from tuna.celery_app.celery_app import stop_active_workers, stop_named_worker -from tuna.celery_app.celery_app import get_backend_env, purge_queue -from tuna.celery_app.utility import get_q_name +from tuna.celery_app.celery_app import (get_backend_env, purge_queue, + stop_active_workers, stop_named_worker) from tuna.celery_app.celery_workers import launch_celery_worker -from tuna.libraries import Operation +from tuna.celery_app.utility import get_q_name from tuna.custom_errors import CustomError +from tuna.dbBase.sql_alchemy import DbSession +from tuna.libraries import Library, Operation +from tuna.machine import Machine from tuna.utils.db_utility import gen_update_query, session_retry +from tuna.utils.logger import setup_logger +from tuna.utils.utility import SimpleDict, get_env_vars +from tuna.worker_interface import WorkerInterface job_counter_lock = threading.Lock() @@ -282,24 +285,17 @@ def get_jobs( ids: list row: SimpleDict - self.logger.info("Fetching DB rows for states=%s, session=%d, claim_num=%s", - find_state, session_id, claim_num) job_list = self.get_job_list(session, find_state, claim_num) - self.logger.info("get_job_list returned %d jobs", len(job_list) if job_list else 0) if not self.check_jobs_found(job_list, find_state, session_id): - self.logger.info("check_jobs_found returned False - no jobs available") return [] if no_update: - self.logger.info("no_update=True, returning %d jobs without state update", - len(job_list)) return job_list ids = [row.id for row in job_list] - self.logger.info("Found %d jobs with IDs: %s (showing first 10)", - len(ids), ids[:10] if len(ids) > 10 else ids) - self.logger.info("Updating job state from %s to %s", find_state, set_state) + # Log summary of jobs being updated + self.logger.info("Updating %d jobs from %s to %s", len(ids), find_state, 
set_state) # OPTIMIZATION: Use bulk UPDATE instead of individual updates if self.dbt is not None: @@ -310,7 +306,6 @@ def get_jobs( WHERE id IN ({id_str}) """ session.execute(text(query)) - self.logger.info("Executed bulk UPDATE for %d jobs", len(ids)) # Update local objects to reflect new state for job in job_list: @@ -319,8 +314,6 @@ def get_jobs( raise CustomError("DBTable must be set") session.commit() - self.logger.info("Transaction committed - %d jobs now in state '%s'", - len(ids), set_state) return job_list @@ -385,6 +378,79 @@ def _should_wait_for_progress(self, job_batch_size): return our_in_progress_count >= progress_threshold + def reconcile_tracking_state(self): + """Reconcile tracking lists with actual database state + + This method queries the database to check the actual state of claimed jobs + and updates the tracking lists accordingly. This prevents the distributor + from getting stuck waiting for jobs that have already completed. + + Returns: + Number of jobs reconciled + """ + if not self.claimed_job_ids: + self.logger.info("No claimed jobs to reconcile") + return 0 + + claimed_set = set(self.claimed_job_ids) + completed_set = set(self.completed_job_ids) + in_progress_set = claimed_set - completed_set + + if not in_progress_set: + self.logger.info("No in-progress jobs to reconcile") + return 0 + + self.logger.info("Reconciling %d in-progress jobs with database state", + len(in_progress_set)) + + # Query database for actual state of these jobs + with DbSession() as session: + try: + # Batch the query to avoid SQL statement too long + in_progress_list = list(in_progress_set) + batch_size = 1000 + reconciled_count = 0 + + for i in range(0, len(in_progress_list), batch_size): + batch = in_progress_list[i:i + batch_size] + id_str = ','.join(map(str, batch)) + + query = f""" + SELECT id, state FROM {self.dbt.job_table.__tablename__} + WHERE id IN ({id_str}) + """ + results = session.execute(text(query)).fetchall() + + completed_jobs = 0 + removed_jobs = 0 + + for job_id, state in results: + if state in ['evaluated', 'errored']: + # Job is complete but not tracked - add to completed + if job_id not in self.completed_job_ids: + self.completed_job_ids.append(job_id) + completed_jobs += 1 + reconciled_count += 1 + elif state in ['compiled', 'new']: + # Job was reset but still in claimed - remove from claimed + if job_id in self.claimed_job_ids: + self.claimed_job_ids.remove(job_id) + removed_jobs += 1 + reconciled_count += 1 + # Jobs in 'eval_start' state are legitimately in progress - no action needed + + # Log batch summary instead of individual jobs + if completed_jobs > 0 or removed_jobs > 0: + self.logger.info("Batch %d: marked %d completed, removed %d from claimed", + i // batch_size + 1, completed_jobs, removed_jobs) + + self.logger.info("Reconciliation complete: %d total jobs updated", reconciled_count) + return reconciled_count + + except Exception as err: # pylint: disable=broad-exception-caught + self.logger.error("Error during reconciliation: %s", err) + return 0 + def _fetch_jobs_with_retry(self, job_batch_size, max_retries=3, @@ -450,6 +516,22 @@ def _process_job_batch(self, job_list, job_counter, q_name): def enqueue_jobs(self, job_counter, job_batch_size, q_name): """Enqueue celery jobs with simplified progress tracking""" + # Configure logger for subprocess to write to stdout + # This ensures logs are captured by bash redirection (> logfile.log 2>&1) + + # Remove any existing handlers to avoid duplicates + self.logger.handlers.clear() + + # Add StreamHandler that 
writes to stdout + stdout_handler = logging.StreamHandler(sys.stdout) + stdout_handler.setLevel(logging.INFO) + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s' + ) + stdout_handler.setFormatter(formatter) + self.logger.addHandler(stdout_handler) + self.logger.setLevel(logging.INFO) + self.logger.info("Starting enqueue loop - batch_size=%d, queue=%s", job_batch_size, q_name) self.logger.info("Fetch states: %s, Set state: %s", self.fetch_state, @@ -460,27 +542,67 @@ def enqueue_jobs(self, job_counter, job_batch_size, q_name): max_empty_fetches = int(os.environ.get('TUNA_MAX_EMPTY_FETCHES', 3)) poll_interval = int(os.environ.get("TUNA_POLL_INTERVAL", 60)) loop_iteration = 0 + + # Track consecutive waits to detect stale state + consecutive_waits = 0 + last_in_progress_count = -1 + reconcile_threshold = int(os.environ.get('TUNA_RECONCILE_THRESHOLD', 5)) while True: loop_iteration += 1 - self.logger.info("=== Enqueue loop iteration %d ===", loop_iteration) + # Only log iteration every 10 iterations or when something interesting happens + if loop_iteration % 10 == 1: + self.logger.info("=== Enqueue loop iteration %d ===", loop_iteration) # 1. Check if we should wait for progress (skip on first batch) if not is_first_batch and self._should_wait_for_progress(job_batch_size): - self.logger.info( - "Waiting for current batch to progress before fetching more jobs (iteration %d)", - loop_iteration) + claimed_set = set(self.claimed_job_ids) + completed_set = set(self.completed_job_ids) + current_in_progress = len(claimed_set - completed_set) + + # Check if we're stuck waiting with the same in-progress count + if current_in_progress == last_in_progress_count: + consecutive_waits += 1 + # Only log warning every 5 waits to reduce verbosity + if consecutive_waits % 5 == 0 or consecutive_waits >= reconcile_threshold - 2: + self.logger.warning( + "Consecutive waits: %d/%d with same in-progress count: %d", + consecutive_waits, reconcile_threshold, current_in_progress) + else: + # Log when wait state changes + if consecutive_waits > 0: + self.logger.info("Wait state changed - in-progress count: %d -> %d", + last_in_progress_count, current_in_progress) + consecutive_waits = 0 + last_in_progress_count = current_in_progress + + # Trigger reconciliation if stuck waiting too long + if consecutive_waits >= reconcile_threshold: + self.logger.warning( + "RECONCILIATION TRIGGERED: Stuck waiting for %d iterations with %d jobs in progress", + consecutive_waits, current_in_progress) + reconciled = self.reconcile_tracking_state() + self.logger.info("Reconciled %d jobs - resetting wait counter", reconciled) + consecutive_waits = 0 + last_in_progress_count = -1 + # Don't sleep, immediately retry fetching jobs + continue + + # Only log wait message on first wait or every 10th wait + if consecutive_waits == 1 or consecutive_waits % 10 == 0: + self.logger.info( + "Waiting for batch progress (iteration %d, wait #%d)", + loop_iteration, consecutive_waits) # Reset consecutive_empty_fetches since we're waiting for progress, not out of jobs consecutive_empty_fetches = 0 - self.logger.info("Sleeping for %d seconds...", poll_interval) time.sleep(poll_interval) continue # 2. 
Fetch jobs with built-in retry logic - self.logger.info("Attempting to fetch %d jobs (iteration %d)", - job_batch_size, loop_iteration) job_list = self._fetch_jobs_with_retry(job_batch_size) - self.logger.info("Fetch returned %d jobs", len(job_list) if job_list else 0) + # Only log fetch details when jobs are found or on errors + if job_list: + self.logger.info("Fetched %d jobs (iteration %d)", len(job_list), loop_iteration) # 3. Handle empty results if not job_list: From b2713f07aa038b265fdbff61c79c756e50b60a82 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Tue, 2 Dec 2025 04:52:29 -0600 Subject: [PATCH 30/33] feat(miopen): add filtering options for applicability updates Add `--new_only` and `--config_limit` command-line arguments to optimize applicability update operations. The `--new_only` flag skips configs with existing applicability data, while `--config_limit` restricts the number of configs processed. Changes: - Add new CLI arguments for filtering applicability updates - Implement worker scaling based on GPU count (4x multiplier) for applicability operations - Pass filtering parameters through worker initialization and query methods - Add detailed logging for worker creation, config processing, and worker count calculations - Reduce log verbosity by commenting out duplicate job warnings in load_job.py This improves performance for large-scale applicability updates by allowing incremental processing and testing with limited datasets. --- tuna/miopen/miopen_lib.py | 42 +++++++++++++++++++++++++++++---- tuna/miopen/subcmd/load_job.py | 4 ++-- tuna/miopen/worker/fin_class.py | 35 +++++++++++++++++++++++---- 3 files changed, 71 insertions(+), 10 deletions(-) diff --git a/tuna/miopen/miopen_lib.py b/tuna/miopen/miopen_lib.py index 665d5510..dcc5d1bc 100644 --- a/tuna/miopen/miopen_lib.py +++ b/tuna/miopen/miopen_lib.py @@ -225,6 +225,20 @@ def parse_args(self): action="store_true", help="Update the applicability table in the database", ) + parser.add_argument( + "--new_only", + dest="new_only", + action="store_true", + default=False, + help="Only update applicability for configs without existing data in this session (use with --update_applicability)", + ) + parser.add_argument( + "--config_limit", + dest="config_limit", + type=int, + default=None, + help="Limit the number of configs to process (useful for testing with --update_applicability)", + ) group.add_argument( "-s", "--status", @@ -364,7 +378,12 @@ def launch_worker(self, gpu_idx, f_vals, worker_lst): kwargs = self.get_kwargs(gpu_idx, f_vals) if self.args.update_applicability: kwargs["fin_steps"] = ["applicability"] + kwargs["new_only"] = self.args.new_only + kwargs["config_limit"] = self.args.config_limit worker = FinClass(**kwargs) + self.logger.info("Created FinClass worker with gpu_id=%s, ROCR_VISIBLE_DEVICES in envmt: %s", + worker.gpu_id, + any('ROCR_VISIBLE_DEVICES' in env for env in worker.envmt)) worker.start() worker_lst.append(worker) return True @@ -404,25 +423,40 @@ def compose_worker_list(self, machines): # fin_steps should only contain one step worker_ids = None if self.args.fin_steps and "eval" in self.args.fin_steps[0]: - worker_ids = machine.get_avail_gpus() + worker_ids = machine.get_avail_gpus() # Use actual GPUs if self.args.gpu_lim and self.args.gpu_lim < len(worker_ids): - worker_ids = range(self.args.gpu_lim) + worker_ids = list(range(self.args.gpu_lim)) + elif self.args.update_applicability: + worker_ids = list(range(len(machine.get_avail_gpus()) * 4 )) # Use GPU count + if self.args.gpu_lim and 
self.args.gpu_lim < len(worker_ids): + worker_ids = list(range(self.args.gpu_lim)) else: - worker_ids = super().get_num_procs(machine) + worker_ids = super().get_num_procs(machine) # Use CPU count for other operations + if self.args.update_applicability: f_vals = super().get_f_vals(machine, [1]) kwargs = self.get_kwargs(0, f_vals) kwargs["fin_steps"] = ["applicability"] + kwargs["new_only"] = self.args.new_only + kwargs["config_limit"] = self.args.config_limit worker = FinClass(**kwargs) - query = worker.query_cfgs(self.args.label) + skip_existing = self.args.new_only + config_limit = self.args.config_limit + query = worker.query_cfgs(self.args.label, skip_existing=skip_existing, config_limit=config_limit) cfg_rows = query.all() len_rows = len(cfg_rows) + self.logger.warning("Found %d configs to process (label=%s, new_only=%s, config_limit=%s)", + len_rows, self.args.label, self.args.new_only, self.args.config_limit) proc_lim = (len_rows + 99) / 100 if 32 < proc_lim: proc_lim = 32 + self.logger.info("Calculated proc_lim=%d based on %d configs", proc_lim, len_rows) + initial_workers = len(worker_ids) while len(worker_ids) > proc_lim: worker_ids.pop() + self.logger.warning("Worker count: initial=%d, after limit=%d (proc_lim=%d)", + initial_workers, len(worker_ids), proc_lim) if len(worker_ids) == 0: return None diff --git a/tuna/miopen/subcmd/load_job.py b/tuna/miopen/subcmd/load_job.py index a45f3b6a..a9d2a177 100755 --- a/tuna/miopen/subcmd/load_job.py +++ b/tuna/miopen/subcmd/load_job.py @@ -183,8 +183,8 @@ def add_jobs(args: argparse.Namespace, dbt: MIOpenDBTables, if job.config in pre_ex: if job.solver in pre_ex[job.config]: - logger.warning("Job exists (skip): %s : %s", job.config, - job.solver) + # logger.warning("Job exists (skip): %s : %s", job.config, + # job.solver) continue session.add(job) diff --git a/tuna/miopen/worker/fin_class.py b/tuna/miopen/worker/fin_class.py index faac2217..19d94a2f 100644 --- a/tuna/miopen/worker/fin_class.py +++ b/tuna/miopen/worker/fin_class.py @@ -67,7 +67,7 @@ def __init__(self, **kwargs): """Constructor""" allowed_keys = set([ 'fin_steps', 'local_file', 'fin_infile', 'fin_outfile', 'config_type', - 'dynamic_solvers_only' + 'dynamic_solvers_only', 'new_only', 'config_limit' ]) self.__dict__.update((key, None) for key in allowed_keys) @@ -101,6 +101,15 @@ def __init__(self, **kwargs): ) self.envmt.append( f"MIOPEN_CUSTOM_CACHE_DIR=/tmp/miopenpdb/thread-{self.gpu_id}/cache") + + if hasattr(self, 'gpu_id') and self.gpu_id is not None: + num_gpus = len(self.machine.get_avail_gpus()) if hasattr(self, 'machine') else 1 + actual_gpu = self.gpu_id % num_gpus # Wrap around available GPUs + self.envmt.append(f"ROCR_VISIBLE_DEVICES={actual_gpu}") + self.logger.info("Set ROCR_VISIBLE_DEVICES=%d for worker (worker_id=%d, num_gpus=%d)", + actual_gpu, self.gpu_id, num_gpus) + else: + self.logger.warning("gpu_id not set - ROCR_VISIBLE_DEVICES not configured. 
All workers may use same GPU!") self.cfg_attr = [column.name for column in inspect(self.dbt.config_table).c] @@ -319,8 +328,8 @@ def applicability(self): return True - def query_cfgs(self, label=None): - """query all configs from table, optionally limit by label""" + def query_cfgs(self, label=None, skip_existing=False, config_limit=None): + """query all configs from table, optionally limit by label, skip existing, and limit count""" with DbSession() as session: query = session.query(self.dbt.config_table)\ .filter(self.dbt.config_table.valid == 1) @@ -329,17 +338,35 @@ def query_cfgs(self, label=None): query = query.filter(self.dbt.config_table.id == self.dbt.config_tags_table.config)\ .filter(self.dbt.config_tags_table.tag == label) + # Skip configs that already have applicability data in this session + if skip_existing: + query = query.outerjoin( + self.dbt.solver_app, + (self.dbt.solver_app.config == self.dbt.config_table.id) & + (self.dbt.solver_app.session == self.session_id) + ).filter(self.dbt.solver_app.id == None) + #order by id for splitting configs into blocks query = query.order_by(self.dbt.config_table.id) + + # Apply config limit if specified + if config_limit is not None and config_limit > 0: + query = query.limit(config_limit) + self.logger.info("Limiting query to %d configs", config_limit) + return query def __set_all_configs(self, idx: int = 0, num_blk: int = 1) -> bool: """Gathering all configs from Tuna DB to set up fin input file""" if idx == 0: - query = self.query_cfgs(self.label) + skip_existing = getattr(self, 'new_only', False) + config_limit = getattr(self, 'config_limit', None) + query = self.query_cfgs(self.label, skip_existing=skip_existing, config_limit=config_limit) rows = query.all() len_rows = len(rows) + self.logger.warning("Query returned %d configs (label=%s, skip_existing=%s, config_limit=%s)", + len_rows, self.label, skip_existing, config_limit) master_cfg_list = [] for row in rows: r_dict = compose_config_obj(row, self.config_type) From 10773d160da45722d2133b600a186df78924030d Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Thu, 4 Dec 2025 09:55:59 -0600 Subject: [PATCH 31/33] feat(miopen): fix SQLAlchemy subquery usage with explicit select() Update subquery filtering to use explicit select() method calls for SQLAlchemy 2.0 compatibility. This change affects config tag filtering and solver application queries, ensuring proper subquery execution in newer SQLAlchemy versions where implicit select is deprecated. 
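For reference, a minimal standalone sketch of the deprecation this fixes; `Config`/`ConfigTag` and the tag value are illustrative stand-ins, not the real Tuna tables:

```python
# Standalone sketch of the SQLAlchemy 1.4/2.0 subquery coercion deprecation.
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Config(Base):
  __tablename__ = 'config'
  id = Column(Integer, primary_key=True)


class ConfigTag(Base):
  __tablename__ = 'config_tag'
  id = Column(Integer, primary_key=True)
  config = Column(Integer)
  tag = Column(String(64))


engine = create_engine('sqlite://')
Base.metadata.create_all(engine)

with Session(engine) as session:
  tag_subq = session.query(ConfigTag.config)\
      .filter(ConfigTag.tag == 'resnet50').subquery()

  # Pre-1.4 style: the Subquery object is implicitly coerced into a SELECT,
  # which newer SQLAlchemy reports as deprecated:
  #   session.query(Config).filter(Config.id.in_(tag_subq))

  # 1.4/2.0 style (what this patch switches to): build the SELECT explicitly.
  query = session.query(Config).filter(Config.id.in_(tag_subq.select()))
  print(query.all())
```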
---
 tuna/miopen/subcmd/load_job.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tuna/miopen/subcmd/load_job.py b/tuna/miopen/subcmd/load_job.py
index a9d2a177..61a9cde6 100755
--- a/tuna/miopen/subcmd/load_job.py
+++ b/tuna/miopen/subcmd/load_job.py
@@ -106,7 +106,7 @@ def config_query(args: argparse.Namespace, session, dbt: MIOpenDBTables):
   if args.tag:
     tag_query = session.query(dbt.config_tags_table.config)\
       .filter(dbt.config_tags_table.tag == args.tag).subquery()
-    cfg_query = cfg_query.filter(dbt.config_table.id.in_(tag_query))
+    cfg_query = cfg_query.filter(dbt.config_table.id.in_(tag_query.select()))
 
   if args.cmd:
     cfg_query = cfg_query.filter(
@@ -136,7 +136,7 @@ def compose_query(args: argparse.Namespace, session, dbt: MIOpenDBTables,
   if args.only_dynamic:
     query = query.filter(Solver.is_dynamic == true())
 
-  query = query.filter(dbt.solver_app.config.in_(cfg_query.subquery()))
+  query = query.filter(dbt.solver_app.config.in_(cfg_query.subquery().select()))
 
   return query
 
From 3ee445984adb47750f0a6104339269f8f1f55f66 Mon Sep 17 00:00:00 2001
From: amd-bartgips
Date: Mon, 22 Dec 2025 14:05:34 +0000
Subject: [PATCH 32/33] NOTE: This commit may introduce a bug with NCHW
 layouts; they do not seem to be processed by "update_applicability" after
 import. Consider reverting this change.

perf(miopen): optimize config import with batch processing and tensor caching

This commit significantly improves the performance of config import
operations:

**Changes:**
- Add batch processing for database operations in import_configs
  - New `batch_insert_tensors()` function for bulk tensor insertion
  - New `get_or_create_tensor_ids()` for efficient tensor ID retrieval
  - New `batch_process_drivers()` to process drivers in configurable batches
- Optimize tensor caching in MIOpenDriver base class
  - Check cache before creating database session to avoid unnecessary connections
  - Reduce redundant cache lookups after loading tensor map
  - Early return on cache hits to minimize database operations
- Add CLI arguments for batch configuration
  - `--batch_size`: configurable batch size (default: 1000)
  - `--disable_batch_import`: fallback flag for debugging

**Why:**
The original implementation processed configs one-by-one, creating a database
session for each tensor lookup/insert. This caused significant overhead when
importing large numbers of configs. The new batch processing approach reduces
database round-trips and leverages bulk operations for better performance.
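The core pattern behind the new import path, distilled into a standalone sketch; `session_factory` (e.g. a `sessionmaker()`), `mapped_cls`, and `rows` are placeholders rather than the real Tuna models:

```python
# Distilled batching pattern: bulk insert per chunk, per-row fallback on
# duplicate-key errors. Names are placeholders, not the real Tuna tables.
from sqlalchemy.exc import IntegrityError


def batched_insert(session_factory, mapped_cls, rows, batch_size=1000):
  """Bulk-insert `rows` (list of dicts) in chunks, falling back to per-row
  inserts when a chunk trips a duplicate-key error."""
  for start in range(0, len(rows), batch_size):
    chunk = rows[start:start + batch_size]
    with session_factory() as session:
      try:
        session.bulk_insert_mappings(mapped_cls, chunk)
        session.commit()
      except IntegrityError:
        session.rollback()
        # Retry one-by-one so a single duplicate does not discard the chunk.
        for row in chunk:
          try:
            session.add(mapped_cls(**row))
            session.commit()
          except IntegrityError:
            session.rollback()
```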
--- tuna/miopen/driver/base.py | 38 +- tuna/miopen/parse_miopen_args.py | 13 + tuna/miopen/subcmd/import_configs.py | 651 ++++++++++++++++++++++++++- tuna/miopen/worker/fin_class.py | 7 +- 4 files changed, 680 insertions(+), 29 deletions(-) diff --git a/tuna/miopen/driver/base.py b/tuna/miopen/driver/base.py index 2bed7e6e..8c0d34e5 100755 --- a/tuna/miopen/driver/base.py +++ b/tuna/miopen/driver/base.py @@ -160,26 +160,38 @@ def get_tensor_id(session: Session, tensor_dict: dict) -> int: def __insert_tensor(self, tensor_dict: dict) -> int: """Insert new row into tensor table and return primary key""" ret_id: int = -1 + + # Check cache first without creating a session + tid = TensorTable(**tensor_dict) + tid.valid = 1 + key = build_dict_val_key(tid) + + # If cache is populated and key exists, return immediately + if MIOpenDriver.tensor_id_map and key in MIOpenDriver.tensor_id_map: + ret_id = MIOpenDriver.tensor_id_map[key] + LOGGER.info("Get Tensor: %s", ret_id) + return ret_id + + # Cache miss or not populated - need database session session: Session with DbSession() as session: try: - tid = TensorTable(**tensor_dict) - tid.valid = 1 - key = build_dict_val_key(tid) #cache the tensor table to avoid queries if not MIOpenDriver.tensor_id_map: MIOpenDriver.tensor_id_map = get_session_val_map( session, TensorTable, MIOpenDriver.tensor_attr) - id_map = MIOpenDriver.tensor_id_map - if key in id_map: - ret_id = id_map[key] - LOGGER.info("Get Tensor: %s", ret_id) - else: - session.add(tid) - session.commit() - ret_id = tid.id - id_map[key] = ret_id - LOGGER.info("Insert Tensor: %s", ret_id) + # Check cache again after loading + if key in MIOpenDriver.tensor_id_map: + ret_id = MIOpenDriver.tensor_id_map[key] + LOGGER.info("Get Tensor: %s", ret_id) + return ret_id + + # Not in cache, insert new tensor + session.add(tid) + session.commit() + ret_id = tid.id + MIOpenDriver.tensor_id_map[key] = ret_id + LOGGER.info("Insert Tensor: %s", ret_id) except IntegrityError as err: LOGGER.warning(err) session.rollback() diff --git a/tuna/miopen/parse_miopen_args.py b/tuna/miopen/parse_miopen_args.py index 5d608d5a..a90ba376 100644 --- a/tuna/miopen/parse_miopen_args.py +++ b/tuna/miopen/parse_miopen_args.py @@ -140,6 +140,19 @@ def get_import_cfg_parser( 'Tag to mark the origin of this config but skips the insert new config \ step in case the config does not exist in the table. Wildcard columns \ allowed for tagging') + parser.add_argument( + '--batch_size', + type=int, + dest='batch_size', + default=1000, + help='Batch size for bulk database operations (default: 1000). \ + Higher values are faster but use more memory.') + parser.add_argument( + '--disable_batch_import', + action='store_true', + dest='disable_batch_import', + help='Disable batch import optimization and use original one-by-one import. 
\ + Use this for debugging or if batch import has issues.') return parser diff --git a/tuna/miopen/subcmd/import_configs.py b/tuna/miopen/subcmd/import_configs.py index 5e23d2ac..c6aa1aba 100755 --- a/tuna/miopen/subcmd/import_configs.py +++ b/tuna/miopen/subcmd/import_configs.py @@ -42,6 +42,8 @@ from tuna.miopen.driver.batchnorm import DriverBatchNorm from tuna.miopen.db.tables import MIOpenDBTables from tuna.miopen.db.benchmark import Framework, Model +from tuna.miopen.db.tensortable import TensorTable +from tuna.utils.db_utility import build_dict_val_key def create_query(tag: str, mark_recurrent: bool, config_id: int) -> dict: @@ -155,6 +157,606 @@ def parse_line(args: argparse.Namespace, line: str, counts: dict, return True +def batch_insert_tensors(session, tensor_dicts: List[dict], logger: logging.Logger) -> dict: + """Insert tensors in batch with duplicate handling. Returns dict mapping tensor_key -> tensor_id""" + if not tensor_dicts: + return {} + + tensor_map = {} + + # Get existing tensors to avoid duplicates + tensor_keys = [build_dict_val_key(TensorTable(**td)) for td in tensor_dicts] + unique_dicts = {build_dict_val_key(TensorTable(**td)): td for td in tensor_dicts} + + # Query existing tensors + existing_tensors = session.query(TensorTable).all() + for tensor in existing_tensors: + key = build_dict_val_key(tensor) + if key in unique_dicts: + tensor_map[key] = tensor.id + + # Filter out already existing tensors + new_tensor_dicts = [td for key, td in unique_dicts.items() if key not in tensor_map] + + if new_tensor_dicts: + try: + # Bulk insert new tensors + session.bulk_insert_mappings(TensorTable, new_tensor_dicts, return_defaults=True) + session.flush() + + # Query back to get IDs + for td in new_tensor_dicts: + key = build_dict_val_key(TensorTable(**td)) + result = session.query(TensorTable.id).filter_by(**td).first() + if result: + tensor_map[key] = result[0] + except IntegrityError as err: + logger.warning(f"Bulk tensor insert failed, falling back to individual inserts: {err}") + session.rollback() + + # Fallback: insert one by one + for td in new_tensor_dicts: + try: + tensor = TensorTable(**td) + tensor.valid = 1 + session.add(tensor) + session.flush() + key = build_dict_val_key(tensor) + tensor_map[key] = tensor.id + except IntegrityError: + session.rollback() + # Already exists, query it + result = session.query(TensorTable.id).filter_by(**td).first() + if result: + key = build_dict_val_key(TensorTable(**td)) + tensor_map[key] = result[0] + + return tensor_map + + +def batch_insert_configs(session, drivers: List[DriverBase], dbt: MIOpenDBTables, + logger: logging.Logger) -> Tuple[int, List]: + """Insert configs in batch with duplicate handling. 
Returns (count_inserted, list_of_config_objects)""" + if not drivers: + return 0, [] + + # Get config objects + config_objs = [driver.get_db_obj(keep_id=True) for driver in drivers] + + # Filter out configs that already have IDs (already in DB) + new_configs = [c for c in config_objs if c.id is None] + + if not new_configs: + return 0, config_objs + + # Get MD5s of new configs to check for existing + new_md5s = [c.md5 for c in new_configs] + existing_md5s = session.query(dbt.config_table.md5).filter( + dbt.config_table.md5.in_(new_md5s) + ).all() + existing_md5_set = {row[0] for row in existing_md5s} + + # Filter to only truly new configs + truly_new_configs = [c for c in new_configs if c.md5 not in existing_md5_set] + + inserted_count = 0 + if truly_new_configs: + try: + # Try bulk insert + session.bulk_save_objects(truly_new_configs, return_defaults=True) + session.flush() + inserted_count = len(truly_new_configs) + except IntegrityError as err: + logger.warning(f"Bulk config insert failed, falling back to individual inserts: {err}") + session.rollback() + + # Fallback: insert one by one + for config in truly_new_configs: + try: + session.add(config) + session.flush() + inserted_count += 1 + except IntegrityError: + session.rollback() + + # Refresh all configs to get IDs + for config in config_objs: + if config.id is None: + # Query to get ID + result = session.query(dbt.config_table.id).filter_by(md5=config.md5).first() + if result: + config.id = result[0] + + return inserted_count, config_objs + + +def batch_insert_tags(session, config_ids: List[int], dbt: MIOpenDBTables, + args: argparse.Namespace, logger: logging.Logger) -> int: + """Insert tags in batch with duplicate handling. Returns count of tags inserted""" + if not config_ids or not (args.tag or args.mark_recurrent): + return 0 + + # Build tag dictionaries + tag_dicts = [] + for config_id in config_ids: + tag_dict = create_query(args.tag, args.mark_recurrent, config_id) + tag_dicts.append(tag_dict) + + if not tag_dicts: + return 0 + + # Get existing tags to avoid duplicates + if args.tag: + existing_tags = session.query(dbt.config_tags_table.config).filter( + dbt.config_tags_table.config.in_(config_ids), + dbt.config_tags_table.tag == args.tag + ).all() + existing_config_ids = {row[0] for row in existing_tags} + tag_dicts = [td for td in tag_dicts if td['config'] not in existing_config_ids] + + inserted_count = 0 + if tag_dicts: + try: + # Try bulk insert + session.bulk_insert_mappings(dbt.config_tags_table, tag_dicts) + session.flush() + inserted_count = len(tag_dicts) + except IntegrityError as err: + logger.warning(f"Bulk tag insert failed, falling back to individual inserts: {err}") + session.rollback() + + # Fallback: insert one by one + for tag_dict in tag_dicts: + try: + tag_obj = dbt.config_tags_table(**tag_dict) + session.merge(tag_obj) + session.flush() + inserted_count += 1 + except IntegrityError: + session.rollback() + + return inserted_count + + +def get_or_create_tensor_ids(session, tensor_dicts: List[dict], logger: logging.Logger) -> dict: + """Get or create tensor IDs in bulk. 
Returns dict mapping tensor_key -> tensor_id""" + if not tensor_dicts: + return {} + + import time + t0 = time.time() + + # Build unique tensor dict map + unique_tensors = {} + for td in tensor_dicts: + td['valid'] = 1 + key = build_dict_val_key(TensorTable(**td)) + unique_tensors[key] = td + + logger.info("Found %d unique tensors to process", len(unique_tensors)) + + # Query existing tensors in bulk + existing_tensors = session.query(TensorTable).all() + tensor_id_map = {} + for tensor in existing_tensors: + key = build_dict_val_key(tensor) + if key in unique_tensors: + tensor_id_map[key] = tensor.id + + logger.info("Found %d existing tensors in DB (%.2fs)", len(tensor_id_map), time.time() - t0) + + # Insert new tensors + new_tensors = [td for key, td in unique_tensors.items() if key not in tensor_id_map] + if new_tensors: + t0 = time.time() + try: + session.bulk_insert_mappings(TensorTable, new_tensors) + session.flush() + logger.info("Bulk inserted %d new tensors (%.2fs)", len(new_tensors), time.time() - t0) + + # Query back to get IDs + for td in new_tensors: + result = session.query(TensorTable.id).filter_by(**td).first() + if result: + key = build_dict_val_key(TensorTable(**td)) + tensor_id_map[key] = result[0] + except IntegrityError as err: + logger.warning(f"Bulk tensor insert failed: {err}") + session.rollback() + # Fallback to individual + for td in new_tensors: + try: + tensor = TensorTable(**td) + session.add(tensor) + session.flush() + key = build_dict_val_key(tensor) + tensor_id_map[key] = tensor.id + except IntegrityError: + session.rollback() + result = session.query(TensorTable.id).filter_by(**td).first() + if result: + key = build_dict_val_key(TensorTable(**td)) + tensor_id_map[key] = result[0] + + return tensor_id_map + + +def import_cfgs_batch_ultra(args: argparse.Namespace, dbt: MIOpenDBTables, + logger: logging.Logger, batch_size: int = 1000) -> dict: + """Ultra-optimized batch import bypassing get_db_obj()""" + import time + import hashlib + from tuna.miopen.utils.metadata import TENSOR_PRECISION + + connect_db() + + counts = {} + counts['cnt_configs'] = 0 + counts['cnt_tagged_configs'] = set() + + # Step 1: Read and parse + start_time = time.time() + logger.info("Reading and parsing config file...") + drivers_to_process = [] + unique_lines = set() + + with open(os.path.expanduser(args.file_name), "r") as infile: + for line in infile: + line = line.strip() + if line: + unique_lines.add(line) + + for line in unique_lines: + try: + if args.config_type == ConfigType.batch_norm: + driver = DriverBatchNorm(line, args.command) + else: + driver = DriverConvolution(line, args.command) + + if not args.batch_list: + drivers_to_process.append(driver) + else: + for bsz in args.batch_list: + driver_copy = DriverBatchNorm(line, args.command) if args.config_type == ConfigType.batch_norm else DriverConvolution(line, args.command) + driver_copy.batchsize = bsz + drivers_to_process.append(driver_copy) + except ValueError as err: + logger.warning(f"Error parsing line: {err}") + + parse_time = time.time() - start_time + logger.info("Parsed %u driver objects (took %.2fs)", len(drivers_to_process), parse_time) + + # Step 2: Collect all unique tensors + start_time = time.time() + logger.info("Collecting tensor dictionaries...") + all_tensor_dicts = [] + for driver in drivers_to_process: + input_t = driver._MIOpenDriver__compose_input_t() if hasattr(driver, '_MIOpenDriver__compose_input_t') else {} + weight_t = driver.compose_weight_t() + all_tensor_dicts.extend([input_t, weight_t]) + + 
logger.info("Collected %d tensor dicts (took %.2fs)", len(all_tensor_dicts), time.time() - start_time) + + # Step 3: Batch process tensors and configs + total_drivers = len(drivers_to_process) + logger.info(f"Starting ultra-optimized batch import (batch size: {batch_size})...") + overall_start = time.time() + + for batch_start in range(0, total_drivers, batch_size): + batch_end = min(batch_start + batch_size, total_drivers) + batch = drivers_to_process[batch_start:batch_end] + + with DbSession() as session: + # Collect tensors for this batch + batch_tensor_dicts = [] + for driver in batch: + input_t = driver._MIOpenDriver__compose_input_t() if hasattr(driver, '_MIOpenDriver__compose_input_t') else {} + weight_t = driver.compose_weight_t() + batch_tensor_dicts.extend([input_t, weight_t]) + + # Get/create tensor IDs + tensor_id_map = get_or_create_tensor_ids(session, batch_tensor_dicts, logger) + + # Build config dictionaries manually (bypass get_db_obj) + config_dicts = [] + for driver in batch: + try: + # Get tensor IDs + input_t = driver._MIOpenDriver__compose_input_t() if hasattr(driver, '_MIOpenDriver__compose_input_t') else {} + weight_t = driver.compose_weight_t() + input_t['valid'] = 1 + weight_t['valid'] = 1 + + input_key = build_dict_val_key(TensorTable(**input_t)) + weight_key = build_dict_val_key(TensorTable(**weight_t)) + + if input_key not in tensor_id_map or weight_key not in tensor_id_map: + logger.warning("Missing tensor IDs for config, skipping") + continue + + # Build config dict manually + config_dict = { + 'batchsize': driver.batchsize, + 'spatial_dim': driver.spatial_dim, + 'pad_h': driver.pad_h, + 'pad_w': driver.pad_w, + 'pad_d': driver.pad_d, + 'conv_stride_h': driver.conv_stride_h, + 'conv_stride_w': driver.conv_stride_w, + 'conv_stride_d': driver.conv_stride_d, + 'dilation_h': driver.dilation_h, + 'dilation_w': driver.dilation_w, + 'dilation_d': driver.dilation_d, + 'group_count': driver.group_count, + 'mode': driver.mode, + 'pad_mode': driver.pad_mode, + 'trans_output_pad_h': driver.trans_output_pad_h, + 'trans_output_pad_w': driver.trans_output_pad_w, + 'trans_output_pad_d': driver.trans_output_pad_d, + 'direction': driver.direction, + 'input_tensor': tensor_id_map[input_key], + 'weight_tensor': tensor_id_map[weight_key], + 'out_layout': driver.out_layout, + 'driver': str(driver) + } + + # Compute MD5 + dict_copy = config_dict.copy() + dict_copy.pop('driver') + md5_str = str(sorted(dict_copy.items())) + config_dict['md5'] = hashlib.md5(md5_str.encode()).hexdigest() + + config_dicts.append(config_dict) + except Exception as err: + logger.warning(f"Error building config dict: {err}") + + # Bulk insert configs + if config_dicts: + # Check for existing + md5s = [cd['md5'] for cd in config_dicts] + existing = session.query(dbt.config_table.md5).filter( + dbt.config_table.md5.in_(md5s) + ).all() + existing_set = {row[0] for row in existing} + + new_configs = [cd for cd in config_dicts if cd['md5'] not in existing_set] + + if new_configs: + try: + session.bulk_insert_mappings(dbt.config_table, new_configs) + session.flush() + counts['cnt_configs'] += len(new_configs) + except IntegrityError as err: + logger.warning(f"Bulk config insert failed: {err}") + session.rollback() + + # Get config IDs for tagging + if args.tag or args.mark_recurrent: + config_ids = [] + for cd in config_dicts: + result = session.query(dbt.config_table.id).filter_by(md5=cd['md5']).first() + if result: + config_ids.append(result[0]) + + if config_ids: + tag_dicts = [create_query(args.tag, 
args.mark_recurrent, cid) for cid in config_ids] + + # Filter existing tags + if args.tag: + existing_tags = session.query(dbt.config_tags_table.config).filter( + dbt.config_tags_table.config.in_(config_ids), + dbt.config_tags_table.tag == args.tag + ).all() + existing_tag_set = {row[0] for row in existing_tags} + tag_dicts = [td for td in tag_dicts if td['config'] not in existing_tag_set] + + if tag_dicts: + try: + session.bulk_insert_mappings(dbt.config_tags_table, tag_dicts) + session.flush() + counts['cnt_tagged_configs'].update([td['config'] for td in tag_dicts]) + except IntegrityError: + session.rollback() + + session.commit() + + if batch_end % 1000 == 0 or batch_end == total_drivers: + logger.info(f"Processed {batch_end}/{total_drivers} configs") + + total_time = time.time() - overall_start + logger.info("Ultra-optimized import complete (took %.2fs, %.2f configs/sec)", + total_time, total_drivers / total_time if total_time > 0 else 0) + return counts + + +def import_cfgs_batch(args: argparse.Namespace, dbt: MIOpenDBTables, + logger: logging.Logger, batch_size: int = 1000) -> dict: + """Optimized batch import of configs with proper tensor handling""" + import time + from tuna.utils.db_utility import get_session_val_map + from tuna.miopen.driver.base import MIOpenDriver + + connect_db() + + counts = {} + counts['cnt_configs'] = 0 + counts['cnt_tagged_configs'] = set() + unique_lines = set() + + # Step 1: Read and deduplicate file + start_time = time.time() + logger.info("Reading and deduplicating config file...") + with open(os.path.expanduser(args.file_name), "r") as infile: + for line_cnt, line in enumerate(infile, 1): + line = line.strip() + if line: + unique_lines.add(line) + if line_cnt % 10000 == 0: + logger.info("Parsed: %u lines, unique configs: %u", line_cnt, len(unique_lines)) + + parse_time = time.time() - start_time + logger.info("File parsing complete. 
Total lines: %u, unique configs: %u (took %.2fs)", + line_cnt, len(unique_lines), parse_time) + + # Step 2: Pre-load tensor cache to avoid repeated queries + start_time = time.time() + logger.info("Pre-loading tensor cache...") + with DbSession() as session: + tensor_attr = [column.name for column in TensorTable.__table__.columns] + MIOpenDriver.tensor_id_map = get_session_val_map(session, TensorTable, tensor_attr) + cache_time = time.time() - start_time + logger.info("Tensor cache loaded with %u entries (took %.2fs)", + len(MIOpenDriver.tensor_id_map), cache_time) + + # Step 3: Parse all driver objects + start_time = time.time() + logger.info("Parsing driver commands...") + drivers_to_process = [] + for line in unique_lines: + try: + if args.config_type == ConfigType.batch_norm: + driver = DriverBatchNorm(line, args.command) + else: + driver = DriverConvolution(line, args.command) + + if not args.batch_list: + drivers_to_process.append(driver) + else: + for bsz in args.batch_list: + driver_copy = DriverBatchNorm(line, args.command) if args.config_type == ConfigType.batch_norm else DriverConvolution(line, args.command) + driver_copy.batchsize = bsz + drivers_to_process.append(driver_copy) + except ValueError as err: + logger.warning(f"Error parsing line: {err}") + + driver_parse_time = time.time() - start_time + logger.info("Parsed %u driver objects to import (took %.2fs)", + len(drivers_to_process), driver_parse_time) + + # Step 4: Process in batches with true batch operations + total_drivers = len(drivers_to_process) + start_time = time.time() + logger.info(f"Starting batch import (batch size: {batch_size})...") + + batch_times = {'get_db_obj': 0, 'check_existing': 0, 'insert_configs': 0, 'insert_tags': 0, 'commit': 0} + + for batch_start in range(0, total_drivers, batch_size): + batch_end = min(batch_start + batch_size, total_drivers) + batch = drivers_to_process[batch_start:batch_end] + batch_start_time = time.time() + + with DbSession() as session: + # Collect all config objects for this batch + t0 = time.time() + config_objs = [] + for driver in batch: + try: + config_obj = driver.get_db_obj(keep_id=True) + config_objs.append((driver, config_obj)) + except ValueError as err: + logger.warning(f"Error creating config object: {err}") + batch_times['get_db_obj'] += time.time() - t0 + + if not args.tag_only: + # Batch insert configs + t0 = time.time() + new_configs = [c for d, c in config_objs if c.id is None] + + if new_configs: + # Check for existing configs by MD5 + new_md5s = [c.md5 for c in new_configs] + existing_md5s = session.query(dbt.config_table.md5).filter( + dbt.config_table.md5.in_(new_md5s) + ).all() + existing_md5_set = {row[0] for row in existing_md5s} + batch_times['check_existing'] += time.time() - t0 + + # Filter to truly new configs + truly_new = [c for c in new_configs if c.md5 not in existing_md5_set] + + if truly_new: + t0 = time.time() + try: + session.bulk_save_objects(truly_new, return_defaults=True) + session.flush() + counts['cnt_configs'] += len(truly_new) + except IntegrityError as err: + logger.warning(f"Bulk insert failed, using individual inserts: {err}") + session.rollback() + for config in truly_new: + try: + session.add(config) + session.flush() + counts['cnt_configs'] += 1 + except IntegrityError: + session.rollback() + batch_times['insert_configs'] += time.time() - t0 + + # Refresh configs to get IDs + for config in new_configs: + if config.id is None: + result = session.query(dbt.config_table.id).filter_by(md5=config.md5).first() + if result: + 
config.id = result[0] + + # Batch insert tags + if args.tag or args.mark_recurrent: + t0 = time.time() + config_ids = [c.id for d, c in config_objs if c.id is not None] + + if config_ids: + tag_dicts = [create_query(args.tag, args.mark_recurrent, cid) for cid in config_ids] + + # Filter out existing tags + if args.tag: + existing_tags = session.query(dbt.config_tags_table.config).filter( + dbt.config_tags_table.config.in_(config_ids), + dbt.config_tags_table.tag == args.tag + ).all() + existing_set = {row[0] for row in existing_tags} + tag_dicts = [td for td in tag_dicts if td['config'] not in existing_set] + + if tag_dicts: + try: + session.bulk_insert_mappings(dbt.config_tags_table, tag_dicts) + session.flush() + counts['cnt_tagged_configs'].update([td['config'] for td in tag_dicts]) + except IntegrityError as err: + logger.warning(f"Bulk tag insert failed, using individual inserts: {err}") + session.rollback() + for tag_dict in tag_dicts: + try: + tag_obj = dbt.config_tags_table(**tag_dict) + session.merge(tag_obj) + session.flush() + counts['cnt_tagged_configs'].add(tag_dict['config']) + except IntegrityError: + session.rollback() + batch_times['insert_tags'] += time.time() - t0 + + # Commit the entire batch + t0 = time.time() + try: + session.commit() + except IntegrityError as err: + logger.error(f"Batch commit failed: {err}") + session.rollback() + batch_times['commit'] += time.time() - t0 + + if batch_end % 1000 == 0 or batch_end == total_drivers: + batch_elapsed = time.time() - batch_start_time + logger.info(f"Processed {batch_end}/{total_drivers} configs (batch took {batch_elapsed:.2f}s)") + + total_import_time = time.time() - start_time + logger.info("Database import complete (took %.2fs)", total_import_time) + logger.info("Timing breakdown: get_db_obj=%.2fs, check_existing=%.2fs, insert_configs=%.2fs, insert_tags=%.2fs, commit=%.2fs", + batch_times['get_db_obj'], batch_times['check_existing'], + batch_times['insert_configs'], batch_times['insert_tags'], batch_times['commit']) + + logger.info("Database import complete.") + return counts + + def import_cfgs(args: argparse.Namespace, dbt: MIOpenDBTables, logger: logging.Logger) -> dict: """import configs to mysql from file with driver invocations""" @@ -163,21 +765,29 @@ def import_cfgs(args: argparse.Namespace, dbt: MIOpenDBTables, counts: dict = {} counts['cnt_configs'] = 0 counts['cnt_tagged_configs'] = set() - unique_lines: List[str] = [] + unique_lines = set() + + logger.info("Reading and deduplicating config file...") with open(os.path.expanduser(args.file_name), "r") as infile: # pylint: disable=unspecified-encoding - line_cnt = 0 - for line in infile: - line_cnt += 1 + for line_cnt, line in enumerate(infile, 1): line = line.strip() - if not line in unique_lines: - unique_lines.append(line) - logger.info("parsed: %u, unique: %u", line_cnt, len(unique_lines)) - for line in unique_lines: - try: - parse_line(args, line, counts, dbt, logger) - except ValueError as err: - logger.warning(err) - + if line: # Skip empty lines + unique_lines.add(line) + if line_cnt % 10000 == 0: + logger.info("Parsed: %u lines, unique configs: %u", line_cnt, len(unique_lines)) + + logger.info("File parsing complete. 
Total lines: %u, unique configs: %u", line_cnt, len(unique_lines)) + logger.info("Starting database import...") + + for idx, line in enumerate(unique_lines, 1): + try: + parse_line(args, line, counts, dbt, logger) + if idx % 1000 == 0: + logger.info("Processed %u/%u unique configs", idx, len(unique_lines)) + except ValueError as err: + logger.warning(err) + + logger.info("Database import complete.") return counts @@ -352,7 +962,20 @@ def run_import_configs(args: argparse.Namespace, return True set_import_cfg_batches(args) - counts = import_cfgs(args, dbt, logger) + + # Use batch import by default unless disabled or tag_only mode + use_batch = not getattr(args, 'disable_batch_import', False) and not args.tag_only + batch_size = getattr(args, 'batch_size', 1000) + + if use_batch: + logger.info("Using optimized batch import (batch_size=%d)", batch_size) + counts = import_cfgs_batch(args, dbt, logger, batch_size) + else: + if args.tag_only: + logger.info("Using original import (tag_only mode)") + else: + logger.info("Using original import (batch import disabled)") + counts = import_cfgs(args, dbt, logger) logger.info('New configs added: %u', counts['cnt_configs']) if args.tag or args.tag_only: diff --git a/tuna/miopen/worker/fin_class.py b/tuna/miopen/worker/fin_class.py index 19d94a2f..41c3f74f 100644 --- a/tuna/miopen/worker/fin_class.py +++ b/tuna/miopen/worker/fin_class.py @@ -335,8 +335,11 @@ def query_cfgs(self, label=None, skip_existing=False, config_limit=None): .filter(self.dbt.config_table.valid == 1) if label: - query = query.filter(self.dbt.config_table.id == self.dbt.config_tags_table.config)\ - .filter(self.dbt.config_tags_table.tag == label) + query = query.join( + self.dbt.config_tags_table, + self.dbt.config_table.id == self.dbt.config_tags_table.config + ).filter(self.dbt.config_tags_table.tag == label) + # Skip configs that already have applicability data in this session if skip_existing: From 2cce59a85ba634780055ab752f1a2a0fd4f92f18 Mon Sep 17 00:00:00 2001 From: amd-bartgips Date: Thu, 22 Jan 2026 07:03:48 -0600 Subject: [PATCH 33/33] feat(config): update database name to silo_heuristic_2d, share changes w Antti Change TUNA_DB_NAME environment variable from 'silo_heuristic' to 'silo_heuristic_2d' in the MI355 setup script to point to the 2D-specific heuristic database instance. --- .gitignore | 2 + tuna/mituna_interface.py | 223 ++++++++++++++++----------------------- 2 files changed, 95 insertions(+), 130 deletions(-) diff --git a/.gitignore b/.gitignore index 52675f89..f4b379f9 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ __pycache__ *.rej *.orig +.cline* +*.egg-info \ No newline at end of file diff --git a/tuna/mituna_interface.py b/tuna/mituna_interface.py index 2fc82b36..78e9fdcf 100644 --- a/tuna/mituna_interface.py +++ b/tuna/mituna_interface.py @@ -67,7 +67,20 @@ class MITunaInterface: # pylint:disable=too-many-instance-attributes,too-many-public-methods """Interface class extended by libraries. The purpose of this class is to define - common functionalities.""" + common functionalities. + + Job Progress Tracking: + ---------------------- + The distributor uses database queries to track job progress, ensuring accuracy + and eliminating synchronization issues with in-memory tracking lists. 
+
+  - claimed_job_ids: List of job IDs claimed by this distributor instance
+  - Progress checking: Queries database directly for actual job states
+  - No reconciliation needed: Database is the single source of truth
+
+  The Redis consumer still runs to process results and update the database,
+  but progress decisions are based solely on database queries.
+  """

   def __init__(self, library=Library.MIOPEN) -> None:

@@ -362,94 +375,63 @@ def celery_enqueue_call(self, context, q_name, task_id=False):
     raise NotImplementedError("Not implemented")

   def _should_wait_for_progress(self, job_batch_size):
-    """Check if we should wait before fetching more jobs based on progress"""
-    # Convert to sets for set operations
-    claimed_set = set(self.claimed_job_ids)
-    completed_set = set(self.completed_job_ids)
-    our_in_progress_count = len(claimed_set - completed_set)
-    progress_threshold = job_batch_size * self.progress_factor
-
-    self.logger.info(
-        "Jobs in progress: %d, completed: %d, threshold: %.0f",
-        our_in_progress_count,
-        len(completed_set),
-        progress_threshold,
-    )
-
-    return our_in_progress_count >= progress_threshold
-
-  def reconcile_tracking_state(self):
-    """Reconcile tracking lists with actual database state
-
-    This method queries the database to check the actual state of claimed jobs
-    and updates the tracking lists accordingly. This prevents the distributor
-    from getting stuck waiting for jobs that have already completed.
+    """Check if we should wait before fetching more jobs based on database state

-    Returns:
-      Number of jobs reconciled
+    This method queries the database directly to get accurate job counts,
+    eliminating reliance on potentially stale in-memory tracking lists.
     """
     if not self.claimed_job_ids:
-      self.logger.info("No claimed jobs to reconcile")
-      return 0
-
-    claimed_set = set(self.claimed_job_ids)
-    completed_set = set(self.completed_job_ids)
-    in_progress_set = claimed_set - completed_set
-
-    if not in_progress_set:
-      self.logger.info("No in-progress jobs to reconcile")
-      return 0
+      # No jobs claimed yet, don't wait
+      return False

-    self.logger.info("Reconciling %d in-progress jobs with database state",
-                     len(in_progress_set))
+    progress_threshold = job_batch_size * self.progress_factor

-    # Query database for actual state of these jobs
+    # Query database for actual state of claimed jobs
     with DbSession() as session:
       try:
         # Batch the query to avoid SQL statement too long
-        in_progress_list = list(in_progress_set)
+        claimed_list = list(self.claimed_job_ids)
         batch_size = 1000
-        reconciled_count = 0
+        total_in_progress = 0
+        total_completed = 0

-        for i in range(0, len(in_progress_list), batch_size):
-          batch = in_progress_list[i:i + batch_size]
+        for i in range(0, len(claimed_list), batch_size):
+          batch = claimed_list[i:i + batch_size]
           id_str = ','.join(map(str, batch))

-          query = f"""
-            SELECT id, state FROM {self.dbt.job_table.__tablename__}
+          # Count jobs still in progress states
+          in_progress_query = f"""
+            SELECT COUNT(*) FROM {self.dbt.job_table.__tablename__}
             WHERE id IN ({id_str})
+            AND state IN ('eval_start', 'compile_start')
           """
-          results = session.execute(text(query)).fetchall()
-
-          completed_jobs = 0
-          removed_jobs = 0
-
-          for job_id, state in results:
-            if state in ['evaluated', 'errored']:
-              # Job is complete but not tracked - add to completed
-              if job_id not in self.completed_job_ids:
-                self.completed_job_ids.append(job_id)
-                completed_jobs += 1
-                reconciled_count += 1
-            elif state in ['compiled', 'new']:
-              # Job was reset but still in claimed - remove from claimed
-              if job_id in self.claimed_job_ids:
-                self.claimed_job_ids.remove(job_id)
-                removed_jobs += 1
-                reconciled_count += 1
-            # Jobs in 'eval_start' state are legitimately in progress - no action needed
+          batch_in_progress = session.execute(text(in_progress_query)).scalar()
+          total_in_progress += batch_in_progress

-          # Log batch summary instead of individual jobs
-          if completed_jobs > 0 or removed_jobs > 0:
-            self.logger.info("Batch %d: marked %d completed, removed %d from claimed",
-                             i // batch_size + 1, completed_jobs, removed_jobs)
+          # Count completed jobs
+          completed_query = f"""
+            SELECT COUNT(*) FROM {self.dbt.job_table.__tablename__}
+            WHERE id IN ({id_str})
+            AND state IN ('evaluated', 'errored', 'completed')
+          """
+          batch_completed = session.execute(text(completed_query)).scalar()
+          total_completed += batch_completed

-        self.logger.info("Reconciliation complete: %d total jobs updated", reconciled_count)
-        return reconciled_count
+        self.logger.info(
+            "DB query - Jobs in progress: %d, completed: %d, threshold: %.0f",
+            total_in_progress,
+            total_completed,
+            progress_threshold,
+        )
+
+        return total_in_progress >= progress_threshold

       except Exception as err:  # pylint: disable=broad-exception-caught
-        self.logger.error("Error during reconciliation: %s", err)
-        return 0
+        self.logger.error("Error querying job progress: %s", err)
+        # On error, be conservative: assume we should wait (return True)
+        # This prevents over-fetching if database is temporarily unavailable
+        self.logger.warning("Defaulting to WAIT due to database error (conservative approach)")
+        return True

   def _fetch_jobs_with_retry(self,
                              job_batch_size,
@@ -542,11 +524,6 @@ def enqueue_jobs(self, job_counter, job_batch_size, q_name):
     max_empty_fetches = int(os.environ.get('TUNA_MAX_EMPTY_FETCHES', 3))
     poll_interval = int(os.environ.get("TUNA_POLL_INTERVAL", 60))
     loop_iteration = 0
-
-    # Track consecutive waits to detect stale state
-    consecutive_waits = 0
-    last_in_progress_count = -1
-    reconcile_threshold = int(os.environ.get('TUNA_RECONCILE_THRESHOLD', 5))

     while True:
       loop_iteration += 1
@@ -555,44 +532,11 @@ def enqueue_jobs(self, job_counter, job_batch_size, q_name):
       self.logger.info("=== Enqueue loop iteration %d ===", loop_iteration)

       # 1. Check if we should wait for progress (skip on first batch)
+      # Database query now provides accurate state, no reconciliation needed
       if not is_first_batch and self._should_wait_for_progress(job_batch_size):
-        claimed_set = set(self.claimed_job_ids)
-        completed_set = set(self.completed_job_ids)
-        current_in_progress = len(claimed_set - completed_set)
-
-        # Check if we're stuck waiting with the same in-progress count
-        if current_in_progress == last_in_progress_count:
-          consecutive_waits += 1
-          # Only log warning every 5 waits to reduce verbosity
-          if consecutive_waits % 5 == 0 or consecutive_waits >= reconcile_threshold - 2:
-            self.logger.warning(
-                "Consecutive waits: %d/%d with same in-progress count: %d",
-                consecutive_waits, reconcile_threshold, current_in_progress)
-        else:
-          # Log when wait state changes
-          if consecutive_waits > 0:
-            self.logger.info("Wait state changed - in-progress count: %d -> %d",
-                             last_in_progress_count, current_in_progress)
-          consecutive_waits = 0
-          last_in_progress_count = current_in_progress
-
-        # Trigger reconciliation if stuck waiting too long
-        if consecutive_waits >= reconcile_threshold:
-          self.logger.warning(
-              "RECONCILIATION TRIGGERED: Stuck waiting for %d iterations with %d jobs in progress",
-              consecutive_waits, current_in_progress)
-          reconciled = self.reconcile_tracking_state()
-          self.logger.info("Reconciled %d jobs - resetting wait counter", reconciled)
-          consecutive_waits = 0
-          last_in_progress_count = -1
-          # Don't sleep, immediately retry fetching jobs
-          continue
-
-        # Only log wait message on first wait or every 10th wait
-        if consecutive_waits == 1 or consecutive_waits % 10 == 0:
-          self.logger.info(
-              "Waiting for batch progress (iteration %d, wait #%d)",
-              loop_iteration, consecutive_waits)
+        self.logger.info(
+            "Waiting for batch progress (iteration %d)",
+            loop_iteration)
         # Reset consecutive_empty_fetches since we're waiting for progress, not out of jobs
         consecutive_empty_fetches = 0
         time.sleep(poll_interval)
@@ -656,24 +600,43 @@ def enqueue_jobs(self, job_counter, job_batch_size, q_name):
           self.logger.info("Batch processed successfully (iteration %d)", loop_iteration)

   def cleanup_completed_jobs(self):
-    """Periodically clean up old job tracking data"""
-    # Keep lists from growing indefinitely
+    """Periodically clean up old job tracking data
+
+    Since we now query the database for accurate progress tracking,
+    we only need to keep claimed_job_ids from growing too large.
+    The completed_job_ids list is kept for Redis consumer compatibility
+    but is not used for progress decisions.
+
+    """
+    # Keep claimed_job_ids list from growing indefinitely
     max_tracking_size = 10000
-    if len(self.completed_job_ids) > max_tracking_size:
-      # Keep only the most recent completions
-      recent_completions = list(self.completed_job_ids)[-5000:]
-      # Clear and repopulate the shared list
-      del self.completed_job_ids[:]
-      self.completed_job_ids.extend(recent_completions)
-
-      # Remove old claimed jobs that are completed
-      completed_set = set(recent_completions[:-1000])
-      claimed_list = [
-          job_id for job_id in self.claimed_job_ids
-          if job_id not in completed_set
-      ]
-      del self.claimed_job_ids[:]
-      self.claimed_job_ids.extend(claimed_list)
+    if len(self.claimed_job_ids) > max_tracking_size:
+      # Query database to find which claimed jobs are actually complete
+      with DbSession() as session:
+        try:
+          claimed_list = list(self.claimed_job_ids)
+          completed_ids = set()
+
+          # Batch the lookup to avoid an oversized IN (...) clause,
+          # matching the batching used in _should_wait_for_progress
+          batch_size = 1000
+          for i in range(0, len(claimed_list), batch_size):
+            batch = claimed_list[i:i + batch_size]
+            id_str = ','.join(map(str, batch))
+
+            # Get IDs of jobs that are complete
+            query = f"""
+              SELECT id FROM {self.dbt.job_table.__tablename__}
+              WHERE id IN ({id_str})
+              AND state IN ('evaluated', 'errored', 'completed')
+            """
+            completed_ids.update(
+                row[0] for row in session.execute(text(query)).fetchall())
+
+          # Keep only jobs that are still in progress
+          active_jobs = [job_id for job_id in claimed_list if job_id not in completed_ids]
+
+          # Update the list
+          del self.claimed_job_ids[:]
+          self.claimed_job_ids.extend(active_jobs)
+
+          self.logger.info(
+              "Cleaned up tracking: removed %d completed jobs, kept %d active jobs",
+              len(completed_ids), len(active_jobs))
+
+        except Exception as err:  # pylint: disable=broad-exception-caught
+          self.logger.error("Error during cleanup: %s", err)

   async def cleanup_redis_results(self, prefix):
     """Remove stale redis results by key"""
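
The snippet below is a minimal, self-contained sketch (not part of any patch in this series) of the batch-and-count decision that _should_wait_for_progress now makes against the job table. An in-memory dict stands in for the per-batch COUNT(*) queries so the logic runs without a database; the state names are taken from the diff above, while batched, should_wait, and job_states are illustrative names only.

# Illustrative sketch only: state names mirror the diff; the helpers below
# are hypothetical stand-ins for the per-batch COUNT(*) queries issued in
# _should_wait_for_progress.
from typing import Dict, Iterable, List

IN_PROGRESS_STATES = {"eval_start", "compile_start"}


def batched(ids: List[int], batch_size: int = 1000) -> Iterable[List[int]]:
  """Yield bounded ID batches, as the patch does to keep IN (...) clauses short."""
  for i in range(0, len(ids), batch_size):
    yield ids[i:i + batch_size]


def should_wait(claimed_ids: List[int], job_states: Dict[int, str],
                job_batch_size: int, progress_factor: float) -> bool:
  """Return True when enough claimed jobs are still in flight to delay fetching more."""
  if not claimed_ids:
    return False
  in_progress = 0
  for batch in batched(claimed_ids):
    # Stand-in for one COUNT(*) ... WHERE id IN (...) query per batch.
    in_progress += sum(1 for jid in batch if job_states.get(jid) in IN_PROGRESS_STATES)
  return in_progress >= job_batch_size * progress_factor


# 50 of 100 claimed jobs still running vs. a threshold of 100 * 0.25 = 25 -> keep waiting.
states = {i: ("eval_start" if i < 50 else "evaluated") for i in range(100)}
print(should_wait(list(range(100)), states, job_batch_size=100, progress_factor=0.25))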