diff --git a/.gitignore b/.gitignore index 5e544db19..fd6e2ef61 100644 --- a/.gitignore +++ b/.gitignore @@ -2,5 +2,7 @@ __pycache__ *.rej *.orig +.cline* +*.egg-info myvenv venv diff --git a/Dockerfile b/Dockerfile index b043df817..9599b6f97 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,7 +6,7 @@ ARG OSDB_BKC_VERSION= ARG HASVER=${ROCMVERSION:+$ROCMVERSION} ARG HASVER=${HASVER:-$OSDB_BKC_VERSION} -ARG BASEIMAGE=rocm/miopen:ci_3708da +ARG BASEIMAGE=rocm/miopen:ci_7c45f0 ARG UBUNTU=ubuntu:22.04 #use UBUNTU with rocm version set @@ -18,6 +18,8 @@ FROM $USEIMAGE as dtuna-ver-0 #args before from are wiped ARG ROCMVERSION= ARG OSDB_BKC_VERSION= +# pass through baseimage for later use +ARG BASEIMAGE RUN test -d /opt/rocm*; \ if [ $? -eq 0 ] ; then \ @@ -71,17 +73,21 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -f -y --all apt-utils \ build-essential \ cmake \ - clang-format-12 \ + clang-format \ curl \ doxygen \ gdb \ git \ lbzip2 \ lcov \ + libboost-filesystem-dev \ + libbz2-dev \ + libeigen3-dev \ libncurses5-dev \ libnuma-dev \ libpthread-stubs0-dev \ mysql-client \ + nlohmann-json3-dev \ openssh-server \ pkg-config \ python3 \ @@ -117,15 +123,64 @@ ENV UBSAN_OPTIONS=print_stacktrace=1 RUN wget https://github.com/Yelp/dumb-init/releases/download/v1.2.0/dumb-init_1.2.0_amd64.deb RUN dpkg -i dumb-init_*.deb && rm dumb-init_*.deb +# Install frugally-deep and its dependencies (header-only libraries) +RUN . /env; if [ -z $SKIP_MIOPEN_BUILD ]; then \ + # Clone FunctionalPlus + git clone https://github.com/Dobiasd/FunctionalPlus.git /tmp/FunctionalPlus && \ + cd /tmp/FunctionalPlus && \ + mkdir build && cd build && \ + cmake -DCMAKE_INSTALL_PREFIX=/usr/local .. && \ + make install && \ + # Clone frugally-deep + git clone https://github.com/Dobiasd/frugally-deep.git /tmp/frugally-deep && \ + cd /tmp/frugally-deep && \ + mkdir build && cd build && \ + cmake -DCMAKE_INSTALL_PREFIX=/usr/local .. && \ + make install && \ + # Clean up + rm -rf /tmp/FunctionalPlus /tmp/frugally-deep; \ + fi + + +# ============================================ +# Check if BOTH MIOpen and Fin are already installed +# ============================================ +# We check both together because Fin depends on MIOpen headers +# If either is missing, we build both to ensure compatibility +RUN if [ -f /opt/rocm/lib/libMIOpen.so ] && [ -d /opt/rocm/include/miopen ] && \ + ([ -f /opt/rocm/bin/fin ] || [ -f /opt/rocm/miopen/bin/fin ]); then \ + echo "=== Both MIOpen and Fin already installed, skipping builds ==="; \ + echo "export SKIP_MIOPEN_BUILD=1" >> /env; \ + echo "export SKIP_FIN_BUILD=1" >> /env; \ + else \ + echo "=== Building MIOpen and Fin from source (Fin needs MIOpen headers) ==="; \ + fi +# ============================================ +# Clone MIOpen (if needed) +# ============================================ ARG ROCM_LIBS_DIR=/root/rocm-libraries ARG MIOPEN_DIR=$ROCM_LIBS_DIR/projects/miopen -#Clone MIOpen -RUN git clone --filter=blob:none --sparse https://github.com/ROCm/rocm-libraries.git $ROCM_LIBS_DIR + +RUN . /env; if [ -z $SKIP_MIOPEN_BUILD ]; then \ + git clone --filter=blob:none --sparse https://github.com/ROCm/rocm-libraries.git $ROCM_LIBS_DIR; \ + else \ + mkdir -p $ROCM_LIBS_DIR/projects && mkdir -p $MIOPEN_DIR; \ + fi + +# Run sparse-checkout from the git repo root +RUN . 
/env; if [ -z $SKIP_MIOPEN_BUILD ]; then \ + cd $ROCM_LIBS_DIR && git sparse-checkout set projects/miopen; \ + fi + WORKDIR $MIOPEN_DIR -RUN git sparse-checkout set projects/miopen -ARG MIOPEN_BRANCH=4940cf3ec -RUN git pull && git checkout $MIOPEN_BRANCH + +# not sure what this commit is, using latest develop for now +# ARG MIOPEN_BRANCH=4940cf3ec +ARG MIOPEN_BRANCH=develop +RUN . /env; if [ -z $SKIP_MIOPEN_BUILD ]; then \ + git pull && git checkout $MIOPEN_BRANCH; \ + fi ARG PREFIX=/opt/rocm ARG MIOPEN_DEPS=$MIOPEN_DIR/deps @@ -133,7 +188,7 @@ ARG MIOPEN_DEPS=$MIOPEN_DIR/deps # Install dependencies # included in rocm/miopen:ci_xxxxxx ARG BUILD_MIOPEN_DEPS= ARG ARCH_TARGET= -RUN . /env; if [ -z $NO_ROCM_INST ] || ! [ -z $BUILD_MIOPEN_DEPS ]; then\ +RUN . /env; if [ -z $SKIP_MIOPEN_BUILD ] && ([ -z $NO_ROCM_INST ] || ! [ -z $BUILD_MIOPEN_DEPS ]); then\ pip install cget; \ if ! [ -z $ARCH_TARGET ]; then \ sed -i "s#\(composable_kernel.*\)#\1 -DGPU_TARGETS=\"$ARCH_TARGET\"#" requirements.txt; \ @@ -141,6 +196,8 @@ RUN . /env; if [ -z $NO_ROCM_INST ] || ! [ -z $BUILD_MIOPEN_DEPS ]; then\ apt-get remove -y composablekernel-dev miopen-hip; \ CXX=/opt/rocm/llvm/bin/clang++ cget install -f ./dev-requirements.txt --prefix $MIOPEN_DEPS -DCMAKE_POLICY_VERSION_MINIMUM=3.5; \ git checkout requirements.txt; \ + echo "=== DEBUG: cget install completed, checking for composable_kernel ==="; \ + ls -la $MIOPEN_DEPS/lib/cmake/ || echo "No cmake configs found"; \ fi ARG TUNA_USER=miopenpdb @@ -150,36 +207,85 @@ WORKDIR $MIOPEN_DIR/build ARG MIOPEN_CACHE_DIR=/tmp/${TUNA_USER}/cache ARG MIOPEN_USER_DB_PATH=/tmp/$TUNA_USER/config/miopen # build kdb objects with offline clang compiler, disable comgr + hiprtc (which would make target id specific code objects) -ARG MIOPEN_CMAKE_ARGS="-DMIOPEN_USE_COMGR=Off -DMIOPEN_USE_HIPRTC=Off -DMIOPEN_INSTALL_CXX_HEADERS=On -DMIOPEN_CACHE_DIR=${MIOPEN_CACHE_DIR} -DMIOPEN_USER_DB_PATH=${MIOPEN_USER_DB_PATH} -DMIOPEN_BACKEND=${BACKEND} -DCMAKE_PREFIX_PATH=${MIOPEN_DEPS}" +ARG MIOPEN_CMAKE_ARGS="-DMIOPEN_USE_COMGR=on -DMIOPEN_USE_HIPRTC=On -DMIOPEN_INSTALL_CXX_HEADERS=On -DMIOPEN_CACHE_DIR=${MIOPEN_CACHE_DIR} -DMIOPEN_USER_DB_PATH=${MIOPEN_USER_DB_PATH} -DMIOPEN_BACKEND=${BACKEND} -DCMAKE_PREFIX_PATH=${MIOPEN_DEPS} -DBUILD_TESTING=Off -DMIOPEN_USE_MLIR=OFF" -RUN echo "MIOPEN: Selected $BACKEND backend." -RUN if [ $BACKEND = "OpenCL" ]; then \ - cmake -DMIOPEN_HIP_COMPILER=/opt/rocm/llvm/bin/clang++ ${MIOPEN_CMAKE_ARGS} $MIOPEN_DIR ; \ - else \ - CXX=/opt/rocm/llvm/bin/clang++ cmake ${MIOPEN_CMAKE_ARGS} $MIOPEN_DIR ; \ +RUN . /env; if [ -z $SKIP_MIOPEN_BUILD ]; then \ + echo "MIOPEN: Selected $BACKEND backend."; \ + fi + + +# Debug: Check if cmake directory exists and list its contents +RUN . 
/env; if [ -z $SKIP_MIOPEN_BUILD ]; then \ + echo "=== DEBUG: Current directory ==="; \ + pwd; \ + echo "=== DEBUG: Parent directory contents ==="; \ + ls -la ..; \ + echo "=== DEBUG: Parent cmake directory ==="; \ + ls -la ../cmake/ || echo "cmake directory not found!"; \ + echo "=== DEBUG: CMAKE_MODULE_PATH value ==="; \ + echo "../cmake"; \ + echo "=== DEBUG: Checking if cmake files exist ==="; \ + test -f ../cmake/ClangCheck.cmake && echo "ClangCheck.cmake EXISTS" || echo "ClangCheck.cmake NOT FOUND"; \ + test -f ../cmake/TargetFlags.cmake && echo "TargetFlags.cmake EXISTS" || echo "TargetFlags.cmake NOT FOUND"; \ + test -f ../cmake/CheckCXXLinkerFlag.cmake && echo "CheckCXXLinkerFlag.cmake EXISTS" || echo "CheckCXXLinkerFlag.cmake NOT FOUND"; \ fi -RUN make -j $(nproc) -RUN make install -#Build Fin -WORKDIR $MIOPEN_DIR -RUN git submodule update --init --recursive +RUN . /env; if [ -z $SKIP_MIOPEN_BUILD ]; then \ + if [ $BACKEND = "OpenCL" ]; then \ + cmake -DMIOPEN_HIP_COMPILER=/opt/rocm/llvm/bin/clang++ ${MIOPEN_CMAKE_ARGS} .. ; \ + else \ + CXX=/opt/rocm/llvm/bin/clang++ cmake ${MIOPEN_CMAKE_ARGS} .. ; \ + fi; \ + fi + +RUN . /env; if [ -z $SKIP_MIOPEN_BUILD ]; then \ + make -j $(nproc) MIOpen; \ + make -j $(nproc) MIOpenDriver; \ + fi + +RUN . /env; if [ -z $SKIP_MIOPEN_BUILD ]; then \ + make install; \ + fi + +# ============================================ +# Build Fin (if needed) +# ============================================ +# Fin is built as a submodule of MIOpen, so we only build it if MIOpen was also built ARG FIN_DIR=$MIOPEN_DIR/fin + +# Initialize Fin submodule (only runs if MIOpen was built) +RUN . /env; if [ -z $SKIP_FIN_BUILD ]; then \ + echo "=== Initializing Fin as MIOpen submodule ==="; \ + cd $MIOPEN_DIR && git submodule update --init --recursive; \ + fi + WORKDIR $FIN_DIR + # Can be a branch or a SHA ARG FIN_BRANCH=develop -RUN if ! [ -z $FIN_BRANCH ]; then \ - git fetch && git checkout $FIN_BRANCH; \ +RUN . /env; if [ -z $SKIP_FIN_BUILD ]; then \ + if ! [ -z $FIN_BRANCH ]; then \ + git fetch && git checkout $FIN_BRANCH; \ + fi; \ fi + # Install dependencies #RUN cmake -P install_deps.cmake WORKDIR $FIN_DIR/_hip -RUN CXX=/opt/rocm/llvm/bin/clang++ cmake -DCMAKE_BUILD_TYPE=Debug -DCMAKE_PREFIX_PATH=$MIOPEN_DEPS $FIN_DIR -RUN make -j $(nproc) -RUN make install +RUN . /env; if [ -z $SKIP_FIN_BUILD ]; then \ + CXX=/opt/rocm/llvm/bin/clang++ cmake -DCMAKE_BUILD_TYPE=Debug -DCMAKE_PREFIX_PATH=$MIOPEN_DEPS $FIN_DIR; \ + fi + +RUN . /env; if [ -z $SKIP_FIN_BUILD ]; then \ + make -j $(nproc); \ + fi + +RUN . 
/env; if [ -z $SKIP_FIN_BUILD ]; then \ + make install; \ + fi #SET MIOPEN ENVIRONMENT VARIABLES ENV MIOPEN_LOG_LEVEL=6 @@ -209,3 +315,26 @@ RUN python3 setup.py install # reset WORKDIR to /tuna WORKDIR /tuna + +# save BASEIMAGE as env variable +ENV BASEIMAGE=${BASEIMAGE} + +# install mysql-server and mysql-client +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -f -y --allow-unauthenticated \ + mysql-server \ + mysql-client + +# install redis-server +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -f -y --allow-unauthenticated \ + redis-server + +# install RabbitMQ server +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -f -y --allow-unauthenticated \ + rabbitmq-server + +# install iproute2 +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -f -y --allow-unauthenticated \ + iproute2 + +# clean up apt cache +RUN apt-get clean && rm -rf /var/lib/apt/lists/* diff --git a/alembic/versions/054211043da5_benchmark.py b/alembic/versions/054211043da5_benchmark.py index 60ff89401..324ffe693 100644 --- a/alembic/versions/054211043da5_benchmark.py +++ b/alembic/versions/054211043da5_benchmark.py @@ -9,7 +9,7 @@ import sqlalchemy as sa from sqlalchemy.sql import func as sqla_func from sqlalchemy import Column, Integer, DateTime, text, ForeignKey, String -from tuna.miopen.benchmark import ModelEnum, FrameworkEnum +from tuna.miopen.db.benchmark import ModelEnum, FrameworkEnum from sqlalchemy.dialects.mysql import TINYINT, DOUBLE, MEDIUMBLOB, LONGBLOB from sqlalchemy import Float, BigInteger, String from sqlalchemy import Enum diff --git a/alembic/versions/a1b2c3d4e5f6_add_machine_hostname_unique.py b/alembic/versions/a1b2c3d4e5f6_add_machine_hostname_unique.py new file mode 100644 index 000000000..629a0438d --- /dev/null +++ b/alembic/versions/a1b2c3d4e5f6_add_machine_hostname_unique.py @@ -0,0 +1,38 @@ +"""add_machine_hostname_unique + +Revision ID: a1b2c3d4e5f6 +Revises: 219858383a66 +Create Date: 2025-11-18 02:38:00.000000 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. 
+revision = 'a1b2c3d4e5f6' +down_revision = '219858383a66' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # First, remove any duplicate hostnames if they exist + # Keep the oldest entry (lowest id) for each hostname + op.execute(""" + DELETE m1 FROM machine m1 + INNER JOIN machine m2 + WHERE m1.id > m2.id + AND m1.hostname = m2.hostname + """) + + # Then add the unique constraint on hostname + # Using prefix length of 255 since hostname is TEXT type + op.create_index('idx_hostname', + 'machine', ['hostname'], + unique=True, + mysql_length={'hostname': 255}) + + +def downgrade() -> None: + # Remove the unique constraint + op.drop_index('idx_hostname', 'machine') diff --git a/requirements.txt b/requirements.txt index 43d5c0d26..de874cef1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ -aioredis==2.0.1 alembic==1.8.1 asn1crypto==0.24.0 -astroid==2.15.4 +astroid>=3.0.0 asyncio==3.4.3 attrs==19.3.0 backcall==0.1.0 @@ -11,7 +10,7 @@ celery==5.3.4 cryptography==43.0.1 decorator==4.3.0 docutils==0.20 -flask==2.2.5 +flask>=3.0.0 flower==2.0.1 idna==3.7 importlib-metadata>=6.6.0 @@ -23,18 +22,18 @@ markdown-it-py==3.0.0 mccabe==0.6.1 myst-parser==3.0.1 more-itertools==8.3.0 -numpy==1.24.2 +numpy>=1.26.0 opentelemetry-api==1.12.0rc2 opentelemetry-distro==0.32b0 opentelemetry-exporter-otlp-proto-http==1.11.1 packaging==24.1 -pandas==1.5.3 +pandas>=2.1.0 paramiko==3.5.0 parso==0.3.1 pathlib2==2.3.5 pexpect==4.6.0 pickleshare==0.7.5 -pluggy==0.13.1 +pluggy>=1.5.0 prompt-toolkit==3.0.36 protobuf<5.0.0dev,>=3.19.5 ptyprocess==0.6.0 @@ -42,21 +41,20 @@ py==1.10.0 pyasn1==0.4.4 pycparser==2.19 Pygments==2.18.0 -pylint<=2.17.0-dev0,>=2.15.4 +pylint>=3.0.0 pymysql==1.1.1 PyNaCl==1.5 pyparsing==2.4.7 -pytest==7.4.4 +pytest>=8.0.0 pytest-asyncio==0.21 -pyyaml==6.0 +pyyaml redis==5.0.1 -six==1.12.0 -sqlalchemy==1.3.23 +six>=1.16.0 +sqlalchemy>=2.0.0 sphinx==7.4.7 sphinx_rtd_theme==2.0.0 traitlets==4.3.2 twine==5.1.1 -typed-ast==1.5.4 types-PyYAML==6.0.12.6 types-paramiko==3.0.0.4 types-PyMySQL==1.0.19.5 diff --git a/tests/test_celery.py b/tests/test_celery.py index aa67462d0..1b6b49455 100644 --- a/tests/test_celery.py +++ b/tests/test_celery.py @@ -29,7 +29,7 @@ import pytest from time import sleep from multiprocessing import Value -import aioredis +import redis.asyncio as aioredis import pytest_asyncio from sqlalchemy.inspection import inspect diff --git a/tuna/db/session_mixin.py b/tuna/db/session_mixin.py index 52b66a300..6df15af73 100644 --- a/tuna/db/session_mixin.py +++ b/tuna/db/session_mixin.py @@ -41,14 +41,14 @@ class SessionMixin(): """Session Mixin to provide interface for the session table""" - arch: str = Column(String(length=20), nullable=False, server_default="") - num_cu: int = Column(Integer, nullable=False) - rocm_v: str = Column(String(length=64), nullable=False) - reason: str = Column(String(length=60), nullable=False) - ticket: str = Column(String(length=64), nullable=False, server_default="N/A") - docker: str = Column(String(length=128), - nullable=False, - server_default="miopentuna") + arch = Column(String(length=20), nullable=False, server_default="") + num_cu = Column(Integer, nullable=False) + rocm_v = Column(String(length=64), nullable=False) + reason = Column(String(length=60), nullable=False) + ticket = Column(String(length=64), nullable=False, server_default="N/A") + docker = Column(String(length=128), + nullable=False, + server_default="miopentuna") def __init__(self): self.id: int = 0 # pylint: disable=invalid-name @@ -60,7 +60,10 @@ 
def get_query(self, sess: Session, sess_obj, entry) -> Query: def add_new_session(self, args: argparse.Namespace, worker) -> None: """Add new session entry""" self.reason = args.label - self.docker = args.docker_name + if len(args.docker_name) >= 128: + self.docker = args.docker_name[:128] + else: + self.docker = args.docker_name if hasattr(args, 'arch') and args.arch: self.arch = args.arch else: diff --git a/tuna/machine.py b/tuna/machine.py index dfaaf75fb..c8ad6a2a7 100644 --- a/tuna/machine.py +++ b/tuna/machine.py @@ -38,6 +38,7 @@ from typing import Set, List, Optional, TextIO, Tuple, Dict, Union, Any, Callable from sqlalchemy import Text, Column, orm +from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.dialects.mysql import TINYINT, INTEGER from paramiko import SSHClient @@ -68,7 +69,7 @@ class Machine(BASE): #pylint: disable=too-many-instance-attributes local_port: int = Column(INTEGER, server_default="22") user: str = Column(Text, nullable=False) password: str = Column(Text, nullable=False) - avail_gpus: List[int] = Column(Text, nullable=False) + _avail_gpus: str = Column('avail_gpus', Text, nullable=False) arch: str = Column(Text, nullable=False) arch_full: str = '' num_cu: int = Column(INTEGER, nullable=False, server_default="64") @@ -136,15 +137,38 @@ def __init__(self, **kwargs: dict) -> None: self.ipmi_user, self.ipmi_password) if not self.avail_gpus is None: - self.avail_gpus = [ - int(val) for val in self.avail_gpus.split(',') #type: ignore - ] #type: ignore - self.num_gpus = len(self.avail_gpus) + # Check if it's already a list (from hybrid property getter) + if isinstance(self.avail_gpus, list): + # Already converted by hybrid property, just use it + self.num_gpus = len(self.avail_gpus) + elif isinstance(self.avail_gpus, str) and self.avail_gpus: + # String from database, convert to list + self.avail_gpus = [int(val) for val in self.avail_gpus.split(',')] + self.num_gpus = len(self.avail_gpus) self.cpus = [] self.gpus = [] self.logger.info("avail gpus: %s", self.avail_gpus) + @hybrid_property + def avail_gpus(self) -> List[int]: + """Return avail_gpus as a list of integers for application use""" + if isinstance(self._avail_gpus, str) and self._avail_gpus: + return [int(x) for x in self._avail_gpus.split(',')] + elif isinstance(self._avail_gpus, list): + return self._avail_gpus + return [] + + @avail_gpus.setter + def avail_gpus(self, value: Union[List[int], str]) -> None: + """Store avail_gpus as comma-separated string for database storage""" + if isinstance(value, list): + self._avail_gpus = ','.join(map(str, value)) + elif value: + self._avail_gpus = str(value) + else: + self._avail_gpus = '' + def set_logger(self, logger: logging.Logger) -> bool: """set logging for machine, use this to associate the machine with a subprocess""" pid: int = os.getpid() diff --git a/tuna/miopen/celery_tuning/celery_tasks.py b/tuna/miopen/celery_tuning/celery_tasks.py index 3dfe464a2..61ade4155 100644 --- a/tuna/miopen/celery_tuning/celery_tasks.py +++ b/tuna/miopen/celery_tuning/celery_tasks.py @@ -26,9 +26,13 @@ # ############################################################################### """Module to register MIOpen celery tasks""" +import os +import socket import copy from celery.signals import celeryd_after_setup from celery.utils.log import get_task_logger +from sqlalchemy.exc import IntegrityError +from sqlalchemy import text from tuna.celery_app.celery_app import app from tuna.libraries import Operation from tuna.machine import Machine @@ -36,15 +40,109 @@ from 
tuna.utils.utility import SimpleDict
 from tuna.utils.celery_utils import prep_default_kwargs, get_cached_worker
 from tuna.miopen.miopen_lib import Q_NAME
+from tuna.dbBase.sql_alchemy import DbSession
 
 logger = get_task_logger(__name__)
 
 
+def check_hostname_unique_constraint(session):
+  """Check if hostname has a unique constraint on the machine table"""
+  try:
+    result = session.execute(
+        text("SELECT COUNT(*) FROM information_schema.statistics "
+             "WHERE table_schema = DATABASE() "
+             "AND table_name = 'machine' "
+             "AND column_name = 'hostname' "
+             "AND non_unique = 0")).scalar()
+    return result > 0
+  except Exception as e:  # pylint: disable=broad-exception-caught
+    logger.warning("Could not check for hostname unique constraint: %s", e)
+    return None  # Unknown state
+
+
 @celeryd_after_setup.connect
 def capture_worker_name(sender, instance, **kwargs):  #pylint: disable=unused-argument
-  """Capture worker name"""
+  """Capture worker name and ensure machine is registered"""
   app.worker_name = sender
 
+  # Ensure this machine is in the database
+  global cached_machine
+
+  # Ensure cached_machine is fully initialized
+  if not cached_machine.hostname:
+    cached_machine.hostname = socket.gethostname()
+    logger.info("Initialized hostname: %s", cached_machine.hostname)
+
+  with DbSession() as session:
+    # Check for the unique constraint on hostname and warn if it is missing
+    if not check_hostname_unique_constraint(session):
+      logger.warning(
+          "WARNING: The 'machine' table does not have a UNIQUE constraint on 'hostname'. "
+          "This may lead to duplicate machine entries and race conditions. "
+          "Please run: ALTER TABLE machine ADD UNIQUE INDEX idx_hostname (hostname(255)); "
+          "Or apply the Alembic migration: alembic upgrade head")
+
+    # Check if machine exists by hostname
+    existing = session.query(Machine).filter(
+        Machine.hostname == cached_machine.hostname).first()
+
+    if not existing:
+      # Create a new machine object for database insertion
+      # Don't use cached_machine directly as it has id=0 hardcoded
+      # Note: avail_gpus can be passed as a list - the avail_gpus hybrid property
+      # setter on Machine converts it to a comma-separated string for database storage
+      new_machine = Machine(
+          hostname=cached_machine.hostname,
+          user=os.getenv('USER', 'unknown'),
+          password='',
+          arch=cached_machine.arch if cached_machine.arch else 'unknown',
+          num_cu=cached_machine.num_cu if cached_machine.num_cu else 64,
+          avail_gpus=cached_machine.avail_gpus
+          if cached_machine.avail_gpus else [])
+
+      try:
+        # Insert the machine and let database auto-assign ID
+        session.add(new_machine)
+        session.commit()
+        session.refresh(new_machine)
+        cached_machine.id = new_machine.id
+        logger.info("Registered machine %s with id %s", cached_machine.hostname,
+                    cached_machine.id)
+      except IntegrityError as ie:
+        # Race condition: another worker beat us to it
+        # Rollback and query again to get the existing record
+        session.rollback()
+        logger.info(
+            "Race condition detected during machine registration, querying existing record"
+        )
+        existing = session.query(Machine).filter(
+            Machine.hostname == cached_machine.hostname).first()
+        if existing:
+          cached_machine.id = existing.id
+          logger.info(
+              "Using existing machine %s with id %s (from race condition recovery)",
+              cached_machine.hostname, cached_machine.id)
+        else:
+          # This should never happen, but log it if it does
+          logger.error(
+              "Failed to find machine after IntegrityError - this should not happen!"
+ ) + raise ie + except Exception as e: # pylint: disable=broad-exception-caught + # Log any other errors during machine registration + session.rollback() + logger.error("Error registering machine: %s", e) + logger.error( + "Machine details - hostname: %s, arch: %s, num_cu: %s, avail_gpus: %s", + cached_machine.hostname, cached_machine.arch, cached_machine.num_cu, + cached_machine.avail_gpus) + raise + else: + # Use existing machine id + cached_machine.id = existing.id + logger.info("Using existing machine %s with id %s", + cached_machine.hostname, cached_machine.id) + cached_machine = Machine(local_machine=True) @@ -91,4 +189,8 @@ def celery_enqueue(context): worker = prep_worker(copy.deepcopy(context)) ret = worker.run() + + # Add machine_id to the context before returning + context['machine_id'] = cached_machine.id + return {"ret": ret, "context": context} diff --git a/tuna/miopen/db/build_schema.py b/tuna/miopen/db/build_schema.py index 583b84430..501a2c579 100755 --- a/tuna/miopen/db/build_schema.py +++ b/tuna/miopen/db/build_schema.py @@ -26,6 +26,7 @@ ############################################################################### """ Module for creating DB tables""" from sqlalchemy.exc import OperationalError +from sqlalchemy import text from tuna.miopen.db.get_db_tables import get_miopen_tables from tuna.miopen.db.triggers import get_miopen_triggers, drop_miopen_triggers from tuna.db_engine import ENGINE @@ -41,15 +42,16 @@ def recreate_triggers(drop_triggers, create_triggers): with ENGINE.connect() as conn: for dtg in drop_triggers: - conn.execute(f"drop trigger if exists {dtg}") + conn.execute(text(f"drop trigger if exists {dtg}")) for trg in create_triggers: try: - conn.execute(trg) + conn.execute(text(trg)) except OperationalError as oerr: LOGGER.warning("Operational Error occurred while adding trigger: '%s'", trg) LOGGER.info('%s \n', oerr) continue + conn.commit() return True diff --git a/tuna/miopen/db/mixin_tables.py b/tuna/miopen/db/mixin_tables.py index b2bfdc1fe..da933314e 100644 --- a/tuna/miopen/db/mixin_tables.py +++ b/tuna/miopen/db/mixin_tables.py @@ -27,7 +27,7 @@ """Represents Mixin type table class definitions """ import enum from sqlalchemy.sql import func as sqla_func -from sqlalchemy.databases import mysql +from sqlalchemy.dialects import mysql from sqlalchemy import Float, Boolean from sqlalchemy.dialects.mysql import TINYINT, MEDIUMBLOB, LONGBLOB from sqlalchemy.ext.declarative import declared_attr @@ -64,7 +64,7 @@ class MIOpenJobMixin(JobMixin): solver = Column(String(length=128), nullable=True, server_default="") eval_mid = Column(Integer, server_default="-1") - fin_step = Column(mysql.MSSet(*(list(k for k in FinStep.__members__))), + fin_step = Column(mysql.SET(*(list(k for k in FinStep.__members__))), nullable=False, server_default="not_fin") diff --git a/tuna/miopen/driver/base.py b/tuna/miopen/driver/base.py index 2bed7e6ef..8c0d34e5e 100755 --- a/tuna/miopen/driver/base.py +++ b/tuna/miopen/driver/base.py @@ -160,26 +160,38 @@ def get_tensor_id(session: Session, tensor_dict: dict) -> int: def __insert_tensor(self, tensor_dict: dict) -> int: """Insert new row into tensor table and return primary key""" ret_id: int = -1 + + # Check cache first without creating a session + tid = TensorTable(**tensor_dict) + tid.valid = 1 + key = build_dict_val_key(tid) + + # If cache is populated and key exists, return immediately + if MIOpenDriver.tensor_id_map and key in MIOpenDriver.tensor_id_map: + ret_id = MIOpenDriver.tensor_id_map[key] + LOGGER.info("Get Tensor: 
%s", ret_id) + return ret_id + + # Cache miss or not populated - need database session session: Session with DbSession() as session: try: - tid = TensorTable(**tensor_dict) - tid.valid = 1 - key = build_dict_val_key(tid) #cache the tensor table to avoid queries if not MIOpenDriver.tensor_id_map: MIOpenDriver.tensor_id_map = get_session_val_map( session, TensorTable, MIOpenDriver.tensor_attr) - id_map = MIOpenDriver.tensor_id_map - if key in id_map: - ret_id = id_map[key] - LOGGER.info("Get Tensor: %s", ret_id) - else: - session.add(tid) - session.commit() - ret_id = tid.id - id_map[key] = ret_id - LOGGER.info("Insert Tensor: %s", ret_id) + # Check cache again after loading + if key in MIOpenDriver.tensor_id_map: + ret_id = MIOpenDriver.tensor_id_map[key] + LOGGER.info("Get Tensor: %s", ret_id) + return ret_id + + # Not in cache, insert new tensor + session.add(tid) + session.commit() + ret_id = tid.id + MIOpenDriver.tensor_id_map[key] = ret_id + LOGGER.info("Insert Tensor: %s", ret_id) except IntegrityError as err: LOGGER.warning(err) session.rollback() diff --git a/tuna/miopen/miopen_lib.py b/tuna/miopen/miopen_lib.py index a20e1a23d..dcc5d1bc9 100644 --- a/tuna/miopen/miopen_lib.py +++ b/tuna/miopen/miopen_lib.py @@ -35,6 +35,7 @@ from kombu.utils.uuid import uuid from sqlalchemy.inspection import inspect from sqlalchemy.exc import OperationalError, DataError, IntegrityError +from sqlalchemy import text from tuna.mituna_interface import MITunaInterface from tuna.miopen.utils.helper import print_solvers from tuna.parse_args import TunaArgs, setup_arg_parser, args_check @@ -60,7 +61,8 @@ from tuna.miopen.db.triggers import drop_miopen_triggers, get_miopen_triggers from tuna.miopen.utils.config_type import ConfigType from tuna.miopen.db.tables import MIOpenDBTables -#from tuna.miopen.celery_tuning.celery_tasks import celery_enqueue + +# from tuna.miopen.celery_tuning.celery_tasks import celery_enqueue from tuna.miopen.utils.json_to_sql import process_fdb_w_kernels, process_tuning_data from tuna.miopen.utils.json_to_sql import process_pdb_compile from tuna.miopen.utils.json_to_sql import clean_cache_table @@ -88,130 +90,179 @@ def parse_args(self): # pylint: disable=too-many-statements """Function to parse arguments""" parser = setup_arg_parser( - 'Run Performance Tuning on a certain architecture', [ - TunaArgs.ARCH, TunaArgs.NUM_CU, TunaArgs.VERSION, - TunaArgs.CONFIG_TYPE, TunaArgs.SESSION_ID, TunaArgs.MACHINES, - TunaArgs.REMOTE_MACHINE, TunaArgs.LABEL, TunaArgs.RESTART_MACHINE, - TunaArgs.DOCKER_NAME, TunaArgs.SHUTDOWN_WORKERS, - TunaArgs.ENQUEUE_ONLY - ]) + "Run Performance Tuning on a certain architecture", + [ + TunaArgs.ARCH, + TunaArgs.NUM_CU, + TunaArgs.VERSION, + TunaArgs.CONFIG_TYPE, + TunaArgs.SESSION_ID, + TunaArgs.MACHINES, + TunaArgs.REMOTE_MACHINE, + TunaArgs.LABEL, + TunaArgs.RESTART_MACHINE, + TunaArgs.DOCKER_NAME, + TunaArgs.SHUTDOWN_WORKERS, + TunaArgs.ENQUEUE_ONLY, + ], + ) parser.add_argument( - '--find_mode', - dest='find_mode', + "--find_mode", + dest="find_mode", type=int, default=1, - help='Set the MIOPEN_FIND_MODE environment variable for MIOpen', - choices=['1', '3']) - parser.add_argument('--ticket', - dest='ticket', - type=str, - default=None, - help='Specify tuning ticket number') + help="Set the MIOPEN_FIND_MODE environment variable for MIOpen", + choices=["1", "3"], + ) + parser.add_argument( + "--ticket", + dest="ticket", + type=str, + default=None, + help="Specify tuning ticket number", + ) parser.add_argument( - '--solver_id', + "--solver_id", type=int, - 
dest='solver_id', + dest="solver_id", default=None, - help='Specify solver_id. Use --list_solvers to see options') - parser.add_argument('--dynamic_solvers_only', - dest='dynamic_solvers_only', - action='store_true', - default=False, - help='Only tune dynamic solvers.') + help="Specify solver_id. Use --list_solvers to see options", + ) + parser.add_argument( + "--dynamic_solvers_only", + dest="dynamic_solvers_only", + action="store_true", + default=False, + help="Only tune dynamic solvers.", + ) parser.add_argument( - '-B', - '--blacklist', - dest='blacklist', + "-B", + "--blacklist", + dest="blacklist", type=str, default=None, - help='MIOpen blacklist algorithm, if multiple then comma separate') - parser.add_argument('-i', - '--reset_interval', - type=int, - dest='reset_interval', - required=False, - help='Restart interval for job in hours.') + help="MIOpen blacklist algorithm, if multiple then comma separate", + ) parser.add_argument( - '--gpu_lim', - dest='gpu_lim', + "-i", + "--reset_interval", + type=int, + dest="reset_interval", + required=False, + help="Restart interval for job in hours.", + ) + parser.add_argument( + "--gpu_lim", + dest="gpu_lim", type=int, default=None, - help='Limit the number of gpu workers created by Tuna, index from 0') + help="Limit the number of gpu workers created by Tuna, index from 0", + ) parser.add_argument( - '-R', - '--rich_data', - dest='rich_data', - action='store_true', + "-R", + "--rich_data", + dest="rich_data", + action="store_true", default=False, - help='record intermediate parameter results from perf tuning') + help="record intermediate parameter results from perf tuning", + ) subcommands = parser.add_subcommands(required=False) - subcommands.add_subcommand('import_configs', + subcommands.add_subcommand("import_configs", get_import_cfg_parser(), required=False) - subcommands.add_subcommand('load_job', + subcommands.add_subcommand("load_job", get_load_job_parser(), required=False) - subcommands.add_subcommand('export_db', + subcommands.add_subcommand("export_db", get_export_db_parser(), required=False) - subcommands.add_subcommand('update_golden', + subcommands.add_subcommand("update_golden", get_update_golden_parser(), required=False) group = parser.add_mutually_exclusive_group() - group.add_argument('--add_tables', - dest='add_tables', - action='store_true', - help='Add MIOpen library specific tables') - - group.add_argument('--init_session', - action='store_true', - dest='init_session', - help='Set up a new tuning session.') group.add_argument( - '--fin_steps', + "--add_tables", + dest="add_tables", + action="store_true", + help="Add MIOpen library specific tables", + ) + + group.add_argument( + "--init_session", + action="store_true", + dest="init_session", + help="Set up a new tuning session.", + ) + group.add_argument( + "--fin_steps", type=str, - dest='fin_steps', - help='Specify fin steps. Multiple steps should be comma separated.') - group.add_argument('--list_solvers', - action='store_true', - dest='list_solvers', - help='List of solvers from the solver table') + dest="fin_steps", + help="Specify fin steps. 
Multiple steps should be comma separated.", + ) + group.add_argument( + "--list_solvers", + action="store_true", + dest="list_solvers", + help="List of solvers from the solver table", + ) # JD: implement the following two using fin_steps - group.add_argument('--update_solvers', - dest='update_solvers', - action='store_true', - help='Update the list of solvers in the database') - group.add_argument('--update_applicability', - dest='update_applicability', - action='store_true', - help='Update the applicability table in the database') - group.add_argument('-s', - '--status', - dest='check_status', - action='store_true', - default=False, - help='Check the status of machines') - - group.add_argument('-e', - '--exec', - dest='execute_cmd', - type=str, - default=None, - help='execute on each machine') + group.add_argument( + "--update_solvers", + dest="update_solvers", + action="store_true", + help="Update the list of solvers in the database", + ) + group.add_argument( + "--update_applicability", + dest="update_applicability", + action="store_true", + help="Update the applicability table in the database", + ) + parser.add_argument( + "--new_only", + dest="new_only", + action="store_true", + default=False, + help="Only update applicability for configs without existing data in this session (use with --update_applicability)", + ) + parser.add_argument( + "--config_limit", + dest="config_limit", + type=int, + default=None, + help="Limit the number of configs to process (useful for testing with --update_applicability)", + ) + group.add_argument( + "-s", + "--status", + dest="check_status", + action="store_true", + default=False, + help="Check the status of machines", + ) + + group.add_argument( + "-e", + "--exec", + dest="execute_cmd", + type=str, + default=None, + help="execute on each machine", + ) self.args = parser.parse_args() if self.args.config_type is None: self.args.config_type = ConfigType.convolution - #overwritte common lib args with subcommand args value + # overwritte common lib args with subcommand args value if self.args.subcommand is not None: self.overwrite_common_args() @@ -221,16 +272,16 @@ def parse_args(self): if self.args.list_solvers: print_solvers() - raise CustomError('Printing solvers...') + raise CustomError("Printing solvers...") - if self.args.fin_steps and self.args.subcommand != 'load_job': + if self.args.fin_steps and self.args.subcommand != "load_job": self.check_fin_args(parser) self.set_prefix() if self.args.find_mode is None and not (self.args.check_status or self.args.restart_machine or self.args.execute_cmd): - parser.error('find_mode must be specified for a tuning run') + parser.error("find_mode must be specified for a tuning run") if self.args.blacklist: self.check_blacklist(parser) @@ -238,8 +289,13 @@ def parse_args(self): args_check(self.args, parser) fin_session_steps = [ - 'miopen_find_compile', 'miopen_find_eval', 'miopen_perf_compile', - 'miopen_perf_eval', 'get_applicability', 'find_compile', 'find_eval' + "miopen_find_compile", + "miopen_find_eval", + "miopen_perf_compile", + "miopen_perf_eval", + "get_applicability", + "find_compile", + "find_eval", ] has_fin = False if self.args.fin_steps: @@ -255,14 +311,14 @@ def parse_args(self): def set_prefix(self): """Set redis key prefix""" if isinstance(self.args.fin_steps, Iterable): - steps_str = ('-').join(x for x in self.args.fin_steps) - self.prefix = f"d_{self.db_name}_sess_{self.args.session_id}_"\ - f"{steps_str}" + steps_str = ("-").join(x for x in self.args.fin_steps) + self.prefix = 
(f"d_{self.db_name}_sess_{self.args.session_id}_" + f"{steps_str}") else: steps_str = self.args.fin_steps[0] self.prefix = f"d_{self.db_name}_sess_{self.args.session_id}_{steps_str}" - self.logger.info('redis prefix: %s', self.prefix) + self.logger.info("redis prefix: %s", self.prefix) def overwrite_common_args(self): """Overwrite common MIOpen_lib args with subcommand args""" @@ -274,12 +330,12 @@ def overwrite_common_args(self): def check_fin_args(self, parser): """! Helper function for fin args - @param parser The command line argument parser + @param parser The command line argument parser """ valid_fin_steps = list(k for k in FinStep.__members__) - if ',' in self.args.fin_steps: - parser.error('Multiple fin_steps currently not supported') - f_steps = self.args.fin_steps.split(',') + if "," in self.args.fin_steps: + parser.error("Multiple fin_steps currently not supported") + f_steps = self.args.fin_steps.split(",") self.args.fin_steps = f_steps for step in self.args.fin_steps: if step not in valid_fin_steps: @@ -288,41 +344,46 @@ def check_fin_args(self, parser): def check_blacklist(self, parser): """! Helper function - @param parser The command line argument parser - @return ret Boolean value - """ - self.args.blacklist = self.args.blacklist.split(',') + @param parser The command line argument parser + @return ret Boolean value + """ + self.args.blacklist = self.args.blacklist.split(",") for sol in self.args.blacklist: if sol not in MIOPEN_ALG_LIST: parser.error("Incorrect blacklist value") def do_fin_work(self, gpu, f_vals): """! Helper function to execute job independendent fin work - @param gpu Unique ID of the GPU - @param f_vals Dict containing runtime information - """ + @param gpu Unique ID of the GPU + @param f_vals Dict containing runtime information + """ kwargs = self.get_kwargs(gpu, f_vals) fin_worker = FinClass(**kwargs) if self.args.update_solvers: if not fin_worker.get_solvers(): - self.logger.error('No solvers returned from Fin class') + self.logger.error("No solvers returned from Fin class") return True def launch_worker(self, gpu_idx, f_vals, worker_lst): """! 
Function to launch worker - @param gpu_idx Unique ID of the GPU - @param f_vals Dict containing runtime information - @param worker_lst List containing worker instances - @return ret Boolean value - """ + @param gpu_idx Unique ID of the GPU + @param f_vals Dict containing runtime information + @param worker_lst List containing worker instances + @return ret Boolean value + """ # pylint: disable=too-many-branches worker = None kwargs = self.get_kwargs(gpu_idx, f_vals) if self.args.update_applicability: - kwargs['fin_steps'] = ['applicability'] + kwargs["fin_steps"] = ["applicability"] + kwargs["new_only"] = self.args.new_only + kwargs["config_limit"] = self.args.config_limit worker = FinClass(**kwargs) + self.logger.info("Created FinClass worker with gpu_id=%s, ROCR_VISIBLE_DEVICES in envmt: %s", + worker.gpu_id, + any('ROCR_VISIBLE_DEVICES' in env for env in worker.envmt)) worker.start() worker_lst.append(worker) return True @@ -330,8 +391,13 @@ def launch_worker(self, gpu_idx, f_vals, worker_lst): worker = FinClass(**kwargs) ret = False if self.args.check_status: - if not super().check_status(worker, f_vals["b_first"], gpu_idx, - f_vals["machine"], self.args.docker_name): + if not super().check_status( + worker, + f_vals["b_first"], + gpu_idx, + f_vals["machine"], + self.args.docker_name, + ): ret = True elif self.args.init_session: Session().add_new_session(self.args, worker) @@ -345,8 +411,8 @@ def launch_worker(self, gpu_idx, f_vals, worker_lst): def compose_worker_list(self, machines): # pylint: disable=too-many-branches """! Helper function to compose worker_list - @param machines List of machines to execute on - """ + @param machines List of machines to execute on + """ worker_lst = [] fin_work_done = False for machine in machines: @@ -354,28 +420,43 @@ def compose_worker_list(self, machines): machine.restart_server(wait=False) continue - #fin_steps should only contain one step + # fin_steps should only contain one step worker_ids = None - if self.args.fin_steps and 'eval' in self.args.fin_steps[0]: - worker_ids = machine.get_avail_gpus() + if self.args.fin_steps and "eval" in self.args.fin_steps[0]: + worker_ids = machine.get_avail_gpus() # Use actual GPUs if self.args.gpu_lim and self.args.gpu_lim < len(worker_ids): - worker_ids = range(self.args.gpu_lim) + worker_ids = list(range(self.args.gpu_lim)) + elif self.args.update_applicability: + worker_ids = list(range(len(machine.get_avail_gpus()) * 4 )) # Use GPU count + if self.args.gpu_lim and self.args.gpu_lim < len(worker_ids): + worker_ids = list(range(self.args.gpu_lim)) else: - worker_ids = super().get_num_procs(machine) + worker_ids = super().get_num_procs(machine) # Use CPU count for other operations + if self.args.update_applicability: f_vals = super().get_f_vals(machine, [1]) kwargs = self.get_kwargs(0, f_vals) - kwargs['fin_steps'] = ['applicability'] + kwargs["fin_steps"] = ["applicability"] + kwargs["new_only"] = self.args.new_only + kwargs["config_limit"] = self.args.config_limit worker = FinClass(**kwargs) - query = worker.query_cfgs(self.args.label) + skip_existing = self.args.new_only + config_limit = self.args.config_limit + query = worker.query_cfgs(self.args.label, skip_existing=skip_existing, config_limit=config_limit) cfg_rows = query.all() len_rows = len(cfg_rows) + self.logger.warning("Found %d configs to process (label=%s, new_only=%s, config_limit=%s)", + len_rows, self.args.label, self.args.new_only, self.args.config_limit) proc_lim = (len_rows + 99) / 100 if 32 < proc_lim: proc_lim = 32 + 
self.logger.info("Calculated proc_lim=%d based on %d configs", proc_lim, len_rows) + initial_workers = len(worker_ids) while len(worker_ids) > proc_lim: worker_ids.pop() + self.logger.warning("Worker count: initial=%d, after limit=%d (proc_lim=%d)", + initial_workers, len(worker_ids), proc_lim) if len(worker_ids) == 0: return None @@ -388,7 +469,7 @@ def compose_worker_list(self, machines): break for gpu_idx in worker_ids: - self.logger.info('launch mid %u, proc %u', machine.id, gpu_idx) + self.logger.info("launch mid %u, proc %u", machine.id, gpu_idx) if not self.launch_worker(gpu_idx, f_vals, worker_lst): break @@ -396,10 +477,10 @@ def compose_worker_list(self, machines): def add_tables(self): """! Function to create new DB tables - @return Bool - """ + @return Bool + """ ret_t = create_tables(get_miopen_tables()) - self.logger.info('DB creation successful: %s', ret_t) + self.logger.info("DB creation successful: %s", ret_t) recreate_triggers(drop_miopen_triggers(), get_miopen_triggers()) return True @@ -414,19 +495,20 @@ def run(self): self.add_tables() return None - if self.args.subcommand is not None and self.args.subcommand == 'import_configs': + if (self.args.subcommand is not None and + self.args.subcommand == "import_configs"): run_import_configs(self.args.import_configs, self.logger) return None - if self.args.subcommand is not None and self.args.subcommand == 'load_job': + if self.args.subcommand is not None and self.args.subcommand == "load_job": run_load_job(self.args.load_job, self.logger) return None - if self.args.subcommand is not None and self.args.subcommand == 'export_db': + if self.args.subcommand is not None and self.args.subcommand == "export_db": run_export_db(self.args.export_db, self.logger) return None - if self.args.subcommand is not None and self.args.subcommand == 'update_golden': + if self.args.subcommand is not None and self.args.subcommand == "update_golden": run_update_golden(self.args.update_golden, self.logger) return None @@ -435,8 +517,7 @@ def run(self): return res def get_envmt(self): - """! Function to construct environment var - """ + """! Function to construct environment var""" envmt = ["MIOPEN_LOG_LEVEL=4"] envmt.append("MIOPEN_SQLITE_KERN_CACHE=ON") @@ -447,58 +528,66 @@ def get_envmt(self): if self.args.blacklist: bk_str = ", ".join([f"{arg}=0" for arg in self.args.blacklist]) - for bk_var in bk_str.split(','): + for bk_var in bk_str.split(","): envmt.append(bk_var) return envmt def get_kwargs(self, gpu_idx, f_vals, tuning=False): """! Helper function to set up kwargs for worker instances - @param gpu_idx Unique ID of the GPU - @param f_vals Dict containing runtime information - @param tuning Boolean that indicates if kwargs are for a tuning step - @return kwargs Dictionary - """ + @param gpu_idx Unique ID of the GPU + @param f_vals Dict containing runtime information + @param tuning Boolean that indicates if kwargs are for a tuning step + @return kwargs Dictionary + """ kwargs = super().get_kwargs(gpu_idx, f_vals, tuning) - kwargs['fin_steps'] = self.args.fin_steps - kwargs['dynamic_solvers_only'] = self.args.dynamic_solvers_only - kwargs['config_type'] = self.args.config_type - kwargs['reset_interval'] = self.args.reset_interval + kwargs["fin_steps"] = self.args.fin_steps + kwargs["dynamic_solvers_only"] = self.args.dynamic_solvers_only + kwargs["config_type"] = self.args.config_type + kwargs["reset_interval"] = self.args.reset_interval return kwargs def get_job_list(self, session, find_state, claim_num): """! 
Get list of jobs - @param session DB session - @param find_state DB job state - @param claim_num Number of DB jobs to pick up - @return List of DB jobs + @param session DB session + @param find_state DB job state + @param claim_num Number of DB jobs to pick up + @return List of DB jobs - """ - job_list = self.get_job_objs(session, find_state, self.args.label, self.dbt, - self.get_job_attr(), claim_num, - self.args.fin_steps) + """ + job_list = self.get_job_objs( + session, + find_state, + self.args.label, + self.dbt, + self.get_job_attr(), + claim_num, + self.args.fin_steps, + ) return job_list - def get_job_objs(self, - session: DbSession, - find_state: list, - label: str, - dbt: DBTablesInterface, - job_attr: List[str], - claim_num: int = None, - fin_steps: List[str] = None) -> List[SimpleDict]: + def get_job_objs( + self, + session: DbSession, + find_state: list, + label: str, + dbt: DBTablesInterface, + job_attr: List[str], + claim_num: int = None, + fin_steps: List[str] = None, + ) -> List[SimpleDict]: """! Get list of job objects - @param session DB session - @param find_state DB job state - @param label DB job reason - @param dbt Class representing all DB tables associated with this class - @param job_attr List of DB job columns - @param claim_num Number of DB jobs to pick up - @param fin_steps List of MIFin steps - @return List of DB jobs - """ + @param session DB session + @param find_state DB job state + @param label DB job reason + @param dbt Class representing all DB tables associated with this class + @param job_attr List of DB job columns + @param claim_num Number of DB jobs to pick up + @param fin_steps List of MIFin steps + @return List of DB jobs + """ entries: List[Tuple[SimpleDict, ...]] conds: List[str] = [f"session={dbt.session.id}", "valid=1"] @@ -506,38 +595,42 @@ def get_job_objs(self, conds.append(f"reason='{label}'") conds.append(f"retries<{self.max_job_retries}") - conds.append("state in (" + str(find_state).strip('{').strip('}') + ")") + conds.append("state in (" + str(find_state).strip("{").strip("}") + ")") entries = self.compose_work_objs(session, conds, dbt, job_attr, claim_num, fin_steps) return entries - def compose_work_objs(self, - session: DbSession, - conds: List[str], - dbt: DBTablesInterface, - job_attr: List[str], - claim_num: int = None, - fin_steps: List[str] = None) -> List[SimpleDict]: + def compose_work_objs( + self, + session: DbSession, + conds: List[str], + dbt: DBTablesInterface, + job_attr: List[str], + claim_num: int = None, + fin_steps: List[str] = None, + ) -> List[SimpleDict]: """! 
Query a job list for update - @param session DB session - @param conds List of conditions for DB job WHERE clause - @param dbt Class representing all DB tables associated with this class - @param job_attr List of DB job columns - @param fin_steps List of MIFin steps - @return List of MIFin work objects - """ + @param session DB session + @param conds List of conditions for DB job WHERE clause + @param dbt Class representing all DB tables associated with this class + @param job_attr List of DB job columns + @param fin_steps List of MIFin steps + @return List of MIFin work objects + """ job_entries = [] if fin_steps: conds.append(f"fin_step like '%{fin_steps[0]}%'") else: conds.append("fin_step='not_fin'") - cond_str = ' AND '.join(conds) + cond_str = " AND ".join(conds) if cond_str: cond_str = f"WHERE {cond_str}" if claim_num: - cond_str += f" ORDER BY retries,config ASC LIMIT {claim_num} FOR UPDATE SKIP LOCKED" + cond_str += ( + f" ORDER BY retries,config ASC LIMIT {claim_num} FOR UPDATE SKIP LOCKED" + ) else: cond_str += " ORDER BY retries,config ASC FOR UPDATE SKIP LOCKED" @@ -546,26 +639,93 @@ def compose_work_objs(self, return job_entries + def detect_and_handle_locked_jobs(self, session: DbSession, + find_state: List[str]) -> bool: + """Detect jobs that are locked and preventing progress + + This method queries for jobs without locking to detect if jobs exist + but are being skipped due to database locks. If found, it marks jobs + with high retry counts as errored to unblock the pipeline. + + @param session DB session + @param find_state List of job states to check + @return True if locked jobs were found and handled, False otherwise + """ + # Query WITHOUT lock to see if jobs are being skipped + conds = [f"session={self.dbt.session.id}", "valid=1"] + + if self.args.label: + conds.append(f"reason='{self.args.label}'") + + conds.append(f"retries<{self.max_job_retries}") + conds.append("state in (" + str(find_state).strip("{").strip("}") + ")") + + if self.args.fin_steps: + conds.append(f"fin_step like '%{self.args.fin_steps[0]}%'") + + cond_str = " AND ".join(conds) + query = f""" + SELECT id, config, retries, state, solver + FROM {self.dbt.job_table.__tablename__} + WHERE {cond_str} + ORDER BY retries, config ASC + LIMIT 10 + """ + + unlocked_jobs = session.execute(text(query)).fetchall() + + if unlocked_jobs: + self.logger.warning( + "Found %d jobs in target state but they were skipped by FOR UPDATE SKIP LOCKED", + len(unlocked_jobs)) + self.logger.warning("Likely cause: stale database locks. Job IDs: %s", + [job[0] for job in unlocked_jobs]) + + # Mark jobs with high retries as errored to unblock + jobs_marked = 0 + for job_row in unlocked_jobs: + job_id, config_id, retries, state, solver = job_row + if retries >= (MAX_ERRORED_JOB_RETRIES - 1): # retries >= 2 + self.logger.warning( + "Marking locked job %d (config=%d, solver=%s, retries=%d) as errored to unblock pipeline", + job_id, config_id, solver, retries) + update_query = f""" + UPDATE {self.dbt.job_table.__tablename__} + SET state = 'errored', + result = 'Marked as errored due to stale lock or excessive retries', + update_ts = NOW() + WHERE id = {job_id} + """ + session.execute(text(update_query)) + jobs_marked += 1 + + if jobs_marked > 0: + session.commit() + self.logger.info("Marked %d locked jobs as errored", jobs_marked) + return True + + return False + def compose_work_objs_fin(self, session, job_entries, dbt) -> List[Tuple[SimpleDict, SimpleDict]]: """! 
Return jobs for fin work - @param session DB session - @param job_entries List of DB jobs - @param dbt Class representing all DB tables associated with this class - @return ret Job tuple - """ + @param session DB session + @param job_entries List of DB jobs + @param dbt Class representing all DB tables associated with this class + @return ret Job tuple + """ ret = [] cfg_rel = { key: { - 'key': list(val.local_columns)[0].name, - 'ftble': str(list(val.remote_side)[0]).split('.', maxsplit=1)[0], - 'fkey': str(list(val.remote_side)[0]).split('.')[1] + "key": list(val.local_columns)[0].name, + "ftble": str(list(val.remote_side)[0]).split(".", maxsplit=1)[0], + "fkey": str(list(val.remote_side)[0]).split(".")[1], } for key, val in inspect(dbt.config_table).relationships.items() } if job_entries: - id_str = ','.join({str(job.config) for job in job_entries}) + id_str = ",".join({str(job.config) for job in job_entries}) cfg_cond_str = f"where valid=1 and id in ({id_str})" cfg_attr = [column.name for column in inspect(dbt.config_table).c] cfg_entries = gen_select_objs(session, cfg_attr, @@ -583,30 +743,32 @@ def compose_work_objs_fin(self, session, job_entries, def attach_tensors(self, session, cfg_rel, cfg_entries): """! Attach tensor relationship information to config entries - @param session DB session - @param cfg_rel DB Config col value - @param cfg_entries List of DB Config entries - @return cfg_entries List of DB Config entries with attached tensors (foreign keys) + @param session DB session + @param cfg_rel DB Config col value + @param cfg_entries List of DB Config entries + @return cfg_entries List of DB Config entries with attached tensors (foreign keys) - """ + """ for key, val in cfg_rel.items(): rel_attr = [ column.name - for column in inspect(get_class_by_tablename(val['ftble'])).c + for column in inspect(get_class_by_tablename(val["ftble"])).c ] - val['fattr'] = rel_attr + val["fattr"] = rel_attr for cfg in cfg_entries: for key, val in cfg_rel.items(): - rel_val = getattr(cfg, val['key']) + rel_val = getattr(cfg, val["key"]) rel_cond_str = f"where {val['fkey']}={rel_val}" setattr( - cfg, key, - gen_select_objs(session, val['fattr'], val['ftble'], - rel_cond_str)[0]) + cfg, + key, + gen_select_objs(session, val["fattr"], val["ftble"], + rel_cond_str)[0], + ) return cfg_entries - #deprecated + # deprecated def get_job_tables(self, job_rows: List[Tuple[SimpleDict, ...]], job_attr: List[str]) -> List[SimpleDict]: """Find job tables in query results""" @@ -626,15 +788,16 @@ def get_job_tables(self, job_rows: List[Tuple[SimpleDict, ...]], def update_operation(self): """! 
Update the workers type that this library needs""" if self.args.fin_steps: - if 'miopen_find_compile' in self.args.fin_steps \ - or 'miopen_perf_compile' in self.args.fin_steps: - self.fetch_state.add('new') - self.set_state = 'compile_start' + if ("miopen_find_compile" in self.args.fin_steps or + "miopen_perf_compile" in self.args.fin_steps): + self.fetch_state.add("new") + self.set_state = "compile_start" self.operation = Operation.COMPILE - elif 'miopen_find_eval' in self.args.fin_steps or 'miopen_perf_eval' in self.args.fin_steps: - self.fetch_state.add('new') - self.fetch_state.add('compiled') - self.set_state = 'eval_start' + elif ("miopen_find_eval" in self.args.fin_steps or + "miopen_perf_eval" in self.args.fin_steps): + self.fetch_state.add("new") + self.fetch_state.add("compiled") + self.set_state = "eval_start" self.operation = Operation.EVAL if self.args.update_applicability: @@ -642,8 +805,8 @@ def update_operation(self): def has_tunable_operation(self): """! Check if its a tuning loop operation - @return Bool value that represents if operation is tuning - """ + @return Bool value that represents if operation is tuning + """ if self.args is None: self.parse_args() if self.args.subcommand and "load_job" in self.args.subcommand: @@ -657,8 +820,8 @@ def has_tunable_operation(self): @lru_cache(1) def get_fdb_attr(self): """! Get find_db table attrs - @return fdb_attr find_db table attributes without timestamps - """ + @return fdb_attr find_db table attributes without timestamps + """ fdb_attr = None fdb_attr = [column.name for column in inspect(self.dbt.find_db_table).c] fdb_attr.remove("insert_ts") @@ -668,8 +831,8 @@ def get_fdb_attr(self): @lru_cache(1) def get_tuning_data_attr(self): """! Get tuning_data table attrs - @return tuning_data_attr tuning_data table attributes without timestamps - """ + @return tuning_data_attr tuning_data table attributes without timestamps + """ tuning_data_attr = None tuning_data_attr = [ column.name for column in inspect(self.dbt.tuning_data_table).c @@ -680,10 +843,10 @@ def get_tuning_data_attr(self): def serialize_jobs(self, session: DbSession, batch_jobs: List[Any]): """! Return list of serialize jobs - @param session DB session - @param batch_jobs List of DB jobs - @return DB jobs, serialized - """ + @param session DB session + @param batch_jobs List of DB jobs + @return DB jobs, serialized + """ entries = self.compose_work_objs_fin(session, batch_jobs, self.dbt) return serialize_chunk(entries) @@ -696,15 +859,15 @@ def build_context( tuning_data_attr = self.get_tuning_data_attr() for job, config in serialized_jobs: context = { - 'job': job, - 'config': config, - 'operation': self.operation, - 'arch': self.dbt.session.arch, - 'num_cu': self.dbt.session.num_cu, - 'kwargs': kwargs, - 'rich_data': self.args.rich_data, - 'fdb_attr': fdb_attr, - 'tuning_data_attr': tuning_data_attr + "job": job, + "config": config, + "operation": self.operation, + "arch": self.dbt.session.arch, + "num_cu": self.dbt.session.num_cu, + "kwargs": kwargs, + "rich_data": self.args.rich_data, + "fdb_attr": fdb_attr, + "tuning_data_attr": tuning_data_attr, } context_list.append(context) @@ -712,48 +875,55 @@ def build_context( def celery_enqueue_call(self, context: dict, q_name: str, task_id=False): """! 
Enqueue job (context) for queue:q_name - @param context Context for Celery job - @param q_name Custom Celery queue name - @param task_id Custom Redis Key - """ + @param context Context for Celery job + @param q_name Custom Celery queue name + @param task_id Custom Redis Key + """ - #hacky way to get the Q_NAME to the task decorator for interpreter to decorate the - #function with correct q_name arg - #if import is moved to top it will result in circular imports - Q_NAME = q_name #pylint: disable=import-outside-toplevel,unused-variable,invalid-name,redefined-outer-name - from tuna.miopen.celery_tuning.celery_tasks import celery_enqueue #pylint: disable=import-outside-toplevel + # hacky way to get the Q_NAME to the task decorator for interpreter to decorate the + # function with correct q_name arg + # if import is moved to top it will result in circular imports + Q_NAME = q_name # pylint: disable=import-outside-toplevel,unused-variable,invalid-name,redefined-outer-name + from tuna.miopen.celery_tuning.celery_tasks import ( + celery_enqueue,) # pylint: disable=import-outside-toplevel - return celery_enqueue.apply_async((context,), - task_id=('-').join([self.prefix, - uuid()]), - queue=q_name, - reply_to=q_name) + return celery_enqueue.apply_async( + (context,), + task_id=("-").join([self.prefix, uuid()]), + queue=q_name, + reply_to=q_name, + ) def process_compile_results(self, session, fin_json, context): """! Process result from fin_build worker - @param session DB session - @param fin_json MIFin results for job - @param context Context for Celery job - @return Boolean value - """ - job = SimpleDict(**context['job']) + @param session DB session + @param fin_json MIFin results for job + @param context Context for Celery job + @return Boolean value + """ + job = SimpleDict(**context["job"]) pending = [] solver_id_map = get_solver_ids() failed_job = False - result_str = '' + result_str = "" status = None try: if fin_json: - if 'success' in fin_json and fin_json["success"] is False: + if "success" in fin_json and fin_json["success"] is False: status = [fin_json] else: - if 'miopen_find_compile_result' in fin_json: - status = process_fdb_w_kernels(session, fin_json, - copy.deepcopy(context), self.dbt, - context['fdb_attr'], pending) - - elif 'miopen_perf_compile_result' in fin_json: + if "miopen_find_compile_result" in fin_json: + status = process_fdb_w_kernels( + session, + fin_json, + copy.deepcopy(context), + self.dbt, + context["fdb_attr"], + pending, + ) + + elif "miopen_perf_compile_result" in fin_json: status = process_pdb_compile(session, fin_json, job, self.dbt, solver_id_map) @@ -761,22 +931,22 @@ def process_compile_results(self, session, fin_json, context): failed_job = not success except (OperationalError, IntegrityError) as err: - self.logger.warning('FinBuild: Unable to update Database %s', err) + self.logger.warning("FinBuild: Unable to update Database %s", err) session.rollback() failed_job = True except DataError as err: self.logger.warning( - 'FinBuild: Invalid data, likely large workspace. DB Error: %s', err) + "FinBuild: Invalid data, likely large workspace. 
DB Error: %s", err) session.rollback() failed_job = True if failed_job: - set_job_state(session, job, self.dbt, 'errored', False, result=result_str) + set_job_state(session, job, self.dbt, "errored", False, result=result_str) else: set_job_state(session, job, self.dbt, - 'compiled', + "compiled", False, result=result_str) @@ -784,73 +954,103 @@ def process_compile_results(self, session, fin_json, context): def process_eval_results(self, session, fin_json, context): """! Process fin_json result - @param session DB session - @param fin_json MIFin results for job - @param context Context for Celery job - @return Boolean value - """ - job = SimpleDict(**context['job']) + @param session DB session + @param fin_json MIFin results for job + @param context Context for Celery job + @return Boolean value + """ + job = SimpleDict(**context["job"]) failed_job = True - result_str = '' + result_str = "" pending = [] - orig_state = 'compiled' + orig_state = "compiled" + + # Extract machine_id from context + machine_id = context.get('machine_id', None) try: if fin_json: - if 'success' in fin_json and fin_json["success"] is False: + if "success" in fin_json and fin_json["success"] is False: status = [fin_json] else: - if 'miopen_find_eval_result' in fin_json: - status = process_fdb_w_kernels(session, - fin_json, - copy.deepcopy(context), - self.dbt, - context['fdb_attr'], - pending, - result_str='miopen_find_eval_result', - check_str='evaluated') - elif 'miopen_perf_eval_result' in fin_json: - status = process_fdb_w_kernels(session, - fin_json, - copy.deepcopy(context), - self.dbt, - context['fdb_attr'], - pending, - result_str='miopen_perf_eval_result', - check_str='evaluated') + if "miopen_find_eval_result" in fin_json: + status = process_fdb_w_kernels( + session, + fin_json, + copy.deepcopy(context), + self.dbt, + context["fdb_attr"], + pending, + result_str="miopen_find_eval_result", + check_str="evaluated", + ) + elif "miopen_perf_eval_result" in fin_json: + status = process_fdb_w_kernels( + session, + fin_json, + copy.deepcopy(context), + self.dbt, + context["fdb_attr"], + pending, + result_str="miopen_perf_eval_result", + check_str="evaluated", + ) if context["rich_data"]: - status = process_tuning_data(session, - fin_json, - copy.deepcopy(context), - self.dbt, - context['tuning_data_attr'], - pending, - result_str='miopen_perf_eval_result', - check_str='evaluated') + status = process_tuning_data( + session, + fin_json, + copy.deepcopy(context), + self.dbt, + context["tuning_data_attr"], + pending, + result_str="miopen_perf_eval_result", + check_str="evaluated", + ) success, result_str = get_fin_result(status) failed_job = not success if failed_job: - if job.retries >= (MAX_ERRORED_JOB_RETRIES - 1): #pylint: disable=no-member - self.logger.warning('max job retries exhausted, setting to errored') - set_job_state(session, job, self.dbt, 'errored', result=result_str) - else: - self.logger.warning('resetting job state to %s, incrementing retries', - orig_state) + if job.retries >= (MAX_ERRORED_JOB_RETRIES - 1): # pylint: disable=no-member + self.logger.warning("max job retries exhausted, setting to errored") set_job_state(session, job, self.dbt, - orig_state, - increment_retries=True, - result=result_str) + "errored", + result=result_str, + machine_id=machine_id) + else: + self.logger.warning("resetting job state to %s, incrementing retries", + orig_state) + set_job_state( + session, + job, + self.dbt, + orig_state, + increment_retries=True, + result=result_str, + machine_id=machine_id, + ) else: 
self.logger.info("\n\n Setting job state to evaluated") - set_job_state(session, job, self.dbt, 'evaluated', result=result_str) + set_job_state(session, + job, + self.dbt, + "evaluated", + result=result_str, + machine_id=machine_id) clean_cache_table(self.dbt, job) except (OperationalError, IntegrityError) as err: - self.logger.warning('FinBuild: Unable to update Database %s', err) + self.logger.warning("FinBuild: Unable to update Database %s", err) session.rollback() - set_job_state(session, job, self.dbt, 'errored', result=result_str) + set_job_state(session, job, self.dbt, "errored", result=result_str) return True + + def extract_job_id_from_context(self, context): + """Extract job ID from MIOpen celery task context""" + try: + # Extract job ID from the job context + return context.get("job", {}).get("id") + except (AttributeError, KeyError): + return None diff --git a/tuna/miopen/parse_miopen_args.py b/tuna/miopen/parse_miopen_args.py index 5d608d5ad..a90ba3767 100644 --- a/tuna/miopen/parse_miopen_args.py +++ b/tuna/miopen/parse_miopen_args.py @@ -140,6 +140,19 @@ def get_import_cfg_parser( 'Tag to mark the origin of this config but skips the insert new config \ step in case the config does not exist in the table. Wildcard columns \ allowed for tagging') + parser.add_argument( + '--batch_size', + type=int, + dest='batch_size', + default=1000, + help='Batch size for bulk database operations (default: 1000). \ + Higher values are faster but use more memory.') + parser.add_argument( + '--disable_batch_import', + action='store_true', + dest='disable_batch_import', + help='Disable batch import optimization and use original one-by-one import. \ + Use this for debugging or if batch import has issues.') return parser diff --git a/tuna/miopen/scripts/dupe_resolve.py b/tuna/miopen/scripts/dupe_resolve.py index e8f2ed2b4..915297ca0 100755 --- a/tuna/miopen/scripts/dupe_resolve.py +++ b/tuna/miopen/scripts/dupe_resolve.py @@ -3,6 +3,7 @@ #!/usr/bin/env python3 from sqlalchemy.exc import IntegrityError, OperationalError +from sqlalchemy import text from tuna.dbBase.sql_alchemy import DbSession from tuna.miopen.utils.helper import handle_op_error @@ -60,15 +61,15 @@ def main(): """main""" with DbSession() as session: - session.execute(view_perf_cfg_rep) + session.execute(text(view_perf_cfg_rep)) session.commit() - res = session.execute("select id, cfg from perf_cfg_rep").all() + res = session.execute(text("select id, cfg from perf_cfg_rep")).all() invalid = 0 for session_id, cfg in res: try: query = f"update conv_perf_config set config={cfg} where id={session_id};" print(query) - #session.execute(query) + #session.execute(text(query)) #session.commit() except OperationalError as error: handle_op_error(LOGGER, error) @@ -79,21 +80,21 @@ def main(): query = f"update conv_perf_config set valid=0 where id={session_id};" LOGGER.warning('Invalidating entry (%s)', query) invalid += 1 - session.execute(query) + session.execute(text(query)) session.commit() if invalid: LOGGER.warning('Invalidated %u perf_config entries', invalid) - session.execute(view_perf_db_rep) + session.execute(text(view_perf_db_rep)) session.commit() - res = session.execute("select theid, mcfg from perf_db_rep").all() + res = session.execute(text("select theid, mcfg from perf_db_rep")).all() invalid = 0 for session_id, cfg in res: try: query = f"update conv_perf_db set miopen_config={cfg} where id={session_id};" print(query) - session.execute(query) + session.execute(text(query)) session.commit() except OperationalError as error: 
handle_op_error(LOGGER, error) @@ -104,7 +105,7 @@ def main(): query = f"update conv_perf_db set valid=0 where id={session_id};" LOGGER.warning('Invalidating entry (%s)', query) invalid += 1 - session.execute(query) + session.execute(text(query)) session.commit() if invalid: diff --git a/tuna/miopen/scripts/report.py b/tuna/miopen/scripts/report.py index 1c6bc0421..0f4b35d90 100755 --- a/tuna/miopen/scripts/report.py +++ b/tuna/miopen/scripts/report.py @@ -28,6 +28,7 @@ import numpy as np import pandas as pd +from sqlalchemy import text from tuna.parse_args import TunaArgs, setup_arg_parser from tuna.utils.logger import setup_logger from tuna.miopen.db.tables import MIOpenDBTables @@ -66,14 +67,14 @@ def get_data(args, dbt, arch, num_cu): query = f"select config, solver, kernel_time from {dbt.find_db_table.__tablename__} "\ f"where session={args.session_id} order by config" pd.options.display.max_rows = 100 - query_data = session.execute(query).fetchall() + query_data = session.execute(text(query)).fetchall() all_cfgs = [x[0] for x in query_data] configs = set(all_cfgs) session_data = pd.DataFrame(data=query_data) query = f"select config, solver, kernel_time from conv_golden where golden_miopen_v="\ f"{args.golden_v} and arch='{arch}' and num_cu={num_cu} and config in "\ f"{tuple(configs)} order by config" - golden_data = pd.DataFrame(data=session.execute(query).fetchall()) + golden_data = pd.DataFrame(data=session.execute(text(query)).fetchall()) session_data.columns = golden_data.columns = ['config', 'solver', 'ktime'] dfr = pd.merge(session_data, diff --git a/tuna/miopen/subcmd/import_configs.py b/tuna/miopen/subcmd/import_configs.py index 5e23d2ac8..c6aa1aba0 100755 --- a/tuna/miopen/subcmd/import_configs.py +++ b/tuna/miopen/subcmd/import_configs.py @@ -42,6 +42,8 @@ from tuna.miopen.driver.batchnorm import DriverBatchNorm from tuna.miopen.db.tables import MIOpenDBTables from tuna.miopen.db.benchmark import Framework, Model +from tuna.miopen.db.tensortable import TensorTable +from tuna.utils.db_utility import build_dict_val_key def create_query(tag: str, mark_recurrent: bool, config_id: int) -> dict: @@ -155,6 +157,606 @@ def parse_line(args: argparse.Namespace, line: str, counts: dict, return True +def batch_insert_tensors(session, tensor_dicts: List[dict], logger: logging.Logger) -> dict: + """Insert tensors in batch with duplicate handling. 
Returns dict mapping tensor_key -> tensor_id""" + if not tensor_dicts: + return {} + + tensor_map = {} + + # Get existing tensors to avoid duplicates + tensor_keys = [build_dict_val_key(TensorTable(**td)) for td in tensor_dicts] + unique_dicts = {build_dict_val_key(TensorTable(**td)): td for td in tensor_dicts} + + # Query existing tensors + existing_tensors = session.query(TensorTable).all() + for tensor in existing_tensors: + key = build_dict_val_key(tensor) + if key in unique_dicts: + tensor_map[key] = tensor.id + + # Filter out already existing tensors + new_tensor_dicts = [td for key, td in unique_dicts.items() if key not in tensor_map] + + if new_tensor_dicts: + try: + # Bulk insert new tensors + session.bulk_insert_mappings(TensorTable, new_tensor_dicts, return_defaults=True) + session.flush() + + # Query back to get IDs + for td in new_tensor_dicts: + key = build_dict_val_key(TensorTable(**td)) + result = session.query(TensorTable.id).filter_by(**td).first() + if result: + tensor_map[key] = result[0] + except IntegrityError as err: + logger.warning(f"Bulk tensor insert failed, falling back to individual inserts: {err}") + session.rollback() + + # Fallback: insert one by one + for td in new_tensor_dicts: + try: + tensor = TensorTable(**td) + tensor.valid = 1 + session.add(tensor) + session.flush() + key = build_dict_val_key(tensor) + tensor_map[key] = tensor.id + except IntegrityError: + session.rollback() + # Already exists, query it + result = session.query(TensorTable.id).filter_by(**td).first() + if result: + key = build_dict_val_key(TensorTable(**td)) + tensor_map[key] = result[0] + + return tensor_map + + +def batch_insert_configs(session, drivers: List[DriverBase], dbt: MIOpenDBTables, + logger: logging.Logger) -> Tuple[int, List]: + """Insert configs in batch with duplicate handling. Returns (count_inserted, list_of_config_objects)""" + if not drivers: + return 0, [] + + # Get config objects + config_objs = [driver.get_db_obj(keep_id=True) for driver in drivers] + + # Filter out configs that already have IDs (already in DB) + new_configs = [c for c in config_objs if c.id is None] + + if not new_configs: + return 0, config_objs + + # Get MD5s of new configs to check for existing + new_md5s = [c.md5 for c in new_configs] + existing_md5s = session.query(dbt.config_table.md5).filter( + dbt.config_table.md5.in_(new_md5s) + ).all() + existing_md5_set = {row[0] for row in existing_md5s} + + # Filter to only truly new configs + truly_new_configs = [c for c in new_configs if c.md5 not in existing_md5_set] + + inserted_count = 0 + if truly_new_configs: + try: + # Try bulk insert + session.bulk_save_objects(truly_new_configs, return_defaults=True) + session.flush() + inserted_count = len(truly_new_configs) + except IntegrityError as err: + logger.warning(f"Bulk config insert failed, falling back to individual inserts: {err}") + session.rollback() + + # Fallback: insert one by one + for config in truly_new_configs: + try: + session.add(config) + session.flush() + inserted_count += 1 + except IntegrityError: + session.rollback() + + # Refresh all configs to get IDs + for config in config_objs: + if config.id is None: + # Query to get ID + result = session.query(dbt.config_table.id).filter_by(md5=config.md5).first() + if result: + config.id = result[0] + + return inserted_count, config_objs + + +def batch_insert_tags(session, config_ids: List[int], dbt: MIOpenDBTables, + args: argparse.Namespace, logger: logging.Logger) -> int: + """Insert tags in batch with duplicate handling. 
Returns count of tags inserted""" + if not config_ids or not (args.tag or args.mark_recurrent): + return 0 + + # Build tag dictionaries + tag_dicts = [] + for config_id in config_ids: + tag_dict = create_query(args.tag, args.mark_recurrent, config_id) + tag_dicts.append(tag_dict) + + if not tag_dicts: + return 0 + + # Get existing tags to avoid duplicates + if args.tag: + existing_tags = session.query(dbt.config_tags_table.config).filter( + dbt.config_tags_table.config.in_(config_ids), + dbt.config_tags_table.tag == args.tag + ).all() + existing_config_ids = {row[0] for row in existing_tags} + tag_dicts = [td for td in tag_dicts if td['config'] not in existing_config_ids] + + inserted_count = 0 + if tag_dicts: + try: + # Try bulk insert + session.bulk_insert_mappings(dbt.config_tags_table, tag_dicts) + session.flush() + inserted_count = len(tag_dicts) + except IntegrityError as err: + logger.warning(f"Bulk tag insert failed, falling back to individual inserts: {err}") + session.rollback() + + # Fallback: insert one by one + for tag_dict in tag_dicts: + try: + tag_obj = dbt.config_tags_table(**tag_dict) + session.merge(tag_obj) + session.flush() + inserted_count += 1 + except IntegrityError: + session.rollback() + + return inserted_count + + +def get_or_create_tensor_ids(session, tensor_dicts: List[dict], logger: logging.Logger) -> dict: + """Get or create tensor IDs in bulk. Returns dict mapping tensor_key -> tensor_id""" + if not tensor_dicts: + return {} + + import time + t0 = time.time() + + # Build unique tensor dict map + unique_tensors = {} + for td in tensor_dicts: + td['valid'] = 1 + key = build_dict_val_key(TensorTable(**td)) + unique_tensors[key] = td + + logger.info("Found %d unique tensors to process", len(unique_tensors)) + + # Query existing tensors in bulk + existing_tensors = session.query(TensorTable).all() + tensor_id_map = {} + for tensor in existing_tensors: + key = build_dict_val_key(tensor) + if key in unique_tensors: + tensor_id_map[key] = tensor.id + + logger.info("Found %d existing tensors in DB (%.2fs)", len(tensor_id_map), time.time() - t0) + + # Insert new tensors + new_tensors = [td for key, td in unique_tensors.items() if key not in tensor_id_map] + if new_tensors: + t0 = time.time() + try: + session.bulk_insert_mappings(TensorTable, new_tensors) + session.flush() + logger.info("Bulk inserted %d new tensors (%.2fs)", len(new_tensors), time.time() - t0) + + # Query back to get IDs + for td in new_tensors: + result = session.query(TensorTable.id).filter_by(**td).first() + if result: + key = build_dict_val_key(TensorTable(**td)) + tensor_id_map[key] = result[0] + except IntegrityError as err: + logger.warning(f"Bulk tensor insert failed: {err}") + session.rollback() + # Fallback to individual + for td in new_tensors: + try: + tensor = TensorTable(**td) + session.add(tensor) + session.flush() + key = build_dict_val_key(tensor) + tensor_id_map[key] = tensor.id + except IntegrityError: + session.rollback() + result = session.query(TensorTable.id).filter_by(**td).first() + if result: + key = build_dict_val_key(TensorTable(**td)) + tensor_id_map[key] = result[0] + + return tensor_id_map + + +def import_cfgs_batch_ultra(args: argparse.Namespace, dbt: MIOpenDBTables, + logger: logging.Logger, batch_size: int = 1000) -> dict: + """Ultra-optimized batch import bypassing get_db_obj()""" + import time + import hashlib + from tuna.miopen.utils.metadata import TENSOR_PRECISION + + connect_db() + + counts = {} + counts['cnt_configs'] = 0 + counts['cnt_tagged_configs'] = set() 
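# [Editor's note] The helpers above (batch_insert_tensors, batch_insert_configs,
# batch_insert_tags) all share one shape: attempt a single fast
# bulk_insert_mappings()/bulk_save_objects() call, and on IntegrityError roll
# back and retry row by row so that duplicates are skipped without losing the
# rest of the batch. A minimal, self-contained sketch of that shape follows for
# reference; `model`, `rows` and the function name are illustrative placeholders
# and are not part of this PR. Unlike the helpers above, which flush() and let
# the caller commit per batch, the sketch commits directly so that a duplicate
# row only discards itself.

from sqlalchemy.exc import IntegrityError


def bulk_insert_with_fallback(session, model, rows, logger):
  """Insert `rows` (list of dicts) into mapped class `model`, tolerating duplicates."""
  try:
    # Fast path: one bulk INSERT for the whole batch.
    session.bulk_insert_mappings(model, rows)
    session.commit()
    return len(rows)
  except IntegrityError as err:
    logger.warning('bulk insert failed, retrying row by row: %s', err)
    session.rollback()
    inserted = 0
    for row in rows:
      try:
        session.add(model(**row))
        # Commit per row so one duplicate key does not roll back the rest.
        session.commit()
        inserted += 1
      except IntegrityError:
        session.rollback()  # duplicate key, skip this row
    return inserted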
+ + # Step 1: Read and parse + start_time = time.time() + logger.info("Reading and parsing config file...") + drivers_to_process = [] + unique_lines = set() + + with open(os.path.expanduser(args.file_name), "r") as infile: + for line in infile: + line = line.strip() + if line: + unique_lines.add(line) + + for line in unique_lines: + try: + if args.config_type == ConfigType.batch_norm: + driver = DriverBatchNorm(line, args.command) + else: + driver = DriverConvolution(line, args.command) + + if not args.batch_list: + drivers_to_process.append(driver) + else: + for bsz in args.batch_list: + driver_copy = DriverBatchNorm(line, args.command) if args.config_type == ConfigType.batch_norm else DriverConvolution(line, args.command) + driver_copy.batchsize = bsz + drivers_to_process.append(driver_copy) + except ValueError as err: + logger.warning(f"Error parsing line: {err}") + + parse_time = time.time() - start_time + logger.info("Parsed %u driver objects (took %.2fs)", len(drivers_to_process), parse_time) + + # Step 2: Collect all unique tensors + start_time = time.time() + logger.info("Collecting tensor dictionaries...") + all_tensor_dicts = [] + for driver in drivers_to_process: + input_t = driver._MIOpenDriver__compose_input_t() if hasattr(driver, '_MIOpenDriver__compose_input_t') else {} + weight_t = driver.compose_weight_t() + all_tensor_dicts.extend([input_t, weight_t]) + + logger.info("Collected %d tensor dicts (took %.2fs)", len(all_tensor_dicts), time.time() - start_time) + + # Step 3: Batch process tensors and configs + total_drivers = len(drivers_to_process) + logger.info(f"Starting ultra-optimized batch import (batch size: {batch_size})...") + overall_start = time.time() + + for batch_start in range(0, total_drivers, batch_size): + batch_end = min(batch_start + batch_size, total_drivers) + batch = drivers_to_process[batch_start:batch_end] + + with DbSession() as session: + # Collect tensors for this batch + batch_tensor_dicts = [] + for driver in batch: + input_t = driver._MIOpenDriver__compose_input_t() if hasattr(driver, '_MIOpenDriver__compose_input_t') else {} + weight_t = driver.compose_weight_t() + batch_tensor_dicts.extend([input_t, weight_t]) + + # Get/create tensor IDs + tensor_id_map = get_or_create_tensor_ids(session, batch_tensor_dicts, logger) + + # Build config dictionaries manually (bypass get_db_obj) + config_dicts = [] + for driver in batch: + try: + # Get tensor IDs + input_t = driver._MIOpenDriver__compose_input_t() if hasattr(driver, '_MIOpenDriver__compose_input_t') else {} + weight_t = driver.compose_weight_t() + input_t['valid'] = 1 + weight_t['valid'] = 1 + + input_key = build_dict_val_key(TensorTable(**input_t)) + weight_key = build_dict_val_key(TensorTable(**weight_t)) + + if input_key not in tensor_id_map or weight_key not in tensor_id_map: + logger.warning("Missing tensor IDs for config, skipping") + continue + + # Build config dict manually + config_dict = { + 'batchsize': driver.batchsize, + 'spatial_dim': driver.spatial_dim, + 'pad_h': driver.pad_h, + 'pad_w': driver.pad_w, + 'pad_d': driver.pad_d, + 'conv_stride_h': driver.conv_stride_h, + 'conv_stride_w': driver.conv_stride_w, + 'conv_stride_d': driver.conv_stride_d, + 'dilation_h': driver.dilation_h, + 'dilation_w': driver.dilation_w, + 'dilation_d': driver.dilation_d, + 'group_count': driver.group_count, + 'mode': driver.mode, + 'pad_mode': driver.pad_mode, + 'trans_output_pad_h': driver.trans_output_pad_h, + 'trans_output_pad_w': driver.trans_output_pad_w, + 'trans_output_pad_d': 
driver.trans_output_pad_d, + 'direction': driver.direction, + 'input_tensor': tensor_id_map[input_key], + 'weight_tensor': tensor_id_map[weight_key], + 'out_layout': driver.out_layout, + 'driver': str(driver) + } + + # Compute MD5 + dict_copy = config_dict.copy() + dict_copy.pop('driver') + md5_str = str(sorted(dict_copy.items())) + config_dict['md5'] = hashlib.md5(md5_str.encode()).hexdigest() + + config_dicts.append(config_dict) + except Exception as err: + logger.warning(f"Error building config dict: {err}") + + # Bulk insert configs + if config_dicts: + # Check for existing + md5s = [cd['md5'] for cd in config_dicts] + existing = session.query(dbt.config_table.md5).filter( + dbt.config_table.md5.in_(md5s) + ).all() + existing_set = {row[0] for row in existing} + + new_configs = [cd for cd in config_dicts if cd['md5'] not in existing_set] + + if new_configs: + try: + session.bulk_insert_mappings(dbt.config_table, new_configs) + session.flush() + counts['cnt_configs'] += len(new_configs) + except IntegrityError as err: + logger.warning(f"Bulk config insert failed: {err}") + session.rollback() + + # Get config IDs for tagging + if args.tag or args.mark_recurrent: + config_ids = [] + for cd in config_dicts: + result = session.query(dbt.config_table.id).filter_by(md5=cd['md5']).first() + if result: + config_ids.append(result[0]) + + if config_ids: + tag_dicts = [create_query(args.tag, args.mark_recurrent, cid) for cid in config_ids] + + # Filter existing tags + if args.tag: + existing_tags = session.query(dbt.config_tags_table.config).filter( + dbt.config_tags_table.config.in_(config_ids), + dbt.config_tags_table.tag == args.tag + ).all() + existing_tag_set = {row[0] for row in existing_tags} + tag_dicts = [td for td in tag_dicts if td['config'] not in existing_tag_set] + + if tag_dicts: + try: + session.bulk_insert_mappings(dbt.config_tags_table, tag_dicts) + session.flush() + counts['cnt_tagged_configs'].update([td['config'] for td in tag_dicts]) + except IntegrityError: + session.rollback() + + session.commit() + + if batch_end % 1000 == 0 or batch_end == total_drivers: + logger.info(f"Processed {batch_end}/{total_drivers} configs") + + total_time = time.time() - overall_start + logger.info("Ultra-optimized import complete (took %.2fs, %.2f configs/sec)", + total_time, total_drivers / total_time if total_time > 0 else 0) + return counts + + +def import_cfgs_batch(args: argparse.Namespace, dbt: MIOpenDBTables, + logger: logging.Logger, batch_size: int = 1000) -> dict: + """Optimized batch import of configs with proper tensor handling""" + import time + from tuna.utils.db_utility import get_session_val_map + from tuna.miopen.driver.base import MIOpenDriver + + connect_db() + + counts = {} + counts['cnt_configs'] = 0 + counts['cnt_tagged_configs'] = set() + unique_lines = set() + + # Step 1: Read and deduplicate file + start_time = time.time() + logger.info("Reading and deduplicating config file...") + with open(os.path.expanduser(args.file_name), "r") as infile: + for line_cnt, line in enumerate(infile, 1): + line = line.strip() + if line: + unique_lines.add(line) + if line_cnt % 10000 == 0: + logger.info("Parsed: %u lines, unique configs: %u", line_cnt, len(unique_lines)) + + parse_time = time.time() - start_time + logger.info("File parsing complete. 
Total lines: %u, unique configs: %u (took %.2fs)", + line_cnt, len(unique_lines), parse_time) + + # Step 2: Pre-load tensor cache to avoid repeated queries + start_time = time.time() + logger.info("Pre-loading tensor cache...") + with DbSession() as session: + tensor_attr = [column.name for column in TensorTable.__table__.columns] + MIOpenDriver.tensor_id_map = get_session_val_map(session, TensorTable, tensor_attr) + cache_time = time.time() - start_time + logger.info("Tensor cache loaded with %u entries (took %.2fs)", + len(MIOpenDriver.tensor_id_map), cache_time) + + # Step 3: Parse all driver objects + start_time = time.time() + logger.info("Parsing driver commands...") + drivers_to_process = [] + for line in unique_lines: + try: + if args.config_type == ConfigType.batch_norm: + driver = DriverBatchNorm(line, args.command) + else: + driver = DriverConvolution(line, args.command) + + if not args.batch_list: + drivers_to_process.append(driver) + else: + for bsz in args.batch_list: + driver_copy = DriverBatchNorm(line, args.command) if args.config_type == ConfigType.batch_norm else DriverConvolution(line, args.command) + driver_copy.batchsize = bsz + drivers_to_process.append(driver_copy) + except ValueError as err: + logger.warning(f"Error parsing line: {err}") + + driver_parse_time = time.time() - start_time + logger.info("Parsed %u driver objects to import (took %.2fs)", + len(drivers_to_process), driver_parse_time) + + # Step 4: Process in batches with true batch operations + total_drivers = len(drivers_to_process) + start_time = time.time() + logger.info(f"Starting batch import (batch size: {batch_size})...") + + batch_times = {'get_db_obj': 0, 'check_existing': 0, 'insert_configs': 0, 'insert_tags': 0, 'commit': 0} + + for batch_start in range(0, total_drivers, batch_size): + batch_end = min(batch_start + batch_size, total_drivers) + batch = drivers_to_process[batch_start:batch_end] + batch_start_time = time.time() + + with DbSession() as session: + # Collect all config objects for this batch + t0 = time.time() + config_objs = [] + for driver in batch: + try: + config_obj = driver.get_db_obj(keep_id=True) + config_objs.append((driver, config_obj)) + except ValueError as err: + logger.warning(f"Error creating config object: {err}") + batch_times['get_db_obj'] += time.time() - t0 + + if not args.tag_only: + # Batch insert configs + t0 = time.time() + new_configs = [c for d, c in config_objs if c.id is None] + + if new_configs: + # Check for existing configs by MD5 + new_md5s = [c.md5 for c in new_configs] + existing_md5s = session.query(dbt.config_table.md5).filter( + dbt.config_table.md5.in_(new_md5s) + ).all() + existing_md5_set = {row[0] for row in existing_md5s} + batch_times['check_existing'] += time.time() - t0 + + # Filter to truly new configs + truly_new = [c for c in new_configs if c.md5 not in existing_md5_set] + + if truly_new: + t0 = time.time() + try: + session.bulk_save_objects(truly_new, return_defaults=True) + session.flush() + counts['cnt_configs'] += len(truly_new) + except IntegrityError as err: + logger.warning(f"Bulk insert failed, using individual inserts: {err}") + session.rollback() + for config in truly_new: + try: + session.add(config) + session.flush() + counts['cnt_configs'] += 1 + except IntegrityError: + session.rollback() + batch_times['insert_configs'] += time.time() - t0 + + # Refresh configs to get IDs + for config in new_configs: + if config.id is None: + result = session.query(dbt.config_table.id).filter_by(md5=config.md5).first() + if result: + 
config.id = result[0] + + # Batch insert tags + if args.tag or args.mark_recurrent: + t0 = time.time() + config_ids = [c.id for d, c in config_objs if c.id is not None] + + if config_ids: + tag_dicts = [create_query(args.tag, args.mark_recurrent, cid) for cid in config_ids] + + # Filter out existing tags + if args.tag: + existing_tags = session.query(dbt.config_tags_table.config).filter( + dbt.config_tags_table.config.in_(config_ids), + dbt.config_tags_table.tag == args.tag + ).all() + existing_set = {row[0] for row in existing_tags} + tag_dicts = [td for td in tag_dicts if td['config'] not in existing_set] + + if tag_dicts: + try: + session.bulk_insert_mappings(dbt.config_tags_table, tag_dicts) + session.flush() + counts['cnt_tagged_configs'].update([td['config'] for td in tag_dicts]) + except IntegrityError as err: + logger.warning(f"Bulk tag insert failed, using individual inserts: {err}") + session.rollback() + for tag_dict in tag_dicts: + try: + tag_obj = dbt.config_tags_table(**tag_dict) + session.merge(tag_obj) + session.flush() + counts['cnt_tagged_configs'].add(tag_dict['config']) + except IntegrityError: + session.rollback() + batch_times['insert_tags'] += time.time() - t0 + + # Commit the entire batch + t0 = time.time() + try: + session.commit() + except IntegrityError as err: + logger.error(f"Batch commit failed: {err}") + session.rollback() + batch_times['commit'] += time.time() - t0 + + if batch_end % 1000 == 0 or batch_end == total_drivers: + batch_elapsed = time.time() - batch_start_time + logger.info(f"Processed {batch_end}/{total_drivers} configs (batch took {batch_elapsed:.2f}s)") + + total_import_time = time.time() - start_time + logger.info("Database import complete (took %.2fs)", total_import_time) + logger.info("Timing breakdown: get_db_obj=%.2fs, check_existing=%.2fs, insert_configs=%.2fs, insert_tags=%.2fs, commit=%.2fs", + batch_times['get_db_obj'], batch_times['check_existing'], + batch_times['insert_configs'], batch_times['insert_tags'], batch_times['commit']) + + logger.info("Database import complete.") + return counts + + def import_cfgs(args: argparse.Namespace, dbt: MIOpenDBTables, logger: logging.Logger) -> dict: """import configs to mysql from file with driver invocations""" @@ -163,21 +765,29 @@ def import_cfgs(args: argparse.Namespace, dbt: MIOpenDBTables, counts: dict = {} counts['cnt_configs'] = 0 counts['cnt_tagged_configs'] = set() - unique_lines: List[str] = [] + unique_lines = set() + + logger.info("Reading and deduplicating config file...") with open(os.path.expanduser(args.file_name), "r") as infile: # pylint: disable=unspecified-encoding - line_cnt = 0 - for line in infile: - line_cnt += 1 + for line_cnt, line in enumerate(infile, 1): line = line.strip() - if not line in unique_lines: - unique_lines.append(line) - logger.info("parsed: %u, unique: %u", line_cnt, len(unique_lines)) - for line in unique_lines: - try: - parse_line(args, line, counts, dbt, logger) - except ValueError as err: - logger.warning(err) - + if line: # Skip empty lines + unique_lines.add(line) + if line_cnt % 10000 == 0: + logger.info("Parsed: %u lines, unique configs: %u", line_cnt, len(unique_lines)) + + logger.info("File parsing complete. 
Total lines: %u, unique configs: %u", line_cnt, len(unique_lines)) + logger.info("Starting database import...") + + for idx, line in enumerate(unique_lines, 1): + try: + parse_line(args, line, counts, dbt, logger) + if idx % 1000 == 0: + logger.info("Processed %u/%u unique configs", idx, len(unique_lines)) + except ValueError as err: + logger.warning(err) + + logger.info("Database import complete.") return counts @@ -352,7 +962,20 @@ def run_import_configs(args: argparse.Namespace, return True set_import_cfg_batches(args) - counts = import_cfgs(args, dbt, logger) + + # Use batch import by default unless disabled or tag_only mode + use_batch = not getattr(args, 'disable_batch_import', False) and not args.tag_only + batch_size = getattr(args, 'batch_size', 1000) + + if use_batch: + logger.info("Using optimized batch import (batch_size=%d)", batch_size) + counts = import_cfgs_batch(args, dbt, logger, batch_size) + else: + if args.tag_only: + logger.info("Using original import (tag_only mode)") + else: + logger.info("Using original import (batch import disabled)") + counts = import_cfgs(args, dbt, logger) logger.info('New configs added: %u', counts['cnt_configs']) if args.tag or args.tag_only: diff --git a/tuna/miopen/subcmd/load_job.py b/tuna/miopen/subcmd/load_job.py index 1cc9fe954..61a9cde6a 100755 --- a/tuna/miopen/subcmd/load_job.py +++ b/tuna/miopen/subcmd/load_job.py @@ -34,6 +34,7 @@ from sqlalchemy.exc import IntegrityError #pylint: disable=wrong-import-order from sqlalchemy.sql.expression import true +from sqlalchemy import text from tuna.miopen.utils.metadata import ALG_SLV_MAP, TENSOR_PRECISION from tuna.miopen.db.solver import get_solver_ids @@ -105,7 +106,7 @@ def config_query(args: argparse.Namespace, session, dbt: MIOpenDBTables): if args.tag: tag_query = session.query(dbt.config_tags_table.config)\ .filter(dbt.config_tags_table.tag == args.tag).subquery() - cfg_query = cfg_query.filter(dbt.config_table.id.in_(tag_query)) + cfg_query = cfg_query.filter(dbt.config_table.id.in_(tag_query.select())) if args.cmd: cfg_query = cfg_query.filter( @@ -135,7 +136,7 @@ def compose_query(args: argparse.Namespace, session, dbt: MIOpenDBTables, if args.only_dynamic: query = query.filter(Solver.is_dynamic == true()) - query = query.filter(dbt.solver_app.config.in_(cfg_query.subquery())) + query = query.filter(dbt.solver_app.config.in_(cfg_query.subquery().select())) return query @@ -160,7 +161,7 @@ def add_jobs(args: argparse.Namespace, dbt: MIOpenDBTables, where session={args.session_id} and fin_step='{fin_step_str}'" logger.info(query) - ret = session.execute(query) + ret = session.execute(text(query)) pre_ex: Dict[str, Dict[str, bool]] = {} for config, solver in ret: if config not in pre_ex: @@ -182,8 +183,8 @@ def add_jobs(args: argparse.Namespace, dbt: MIOpenDBTables, if job.config in pre_ex: if job.solver in pre_ex[job.config]: - logger.warning("Job exists (skip): %s : %s", job.config, - job.solver) + # logger.warning("Job exists (skip): %s : %s", job.config, + # job.solver) continue session.add(job) diff --git a/tuna/miopen/subcmd/update_golden.py b/tuna/miopen/subcmd/update_golden.py index 7ec485cff..e0d3f268d 100755 --- a/tuna/miopen/subcmd/update_golden.py +++ b/tuna/miopen/subcmd/update_golden.py @@ -31,6 +31,7 @@ from typing import Dict, Any from sqlalchemy.sql.expression import func as sqlfunc from sqlalchemy.exc import OperationalError +from sqlalchemy import text from tuna.miopen.parse_miopen_args import get_update_golden_parser from tuna.dbBase.sql_alchemy import DbSession @@ 
-142,9 +143,10 @@ def create_perf_table(args: argparse.Namespace, logger: logging.Logger): print(table_name) with ENGINE.connect() as conn: try: - conn.execute(f'drop table if exists {table_name}') + conn.execute(text(f'drop table if exists {table_name}')) logger.info('Creating new performance table %s', table_name) - conn.execute(get_perf_str(args, table_name)) + conn.execute(text(get_perf_str(args, table_name))) + conn.commit() logger.info('Done creating new performance table %s', table_name) except OperationalError as oerr: logger.info('%s \n', oerr) @@ -169,7 +171,7 @@ def gold_base_update(session: DbSession, f" where cg.golden_miopen_v={gold_v} and ps.golden_miopen_v={base_gold_v} and ps.valid=1"\ " and ps.kernel_time>0;" logger.info(update_q) - session.execute(update_q) + session.execute(text(update_q)) logger.info("Inserting golden version %s -> %s.", base_gold_v, gold_v) insert_q = "insert ignore into conv_golden (valid, golden_miopen_v, arch, num_cu, config"\ @@ -178,7 +180,7 @@ def gold_base_update(session: DbSession, ", workspace_sz, alg_lib, opencl, kernel_group, session, solver"\ f" from conv_golden where golden_miopen_v={base_gold_v} and valid=1 and kernel_time>0;" logger.info(insert_q) - session.execute(insert_q) + session.execute(text(insert_q)) session.commit() return True @@ -200,7 +202,7 @@ def gold_session_update(session: DbSession, ", cg.kernel_time=ps.kernel_time, cg.kernel_group=ps.kernel_group, cg.session=ps.session"\ f" where cg.golden_miopen_v={gold_v} and ps.session={tune_s} and ps.valid=1"\ " and ps.kernel_time>0;" - session.execute(update_q) + session.execute(text(update_q)) logger.info("Gold %s Insert session %s.", gold_v, tune_s) insert_q = "insert ignore into conv_golden (valid, golden_miopen_v, arch, num_cu, config"\ @@ -209,7 +211,7 @@ def gold_session_update(session: DbSession, ", workspace_sz, alg_lib, opencl, kernel_group, session, solver"\ " from conv_find_db as cfd inner join session as s on cfd.session=s.id"\ f" where session={tune_s} and cfd.valid=1 and kernel_time>0;" - session.execute(insert_q) + session.execute(text(insert_q)) session.commit() return True diff --git a/tuna/miopen/utils/helper.py b/tuna/miopen/utils/helper.py index fea4ffade..f71cee520 100644 --- a/tuna/miopen/utils/helper.py +++ b/tuna/miopen/utils/helper.py @@ -31,6 +31,7 @@ from time import sleep from sqlalchemy.exc import IntegrityError, OperationalError from sqlalchemy.orm import Query +from sqlalchemy import text from tuna.utils.logger import setup_logger from tuna.dbBase.sql_alchemy import DbSession @@ -203,18 +204,42 @@ def get_db_id(db_elems, config_table): return cid -def set_job_state(session, job, dbt, state, increment_retries=False, result=""): +def set_job_state(session, + job, + dbt, + state, + increment_retries=False, + result="", + machine_id=None): """Update job state for builder/evaluator job_set_attr: List[str]""" LOGGER.info('Setting job id %s state to %s', job.id, state) job_set_attr = ['state', 'gpu_id'] job.state = state + + # Add machine_id if provided + if machine_id is not None: + job_set_attr.append('machine_id') + job.machine_id = machine_id + LOGGER.info('Setting job %s machine_id to %s', job.id, machine_id) + if result: job_set_attr.append('result') job.result = result if increment_retries: job_set_attr.append('retries') - job.retries += 1 + # Query current retry count from database to avoid using stale context data + query_retries = f"SELECT retries FROM {dbt.job_table.__tablename__} WHERE id = {job.id}" + current_retries = 
session.execute(text(query_retries)).scalar() + if current_retries is not None: + job.retries = current_retries + 1 + LOGGER.info('Job %s retry count: %d -> %d', job.id, current_retries, + job.retries) + else: + # Fallback if query fails + job.retries = getattr(job, 'retries', 0) + 1 + LOGGER.warning( + 'Could not query current retries for job %s, using fallback', job.id) #pylint: disable=duplicate-code if '_start' in state: @@ -229,7 +254,7 @@ def set_job_state(session, job, dbt, state, increment_retries=False, result=""): query: str = gen_update_query(job, job_set_attr, dbt.job_table.__tablename__) def callback() -> bool: - session.execute(query) + session.execute(text(query)) session.commit() return True diff --git a/tuna/miopen/utils/json_to_sql.py b/tuna/miopen/utils/json_to_sql.py index 3b3eb31d4..5e0256263 100644 --- a/tuna/miopen/utils/json_to_sql.py +++ b/tuna/miopen/utils/json_to_sql.py @@ -27,6 +27,7 @@ """Utility module for parsing fin json results""" import functools from sqlalchemy.exc import OperationalError +from sqlalchemy import text from tuna.utils.logger import setup_logger from tuna.dbBase.sql_alchemy import DbSession @@ -68,13 +69,13 @@ def __update_fdb_w_kernels( #pylint: disable=too-many-arguments,too-many-locals if not pending: query = gen_update_query(fdb_entry, fdb_attr, dbt.find_db_table.__tablename__) - session.execute(query) + session.execute(text(query)) else: assert len(pending) == 1 pending.pop() query = gen_insert_query(fdb_entry, fdb_attr, dbt.find_db_table.__tablename__) - session.execute(query) + session.execute(text(query)) fdb_entry = __update_fdb_entry(session, solver_id_map[fdb_obj['solver_name']], @@ -83,7 +84,7 @@ def __update_fdb_w_kernels( #pylint: disable=too-many-arguments,too-many-locals fdb_entry.kernel_group = fdb_entry.id query = gen_update_query(fdb_entry, ['kernel_group'], dbt.find_db_table.__tablename__) - session.execute(query) + session.execute(text(query)) if fdb_obj['reason'] == 'Success': __compose_kernel_entry(session, fdb_obj, fdb_entry, dbt) @@ -377,11 +378,11 @@ def __submit_tuning_data_entry( #pylint: disable=too-many-arguments pending.remove(tuning_data_entry) query = gen_insert_query(tuning_data_entry, tuning_data_attr, dbt.tuning_data_table.__tablename__) - session.execute(query) + session.execute(text(query)) else: query = gen_update_query(tuning_data_entry, tuning_data_attr, dbt.tuning_data_table.__tablename__) - session.execute(query) + session.execute(text(query)) def process_fdb_w_kernels(session, diff --git a/tuna/miopen/worker/fin_class.py b/tuna/miopen/worker/fin_class.py index 691869250..41c3f74fb 100644 --- a/tuna/miopen/worker/fin_class.py +++ b/tuna/miopen/worker/fin_class.py @@ -39,7 +39,7 @@ except ImportError: import Queue as queue #type: ignore -from sqlalchemy import func as sqlalchemy_func +from sqlalchemy import func as sqlalchemy_func, text from sqlalchemy.exc import IntegrityError, InvalidRequestError #pylint: disable=wrong-import-order from sqlalchemy.inspection import inspect @@ -67,7 +67,7 @@ def __init__(self, **kwargs): """Constructor""" allowed_keys = set([ 'fin_steps', 'local_file', 'fin_infile', 'fin_outfile', 'config_type', - 'dynamic_solvers_only' + 'dynamic_solvers_only', 'new_only', 'config_limit' ]) self.__dict__.update((key, None) for key in allowed_keys) @@ -101,6 +101,15 @@ def __init__(self, **kwargs): ) self.envmt.append( f"MIOPEN_CUSTOM_CACHE_DIR=/tmp/miopenpdb/thread-{self.gpu_id}/cache") + + if hasattr(self, 'gpu_id') and self.gpu_id is not None: + num_gpus = 
len(self.machine.get_avail_gpus()) if hasattr(self, 'machine') else 1 + actual_gpu = self.gpu_id % num_gpus # Wrap around available GPUs + self.envmt.append(f"ROCR_VISIBLE_DEVICES={actual_gpu}") + self.logger.info("Set ROCR_VISIBLE_DEVICES=%d for worker (worker_id=%d, num_gpus=%d)", + actual_gpu, self.gpu_id, num_gpus) + else: + self.logger.warning("gpu_id not set - ROCR_VISIBLE_DEVICES not configured. All workers may use same GPU!") self.cfg_attr = [column.name for column in inspect(self.dbt.config_table).c] @@ -319,27 +328,48 @@ def applicability(self): return True - def query_cfgs(self, label=None): - """query all configs from table, optionally limit by label""" + def query_cfgs(self, label=None, skip_existing=False, config_limit=None): + """query all configs from table, optionally limit by label, skip existing, and limit count""" with DbSession() as session: query = session.query(self.dbt.config_table)\ .filter(self.dbt.config_table.valid == 1) if label: - query = query.filter(self.dbt.config_table.id == self.dbt.config_tags_table.config)\ - .filter(self.dbt.config_tags_table.tag == label) + query = query.join( + self.dbt.config_tags_table, + self.dbt.config_table.id == self.dbt.config_tags_table.config + ).filter(self.dbt.config_tags_table.tag == label) + + + # Skip configs that already have applicability data in this session + if skip_existing: + query = query.outerjoin( + self.dbt.solver_app, + (self.dbt.solver_app.config == self.dbt.config_table.id) & + (self.dbt.solver_app.session == self.session_id) + ).filter(self.dbt.solver_app.id == None) #order by id for splitting configs into blocks query = query.order_by(self.dbt.config_table.id) + + # Apply config limit if specified + if config_limit is not None and config_limit > 0: + query = query.limit(config_limit) + self.logger.info("Limiting query to %d configs", config_limit) + return query def __set_all_configs(self, idx: int = 0, num_blk: int = 1) -> bool: """Gathering all configs from Tuna DB to set up fin input file""" if idx == 0: - query = self.query_cfgs(self.label) + skip_existing = getattr(self, 'new_only', False) + config_limit = getattr(self, 'config_limit', None) + query = self.query_cfgs(self.label, skip_existing=skip_existing, config_limit=config_limit) rows = query.all() len_rows = len(rows) + self.logger.warning("Query returned %d configs (label=%s, skip_existing=%s, config_limit=%s)", + len_rows, self.label, skip_existing, config_limit) master_cfg_list = [] for row in rows: r_dict = compose_config_obj(row, self.config_type) @@ -491,7 +521,7 @@ def __insert_applicability(self, session: DbSession, self.logger.info('Commit bulk configs (%s), entries (%s), please wait', len(app_cfgs), len(app_values)) for sql_str in inserts: - session.execute(sql_str) + session.execute(text(sql_str)) session.commit() self.logger.info('End bulk inserts') diff --git a/tuna/mituna_interface.py b/tuna/mituna_interface.py index 140d1542b..78e9fdcfa 100644 --- a/tuna/mituna_interface.py +++ b/tuna/mituna_interface.py @@ -25,45 +25,62 @@ # ############################################################################### """Interface class to set up and launch tuning functionality""" -import os -from multiprocessing import Value, Lock, Queue as mpQueue, Process -from typing import Optional, Dict, Any, List -from io import StringIO -from functools import lru_cache +import argparse +import asyncio import json import logging -import argparse +import os import subprocess -import time +import sys import threading -import asyncio +import time from 
datetime import timedelta -from sqlalchemy.exc import NoInspectionAvailable -from sqlalchemy.inspection import inspect -import aioredis +from functools import lru_cache +from io import StringIO +from multiprocessing import Lock, Manager, Process +from multiprocessing import Queue as mpQueue +from multiprocessing import Value +from typing import Any, Dict, List, Optional + import kombu +import redis.asyncio as aioredis from paramiko.channel import ChannelFile +from sqlalchemy import text +from sqlalchemy.exc import NoInspectionAvailable +from sqlalchemy.inspection import inspect -from tuna.worker_interface import WorkerInterface -from tuna.machine import Machine -from tuna.libraries import Library -from tuna.utils.logger import setup_logger -from tuna.utils.utility import get_env_vars, SimpleDict -from tuna.dbBase.sql_alchemy import DbSession -from tuna.celery_app.celery_app import stop_active_workers, stop_named_worker -from tuna.celery_app.celery_app import get_backend_env, purge_queue -from tuna.celery_app.utility import get_q_name +from tuna.celery_app.celery_app import (get_backend_env, purge_queue, + stop_active_workers, stop_named_worker) from tuna.celery_app.celery_workers import launch_celery_worker -from tuna.libraries import Operation +from tuna.celery_app.utility import get_q_name from tuna.custom_errors import CustomError +from tuna.dbBase.sql_alchemy import DbSession +from tuna.libraries import Library, Operation +from tuna.machine import Machine from tuna.utils.db_utility import gen_update_query, session_retry +from tuna.utils.logger import setup_logger +from tuna.utils.utility import SimpleDict, get_env_vars +from tuna.worker_interface import WorkerInterface job_counter_lock = threading.Lock() -class MITunaInterface(): #pylint:disable=too-many-instance-attributes,too-many-public-methods - """ Interface class extended by libraries. The purpose of this class is to define - common functionalities. """ +class MITunaInterface: # pylint:disable=too-many-instance-attributes,too-many-public-methods + """Interface class extended by libraries. The purpose of this class is to define + common functionalities. + + Job Progress Tracking: + ---------------------- + The distributor uses database queries to track job progress, ensuring accuracy + and eliminating synchronization issues with in-memory tracking lists. + + - claimed_job_ids: List of job IDs claimed by this distributor instance + - Progress checking: Queries database directly for actual job states + - No reconciliation needed: Database is the single source of truth + + The Redis consumer still runs to process results and update the database, + but progress decisions are based solely on database queries. + """ def __init__(self, library=Library.MIOPEN) -> None: @@ -77,16 +94,22 @@ def __init__(self, library=Library.MIOPEN) -> None: self.max_job_retries = 10 self.dbt = None self.operation = None - self.db_name = os.environ['TUNA_DB_NAME'] + self.db_name = os.environ["TUNA_DB_NAME"] self.prefix = None + # Track jobs claimed by this specific instance when in distributor mode + self.claimed_job_ids = set() + self.completed_job_ids = set() + # if less than 25% of the jobs are remaining, we can grab more jobs + self.progress_factor = 0.25 + def check_docker(self, worker: WorkerInterface, dockername="miopentuna") -> bool: """! 
Checking for docker - @param worker The worker interface instance - @param dockername The name of the docker - """ + @param worker The worker interface instance + @param dockername The name of the docker + """ out2: ChannelFile _, out2, _ = worker.exec_command("sudo docker info") while not out2.channel.exit_status_ready(): @@ -102,34 +125,44 @@ def check_docker(self, for line in out.readlines(): if line is not None: if line.find(dockername) != -1: - self.logger.warning('%s docker image exists', dockername) + self.logger.warning("%s docker image exists", dockername) return True if line is None: - self.logger.warning('%s docker image does not exist', dockername) + self.logger.warning("%s docker image does not exist", dockername) return False return False - def check_status(self, - worker: WorkerInterface, - b_first: int, - gpu_idx: int, - machine: Machine, - dockername: str = "miopentuna") -> bool: + def check_status( + self, + worker: WorkerInterface, + b_first: int, + gpu_idx: int, + machine: Machine, + dockername: str = "miopentuna", + ) -> bool: """! Function to check gpu_status - @param worker The worker interface instance - @param b_first Flag to keep track of visited GPU - @param gpu_idx Unique ID of the GPU - @param machine The machine instance - @param dockername The name of the docker - """ + @param worker The worker interface instance + @param b_first Flag to keep track of visited GPU + @param gpu_idx Unique ID of the GPU + @param machine The machine instance + @param dockername The name of the docker + """ if machine.chk_gpu_status(worker.gpu_id): - self.logger.info('Machine: (%s, %u) GPU_ID: %u OK', machine.hostname, - machine.port, gpu_idx) + self.logger.info( + "Machine: (%s, %u) GPU_ID: %u OK", + machine.hostname, + machine.port, + gpu_idx, + ) else: - self.logger.info('Machine: (%s, %u) GPU_ID: %u ERROR', machine.hostname, - machine.port, gpu_idx) + self.logger.info( + "Machine: (%s, %u) GPU_ID: %u ERROR", + machine.hostname, + machine.port, + gpu_idx, + ) if not b_first: return False @@ -146,10 +179,10 @@ def check_status(self, for line in out.readlines(): if line is not None: if line.find(dockername) != -1: - self.logger.warning('%s docker image exists', dockername) + self.logger.warning("%s docker image exists", dockername) break else: - self.logger.warning('%s docker image does not exist', dockername) + self.logger.warning("%s docker image does not exist", dockername) return True @@ -163,16 +196,16 @@ def get_num_procs(self, machine: Machine) -> List: num_procs: int env: Dict[str, Any] env = get_env_vars() - if env['slurm_cpus'] > 0: - num_procs = int(env['slurm_cpus']) + if env["slurm_cpus"] > 0: + num_procs = int(env["slurm_cpus"]) else: - num_procs = int(machine.get_num_cpus() * .6) + num_procs = int(machine.get_num_cpus() * 0.6) worker_ids = list(range(num_procs)) if len(worker_ids) == 0: - self.logger.error('num_procs must be bigger than zero to launch worker') - self.logger.error('Cannot launch worker on machine: %s', machine.id) + self.logger.error("num_procs must be bigger than zero to launch worker") + self.logger.error("Cannot launch worker on machine: %s", machine.id) worker_ids = [] return worker_ids @@ -181,14 +214,14 @@ def get_f_vals(self, machine: Machine, worker_ids: range, tuning=False) -> Dict[str, Any]: - #pylint:disable=unused-argument + # pylint:disable=unused-argument """Determine kwargs for worker_interface""" f_vals: Dict[str, Any] f_vals = self.compose_f_vals(machine) - f_vals['envmt'] = self.get_envmt() + f_vals["envmt"] = self.get_envmt() if 
not tuning: - f_vals["num_procs"] = Value('i', len(worker_ids)) + f_vals["num_procs"] = Value("i", len(worker_ids)) return f_vals @@ -198,20 +231,20 @@ def get_envmt(self): def compose_f_vals(self, machine: Machine, tuning=False) -> Dict[str, Any]: """! Compose dict for WorkerInterface constructor - @param args The command line arguments - @param machine Machine instance - """ + @param args The command line arguments + @param machine Machine instance + """ f_vals: Dict[str, Any] = {} f_vals["b_first"] = True - #adding non-serializable obj when not running through celery + # adding non-serializable obj when not running through celery if not tuning: f_vals["machine"] = machine f_vals["bar_lock"] = Lock() - #multiprocess queue for jobs, shared on machine + # multiprocess queue for jobs, shared on machine f_vals["job_queue"] = mpQueue() f_vals["job_queue_lock"] = Lock() - f_vals["end_jobs"] = Value('i', 0) + f_vals["end_jobs"] = Value("i", 0) return f_vals @@ -220,21 +253,21 @@ def get_kwargs(self, f_vals: Dict[str, Any], tuning=False) -> Dict[str, Any]: """! Helper function to set up kwargs for worker instances - @param gpu_idx Unique ID of the GPU - @param f_vals Dict containing runtime information - """ + @param gpu_idx Unique ID of the GPU + @param f_vals Dict containing runtime information + """ envmt: Dict[str, Any] = f_vals["envmt"].copy() kwargs: Dict[str, Any] = {} kwargs = { - 'gpu_id': gpu_idx, - 'envmt': envmt, - 'label': self.args.label, - 'docker_name': self.args.docker_name, - 'session_id': self.args.session_id + "gpu_id": gpu_idx, + "envmt": envmt, + "label": self.args.label, + "docker_name": self.args.docker_name, + "session_id": self.args.session_id, } - #adding non-serializable obj when not running through celery + # adding non-serializable obj when not running through celery if not tuning: kwargs["machine"] = f_vals["machine"] kwargs["job_queue"] = f_vals["job_queue"] @@ -251,19 +284,20 @@ def get_job_list(self, session, find_state, claim_num): """Get list of jobs""" raise NotImplementedError("Not implemented") - def get_jobs(self, - session: DbSession, - find_state: List[str], - set_state: str, - session_id: int, - claim_num: int = None, - no_update=False): + def get_jobs( + self, + session: DbSession, + find_state: List[str], + set_state: str, + session_id: int, + claim_num: int = None, + no_update=False, + ): """Interface function to get jobs based on session and find_state""" - #job_rows: List[SimpleDict] + # job_rows: List[SimpleDict] ids: list row: SimpleDict - self.logger.info('Fetching DB rows...') job_list = self.get_job_list(session, find_state, claim_num) if not self.check_jobs_found(job_list, find_state, session_id): @@ -273,16 +307,24 @@ def get_jobs(self, return job_list ids = [row.id for row in job_list] - self.logger.info("%s jobs %s", find_state, ids) - self.logger.info('Updating job state to %s', set_state) - for job in job_list: - job.state = set_state - if self.dbt is not None: - query: str = gen_update_query(job, ['state'], - self.dbt.job_table.__tablename__) - else: - raise CustomError('DBTable must be set') - session.execute(query) + # Log summary of jobs being updated + self.logger.info("Updating %d jobs from %s to %s", len(ids), find_state, set_state) + + # OPTIMIZATION: Use bulk UPDATE instead of individual updates + if self.dbt is not None: + id_str = ','.join(map(str, ids)) + query = f""" + UPDATE {self.dbt.job_table.__tablename__} + SET state = '{set_state}' + WHERE id IN ({id_str}) + """ + session.execute(text(query)) + + # Update local objects 
to reflect new state + for job in job_list: + job.state = set_state + else: + raise CustomError("DBTable must be set") session.commit() @@ -295,68 +337,306 @@ def shutdown_workers(self): def cancel_consumer(self, queue): """Cancel consumers for queue""" try: - cmd = f"celery -A tuna.celery_app.celery_app control cancel_consumer {queue}" - subp = subprocess.Popen( #pylint: disable=consider-using-with + cmd = ( + f"celery -A tuna.celery_app.celery_app control cancel_consumer {queue}" + ) + subp = subprocess.Popen( # pylint: disable=consider-using-with cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True, - universal_newlines=True) + universal_newlines=True, + ) - #filter the workers by session id - sess_str = "sess_" + queue.split('_')[-1] + # filter the workers by session id + sess_str = "sess_" + queue.split("_")[-1] stdout, _ = subp.stdout, subp.stderr while True: line = stdout.readline() if not line: break - #stop workers that were feeding from this queue + # stop workers that were feeding from this queue if "->" in line and sess_str in line: - hostname = line.split('->')[1].split()[0].split(':')[0] + hostname = line.split("->")[1].split()[0].split(":")[0] stop_named_worker(hostname) - except Exception as exp: #pylint: disable=broad-exception-caught + except Exception as exp: # pylint: disable=broad-exception-caught self.logger.warning( - 'Error occurred trying to cancel consumer for queue: %s ', queue) + "Error occurred trying to cancel consumer for queue: %s ", queue) self.logger.warning(exp) return False - self.logger.info('Sucessfully cancelled consumer for queue: %s', queue) + self.logger.info("Sucessfully cancelled consumer for queue: %s", queue) return True def celery_enqueue_call(self, context, q_name, task_id=False): """Wrapper function for celery enqueue func""" - raise NotImplementedError('Not implemented') + raise NotImplementedError("Not implemented") - def enqueue_jobs(self, job_counter, job_batch_size, q_name): - """Enqueue celery jobs""" - self.logger.info('Starting enqueue') + def _should_wait_for_progress(self, job_batch_size): + """Check if we should wait before fetching more jobs based on database state + + This method queries the database directly to get accurate job counts, + eliminating reliance on potentially stale in-memory tracking lists. 
+ """ + if not self.claimed_job_ids: + # No jobs claimed yet, don't wait + return False + + progress_threshold = job_batch_size * self.progress_factor + + # Query database for actual state of claimed jobs with DbSession() as session: - while True: - job_list = [] - #get all the jobs from mySQL - job_list = self.get_jobs( - session, - self.fetch_state, - self.set_state, #pylint: disable=no-member - self.args.session_id, #pylint: disable=no-member - job_batch_size) - - with job_counter_lock: - job_counter.value = job_counter.value + len(job_list) - - for i in range(0, len(job_list), job_batch_size): - batch_jobs = job_list[i:min(i + job_batch_size, len(job_list))] - context_list = self.get_context_list(session, batch_jobs) - for context in context_list: - #calling celery task, enqueuing to celery queue - self.celery_enqueue_call(context, q_name=q_name) - - self.logger.info('Job counter: %s', job_counter.value) - if not job_list: - self.logger.info('All tasks added to queue') - break + try: + # Batch the query to avoid SQL statement too long + claimed_list = list(self.claimed_job_ids) + batch_size = 1000 + total_in_progress = 0 + total_completed = 0 + + for i in range(0, len(claimed_list), batch_size): + batch = claimed_list[i:i + batch_size] + id_str = ','.join(map(str, batch)) + + # Count jobs still in progress states + in_progress_query = f""" + SELECT COUNT(*) FROM {self.dbt.job_table.__tablename__} + WHERE id IN ({id_str}) + AND state IN ('eval_start', 'compile_start') + """ + batch_in_progress = session.execute(text(in_progress_query)).scalar() + total_in_progress += batch_in_progress + + # Count completed jobs + completed_query = f""" + SELECT COUNT(*) FROM {self.dbt.job_table.__tablename__} + WHERE id IN ({id_str}) + AND state IN ('evaluated', 'errored', 'completed') + """ + batch_completed = session.execute(text(completed_query)).scalar() + total_completed += batch_completed + + self.logger.info( + "DB query - Jobs in progress: %d, completed: %d, threshold: %.0f", + total_in_progress, + total_completed, + progress_threshold, + ) + + return total_in_progress >= progress_threshold + + except Exception as err: # pylint: disable=broad-exception-caught + self.logger.error("Error querying job progress: %s", err) + # On error, be conservative: assume we should wait (return True) + # This prevents over-fetching if database is temporarily unavailable + self.logger.warning("Defaulting to WAIT due to database error (conservative approach)") + return True + + def _fetch_jobs_with_retry(self, + job_batch_size, + max_retries=3, + retry_delay=5): + """Fetch jobs from database with retry logic + + Returns: + List of jobs if successful, empty list if no jobs found, None if error + """ + for attempt in range(max_retries): + try: + with DbSession() as session: + job_list = self.get_jobs( + session, + self.fetch_state, + self.set_state, # pylint: disable=no-member + self.args.session_id, # pylint: disable=no-member + job_batch_size, + ) + return job_list + + except Exception as db_err: # pylint: disable=broad-exception-caught + self.logger.warning('Database error on attempt %d/%d: %s', attempt + 1, + max_retries, db_err) + if attempt < max_retries - 1: + time.sleep(retry_delay * (attempt + 1)) # Exponential backoff + else: + self.logger.error('Max retries exceeded for database operation.') + raise + + return None + + def _process_job_batch(self, job_list, job_counter, q_name): + """Process a batch of jobs by enqueuing them to Celery""" + # Track the jobs we just claimed (extend list with new job IDs) + 
new_job_ids = [job.id for job in job_list] + self.claimed_job_ids.extend(new_job_ids) + self.logger.info("Claimed %d jobs", len(new_job_ids)) + + # Update job counter + with job_counter_lock: + job_counter.value = job_counter.value + len(job_list) + + # Get context and enqueue each job + with DbSession() as session: + context_list = self.get_context_list(session, job_list) + + for context in context_list: + try: + self.celery_enqueue_call(context, q_name=q_name) + except Exception as enqueue_err: # pylint: disable=broad-exception-caught + self.logger.error('Failed to enqueue job: %s', enqueue_err) + continue + + self.logger.info( + "Job counter: %s, enqueued batch size: %s", + job_counter.value, + len(job_list), + ) + + # Cleanup old tracking data periodically + self.cleanup_completed_jobs() + + def enqueue_jobs(self, job_counter, job_batch_size, q_name): + """Enqueue celery jobs with simplified progress tracking""" + # Configure logger for subprocess to write to stdout + # This ensures logs are captured by bash redirection (> logfile.log 2>&1) + + # Remove any existing handlers to avoid duplicates + self.logger.handlers.clear() + + # Add StreamHandler that writes to stdout + stdout_handler = logging.StreamHandler(sys.stdout) + stdout_handler.setLevel(logging.INFO) + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s' + ) + stdout_handler.setFormatter(formatter) + self.logger.addHandler(stdout_handler) + self.logger.setLevel(logging.INFO) + + self.logger.info("Starting enqueue loop - batch_size=%d, queue=%s", + job_batch_size, q_name) + self.logger.info("Fetch states: %s, Set state: %s", self.fetch_state, + self.set_state) + + is_first_batch = True + consecutive_empty_fetches = 0 + max_empty_fetches = int(os.environ.get('TUNA_MAX_EMPTY_FETCHES', 3)) + poll_interval = int(os.environ.get("TUNA_POLL_INTERVAL", 60)) + loop_iteration = 0 + + while True: + loop_iteration += 1 + # Only log iteration every 10 iterations or when something interesting happens + if loop_iteration % 10 == 1: + self.logger.info("=== Enqueue loop iteration %d ===", loop_iteration) + + # 1. Check if we should wait for progress (skip on first batch) + # Database query now provides accurate state, no reconciliation needed + if not is_first_batch and self._should_wait_for_progress(job_batch_size): + self.logger.info( + "Waiting for batch progress (iteration %d)", + loop_iteration) + # Reset consecutive_empty_fetches since we're waiting for progress, not out of jobs + consecutive_empty_fetches = 0 + time.sleep(poll_interval) + continue + + # 2. Fetch jobs with built-in retry logic + job_list = self._fetch_jobs_with_retry(job_batch_size) + # Only log fetch details when jobs are found or on errors + if job_list: + self.logger.info("Fetched %d jobs (iteration %d)", len(job_list), loop_iteration) + + # 3. 
Handle empty results + if not job_list: + consecutive_empty_fetches += 1 + self.logger.warning( + 'No jobs found (attempt %d/%d) - iteration %d', + consecutive_empty_fetches, max_empty_fetches, loop_iteration) + + # Check if jobs are being skipped due to database locks + if consecutive_empty_fetches == 2: # After 2nd empty fetch + self.logger.warning( + "Checking for locked jobs that may be blocking progress (iteration %d)", + loop_iteration) + with DbSession() as lock_check_session: + if hasattr(self, 'detect_and_handle_locked_jobs'): + try: + handled = self.detect_and_handle_locked_jobs( + lock_check_session, list(self.fetch_state)) + if handled: + self.logger.info( + "Handled locked jobs, resetting empty fetch counter") + consecutive_empty_fetches = 0 # Reset counter to retry + continue + else: + self.logger.info("No locked jobs found to handle") + except Exception as lock_err: # pylint: disable=broad-exception-caught + self.logger.error("Error checking for locked jobs: %s", lock_err) + else: + self.logger.warning( + "detect_and_handle_locked_jobs method not available") + + if consecutive_empty_fetches >= max_empty_fetches: + self.logger.warning( + 'EXITING: No more jobs available after %d attempts (iteration %d). Exiting enqueue loop.', + max_empty_fetches, loop_iteration) + self.logger.info("Final state - claimed: %d, completed: %d", + len(self.claimed_job_ids), len(self.completed_job_ids)) + return + + self.logger.info("Sleeping for %d seconds before retry (iteration %d)...", + poll_interval, loop_iteration) + time.sleep(poll_interval) + continue + + # 4. Process the batch + self.logger.info("Processing batch of %d jobs (iteration %d)", len(job_list), + loop_iteration) + consecutive_empty_fetches = 0 + self._process_job_batch(job_list, job_counter, q_name) + is_first_batch = False + self.logger.info("Batch processed successfully (iteration %d)", loop_iteration) + + def cleanup_completed_jobs(self): + """Periodically clean up old job tracking data + + Since we now query the database for accurate progress tracking, + we only need to keep claimed_job_ids from growing too large. + The completed_job_ids list is kept for Redis consumer compatibility + but is not used for progress decisions. 
+    """
+    # Keep claimed_job_ids list from growing indefinitely
+    max_tracking_size = 10000
+    if len(self.claimed_job_ids) > max_tracking_size:
+      # Query database to find which claimed jobs are actually complete
+      with DbSession() as session:
+        try:
+          claimed_list = list(self.claimed_job_ids)
+          id_str = ','.join(map(str, claimed_list))
+
+          # Get IDs of jobs that are complete
+          query = f"""
+            SELECT id FROM {self.dbt.job_table.__tablename__}
+            WHERE id IN ({id_str})
+            AND state IN ('evaluated', 'errored', 'completed')
+          """
+          completed_ids = {row[0] for row in session.execute(text(query)).fetchall()}
+
+          # Keep only jobs that are still in progress
+          active_jobs = [job_id for job_id in claimed_list if job_id not in completed_ids]
+
+          # Update the list
+          del self.claimed_job_ids[:]
+          self.claimed_job_ids.extend(active_jobs)
+
+          self.logger.info(
+              "Cleaned up tracking: removed %d completed jobs, kept %d active jobs",
+              len(completed_ids), len(active_jobs))
+
+        except Exception as err:  # pylint: disable=broad-exception-caught
+          self.logger.error("Error during cleanup: %s", err)
 
   async def cleanup_redis_results(self, prefix):
     """Remove stale redis results by key"""
@@ -366,25 +646,25 @@ async def cleanup_redis_results(self, prefix):
     keys = []
     cursor = "0"
     if prefix:
-      #a prefix is necessary when the need to different results in redis based on operation
-      #withough a prefix the redis key defaults to: "celery-task-meta-"
-      #with a prefix the key will look like: "celery-task-meta--"
-      #the prefix can be applied when filtering the redis keys as bellow
+      # a prefix is necessary when we need to differentiate results in redis based on operation
+      # without a prefix the redis key defaults to: "celery-task-meta-"
+      # with a prefix the key will look like: "celery-task-meta--"
+      # the prefix can be applied when filtering the redis keys as below
       cursor, results = await redis.scan(cursor, match=f"*{prefix}*")
     else:
-      #no prefix, match any key
+      # no prefix, match any key
       cursor, results = await redis.scan(cursor, match="*")
     keys.extend(results)
-    self.logger.info('Found %s old results', len(results))
+    self.logger.info("Found %s old results", len(results))
 
     for key in keys:
       try:
         await redis.delete(key)
-      except aioredis.exceptions.ResponseError as red_err:
+      except Exception as red_err:
         self.logger.error(red_err)
-        self.logger.info(key.decode('utf-8'))
+        self.logger.info(key.decode("utf-8"))
         continue
 
-    self.logger.info('Done removing old redis results for prefix: %s', prefix)
+    self.logger.info("Done removing old redis results for prefix: %s", prefix)
 
     return True
 
@@ -399,30 +679,30 @@ async def consume(self, job_counter, prefix):
     keys = []
     while cursor != 0:
      if prefix:
-        #a prefix is necessary when the need to different results in redis based on operation
-        #withough a prefix the redis key defaults to: "celery-task-meta-"
-        #with a prefix the key will look like: "celery-task-meta--"
-        #the prefix can be applied when filtering the redis keys as bellow
+        # a prefix is necessary when we need to differentiate results in redis based on operation
+        # without a prefix the redis key defaults to: "celery-task-meta-"
+        # with a prefix the key will look like: "celery-task-meta--"
+        # the prefix can be applied when filtering the redis keys as below
         cursor, results = await redis.scan(cursor, match=f"*{prefix}*")
       else:
-        #no prefix, match any key
+        # no prefix, match any key
         cursor, results = await redis.scan(cursor, match="*")
       keys.extend(results)
-      self.logger.info('Found %s results', len(results))
+      self.logger.info("Found 
%s results", len(results)) for key in keys: try: data = await redis.get(key) if data: - _ = await self.parse_result(data.decode('utf-8')) + _ = await self.parse_result(data.decode("utf-8")) await redis.delete(key) with job_counter_lock: job_counter.value = job_counter.value - 1 - except aioredis.exceptions.ResponseError as red_err: + except Exception as red_err: self.logger.error(red_err) - self.logger.info(key.decode('utf-8')) + self.logger.info(key.decode("utf-8")) await asyncio.sleep(1) - self.logger.info('Job counter reached 0') + self.logger.info("Job counter reached 0") await redis.close() return True @@ -434,33 +714,33 @@ def prep_tuning(self): q_name = None if self.operation == Operation.COMPILE: q_name = get_q_name(self, op_compile=True) - cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -n tuna_HOSTNAME_sess_{self.args.session_id} -Q {q_name}" #pylint: disable=line-too-long + cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -n tuna_HOSTNAME_sess_{self.args.session_id} -Q {q_name}" # pylint: disable=line-too-long else: q_name = get_q_name(self, op_eval=True) - cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -c 1 -n tuna_HOSTNAME_sess_{self.args.session_id}_gpu_id_GPUID -Q {q_name}" #pylint: disable=line-too-long + cmd = f"celery -A tuna.celery_app.celery_app worker -l info -E -c 1 -n tuna_HOSTNAME_sess_{self.args.session_id}_gpu_id_GPUID -Q {q_name}" # pylint: disable=line-too-long - self.logger.info('celery Q name: %s', q_name) + self.logger.info("celery Q name: %s", q_name) if not self.args.enqueue_only: try: - self.logger.info('Launching celery workers for queue %s', q_name) + self.logger.info("Launching celery workers for queue %s", q_name) subp_list = launch_celery_worker(self.operation, cmd, self.args, True) - self.logger.info('Done launching celery workers') + self.logger.info("Done launching celery workers") if not subp_list: - raise CustomError('Could not launch celery worker') + raise CustomError("Could not launch celery worker") except kombu.exceptions.OperationalError as k_err: - self.logger.error('Redis error ocurred: %s', k_err) + self.logger.error("Redis error ocurred: %s", k_err) return False else: purge_queue([q_name]) return q_name, subp_list - #pylint: disable=too-many-locals + # pylint: disable=too-many-locals def tune(self, job_batch_size=1000): """tuning loop to spin out celery tasks""" if self.args.shutdown_workers: - self.logger.info('Shutting down all celery workers') + self.logger.info("Shutting down all celery workers") stop_active_workers() return True @@ -471,7 +751,7 @@ def tune(self, job_batch_size=1000): return False try: - #if enqueue_only is False, we launch the celery workers + # if enqueue_only is False, we launch the celery workers if not self.args.enqueue_only: for subp in subp_list: subp.wait() @@ -483,43 +763,41 @@ def tune(self, job_batch_size=1000): start = time.time() - #set job count to 1 until first job fetch is finished - job_counter = Value('i', 1) - try: - enqueue_proc = Process(target=self.enqueue_jobs, - args=[job_counter, job_batch_size, q_name]) - #Start enqueue proc - enqueue_proc.start() + # set job count to 1 until first job fetch is finished + job_counter = Value("i", 1) - #cleanup old results + # Create shared data structures for cross-process communication + manager = Manager() + self.claimed_job_ids = manager.list() # Shared list across processes + self.completed_job_ids = manager.list() # Shared list across processes + + try: + # cleanup old results cleanup_proc = 
Process(target=self.async_wrap, args=(self.cleanup_redis_results, self.prefix)) cleanup_proc.start() cleanup_proc.join() - #start async consume thread, blocking + # start async consume thread, blocking consume_proc = Process(target=self.async_wrap, args=(self.consume, job_counter, self.prefix)) - self.logger.info('Starting consume thread') + self.logger.info("Starting consume thread") consume_proc.start() - enqueue_proc.join() - #enqueue finished first fetch, remove hold on job_counter - with job_counter_lock: - job_counter.value = job_counter.value - 1 - - #check for new jobs - while consume_proc.is_alive(): - enqueue_proc = Process(target=self.enqueue_jobs, - args=[job_counter, job_batch_size, q_name]) - enqueue_proc.start() - enqueue_proc.join() - time.sleep(10) + # Start enqueue proc - let it run continuously with persistent state + enqueue_proc = Process(target=self.enqueue_jobs, + args=[job_counter, job_batch_size, q_name]) + enqueue_proc.start() + # Wait for both processes to complete naturally consume_proc.join() + enqueue_proc.join() - except (KeyboardInterrupt, Exception) as exp: #pylint: disable=broad-exception-caught - self.logger.error('Error ocurred %s', exp) + except ( + KeyboardInterrupt, + Exception, + ) as exp: # pylint: disable=broad-exception-caught + self.logger.error("Error ocurred %s", exp) purge_queue([q_name]) self.cancel_consumer(q_name) self.reset_job_state_on_ctrl_c() @@ -528,7 +806,7 @@ def tune(self, job_batch_size=1000): self.cancel_consumer(q_name) end = time.time() - self.logger.info("Took {:0>8} to tune".format( #pylint: disable=consider-using-f-string + self.logger.info("Took {:0>8} to tune".format( # pylint: disable=consider-using-f-string str(timedelta(seconds=end - start)))) return True @@ -542,38 +820,40 @@ def async_wrap(self, async_func, *args): try: asyncio.run(self.async_callback(async_func, *args)) except KeyboardInterrupt: - self.logger.warning('Keyboard interrupt caught, terminating') + self.logger.warning("Keyboard interrupt caught, terminating") def reset_job_state_on_ctrl_c(self): """Reset job state for jobs in flight""" temp_obj = SimpleDict() - temp_obj.session_id = self.args.session_id #pylint: disable=invalid-name - attribs = ['state'] + temp_obj.session_id = self.args.session_id # pylint: disable=invalid-name + attribs = ["state"] temp_obj.state = 1 - self.logger.info('Resetting job state in DB for in flight jobs') + self.logger.info("Resetting job state in DB for in flight jobs") if self.operation == Operation.COMPILE: state = 16 elif self.operation == Operation.EVAL: state = 12 - query = gen_update_query(temp_obj, attribs, - self.dbt.job_table.__tablename__, - [('session', self.args.session_id), - ('state', state)]) + query = gen_update_query( + temp_obj, + attribs, + self.dbt.job_table.__tablename__, + [("session", self.args.session_id), ("state", state)], + ) with DbSession() as session: - #pylint: disable=duplicate-code + # pylint: disable=duplicate-code def callback() -> bool: - session.execute(query) + session.execute(text(query)) session.commit() return True - #pylint: enable=duplicate-code + # pylint: enable=duplicate-code assert session_retry(session, callback, lambda x: x(), self.logger) - self.logger.info('Sucessfully reset job state') + self.logger.info("Sucessfully reset job state") return True return False @@ -598,7 +878,7 @@ def check_jobs_found(self, job_rows: List[SimpleDict], find_state: List[Any], """check for end of jobs""" if not job_rows: # we are done - self.logger.warning('No %s jobs found, session %s', 
find_state, + self.logger.warning("No %s jobs found, session %s", find_state, session_id) return False return True @@ -624,7 +904,7 @@ def get_context_list(self, session, batch_jobs): context_list: List[dict] = None serialized_jobs = self.serialize_jobs(session, batch_jobs) - #build context for each celery task + # build context for each celery task context_list = self.build_context(serialized_jobs) return context_list @@ -635,22 +915,73 @@ async def parse_result(self, data): with DbSession() as session: try: - fin_json = data['result']['ret'] - context = data['result']['context'] + fin_json = data["result"]["ret"] + context = data["result"]["context"] + + # Extract job ID from context to track completion + job_id = self.extract_job_id_from_context(context) + except KeyError as kerr: self.logger.error(kerr) return False - self.logger.info('Parsing: %s', fin_json) + self.logger.info("Parsing: %s", fin_json) if self.operation == Operation.COMPILE: self.process_compile_results(session, fin_json, context) elif self.operation == Operation.EVAL: self.process_eval_results(session, fin_json, context) else: - raise CustomError('Unsupported tuning operation') + raise CustomError("Unsupported tuning operation") + + # Update tracking after processing to get the final job state + if job_id and job_id in self.claimed_job_ids: + # Check the final state of the job after processing + final_state = self.get_job_final_state(session, job_id) + + if final_state in ['evaluated', 'errored']: + # Job is truly complete - append to completed list + self.completed_job_ids.append(job_id) + self.logger.info("Marked job %s as completed with state: %s", job_id, + final_state) + elif final_state == 'compiled': + # Job failed and was reset to compiled for retry + # Remove from claimed so it can be re-grabbed + try: + self.claimed_job_ids.remove(job_id) + self.logger.info( + "Job %s failed and reset to 'compiled' - removed from claimed list for retry", + job_id) + except ValueError: + # Job ID not in list, ignore + pass + else: + self.logger.warning("Job %s has unexpected final state: %s", job_id, + final_state) return True + def get_job_final_state(self, session, job_id): + """Query the database to get the current state of a job""" + try: + if self.dbt is not None: + query = f""" + SELECT state FROM {self.dbt.job_table.__tablename__} + WHERE id = {job_id} + """ + result = session.execute(text(query)).fetchone() + if result: + return result[0] + return None + except Exception as err: # pylint: disable=broad-exception-caught + self.logger.error("Error querying job state for job %s: %s", job_id, err) + return None + + def extract_job_id_from_context(self, context): + """Extract job ID from celery task context""" + # This needs to be implemented in the MIOpen subclass + # based on how job IDs are stored in the context + raise NotImplementedError("Subclass must implement job ID extraction") + def process_compile_results(self, session, fin_json, context): """Process result from fin_build worker""" raise NotImplementedError("Not implemented") diff --git a/tuna/rocmlir/rocmlir_tables.py b/tuna/rocmlir/rocmlir_tables.py index 49b81a502..a451fab00 100644 --- a/tuna/rocmlir/rocmlir_tables.py +++ b/tuna/rocmlir/rocmlir_tables.py @@ -770,11 +770,10 @@ def get_tables() -> List[BASE]: tables: List[BASE] = [] with DbSession() as session: engine = session.bind - connect = session.connection() def append_if_not_exists(table): - # Note: this changes in sqlalchemy 1.4. 
- if not inspect(engine).dialect.has_table(connect, table.__tablename__): + # Updated for SQLAlchemy 2.0 + if not inspect(engine).has_table(table.__tablename__): tables.append(table) append_if_not_exists(SessionRocMLIR()) diff --git a/tuna/rocmlir/rocmlir_worker.py b/tuna/rocmlir/rocmlir_worker.py index 2964c6c50..df3a3c17b 100644 --- a/tuna/rocmlir/rocmlir_worker.py +++ b/tuna/rocmlir/rocmlir_worker.py @@ -35,6 +35,7 @@ import traceback from sqlalchemy.inspection import inspect +from sqlalchemy import text from tenacity import Retrying, stop_after_attempt, before_sleep_log, wait_random @@ -94,7 +95,7 @@ def update_result_table(self, session, result_str): self.logger.info('Inserting results for job_id=%s', self.job.id) query = gen_insert_query(obj, self.result_attr, self.dbt.results.__tablename__) - session.execute(query) + session.execute(text(query)) session.commit() return True diff --git a/tuna/utils/db_utility.py b/tuna/utils/db_utility.py index 400559c14..5c2dc1697 100644 --- a/tuna/utils/db_utility.py +++ b/tuna/utils/db_utility.py @@ -35,7 +35,7 @@ from typing import Callable, Any, List, Dict import pymysql from sqlalchemy.exc import OperationalError, IntegrityError, ProgrammingError -from sqlalchemy import create_engine +from sqlalchemy import create_engine, text from tuna.dbBase.sql_alchemy import DbSession from tuna.dbBase.base_class import BASE @@ -49,8 +49,7 @@ ENV_VARS = get_env_vars() ENGINE = create_engine(f"mysql+pymysql://{ENV_VARS['user_name']}:{ENV_VARS['user_password']}" +\ - f"@{ENV_VARS['db_hostname']}:3306/{ENV_VARS['db_name']}", - encoding="utf8") + f"@{ENV_VARS['db_hostname']}:3306/{ENV_VARS['db_name']}") def connect_db(): @@ -62,19 +61,25 @@ def connect_db(): raise ValueError('DB name must be specified in env variable: TUNA_DB_NAME') try: - ENGINE.execute(f'Use {db_name}') + with ENGINE.connect() as conn: + conn.execute(text(f'Use {db_name}')) + conn.commit() return except OperationalError: # as err: LOGGER.warning('Database %s does not exist, attempting to create database', db_name) try: - ENGINE.execute(f'Create database if not exists {db_name}') + with ENGINE.connect() as conn: + conn.execute(text(f'Create database if not exists {db_name}')) + conn.commit() except OperationalError as err: LOGGER.error('Database creation failed %s for username: %s', err, ENV_VARS['user_name']) - ENGINE.execute(f'Use {db_name}') - ENGINE.execute('SET GLOBAL max_allowed_packet=4294967296') + with ENGINE.connect() as conn: + conn.execute(text(f'Use {db_name}')) + conn.execute(text('SET GLOBAL max_allowed_packet=4294967296')) + conn.commit() def create_tables(all_tables): @@ -100,7 +105,8 @@ def create_indices(all_indices): with ENGINE.connect() as conn: for idx in all_indices: try: - conn.execute(idx) + conn.execute(text(idx)) + conn.commit() LOGGER.info('Idx created successfully: %s', idx) except (OperationalError, ProgrammingError) as oerr: LOGGER.info('%s \n', oerr) @@ -132,6 +138,20 @@ def session_retry(session: DbSession, return False +def sanitize_sql_string(value: str, max_length: int = 2000) -> str: + """Sanitize string for safe SQL insertion by escaping special characters""" + # Truncate to safe length to avoid excessively long queries + if len(value) > max_length: + value = value[:max_length] + '... 
[truncated]' + + # Escape backslashes first (must be done before quotes) + value = value.replace('\\', '\\\\') + # Escape single quotes by doubling them (SQL standard) + value = value.replace("'", "''") + + return value + + def get_attr_vals(obj, attr_list): """create the dictionary of values for the attribute list """ attr_vals = {} @@ -140,10 +160,14 @@ def get_attr_vals(obj, attr_list): if val is None: val = 'NULL' elif isinstance(val, (datetime, str)): - val = f"'{val}'" + # Sanitize and escape the string value + sanitized = sanitize_sql_string(str(val)) + val = f"'{sanitized}'" elif isinstance(val, bytes): val = val.decode('utf-8') - val = f"'{val}'" + # Sanitize and escape the string value + sanitized = sanitize_sql_string(val) + val = f"'{sanitized}'" else: val = str(val) attr_vals[attr] = val @@ -213,7 +237,7 @@ def get_job_rows(session, attribs, tablename, cond_str): LOGGER.info('Query Select: %s', query) try: - ret = session.execute(query) + ret = session.execute(text(query)) except (Exception, KeyboardInterrupt) as ex: #pylint: disable=broad-except LOGGER.warning(ex) ret = None @@ -245,7 +269,7 @@ def has_attr_set(obj, attribs): def get_class_by_tablename(tablename): """use tablename to find class""" # pylint: disable=protected-access - for class_name in BASE._decl_class_registry.values(): + for class_name in BASE.registry._class_registry.values(): if hasattr(class_name, '__tablename__') and class_name.__tablename__ == tablename: return class_name diff --git a/tuna/worker_interface.py b/tuna/worker_interface.py index 846d57af3..d7be120e0 100644 --- a/tuna/worker_interface.py +++ b/tuna/worker_interface.py @@ -44,6 +44,7 @@ from typing import List, Tuple, Union, Set, Optional, Any, Dict from sqlalchemy.exc import IntegrityError, OperationalError, NoInspectionAvailable from sqlalchemy.inspection import inspect +from sqlalchemy import text from tuna.dbBase.sql_alchemy import DbSession from tuna.machine import Machine @@ -283,7 +284,7 @@ def get_job(self, find_state: str, set_state: str, imply_end: bool) -> bool: job_set_attr = ['state'] query: str = gen_update_query(job, job_set_attr, self.dbt.job_table.__tablename__) - session.execute(query) + session.execute(text(query)) session.commit() self.job_queue_push(job_rows) @@ -349,7 +350,7 @@ def set_job_state(self, self.dbt.job_table.__tablename__) def callback() -> bool: - session.execute(query) + session.execute(text(query)) session.commit() return True
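
Note on the raw SQL added in this patch: the bulk state update in get_jobs, the progress counts in _should_wait_for_progress, get_job_final_state, and the cleanup query all interpolate id lists and state values directly into SQL text, which is what makes sanitize_sql_string necessary for string values elsewhere. A minimal sketch of the same queries using SQLAlchemy's expanding bind parameters is shown below; the helper names and the plain Session type hint are illustrative assumptions, while the column names (id, state) and the state values follow the queries in this diff. Only the table name, which cannot be bound, still goes through an f-string.

from typing import List

from sqlalchemy import bindparam, text
from sqlalchemy.orm import Session


def bulk_set_job_state(session: Session, tablename: str, ids: List[int],
                       set_state: str) -> None:
  """Set the state of many jobs in one UPDATE using bound parameters."""
  # Values are passed as parameters, so the driver handles quoting and no
  # manual escaping or truncation of the statement is required.
  stmt = text(f"UPDATE {tablename} SET state = :state WHERE id IN :ids"
             ).bindparams(bindparam("ids", expanding=True))
  session.execute(stmt, {"state": set_state, "ids": ids})
  session.commit()


def count_jobs_in_states(session: Session, tablename: str, ids: List[int],
                         states: List[str]) -> int:
  """Count how many of the given jobs are currently in any of the states."""
  stmt = text(f"SELECT COUNT(*) FROM {tablename} "
              "WHERE id IN :ids AND state IN :states").bindparams(
                  bindparam("ids", expanding=True),
                  bindparam("states", expanding=True))
  return session.execute(stmt, {"ids": ids, "states": states}).scalar()

With this approach the id batching in _should_wait_for_progress can still be kept to bound statement size, but the escaping concerns that motivated sanitize_sql_string do not apply to these code paths.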