diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..190b04a --- /dev/null +++ b/.flake8 @@ -0,0 +1,3 @@ +[flake8] +exclude = .venv/* +docstring-convention=google \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b5bb093 --- /dev/null +++ b/.gitignore @@ -0,0 +1,228 @@ +# Created by https://www.toptal.com/developers/gitignore/api/python,venv,visualstudiocode +# Edit at https://www.toptal.com/developers/gitignore?templates=python,venv,visualstudiocode + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json + +### venv ### +# Virtualenv +# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ +[Bb]in +[Ii]nclude +[Ll]ib +[Ll]ib64 +[Ll]ocal +[Ss]cripts +pyvenv.cfg +pip-selfcheck.json + +### VisualStudioCode ### +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +!.vscode/*.code-snippets + +# Local History for Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix + +### VisualStudioCode Patch ### +# Ignore all local history of files +.history +.ionide + +# End of https://www.toptal.com/developers/gitignore/api/python,venv,visualstudiocode + +uv.lock +.ruff_cache + +# Data Folders +data + +# Model Storage +model_checkpoints/* + + +# testing/debugging notebooks +test.ipynb +buowset.ipynb + +# Question: do we want to commit vscode setting.json files? +settings.json + +# Block all configs besides the example config +whoot_model_training/configs +!whoot_model_training/configs/config.yml diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..e4fba21 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.12 diff --git a/README.md b/README.md index 3f18983..64d767b 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,45 @@ # whoot Tools for capturing, analyzing, and parsing audio data + +# Installation Instructions + +## Default Python Instructions +1) Install Python >=3.10,<3.13 (see `requires-python` in `pyproject.toml`) +2) Create a virtual environment via `python -m venv .venv` +3) Activate the environment using an activate script: + +- Windows: `.venv\Scripts\activate` +- macOS/Linux: `source .venv/bin/activate` + +If this works, your command prompt should be prefixed with the environment name, e.g. `(.venv)`. If not, check https://docs.python.org/3/library/venv.html#how-venvs-work + +4) From the project root, run `pip install -e .` + +To install optional dependencies, run `pip install -e .[extra1,extra2,...]` + +Currently supported optional dependency collections include: + +- `cpu`: Installs torch and torchvision for CPU use only +- `cu128`: Installs torch and torchvision with CUDA 12.8 binaries +- `model-training`: Required for running scripts in `whoot_model_training`; make sure to also add either `cpu` or `cu128` +- `dev`: Installs the linters pylint and flake8. MUST be installed by developers of whoot + +## Usage + +Once the environment is activated, you should be able to run any of the whoot scripts with `python path/to/script.py`.
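+
+To confirm which interpreter you are actually using, this quick standard-library check can help (illustrative only, not part of the whoot tooling):
+
+```python
+import sys
+
+print(sys.prefix)      # points inside .venv when the environment is active
+print(sys.executable)  # the interpreter actually being used
+```
+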
If a script reports that a package is missing, you might not be using the virtual environment. + +# Developer Notes + +## Creating a new Project + +When adding a new package, such as `assess_birdnet`, to the whoot toolkit, add the package name to the `[tool.setuptools]` section of `pyproject.toml` + +### Linting + +Style guidelines are listed in `.flake8` and `pylintrc`. To use these tools, do the following: + +1) Follow the Installation Instructions, but include the `dev` extra when installing: `pip install -e .[dev,extra1,extra2,...]`. +2) Activate the environment + +To run the linters, run `python -m flake8` and `python -m pylint --recursive=y PATH/TO/FILES.py`. +Both linters must pass before contributing to whoot. \ No newline at end of file diff --git a/cfgs/params_segment_2017_data.yaml b/cfgs/params_segment_2017_data.yaml new file mode 100644 index 0000000..f09ca5a --- /dev/null +++ b/cfgs/params_segment_2017_data.yaml @@ -0,0 +1,2 @@ +# length added to the beginning and end of a detection segment (ms) +padding: 100 diff --git a/comet_ml_panels/leaderboard.py b/comet_ml_panels/leaderboard.py new file mode 100644 index 0000000..7f73c6d --- /dev/null +++ b/comet_ml_panels/leaderboard.py @@ -0,0 +1,85 @@ +# """Creates the Leaderboard for Comet ML Panels + +# This script queries from a given Comet ML project a DataFrame of +# model metrics at each step for each model in the project, +# then displays the top models. + +# Example: +# This is not intended to be run locally. Please test on Comet-ML. + +# For Developers: +# For more on adding to this see docs at +# https://www.comet.com/docs/v2/guides/comet-ui/experiment-management/visualizations/python-panel/ + +# Note that updating this file does not update comet-ml. Please +# go into the project to update after pushing to GitHub. + +# Do not include the docstring in comet-ml... for some reason it +# is displayed in the comet-ml panel if copied directly +# """ +from comet_ml import API, APIExperiment, ui +import pandas as pd +import numpy as np + +# Initialize Comet API +api = API() + +# Select the experiments and metrics to compare +available_metrics = ["train/valid_cMAP", "train/valid_ROCAUC"] +selected_metric = ui.dropdown("Select a metric:", available_metrics) + +experiment_keys = api.get_panel_experiment_keys() +data = api.get_metrics_for_chart( + experiment_keys, metrics=[selected_metric], parameters=["task"]) + +# Given all experiments, find all possible tasks to measure!
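+# As consumed below, `data` maps each experiment key to a dict that is
+# expected to look roughly like (structure inferred from this script's usage,
+# not from the Comet ML API docs):
+#   {"experimentName": ..., "params": {"task": ...}, "metrics": [{"values": [...], "steps": [...]}]}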
+available_tasks = list( + set(data[key]["params"]["task"] + for key in data if "task" in data[key]["params"]) +) +available_tasks.append(None) +selected_task = ui.dropdown("Select a Task:", available_tasks) + +processed_data = [] + +for key in data: + # Note, some of the early runs have no value for the task + # The following code handles those cases + TASK = None + if "task" in data[key]["params"]: + TASK = data[key]["params"]["task"] + + # Only display the leaderboard for tasks we want + # This CAN include runs with no task + if TASK is not selected_task and TASK != selected_task: + continue + + # Failed runs may not have metrics + if len(data[key]["metrics"]) == 0: + continue + + metric_values = data[key]["metrics"][0]["values"] + max_index = np.argmax(metric_values) + + processed_data.append({ + "experiment_name": data[key]["experimentName"], + "experiment_key": key, + selected_metric: max(metric_values), + "step": data[key]["metrics"][0]["steps"][max_index], + }) + +leaderboard_df = pd.DataFrame(processed_data).sort_values( + selected_metric, ascending=False) + +leaderboard_df["users"] = leaderboard_df["experiment_key"].apply( + lambda key: APIExperiment(previous_experiment=key).get_user() +) + +col_order = [ + "experiment_name", + selected_metric, + "experiment_key", + "step", + "users" +] +ui.display(leaderboard_df[col_order]) diff --git a/create_dataset/create_dataset.py b/create_dataset/create_dataset.py new file mode 100644 index 0000000..a0d036d --- /dev/null +++ b/create_dataset/create_dataset.py @@ -0,0 +1,134 @@ +"""Create dataset of burrowing owl vocalizations and noise. + +This script will parse through 2017 and 2018 human labeled +burrowing owl data. It will create a folder with segments of +labeled detections, and an equal number of noise samples from +the same wav files. It will create a CSV with metadata associated +with the segments. The metadata will include the UUID of the segment, +the label, the original filepath of the original wav the segment came +from, the path to the segment, and the start and end time of the +labeled detection relative to the original wav file. The labeled +segments will be the duration of the label, and the duration of the +noise will be fixed and consistent. The user of the dataset may choose +to pad the labeled detections if they need consistent length segments. + +Usage: + python3 create_dataset.py -labels /path/to/human/labeled.csv + -wav_dir /path/to/parent/dir/of/wavs/ + -output_dir /path/to/desired/output/dir/ + -class_list /path/to/classes.txt + +""" +import argparse +import ntpath +import os +import pandas as pd +from create_segments import get_paths, create_segments +from create_segments import create_noise_segments +from filter_labels import filter_labels_2017, filter_labels_2018 + + +def create_dataset(labels, wav_dir, output_dir, class_list): + """Creates labeled and non labeled segments and metadata. + + Creates segments based on human labeled data of a detection, + and then creates an equal number of randomized 'non-detection' + segments at fixed length. It cretaes a uuid for each segment + and spits out a metadata file that matches the segment to its + label, original wav file, relative start time to original wav, + and duration. + + Args: + labels (str): Path to label file. + wav_dir (str): Path to original wav segments of audio. + output_dir (str): Path to where the segments and metadata + will go. 
+ class_list (str): Path to file containing the classes + seen in the human labels file that you want to create + segments for. Current format is ',' delimited list + in a .txt file. + """ + # parse the inputs + out_file = ntpath.dirname(output_dir) + result_file = os.path.join(out_file, "metadata.csv") + if os.path.exists(result_file): + all_data = pd.read_csv(result_file, index_col=0) + else: + all_data = pd.DataFrame() + # walk dir to list paths to each original wav file + wav_file_paths = get_paths(wav_dir) + # open human label file + labels = pd.read_csv(labels) + use_2017 = None + # iterate through each individual original wav + if "2017" in labels['DATE'].iloc[0]: + use_2017 = True + elif "2018" in labels['DATE'].iloc[0]: + use_2017 = False + wav_files = [] + num_samples = [] + for wav in wav_file_paths: + # check which label format to select parsing method + # create dataframe of only the labels that correspond to the wav + if use_2017: + filtered_labels = filter_labels_2017(wav, + labels) + else: + filtered_labels = filter_labels_2018(wav, + labels) + # output the labeled segments and return the dataframe of annotations + new_buow_rows = create_segments(wav, + filtered_labels, + output_dir, + class_list) + # create same number of noise segments from the same wav file randomly + all_buow_rows = create_noise_segments(wav, + new_buow_rows, + output_dir) + # add the annotations to the csv of metadata for the dataset + if not all_buow_rows.empty: + wavv = str(wav) + wav_files.append(wavv) + num_samples.append(len(all_buow_rows)) + all_data = pd.concat([all_data, all_buow_rows], ignore_index=True) + print("printing concated data") + print(all_data) + + all_data.index = all_data.index.astype(int) + all_data.to_csv(result_file) + intt = 0 + for wavs in wav_files: + print(f"{wavs} had {num_samples[intt]} including noise segments") + intt += 1 + print(f"Created results: {result_file}") + + +def main(labels, wav_dir, output_dir, class_list): + """Main script to run create dataset. + + Args: + labels (str): Path to label file. + wav_dir (str): Path to original wav segments of audio. + output_dir (str): Path to where the segments and metadata + will go. + class_list (str): Path to file containing the classes + seen in the human labels file that you want to + create segments for. + """ + create_dataset(labels, wav_dir, output_dir, class_list) + + +if __name__ == "__main__": + PARSER = argparse.ArgumentParser( + description='Input Directory Path' + ) + PARSER.add_argument('-labels', type=str, + help='Path to human labeled csv') + PARSER.add_argument('-wav_dir', type=str, + help='Path to directory containing wav files.') + PARSER.add_argument('-output_dir', type=str, + help='Path to desired directory for segments.') + PARSER.add_argument('-class_list', type=str, + help='Path to txt file of list of labeled classes') + ARGS = PARSER.parse_args() + main(ARGS.labels, ARGS.wav_dir, ARGS.output_dir, ARGS.class_list) diff --git a/create_dataset/create_segments.py b/create_dataset/create_segments.py new file mode 100644 index 0000000..c80eaf5 --- /dev/null +++ b/create_dataset/create_segments.py @@ -0,0 +1,167 @@ +"""Functions to create segments of detections of interest from wavs. + +Functions get called in create_dataset.py. +""" +import os +import uuid +import csv +import pandas as pd +from pydub import AudioSegment, exceptions +import numpy as np + + +def get_paths(home_dir): + """Obtain paths to every wav in the directory provided. 
+ + Args: + home_dir (str): Path to directory containing original wavs. + + Returns: + list: List of all the full paths to a wav in + the given directory. + """ + wavs_file_paths = [] + for path, _, files in os.walk(home_dir): + for file in files: + if file.endswith('.wav'): + new_file = os.path.join(path, file) + wavs_file_paths.append(new_file) + return wavs_file_paths + + +def create_segments(wav, filtered_labels, out_path, class_list): + """Create the labeled segments. + + Args: + wav (str): Path to current wav file in loop. + + filtered_labels (pd.Dataframe): The human label file reduced + to only contain the rows of detections pertinent to the + wav of interest. + out_path (str): Path to directory where segment will be saved. + class_list (str): Path to the class list that you'd like segments + to be created for. What the manual ID's are in the human + label file- will ignore everything that is misspelled or + unknown labels. + + Returns: + pd.Dataframe: The metadata now associated with the + created segments for a given wav file. + """ + print(f"creating segments for {wav}") + if filtered_labels is None: + print(f"skipping segment creation for {wav} because " + "it does not have labels or is not a file of interest") + return None + if filtered_labels.empty: + print("filtered labels is an empty dataframe, " + "meaning either the sound file was not " + "labeled or has no detections") + return None + output_rows = pd.DataFrame(columns=['segment', + 'label', + 'segment_path', + 'original_path', + 'segment_duration_s', + 'segment_rel_start_ms']) + with open(class_list, 'r', newline='', encoding='utf-8') as file: + reader = csv.reader(file) + classes = next(reader) + print(classes) + try: + audio = AudioSegment.from_wav(wav) + except exceptions.CouldntDecodeError: + print(f"Couldn't decode: {wav}, moving to next file") + filtered_labels['MANUAL ID*'] = filtered_labels['MANUAL ID*'].str.lower() + print(filtered_labels) + df_row = 0 + for _, row in filtered_labels.iterrows(): + for call_type in classes: + if row['MANUAL ID*'] == call_type: + start_time = float(row['OFFSET']) + end_time = start_time + float(row['DURATION']) + start_time = start_time * 1000 + end_time = end_time * 1000 + segment = audio[start_time:end_time] + segment_id = uuid.uuid4() + segment_id = str(segment_id) + '.wav' + segment_path = os.path.join(out_path, segment_id) + segment.export(segment_path, format='wav') + output_rows.loc[df_row] = [segment_id, + call_type, + segment_path, + wav, + float(row['DURATION']), + start_time] + df_row += 1 + else: + continue + return output_rows + + +def create_noise_segments(wav, new_buow_rows, out_path): + """Create 'no_buow' segments. + Randomly select an equal number of 3s noise segments to + the number of detections per audio file, a buffer length + away from all of the detections in the file. + + Args: + wav (str): The path to the given wav. + new_buow_rows (pd.Dataframe): The human labeled detection + segment metadata for the given wav. + out_path (str): The directory where the new no_buow segments will + go to join the human labeled segments. + + Returns: + pd.Dataframe: The metadata for the detection as well as + the no_buow segments created from the given wav. 
+ """ + if new_buow_rows is None: + print(f"not creating noise segments from {wav} because " + "there were no labels or no associated labels") + all_buow_rows = pd.DataFrame() + return all_buow_rows + try: + audio = AudioSegment.from_wav(wav) + # duration in seconds, cutting off the ms + duration = int(len(audio) / 1000) + except exceptions.CouldntDecodeError: + print(f"Couldn't decode: {wav}, moving to next file") + call_type = "no_buow" + num = len(new_buow_rows) * 2 + seconds_array = np.zeros(duration) + for _, row in new_buow_rows.iterrows(): + start = int((row['segment_rel_start_ms'] / 1000) - 1) + end = int((row['segment_rel_start_ms'] / 1000) + + row['segment_duration_s']) + mask_start = max(0, start - 30) + mask_end = min(len(seconds_array), end + 30 + 1) + seconds_array[mask_start:mask_end] = 1 + new_sample = num / 2 + while num > new_sample: + try: + random_index = np.random.choice(len(seconds_array)-3) + except ValueError: + print(f"{wav} is not long enough to generate no_buow sounds, " + "keeping the detection segment but adding no no_buow") + return new_buow_rows + if (seconds_array[random_index] == 0 and + seconds_array[random_index + 3] == 0): + start_time = (random_index + 1) * 1000 + end_time = (random_index + 4) * 1000 + segment = audio[start_time:end_time] + duration_of_segment = len(segment) / 1000 + segment_id = uuid.uuid4() + segment_id = str(segment_id) + '.wav' + segment_path = os.path.join(out_path, segment_id) + segment.export(segment_path, format='wav') + new_buow_rows.loc[new_sample] = [segment_id, + call_type, + segment_path, + wav, + duration_of_segment, + start_time] + new_sample += 1 + + all_buow_rows = new_buow_rows + return all_buow_rows diff --git a/create_dataset/filter_labels.py b/create_dataset/filter_labels.py new file mode 100644 index 0000000..d6d9373 --- /dev/null +++ b/create_dataset/filter_labels.py @@ -0,0 +1,96 @@ +"""Correlating the wav paths with the labels for 2017 and 2018. + +The label file format is different for the 2018 and 2017 label +files. This means we use different information in those files to +ensure the wav file we found in the folder corresponds to the +label in the label file. Depending on the label file, one +of these two functions gets called to ensure we're dealing +with the proper wav file and only the labels that correspond +to that wav file. +""" +import os +import ntpath + + +def filter_labels_2017(wav, labels): + """Filter labels from 2017 data. + + Args: + wav (str): The current wav file. + labels (pd.DataFrame): All of the labels. + + Returns: + pd.DataFrame: The labels associated with the wav of interest. + """ + file_name = ntpath.basename(wav) + # isolate labels that match the wav basename + filtered_labels = labels[labels['IN FILE'] == file_name] + index_drop = [] + wav = str(wav) + # ensure the labels match the site and burrow name of wav file + # this step is crucial, it catches accidential duplicates of wav files + for index, row in filtered_labels.iterrows(): + burrow = row['Burrow'] + bur = burrow[:-1] + site = burrow[-1:] + if bur not in wav: + print(f"{bur} is not in {wav}") + index_drop.append(index) + if site not in wav: + print(f"{site} is not in {wav}") + index_drop.append(index) + + filtered_labels = filtered_labels.drop(index_drop) + return filtered_labels + + +def filter_labels_2018(wav, labels): + """Filter labels from 2018 data. + + Args: + wav (str): The current wav file. + labels (pd.DataFrame): All of the labels. + + Returns: + pd.DataFrame: The labels associated with the wav of interest. 
+ """ + file_name = ntpath.basename(wav) + path_name = ntpath.dirname(wav) + basepath = os.path.basename(path_name) + if basepath in ('ClassificationResults', 'Classification_Results'): + print(f"skipping {wav} because it's basepath is {basepath}") + # skipping extra wav files that exist as duplicates in these sub dirs + return None + # some of the folders have an underscore and some do not + path_labels = [] + path_labels.append(path_name + "/ClassificationResults/") + path_labels.append(path_name + "/Classification_Results/") + path_to_results = None + # checking if it's the one with an underscore vs not + for path in path_labels: + exists = os.path.exists(path) + if exists is True: + path_to_results = path + else: + print(f"{path} does not exist") + continue + if path_to_results is None: + print(f"skipping {wav} because it's not a file of interest") + return None + filtered_labels = labels[labels['IN FILE'] == file_name] + index_to_drop = [] + # iterating the columns in labels that match the wav file name + for index, row in filtered_labels.iterrows(): + stripped = row['Fled_2018_LS133_SM1.csv '].strip() + check_path = os.path.join(path_to_results, stripped) + if os.path.isfile(check_path): + continue + if stripped == 'EarBreed_2018_LS128_SM10A.csv': + check_path = os.path.join(path_to_results, + 'EarBreed_LS128_SM10A.csv') + if not os.path.isfile(check_path): + index_to_drop.append(index) + else: + index_to_drop.append(index) + filtered_labels = filtered_labels.drop(index_to_drop) + return filtered_labels diff --git a/create_dataset/k_fold_split_copy.py b/create_dataset/k_fold_split_copy.py new file mode 100644 index 0000000..be7e47b --- /dev/null +++ b/create_dataset/k_fold_split_copy.py @@ -0,0 +1,218 @@ +"""Optimizing k-fold splits with groups. + +These functions aid strat_k_folds.py in calculating the +optimal fold allocation for all the groups in the dataset +ensuring the folds are as equal in size as they can be, +while also being as close to the actual class distribution +as possible. + +Downloaded and modified from: +https://github.com/joaofig/strat-group-split/tree/main +""" +from typing import Set, Tuple +import numpy as np +from numpy.random import default_rng +from numba import njit + + +@njit +def calculate_cost(problem: np.ndarray, + solution: np.ndarray, + k: int) -> float: + """Calculate difference of current solution to optimal solution. + + Args: + problem (np.array): A matrix with a column per class and the + class counts for each group as the values. + solution (np.ndarray): A 1D array where each value is the current + fold allocation for the corresponding group. + k (int): Number of folds. + + Returns: + float: The summation of the differences between the folds' + class distributions from the optimal class distribution, and the + size of the folds to the size the folds should be. 
+ """ + cost = 0.0 + total = np.sum(problem) + class_sums = np.sum(problem, axis=0) + num_classes = problem.shape[1] + + for i in range(k): + idx = solution == i + fold_sum = np.sum(problem[idx, :]) + + # Start by calculating the fold imbalance cost + cost += (fold_sum / total - 1.0 / k) ** 2 + + # Now calculate the cost associated with the class imbalances + # Katie: had to add division by 0 error for if fold_sums equal 0 + # there were no chick begging calls during test so this row was 0 + for j in range(num_classes): + if fold_sum == 0: + cost += (0 - class_sums[j] / total) ** 2 + else: + sum_problem = np.sum(problem[idx, j]) / fold_sum + cost += (sum_problem - class_sums[j] / total) ** 2 + return cost + + +@njit +def generate_search_space(problem: np.ndarray, + solution: np.ndarray, + k: int) -> np.ndarray: + """Generate the search space. + + Args: + problem (np.ndarray): A matrix with a column per class and the + class counts for each group as the values. + solution (np.ndarray): The last known solution. + k (int): Number of folds. + + Returns: + np.ndarray: The search space. Folds as columns and cost values + for each group with a placeholder in one fold each to allow for + a cost calculation relative to the placeholder. + """ + num_groups = problem.shape[0] + + space = np.zeros((num_groups, k)) + sol = solution.copy() + + for i in range(num_groups): + for j in range(k): + if solution[i] == j: + space[i, j] = np.inf + else: + sol[i] = j + space[i, j] = calculate_cost(problem, sol, k) + sol[i] = solution[i] + return space + + +@njit +def solution_to_str(solution: np.ndarray) -> str: + """Convert the solution to a string. + + Args: + solution (np.ndarray): The current solution. + Returns: + str: The current solution as a string. + """ + return "".join([str(n) for n in solution]) + + +def generate_initial_solution(problem: np.ndarray, + k: int, + algo: str = "k-bound") -> np.ndarray: + """Generate the first solution. + + Args: + problem (np.array): A matrix with a column per class and the + class counts for each group as the values. + k (int): The number of folds. + algo (str): Method for creating initial solution. Defaults to a + greedy algorithm to satisfy fold proportion requirements only. + + Returns: + np.ndarray: A 1D array where each value is the current fold + allocation for the corresponding group. + """ + num_groups = problem.shape[0] + if algo == "k-bound": + rng = default_rng() + total = np.sum(problem) + indices = rng.permutation(problem.shape[0]) + + solution = np.zeros(num_groups, dtype=int) + current_fold = 0 + fold_total = 0 + for i in indices: + group = np.sum(problem[i, :]) + if fold_total + group < total / k: + fold_total += group + else: + current_fold = (current_fold + 1) % k + fold_total = group + solution[i] = current_fold + elif algo == "random": + rng = default_rng() + solution = rng.integers(low=0, high=k, size=num_groups) + elif algo == "zeros": + solution = np.zeros(num_groups, dtype=int) + else: + raise Exception("Invalid algorithm name") + return solution + + +def solve(problem: np.ndarray, + k=5, + min_cost=1e-5, + max_retry=100, + verbose=False) -> np.ndarray: + """Solve the problem. + + Args: + problem (np.ndarray): The problem matrix. + k (int): Number of folds, default 5. + min_cost (float): The largest the cost can be for an + acceptable solution. Default 1e-5. + max_retry (int): The max amount of times the program will + attempt to alter the current solution for a more + optimal one. + verbose (bool): True for more debug prints, defaults to False. 
+ + Returns: + np.ndarray: Optimized solution as a 1D array where each + value is the fold allocation for each group. + """ + hist = set() + retry = 0 + + solution = generate_initial_solution(problem, k) + incumbent = solution.copy() + low_cost = calculate_cost(problem, solution, k) + cost = 1.0 + while retry < max_retry and cost > min_cost: + decision = generate_search_space(problem, solution, k=k) + grp, cls = select_move(decision, solution, hist) + + if grp != -1: + solution[grp] = cls + cost = calculate_cost(problem, solution, k=k) + if cost < low_cost: + low_cost = cost + incumbent = solution.copy() + retry = 0 + if verbose: + print(cost) + else: + retry += 1 + hist.add(solution_to_str(solution)) + return incumbent + + +def select_move(decision: np.ndarray, + solution: np.ndarray, + history: Set) -> Tuple: + """Select the change to make to the current solution. + + Args: + decision (np.ndarray): The current search space matrix. + solution (np.ndarray): The current solution. + history (Set): Previous solutions. + Returns: + Tuple: Position in the solution matrix to move a group + into a different fold. + """ + candidates = np.argsort(decision, axis=None) + + for candidate in candidates: + position = np.unravel_index(candidate, decision.shape) + sol = solution.copy() + sol[position[0]] = position[1] + sol_str = solution_to_str(sol) + + if sol_str not in history: + return position[0], position[1] + return -1, -1 # No move found! diff --git a/create_dataset/strat_k_folds.py b/create_dataset/strat_k_folds.py new file mode 100644 index 0000000..0de2eb0 --- /dev/null +++ b/create_dataset/strat_k_folds.py @@ -0,0 +1,90 @@ +"""Split buowset into stratified k-folds. + +Groups detections from the same wav file into 'groups' +and then determines the overall class distribution and +the class distribution for each 'group'. It allocates +all the groups to a 'fold' in a way where the folds +have roughly the same class distribution as the overall +dataset. + +Usage: + python3 strat_k_folds.py /path/to/metadata.csv +""" +import argparse +import pandas as pd +import numpy as np + + +from k_fold_split_copy import solve + + +def create_strat_folds(df): + """Create grouped stratified k-folds. + + Args: + df (pd.Dataframe): The metadata csv from when the dataset was created. + + Returns: + pd.DataFrame: The same metadata but with labels as ints and a new fold + column to denote the fold that segment is a part of.
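+ + Label-to-integer mapping used below (documented here for convenience): + cluck=0, coocoo=1, twitter=2, alarm=3, chick begging=4, no_buow=5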
+ """ + num_classes = 6 + original_df = df + df['label'] = df['label'].replace('cluck', 0) + df['label'] = df['label'].replace('coocoo', 1) + df['label'] = df['label'].replace('twitter', 2) + df['label'] = df['label'].replace('alarm', 3) + df['label'] = df['label'].replace('chick begging', 4) + df['label'] = df['label'].replace('no_buow', 5) + # group is the subset of the index which is the wav file they all come from + grouped = df.groupby('original_path') + group_names = [] + group_matrix = [] + for index, group in grouped: + counts = np.zeros(num_classes, dtype=int) + label_counts = group['label'].value_counts() + for label, count in label_counts.items(): + counts[int(label)] = count + group_matrix.append(counts) + group_names.append(index) + problem = np.array(group_matrix) + solution = solve(problem, k=5, verbose=True) + # the fold allocation for each 'group' + print(f"solution {solution}") + print(np.sum(problem, axis=0) / np.sum(problem)) + folds = [problem[solution == i] for i in range(5)] + fold_percents = np.array( + [np.sum(folds[i], axis=0) / np.sum(folds[i]) for i in range(5)] + ) + # the % of each class in each fold + print(f"Fold percents: {fold_percents}") + print(folds) + grouped_original = original_df.groupby('original_path') + df_with_folds = pd.DataFrame() + count = 0 + for i, group in grouped_original: + group['fold'] = solution[count] + df_with_folds = pd.concat([df_with_folds, group], ignore_index=True) + count += 1 + return df_with_folds + + +def main(meta): + """Execute main script. + + Args: + meta (str): Path to metadata csv from creating the dataset. + """ + df = pd.read_csv(meta, index_col=0) + df_with_folds = create_strat_folds(df) + df_with_folds.to_csv("5-fold_meta.csv") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description='Input Directory Path' + ) + parser.add_argument('meta', type=str, + help='Path to metadata csv') + args = parser.parse_args() + main(args.meta) diff --git a/create_dataset/tests/label_2017_wavs.py b/create_dataset/tests/label_2017_wavs.py new file mode 100644 index 0000000..b0fc5e9 --- /dev/null +++ b/create_dataset/tests/label_2017_wavs.py @@ -0,0 +1,167 @@ +"""Create human labeled audio segments. + +Using a CSV with human labels across a large dataset, we can +find the segments in the audio files that correspond to a +burrowing owl call as labeled by a human labeler. We can then +segment these audio chunks into a folder so that we can use +them to easily train other models. We can also do the same +for the rest of the data to obtain segments with no bird +call labels, to provide another class in the same domain +as our bird vocalizations. As there are significantly more +negatives than positives, we can choose if we'd like to get +the same number output or select a higher or lower amount. + +Example: + + $ python segment_labeled_2017_data.py /path/to/human_labels.csv \ + /path/to/directory/of/wavs/ /path/to/directory/output/ /path/to/config.yaml + +""" + +import argparse +import yaml +import os +import pandas as pd +from pydub import AudioSegment, exceptions + + +def read_configs(config): + """reading in config file variables + + """ + with open(config, "r", encoding='utf-8') as cfg: + configs = yaml.load(cfg, Loader=yaml.SafeLoader) + + return configs + +def create_bird_segments(labels, wavs, output, config): + """Create human labeled dataframes. + + Main script to create csvs of human labeled data for each + wav file of interest. + + Args: + labels (str): The path to human labeled csv. 
+ wavs (str): The path to all audio files. + output (str): The path to directory where each csv will + output (1 for each wav). + + """ + os.makedirs(output, exist_ok=True) + + scored_data = pd.read_csv(labels) + output = output + "bird_sounds/" + os.makedirs(output, exist_ok=True) + + configs = read_configs(config) + padding = configs['padding'] + padding = int(padding) + + for audio_file in os.listdir(wavs): + if audio_file.endswith('.wav'): + audio_path = os.path.join(wavs, audio_file) + + filtered_data = scored_data[scored_data['IN FILE'] == audio_file] + try: + bird_sound = AudioSegment.from_wav(audio_path) + except exceptions.CouldntDecodeError: + print(f"Counldn't decode: {audio_path}, moving to next file.") + continue + segment_index = 0 + for _, row in filtered_data.iterrows(): + if row['TOP1MATCH'] != 'null': + start_time = float(row['OFFSET']) + end_time = (start_time + float(row['DURATION'])) + start_time = start_time * 1000 + end_time = end_time * 1000 + start_time = start_time - padding + end_time = end_time + padding + segment = bird_sound[start_time:end_time] + output_file = os.path.join( + output, f'{os.path.splitext(audio_file)[0]}_segment_{segment_index}.wav' + ) + segment.export(output_file, format='wav') + segment_index += 1 + + print("Processing complete!") + +def create_no_bird_segments(labels, wavs, output): + """Create no bird call audio segments. + + """ + os.makedirs(output, exist_ok=True) + + scored_data = pd.read_csv(labels) + output = output + "no_bird_sounds/" + os.makedirs(output, exist_ok=True) + + for audio_file in os.listdir(wavs): + if audio_file.endswith('.wav'): + audio_path = os.path.join(wavs, audio_file) + + try: + time_series, sample_rate = librosa.load(audio_path, sr=None) + audio_duration = librosa.get_duration(y=time_series, sr=sample_rate) + except Exception as err: + print(f"Error processing {audio_file}: {err}") + continue + + total_chunks = int(audio_duration // 3) + 1 + chunks_data = { + 'Chunk Start': [i * 3 for i in range(total_chunks)], + 'Chunk End': [(i + 1) * 3 for i in range(total_chunks)], + 'Label': ['no'] * total_chunks + } + chunks_df = pd.DataFrame(chunks_data) + + filtered_data = scored_data[scored_data['IN FILE'] == audio_file] + + for _, row in filtered_data.iterrows(): + if row['TOP1MATCH'] != 'null': + start_time = float(row['OFFSET']) + end_time = start_time + float(row['DURATION']) + + for i in range(len(chunks_df)): + chunk_start = chunks_df.loc[i, 'Chunk Start'] + chunk_end = chunks_df.loc[i, 'Chunk End'] + if start_time < chunk_end and end_time > chunk_start: + chunks_df.loc[i, 'Label'] = 'bird' + + bird_sound = AudioSegment.from_wav(audio_path) + segment_index = 0 + for i in range(len(chunks_df)): + if chunks_df.loc[i, 'Label'] == 'no': + chunk_start = chunks_df.loc[i, 'Chunk Start'] * 1000 + chunk_end = chunks_df.loc[i, 'Chunk End'] * 1000 + segment = bird_sound[chunk_start:chunk_end] + + output_file = os.path.join( + output, f'{os.path.splitext(audio_file)[0]}_nobird_segment_{segment_index}.wav' + ) + segment.export(output_file, format='wav') + segment_index += 1 + + print("Processing complete!") + +def main(labels, wavs, output, config_file): + """Run main script + + """ + create_bird_segments(labels, wavs, output, config_file) + #create_no_bird_segments(labels, wavs, output) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='Input Directory Path' + ) + parser.add_argument('labels', type=str, + help='Path to human labeled csv') + parser.add_argument('wavs', type=str, + help='Path to 
all wav files that have been labeled') + parser.add_argument('output', type=str, + help='Path to desired directory for output csvs') + parser.add_argument('config_file', type=str, + help='Path to config file') + args = parser.parse_args() + main(args.labels, args.wavs, args.output, args.config_file) diff --git a/make_model/buowset/embed_to_df_birdnet.py b/make_model/buowset/embed_to_df_birdnet.py index 29f400e..e74d04f 100644 --- a/make_model/buowset/embed_to_df_birdnet.py +++ b/make_model/buowset/embed_to_df_birdnet.py @@ -35,10 +35,11 @@ def obtain_birdnet_embeddings(embeds): filename = ntpath.basename(embed) filename = filename.replace(".birdnet.embeddings.txt", ".wav") dfb = pd.read_csv(embed, - delimiter="[,\t]", + delimiter="[\t]", engine='python', header=None) - dfb_stripped = dfb.drop(dfb.columns[:2], axis=1) + dfb[2] = dfb[2].apply(lambda x: [float(i) for i in x.split(',') if i]) + dfb_stripped = dfb.iloc[:, 2:] flattened = dfb_stripped.values.flatten() if len(flattened) > 1024: print(f"filename {filename} has extra lines. Trunicating") @@ -62,7 +63,7 @@ def merge_dfs(metadata, embed_dict): embed_df.index.name = 'segment' df_merged = metadata.merge(embed_df, on='segment') df_merged = df_merged.drop(columns=['segment_duration_s']) - + df_merged = df_merged.rename(columns={0: 'embedding'}) return df_merged diff --git a/make_model/buowset/make_svm.py b/make_model/buowset/make_svm.py index ed3024a..0d8c298 100644 --- a/make_model/buowset/make_svm.py +++ b/make_model/buowset/make_svm.py @@ -61,11 +61,9 @@ def make_x_and_y(embed_df): train_df = embed_df[embed_df['fold'].isin(TRAINING_FOLDS)] test_df = embed_df[embed_df['fold'].isin(TESTING_FOLDS)] - embedding_cols = embed_df.select_dtypes(include='float64').columns.tolist() - - x_train = train_df[embedding_cols].values + x_train = list(train_df['embedding'].values) y_train = train_df['binary_label'].values - x_test = test_df[embedding_cols].values + x_test = list(test_df['embedding'].values) y_test = test_df['binary_label'].values return x_train, y_train, x_test, y_test diff --git a/make_model/make_perch_embeddings.py b/make_model/make_perch_embeddings.py new file mode 100644 index 0000000..02b634b --- /dev/null +++ b/make_model/make_perch_embeddings.py @@ -0,0 +1,122 @@ +''' +Create Perch Embeddings Script + +This script processes a directory of audio chunks (.wav files), +creates perch embeddings, and stores the results as a sqlite database + +Usage: + python make_perch_embeddings.py dataset_name path/to/directory/of/wavs + path/to/desired/output/dir + +Outputs: + hoplite.sqlite + usearch.index + +Note: + this code requires: + python, version 3.10+ + numpy, version 1.2+ + tensorflow, version 2+ +''' + +import argparse +from etils import epath + +from perch_hoplite.agile import colab_utils +from perch_hoplite.agile import embed +from perch_hoplite.agile import source_info +from perch_hoplite.db import interface + +def create_embeddings(dataset_name, wavs, output): + ''' + creates perch embeddings + + Args: + dataset_name (str): name of dataset being embedded + wavs (str): path to directory containing .wav audio segments + output (str): path to directory for output files (SQLite DB) + + Returns: + None + ''' + + dataset_base_path = wavs + dataset_fileglob = '*.wav' + db_path = output + model_choice = 'perch_8' + + use_file_sharding = True + + audio_glob = source_info.AudioSourceConfig( + dataset_name=dataset_name, + base_path=dataset_base_path, + file_glob=dataset_fileglob, + min_audio_len_s=1.0, + target_sample_rate_hz=-2, + 
shard_len_s=60.0 if use_file_sharding else None, + ) + + configs = colab_utils.load_configs( + source_info.AudioSources((audio_glob,)), + db_path, + model_config_key=model_choice, + db_key='sqlite_usearch') + + # Initialize DB + db = configs.db_config.load_db() + num_embeddings = db.count_embeddings() + print('Initialized DB located at ', configs.db_config.db_config.db_path) + + def drop_and_reload_db() -> interface.HopliteDBInterface: + db_path = epath.Path(configs.db_config.db_config.db_path) + for fp in db_path.glob('hoplite.sqlite*'): + fp.unlink() + (db_path / 'usearch.index').unlink() + print('\n Deleted previous db at: ', + configs.db_config.db_config.db_path) + + if num_embeddings > 0: + print('Existing DB contains datasets: ', db.get_dataset_names()) + print('num embeddings: ', num_embeddings) + print(f'This will permanently delete all {num_embeddings} ' + 'embeddings from the existing database.\n') + drop_and_reload_db() + + # Run embedding + print(f'Embedding dataset: {audio_glob.dataset_name}') + + worker = embed.EmbedWorker( + audio_sources=configs.audio_sources_config, + db=db, + model_config=configs.model_config) + + worker.process_all(target_dataset_name=audio_glob.dataset_name) + + print('\n\nEmbedding complete! \nTotal embeddings: ', db.count_embeddings()) + print(f'Embeddings dataset saved at: \n ' + f'\t{output}/hoplite.sqlite \n ' + f'\t{output}/usearch.index') + +def main(dataset_name, wavs, output): + ''' + run main script + ''' + + create_embeddings(dataset_name, wavs, output) + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser( + description='Input Directory Paths' + ) + parser.add_argument('dataset_name', type=str, + help='Name of dataset to embed') + parser.add_argument('wavs', type=str, + help='Path to labeled audio chunks. ' + 'All .wav files will be embedded') + parser.add_argument('output', type=str, + help='Path to desired directory for output database') + args = parser.parse_args() + + main(args.dataset_name, args.wavs, args.output) diff --git a/make_model/prepare_perch_embeddings.py b/make_model/prepare_perch_embeddings.py new file mode 100644 index 0000000..c006b8f --- /dev/null +++ b/make_model/prepare_perch_embeddings.py @@ -0,0 +1,105 @@ +''' +Convert Perch Embedding Output to standard embeddings .pkl +for easy training of various models + +This script processes the outputs of the perch embedding +scripts and converts to a .pkl dataframe that stores +filename, embedding, and related metadata + +Usage: python prepare_perch_embeddings \ + /path/to/sqlite_dir \ + /path/to/metadata_file \ + /path/to/output_dir \ + embeddings_description + +Arguments: + sqlite_dir (str): path to directory that contains + hoplite.sqlite & usearch.index + metadata_path (str): path to metadata file with labels + and fold information + outout_dir (str): path to directory to store output + pkl + embeddings_description (str) :description of set of embeddings + for file naming purposes + +Outputs: + _perch_embeddings.pkl + +''' + + +import os +import argparse +import pandas as pd +from perch_hoplite.db import sqlite_usearch_impl + + +def prepare_perch_embeddings(sqlite_dir, + metadata_path, + output_dir, + embeddings_description): + ''' + converts raw perch embeddings (from sqlite database ) into standard + dataframe format for SVM. 
+ + Args: + sqlite_dir (str): path to directory that contains + hoplite.sqlite $ usearch.index + metadata_path (str): path to metadata file + output_dir (str): path to directory to store + output .pkl file + embeddings_description (str): description of set of embeddings + for file naming purposes + + Returns: + None + ''' + + # load embeddings database + db = sqlite_usearch_impl.SQLiteUsearchDB.create(sqlite_dir) + + # load dataset metadata + metadata = pd.read_csv(metadata_path, index_col=0) + + embeddings_data = [] + + n_embeddings = db.count_embeddings() + + for i in range(n_embeddings): + + file_name = db.get_embedding_source(i+1).source_id + embedding = db.get_embedding(i+1) + + base_dict = {'segment': file_name, + 'embedding': embedding} + + embeddings_data.append(base_dict) + + embeddings_df = pd.DataFrame(embeddings_data) + merged_df = pd.merge(embeddings_df, metadata, on='segment') + merged_df = merged_df.drop('segment_duration_s', axis=1) + merged_df = merged_df[['segment', 'label', 'fold', 'embedding']] + + output_filename = os.path.join(output_dir, f'{embeddings_description}_perch_embeddings.pkl') + merged_df.to_pickle(output_filename) + + print(f'Embeddings saved at:\n\t{output_filename}') + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser( + description='Input Perch Embeddings sqlite database and output directory') + + parser.add_argument('sqlite_dir', type=str, + help='Path to directory that contains ' + 'hoplite.sqlite and usearch.index') + parser.add_argument('metadata_path', type=str, + help='Path to metadata file') + parser.add_argument('output_dir', type=str, + help='Directory for output file') + parser.add_argument('embeddings_description', type=str, + help='Name of embeddings group') + + args = parser.parse_args() + prepare_perch_embeddings(args.sqlite_dir, args.metadata_path, args.output_dir, args.embeddings_description) diff --git a/pylintrc b/pylintrc new file mode 100644 index 0000000..4b1c907 --- /dev/null +++ b/pylintrc @@ -0,0 +1,3 @@ +[disable] +max-args=10 +max-positional-arguments=10 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..0f44af2 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,57 @@ +[project] +name = "whoot" +dynamic = ["version"] +description = "Tools for capturing, analyzing, and parsing audio data" +readme = "README.md" +requires-python = ">= 3.10.0, < 3.13.0" +dependencies = [ + "librosa>=0.10.2.post1", + "numba==0.61.0", + "pandas>=2.3.0", + "pydub>=0.25.1", + "pyyaml>=6.0.2", + "scikit-learn>=1.7.0", + "tqdm>=4.67.1", +] + +[dependency-groups] +dev = [ + "flake8>=7.3.0", + "pylint>=3.3.7", + "flake8-docstrings>=1.7.0", +] + +[tool.setuptools.dynamic] +version = {attr = "whoot.__version__"} + +[project.optional-dependencies] +cpu = [ + "torch>=2.7.0", + "torchvision>=0.22.0", +] +cu128 = [ + "torch>=2.7.0", + "torchvision>=0.22.0", +] +model-training = [ + "datasets>=3.5.1,<4.0.0", + "timm>=1.0.15", + "pyha-analyzer@git+https://github.com/UCSD-E4E/pyha-analyzer-2.0.git@support_whoot", + "comet-ml>=3.43.2", +] + +notebooks = [ + "ipykernel>=6.29.5", + "ipywidgets>=8.1.6", +] + + +[packages.index] +cu128 = "https://download.pytorch.org/whl/cu128" + +[tool.setuptools] +packages = ["make_model", "assess_birdnet", "whoot_model_training"] + +[tool.uv.sources] +pyha-analyzer = { git = "https://github.com/UCSD-E4E/pyha-analyzer-2.0.git", branch = "support_whoot" } + diff --git a/whoot_model_training/README.md b/whoot_model_training/README.md new file mode 100644 index 
0000000..4138f86 --- /dev/null +++ b/whoot_model_training/README.md @@ -0,0 +1,57 @@ +Toolkit for training Machine Learning Classification Models over audio dataset + +Key inspiration is https://github.com/UCSD-E4E/pyha-analyzer-2.0/tree/main. This repo differs in that it uses a traditional training pipeline rather than the Hugging Face Trainer. Hugging face trainer abstracts the training code, which should be explicit for this toolkit. + + +# Install + +To set up environment for model training: + +1) run steps 1 - 3 of the installation instructions in `whoot/README.md` +2) For step 4, specifically run `pip install -e .[model-training, cpu]` for cpu training, `pip install -e .[model-training, cu128]` for training on Nvidia GPUs + +Note that you should check what is supported by CUDA on your machine. See developers if you need a different CUDA version + +# Running + +0) Add your Comet-ML API to your local environment. See https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/ +1) Create a copy of the config found in `configs/config.yml` and fill it out for your dataset. See the [config](#config) section +2) Edit train.py to set up training for your dataset. If you are using a new dataset which an extractor does not exist for, contact code authors. +3) run `python train.py path/to/your/config/file.yml` + +# Config + +## Default Config Properties +The properties of `config.yml` are as follows: +### Data paths +`metadata_csv`: the path to the metadata file for your dataset. +`data_path`: Path to the highest level parent folder containing audio. Audio can be in a different path than the metadata! +`hf_cache_path`: cache for hugging face. This path will be automatically made as you run the script, this would be the location of where the new file should go + +### Required Variables +`COMET_PROJECT_NAME`: "whoot", this is the project on comet-ml training will run on. +`CUDA_VISIBLE_DEVICES`: "0" or "0,1", this controls how many GPUs the training uses. +`SUBPROJECT_NAME`: Some description to help filter which training this is used for, can be the task being done (multi_label_classification) or something else (fun_training_test) +`DATASET_NAME`: Name of the dataset being trained on, will be embedded on comet_ml to make searching easier + +## Project Specific config information +### Buowset +The filenames in metadata_csv are the audio files found in `data_path`. + +`SUBPROJECT_NAME` is either "binary" or "multilabelClass" +`DATASET_NAME` is buowset0 + +# Repo Philosophy + +The most challenging issue with machine learning is the dataset. This training repo intends to make it easy to modularize parts of the training pipeline, and integrate them together, ideally regardless of the dataset. + +The pipeline works in 5 parts: +- Extractors: Extractors take in raw data and reformats it into `AudioDatasets`, apache-arrow data structures implemented via HuggingFace with common columns between any dataset. Every label is one_hot_encoded and treated as multilabel regardless of the problem. Audio filepaths as casted into [Audio columns](https://huggingface.co/docs/datasets/v3.6.0/en/package_reference/main_classes#datasets.Audio). Extractors are *unique for each dataset* but *uniform in the AudioDataset*. + +- Preprocessors: Online preprocessors take rows in `AudioDatasets` and output `ModelInputs`, formatted data specific to a given model. Preprocessors read AudioDatasets and translate it so the Model can read it + +- Models: Models have defined `ModelInput` and `ModelOutput` formats. 
All ModelInputs and ModelOutputs have common data that they are required to have such that the `PyhaTrainer` can understand how to feed information to the Model, and how to read information from the model. All models implement their own loss functions and return a loss given labels. + +- Augmentations: TODO + +- PyhaTrainer: With few exceptions unrelated to bioacoustic classifications, all PyTorch training code is the same. The HuggingFace Trainer and the extension PyhaTrainer handle most training scripts you will ever write. Why not use it and focus on model design, dataset preprocessing and cleaning. As long as the trainer knows how to feed data into a model (`AudioDatasets` and `Preprocessors`) and how to read it (`ModelOutputs`), then it will have no issues. \ No newline at end of file diff --git a/whoot_model_training/configs/config.yml b/whoot_model_training/configs/config.yml new file mode 100644 index 0000000..c0563f5 --- /dev/null +++ b/whoot_model_training/configs/config.yml @@ -0,0 +1,12 @@ +# Data paths +metadata_csv: data/burrowing_owl_dataset/metadata.csv +data_path: data/burrowing_owl_dataset/audio +hf_cache_path: data/burrowing_owl_dataset/cache/metadata.hf + +# Required Variables +COMET_PROJECT_NAME: "whoot" +CUDA_VISIBLE_DEVICES: "0" #"0,1" +COMET_WORKSPACE: +SUBPROJECT_NAME: +DATASET_NAME: +COMET_WORKSPACE: \ No newline at end of file diff --git a/whoot_model_training/train.py b/whoot_model_training/train.py new file mode 100644 index 0000000..316f860 --- /dev/null +++ b/whoot_model_training/train.py @@ -0,0 +1,167 @@ +"""Trains a Mutliclass Model with Pytorch and Huggingface. + +This script can be used to run experiments with different +models and datasets to create any model for bioacoustic classification + +It is intended this script to be heavily modified with each experiment +(say one wants to use a different dataset, one should copy this and change the +extractor!) + +Usage: + $ python train.py /path/to/config.yml + +config.yml should contain frequently changed hyperparameters +""" +import os +import argparse +import yaml + +from whoot_model_training.trainer import WhootTrainer, WhootTrainingArguments +from whoot_model_training.data_extractor import buowset_extractor +from whoot_model_training.models import TimmModel, TimmInputs, TimmModelConfig +from whoot_model_training import CometMLLoggerSupplement + +from whoot_model_training.preprocessors import ( + MelModelInputPreprocessor +) + +# Uncomment for use with data augmentation +# from pyha_analyzer.preprocessors import MixItUp, ComposeAudioLabel +# from audiomentations import ( +# Compose, AddColorNoise, +# AddBackgroundNoise, PolarityInversion, Gain +# ) + + +def parse_config(config_path: str) -> dict: + """Wrapper to parse config. + + Args: + config_path (str): path to config file for training! + + Returns: + (dict): hyperparameters parameters + """ + config = {} + with open(config_path, "r", encoding="UTF-8") as f: + config = yaml.safe_load(f) + return config + + +def train(config): + """Highest level logic for training. + + Does the following: + - Formats the dataset into an AudioDataset + - Prepares preprocessing for each audio clip + - Builds the model + - Configures and runs the trainer + - Runs evaluation + + Args: + config (dict): the config used for training. 
Defined in yaml file + """ + # Extract the dataset + ds = buowset_extractor( + metadata_csv=config["metadata_csv"], + parent_path=config["data_path"], + output_path=config["hf_cache_path"], + ) + + # Create the model + run_name = "flac_pylint_test_efficientnet_b1_buowset" + model_config = TimmModelConfig( + timm_model="efficientnet_b1", + num_classes=ds.get_num_classes()) + model = TimmModel(model_config) + + # Preprocessors + + # Uncomment if doing work with data augmentation + # # Augmentations + # wav_augs = ComposeAudioLabel([ + # # AddBackgroundNoise( #We don't have background noise yet... + # # sounds_path="data_birdset/background_noise", + # # min_snr_db=10, + # # max_snr_db=30, + # # noise_transform=PolarityInversion(), + # # p=0.8 + # # ), + # Gain( + # min_gain_db = -12, + # max_gain_db = 12, + # p = 0.8 + # ), + # MixItUp( + # dataset_ref=ds["train"], + # min_snr_db=10, + # max_snr_db=30, + # noise_transform=PolarityInversion(), + # p=0.8 + # ) + # ]) + + # Online preprocessors prepare data for training + train_preprocessor = MelModelInputPreprocessor( + TimmInputs, duration=3 + ) + + preprocessor = MelModelInputPreprocessor( + TimmInputs, duration=3 + ) + + ds["train"].set_transform(train_preprocessor) + ds["valid"].set_transform(preprocessor) + ds["test"].set_transform(preprocessor) + + # Run training + training_args = WhootTrainingArguments( + run_name=run_name, + subproject_name=config["SUBPROJECT_NAME"], + dataset_name=config["DATASET_NAME"], + ) + + # COMMON OPTIONAL ARGS + training_args.num_train_epochs = 2 + training_args.eval_steps = 100 + training_args.per_device_train_batch_size = 32 + training_args.per_device_eval_batch_size = 32 + training_args.dataloader_num_workers = 36 + training_args.run_name = run_name + + trainer = WhootTrainer( + model=model, + dataset=ds, + training_args=training_args, + logger=CometMLLoggerSupplement( + augmentations=None, + name=training_args.run_name + ), + ) + + trainer.train() + model.save_pretrained("model_checkpoints/test") + + +def init_env(config: dict): + """Sets up local environment for COMET-ML training logging. + + Args: config (dict): at a minimum this has the project name + and CUDA devices that are allowed to be used. + """ + print(config) + os.environ["COMET_PROJECT_NAME"] = config["COMET_PROJECT_NAME"] + os.environ["CUDA_VISIBLE_DEVICES"] = config["CUDA_VISIBLE_DEVICES"] + check_for_comet = config["COMET_WORKSPACE"] is not None + assert check_for_comet, "Make sure to add a COMET_WORKSPACE to config" + os.environ["COMET_WORKSPACE"] = config["COMET_WORKSPACE"] + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Input config path") + parser.add_argument("config", type=str, help="Path to config.yml") + args = parser.parse_args() + _config = parse_config(args.config) + + init_env(_config) + train(_config) diff --git a/whoot_model_training/train_binary.py b/whoot_model_training/train_binary.py new file mode 100644 index 0000000..4dc7707 --- /dev/null +++ b/whoot_model_training/train_binary.py @@ -0,0 +1,163 @@ +"""Trains a Mutliclass Model with Pytorch and Huggingface. + +This script can be used to run experiments with different +models and datasets to create any model for bioacoustic classification + +It is intended this script to be heavily modified with each experiment +(say one wants to use a different dataset, one should copy this and change the +extractor!) 
+ +Usage: + $ python train.py /path/to/config.yml + +config.yml should contain frequently changed hyperparameters +""" +import os +import argparse +import yaml + +from whoot_model_training.trainer import WhootTrainer, WhootTrainingArguments +from whoot_model_training.data_extractor import buowset_binary_extractor +from whoot_model_training.models import TimmModel, TimmInputs, TimmModelConfig +from whoot_model_training import CometMLLoggerSupplement + +from whoot_model_training.preprocessors import ( + MelModelInputPreprocessor +) + +# Uncomment for use with data augmentation +# from pyha_analyzer.preprocessors import MixItUp, ComposeAudioLabel +# from audiomentations import ( +# Compose, AddColorNoise, +# AddBackgroundNoise, PolarityInversion, Gain +# ) + + +def parse_config(config_path: str) -> dict: + """Wrapper to parse config. + + Args: + config_path (str): path to config file for training! + + Returns: + (dict): hyperparameters parameters + """ + config = {} + with open(config_path, "r", encoding="UTF-8") as f: + config = yaml.safe_load(f) + return config + + +def train(config): + """Highest level logic for training! + + Does the following: + - Formats the dataset into an AudioDataset + - Prepares preprocessing for each audio clip + - Builds the model + - Configures and runs the trainer + - Runs evaluation + + Args: + config (dict): the config used for training. Defined in yaml file + """ + # Extract the dataset + ds = buowset_binary_extractor( + metadata_csv=config["metadata_csv"], + parent_path=config["data_path"], + output_path=config["hf_cache_path"], + ) + + # Create the model + run_name = "efficientnet_b1_testing_confusion_matrix_no_data_aug" + model_config = TimmModelConfig( + timm_model="efficientnet_b1", + num_classes=ds.get_num_classes()) + model = TimmModel(model_config) + + # Preprocessors + + # Uncomment if doing work with data augmentation + # # Augmentations + # wav_augs = ComposeAudioLabel([ + # # AddBackgroundNoise( #We don't have background noise yet... + # # sounds_path="data_birdset/background_noise", + # # min_snr_db=10, + # # max_snr_db=30, + # # noise_transform=PolarityInversion(), + # # p=0.8 + # # ), + # Gain( + # min_gain_db = -12, + # max_gain_db = 12, + # p = 0.8 + # ), + # MixItUp( + # dataset_ref=ds["train"], + # min_snr_db=10, + # max_snr_db=30, + # noise_transform=PolarityInversion(), + # p=0.8 + # ) + # ]) + + # Offline preprocessors prepare data for training + train_preprocessor = MelModelInputPreprocessor( + TimmInputs, duration=3 + ) + + preprocessor = MelModelInputPreprocessor( + TimmInputs, duration=3 + ) + + ds["train"].set_transform(train_preprocessor) + ds["valid"].set_transform(preprocessor) + ds["test"].set_transform(preprocessor) + + # Run training + training_args = WhootTrainingArguments( + run_name=run_name, + subproject_name=config["SUBPROJECT_NAME"], + dataset_name=config["DATASET_NAME"], + ) + + # COMMON OPTIONAL ARGS + training_args.num_train_epochs = 2 + training_args.eval_steps = 20 + training_args.per_device_train_batch_size = 32 + training_args.per_device_eval_batch_size = 32 + training_args.dataloader_num_workers = 36 + training_args.run_name = run_name + + trainer = WhootTrainer( + model=model, + dataset=ds, + training_args=training_args, + logger=CometMLLoggerSupplement( + augmentations=None, + name=training_args.run_name + ), + ) + + trainer.train() + + +def init_env(config: dict): + """Sets up local environment for COMET-ML training logging. 
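+
+    For example, with the values from configs/config.yml (a sketch; use
+    your own Comet-ML project and GPU ids):
+
+        init_env({
+            "COMET_PROJECT_NAME": "whoot",
+            "CUDA_VISIBLE_DEVICES": "0",
+        })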
+ + Args: config (dict): at a minimum this has the project name + and CUDA devices that are allowed to be used. + """ + print(config) + os.environ["COMET_PROJECT_NAME"] = config["COMET_PROJECT_NAME"] + os.environ["CUDA_VISIBLE_DEVICES"] = config["CUDA_VISIBLE_DEVICES"] + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Input config path") + parser.add_argument("config", type=str, help="Path to config.yml") + args = parser.parse_args() + _config = parse_config(args.config) + + init_env(_config) + train(_config) diff --git a/whoot_model_training/whoot_model_training/__init__.py b/whoot_model_training/whoot_model_training/__init__.py new file mode 100644 index 0000000..638ac1f --- /dev/null +++ b/whoot_model_training/whoot_model_training/__init__.py @@ -0,0 +1,5 @@ +"""Logging Toolkit for different MLops platforms.""" + +from .logger import CometMLLoggerSupplement + +__all__ = ["CometMLLoggerSupplement"] diff --git a/whoot_model_training/whoot_model_training/data_extractor/__init__.py b/whoot_model_training/whoot_model_training/data_extractor/__init__.py new file mode 100644 index 0000000..5e0ffe7 --- /dev/null +++ b/whoot_model_training/whoot_model_training/data_extractor/__init__.py @@ -0,0 +1,12 @@ +"""A zoo for extractors. + +Extractors convert raw data into AudioDatasets +Ideally you make a new Extractor for each new raw dataset +""" +from .buowset_extractor import ( + buowset_extractor, + buowset_binary_extractor, +) +from .esc50_extractor import esc50_extractor + +__all__ = ["buowset_extractor", "buowset_binary_extractor", "esc50_extractor"] diff --git a/whoot_model_training/whoot_model_training/data_extractor/buowset_extractor.py b/whoot_model_training/whoot_model_training/data_extractor/buowset_extractor.py new file mode 100644 index 0000000..10e9e83 --- /dev/null +++ b/whoot_model_training/whoot_model_training/data_extractor/buowset_extractor.py @@ -0,0 +1,177 @@ +"""Standardizes the format of the buowset dataset. + +Inspired by https://github.com/UCSD-E4E/pyha-analyzer-2.0/ + tree/main/pyha_analyzer/extractors + +The idea being extractors is that they take raw data, and +format it into a uniform dataset format, AudioDataset + +This way, it should be easier to define what a +common audio dataset format is between +parts of the codebase for training + +Supports both multilabel and binary labels +""" + +import os +from dataclasses import dataclass + +import numpy as np +from datasets import ( + load_dataset, + Audio, + DatasetDict, + ClassLabel, + Sequence, +) +from ..dataset import AudioDataset + + +def one_hot_encode(row: dict, classes: list): + """One hot Encodes a list of labels. + + Args: + row (dict): row of data in a dataset containing a labels column + classes: a list of classes + """ + one_hot = np.zeros(len(classes)) + one_hot[row["labels"]] = 1 + row["labels"] = np.array(one_hot, dtype=float) + return row + + +@dataclass +class BuowsetParams(): + """Parameters that describe the Buowset. + + Args: + validation_fold (int): label for valid split + test_fold (int): label for valid split + sample_rate (int): sample rate of the data + filepath (int): name of column in csv for filepaths + """ + validation_fold = 4 + test_fold = 3 + sr = 32_000 + filepath = "segment" + + +def buowset_extractor( + metadata_csv, + parent_path, + output_path, + params: BuowsetParams = BuowsetParams() +): + """Extracts raw data in the buowset format into an AudioDataset. 
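+
+    For example, with the placeholder paths from configs/config.yml
+    (a sketch; point these at your local copy of the dataset):
+
+        ds = buowset_extractor(
+            metadata_csv="data/burrowing_owl_dataset/metadata.csv",
+            parent_path="data/burrowing_owl_dataset/audio",
+            output_path="data/burrowing_owl_dataset/cache/metadata.hf",
+        )
+        print(ds.get_num_classes())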
+
+    Args:
+        metadata_csv (str): Path to csv containing buowset metadata
+        parent_path (str): Path to the parent folder for all audio data.
+            Note it is assumed the audio filepath
+            in the csv is relative to parent_path
+        output_path (str): Path to where HF cache for this dataset should live
+        params (BuowsetParams): Dataset parameters. Defaults to
+            BuowsetParams(): validation fold 4, test fold 3,
+            sample rate 32_000, filepath column "segment"
+
+    Returns:
+        (AudioDataset): See dataset.py, AudioDatasets are considered
+        the universal dataset for the training pipeline.
+    """
+    # Hugging face by default defines a train split
+    ds = load_dataset("csv", data_files=metadata_csv)["train"]
+    ds = ds.rename_column("label", "labels")  # Convention here is labels
+
+    # Convert to a uniform one_hot encoding for classes
+    ds = ds.class_encode_column("labels")
+    class_list = ds.features["labels"].names
+    multilabel_class_label = Sequence(ClassLabel(names=class_list))
+    ds = ds.map(lambda row: one_hot_encode(row, class_list)).cast_column(
+        "labels", multilabel_class_label
+    )
+
+    # Get audio into uniform format
+    ds = ds.add_column(
+        "audio", [
+            os.path.join(parent_path, file) for file in ds[params.filepath]
+        ]
+    )
+
+    ds = ds.add_column("filepath", ds["audio"])
+    ds = ds.cast_column("audio", Audio(sampling_rate=params.sr))
+
+    # Create splits of the data
+    test_ds = ds.filter(lambda x: x["fold"] == params.test_fold)
+    valid_ds = ds.filter(lambda x: x["fold"] == params.validation_fold)
+    train_ds = ds.filter(
+        lambda x: (
+            x["fold"] != params.test_fold
+            and x["fold"] != params.validation_fold
+        )
+    )
+    ds = AudioDataset(
+        DatasetDict({"train": train_ds, "valid": valid_ds, "test": test_ds})
+    )
+
+    ds.save_to_disk(output_path)
+
+    return ds
+
+
+def binarize_data(row, target_col=0):
+    """Convert a multilabel label into a binary one.
+
+    Args:
+        row (dict): an example of data
+        target_col (int): which index is the label for no_buow
+
+    Returns:
+        row (dict): now with a binary label instead
+    """
+    row["labels"] = [row["labels"][target_col], 1-row["labels"][target_col]]
+    return row
+
+
+def buowset_binary_extractor(
+        metadata_csv,
+        parent_path,
+        output_path,
+        target_col=0):
+    """Extracts raw data in the buowset format into an AudioDataset.
+
+    BUT only allows for two classes: no_buow, buow
+
+    Args:
+        metadata_csv (str): Path to csv containing buowset metadata
+        parent_path (str): Path to the parent folder for all audio data.
+            Note it is assumed the audio filepath
+            in the csv is relative to parent_path
+        output_path (str): Path to where HF cache for this dataset should live
+        target_col (int): index of the no_buow label in the
+            multilabel encoding
+
+    Returns:
+        (AudioDataset): See dataset.py, AudioDatasets are considered
+        the universal dataset for the training pipeline.
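+
+    Example (a minimal sketch; paths mirror the placeholders in
+    configs/config.yml):
+
+        ads = buowset_binary_extractor(
+            metadata_csv="data/burrowing_owl_dataset/metadata.csv",
+            parent_path="data/burrowing_owl_dataset/audio",
+            output_path="data/burrowing_owl_dataset/cache/metadata.hf",
+            target_col=0,
+        )
+        # Each label is now the two-class one-hot [no_buow, buow]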
+ """ + # Use the original extractor to create a multilabeled dataset + ads = buowset_extractor( + metadata_csv, + parent_path, + output_path, + ) + + # Now we just need to convert labels from multilabel to + # 0 or 1 + binary_class_label = Sequence(ClassLabel(names=["no_buow", "buow"])) + for split in ads: + ads[split] = ads[split].map( + lambda row: binarize_data(row, target_col=target_col) + ).cast_column("labels", binary_class_label) + + return ads diff --git a/whoot_model_training/whoot_model_training/data_extractor/esc50_extractor.py b/whoot_model_training/whoot_model_training/data_extractor/esc50_extractor.py new file mode 100644 index 0000000..4dd26bf --- /dev/null +++ b/whoot_model_training/whoot_model_training/data_extractor/esc50_extractor.py @@ -0,0 +1,128 @@ +"""Standardizes the format of the ESC-50 dataset. + +Inspired by https://github.com/UCSD-E4E/pyha-analyzer-2.0/ + tree/main/pyha_analyzer/extractors + +The idea being extractors is that they take raw data, and +format it into a uniform dataset format, AudioDataset + +This way, it should be easier to define what a +common audio dataset format is between +parts of the codebase for training + +Supports multilabel. + +Dataset: https://github.com/karolpiczak/ESC-50# +""" + +import os +from dataclasses import dataclass + +import numpy as np +from datasets import ( + load_dataset, + Audio, + DatasetDict, + ClassLabel, + Sequence, +) +from ..dataset import AudioDataset + + +def one_hot_encode(row: dict, classes: list): + """One hot Encodes a list of labels. + + Args: + row (dict): row of data in a dataset containing a labels column + classes: a list of classes + """ + one_hot = np.zeroes(len(classes)) + one_hot[row["labels"]] = 1 + row["labels"] = np.array(one_hot, dtype=float) + return row + + +@dataclass +class ESC50Params(): + """Parameters that describe ESC-50. + + validation_fold (int): label for valid split + test_fold (int): label for valid split + sample_rate (int): sample rate of the data + filepath (string): name of column in csv for filepaths + """ + validation_fold = 4 + test_fold = 5 + sample_rate = 44_100 + filepath = "filename" + + +def esc50_extractor( + metadata_csv, + parent_path, + output_path, + params: ESC50Params = ESC50Params() +): + """Extracts raw data in the ESC-50 format into an AudioDataset. + + Args: + Metdata_csv (str): Path to csv containing ESC-50 metadata + parent_path (str): Path to the parent folder for all audio data. + Note its assumed the audio filepath + in the csv is relative to parent_path + output_path (str): Path to where HF cache for this dataset should live + validation_fold (int): which fold is considered the validation set + Default 4 + test_fold (int): Which fold is considered the test set Default 3 + sr (int): Sample Rate of the audio files Default: 44_100 + filepath (str): Name of the column in the dataset containing + the filepaths Default: filename + + Returns: + (AudioDataset): See dataset.py, AudioDatasets are consider + the universal dataset for the training pipeline. 
+ """ + # Hugging face by default defines a train split + dataset = load_dataset("csv", data_files=metadata_csv)["train"] + dataset = dataset.rename_column("category", "labels") + + dataset = dataset.class_encode_column("labels") + + class_list = dataset.features["labels"].names + + multilabel_class_label = Sequence(ClassLabel(names=class_list)) + + dataset = dataset.map( + lambda row: one_hot_encode(row, class_list) + ).cast_column( + "labels", + multilabel_class_label + ) + + dataset = dataset.add_column( + "audio", [ + os.path.join(parent_path, + file) for file in dataset[params.filepath] + ] + ) + dataset = dataset.add_column("filepath", dataset["audio"]) + dataset = dataset.cast_column("audio", + Audio(sampling_rate=params.sample_rate)) + + # Create splits of the data + test_ds = dataset.filter(lambda x: x["fold"] == params.test_fold) + valid_ds = dataset.filter(lambda x: x["fold"] == params.validation_fold) + train_ds = dataset.filter( + lambda x: ( + x["fold"] != params.test_fold + and x["fold"] != params.validation_fold + ) + ) + + dataset = AudioDataset( + DatasetDict({"train": train_ds, "valid": valid_ds, "test": test_ds}) + ) + + dataset.save_to_disk(output_path) + + return dataset diff --git a/whoot_model_training/whoot_model_training/dataset.py b/whoot_model_training/whoot_model_training/dataset.py new file mode 100644 index 0000000..70e86c4 --- /dev/null +++ b/whoot_model_training/whoot_model_training/dataset.py @@ -0,0 +1,92 @@ +"""The Canonical Dataset used for any and all bioacoustic training. + +Pulled from: +https://github.com/UCSD-E4E/pyha-analyzer-2.0/blob/main/pyha_analyzer/dataset.py +Key idea is we define a generic AudioDataset with uniform features + +Using an Arrow Dataset from Hugging Face's dataset library because +- Cool audio features https://huggingface.co/docs/datasets/en/audio_process +- Faster than pandas, better at managing memory +""" + +from datasets import DatasetDict, ClassLabel + +DEFAULT_COLUMNS = ["labels", "audio"] + + +class AudioDataset(DatasetDict): + """AudioDataset Class. + + If your dataset is an AudioDataset, it can be read by + the rest of the system + + Behind the scenes, this is a Apache Arrow Dataset Dict + (via hf library) where + each key is a split of the data (test/train/valid) + and the value is an arrow dataset + with at a minimum 2 columns: + - labels (Sequence of class labels, such as [0,10]) + - audio (Audio Column type from hugging face) + """ + def __init__(self, ds: DatasetDict): + """Creates the Audio Datasets. + + ds should be in the AudioDataset format after + being extracted by extractors + """ + self.validate_format(ds) + super().__init__(ds) + + def validate_format(self, ds: DatasetDict): + """Validates dataset is correctly formatted. + + Raises: + AssertionError if dataset is not correctly formatted. + """ + for split in ds.keys(): + dataset = ds[split] + for column in DEFAULT_COLUMNS: + phrase_one = "The column `" + phrase_two = "` is missing from dataset split `" + phrase_three = "`. Required by system" + state = ( + f"{phrase_one}{column}{phrase_two}{split}{phrase_three}" + ) + assert column in dataset.features, state + + def get_num_classes(self): + """Gets the number of classes in the dataset. + + Returns: + (int): the number of classes in this dataset + """ + return self["train"].features["labels"].feature.num_classes + + def get_number_species(self) -> int: + """Get the number of classes in the dataset! + + PyhaAnalyzer uses `get_number_species` for getting class count + This... 
isn't always the case that the dataset is species only + (could have calls!) + To support legacy PyhaAnalyzer, we therefore have this function. + + This should be deprecated in future versions of PyhaAnalyzer + + Returns: + (int): number of classes + """ + return self.get_num_classes() + + def get_class_labels(self) -> ClassLabel: + """Class mapping for this dataset. + + A common problem is when moving between datasets + creating mappings between classes + This aims to help standardize that by being + able to get the classLabels for this dataset + + Returns: + (ClassLabel): Mapping of all the names of + the labels to their index. + """ + return ClassLabel(names=self["train"].features["labels"].feature.names) diff --git a/whoot_model_training/whoot_model_training/logger.py b/whoot_model_training/whoot_model_training/logger.py new file mode 100644 index 0000000..5dad594 --- /dev/null +++ b/whoot_model_training/whoot_model_training/logger.py @@ -0,0 +1,62 @@ +"""Contains useful tools for additional logging. + +For example, CometMLLoggerSupplement adds additional +logging for data augmentations used compared +to the base logging done by the HF trainer +integration +""" + +import comet_ml + + +# pylint disable-next=R0903 +class CometMLLoggerSupplement(): + """Note, that is working with the Trainer! + + The Trainer class implements their own CometML Callback during training + This handles a lot but NOT ALL of the logging we want + + This class handles the last 10% of the logging we want such as + - Better dataset hashing + - git hash saving + - etc + """ + + def __init__(self, augmentations, name): + """Log in and start new experiment. + + Args: + augmentations: list of augmentations + To record what was used during run + name (str): run name + """ + comet_ml.login() + self.start(augmentations, name) + + def start(self, augmentations, name): + """Begins a new set of experiments. + + Helpful for cases where a new run has begun + + Args: + augmentations: list of augmentations + To record what was used during run + name (str): run name + """ + self.experiment = comet_ml.start() + + self.experiment.log_parameter("augmentations", augmentations) + self.experiment.set_name(name) + + def end(self): + """Fully ends experiment if still running.""" + return self.experiment.end() + + def log_task(self, task_name): + """Log what task this model should be listed under. + + Args: + task_name: usually what task the model is doing + and the dataset being used for training + """ + self.experiment.log_parameter("task", task_name) diff --git a/whoot_model_training/whoot_model_training/metrics.py b/whoot_model_training/whoot_model_training/metrics.py new file mode 100644 index 0000000..6206f43 --- /dev/null +++ b/whoot_model_training/whoot_model_training/metrics.py @@ -0,0 +1,80 @@ +"""Metrics for Bioacoustic multilabel Models. + +Helps us evaluate which models do well + +These the metrics with HF Trainer and are called +as part of a callback during training + +WhootMutliClassMetrics: Computes CMAP, ROCAUC and + confusion matrices each evaluation step of + the trainer +""" + +import comet_ml +import torch +from sklearn.metrics import confusion_matrix + +from pyha_analyzer.metrics.classification_metrics \ + import AudioClassificationMetrics + + +class WhootMutliClassMetrics(AudioClassificationMetrics): + """Report metrics to logging. + + Supports CMAP, ROCAUC, and confusion matrices. + and reports them to Comet-ML dashboards + """ + def __init__(self, classes: list): + """Initializes metric reporting. 
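+
+        Example (class names are illustrative; the trainer builds this
+        from dataset.get_class_labels().names):
+
+            metrics = WhootMutliClassMetrics(["no_buow", "buow"])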
+ + classes (list): all classes used by model + """ + self.classes = classes + self.training = True + super().__init__([], len(classes), multilabel=True) + + def __call__(self, eval_pred) -> dict[str, float]: + """Log all metrics. + + Args: + eval_pred: package of data provided by trainer + contains + - predictions: np.array of model outputs + - label_ids: np.array of ground truth targets + + Returns: + (dict) key name of metric, float metric score + """ + # CMAP / ROCAUC, done by AudioClassificationMetrics + initial_metrics = super().__call__(eval_pred=eval_pred) + + # Confusion Matrix + self.log_comet_ml_only(eval_pred) + + # Return the metrics that can be logged to console AND comet-ml + return initial_metrics + + def log_comet_ml_only(self, eval_pred): + """Logs confusion matrix. + + eval_pred: package of data provided by trainer + contains + - predictions: np.array of model outputs + - label_ids: np.array of ground truth targets + """ + # For metrics that are not loggable to console + # We can only have comet_ml for these metrics + experiment = comet_ml.get_running_experiment() + if experiment is None: + return + logits = torch.Tensor(eval_pred.predictions) + target = torch.Tensor(eval_pred.label_ids).to(torch.long) + + # Confusion Matrix WARNING, ONLY MAKES SENSE + # IF DATA IS MOSTLY MUTLICLASS + cm = confusion_matrix( + torch.argmax(target, dim=1), + torch.argmax(logits, dim=1) + ) + experiment.log_confusion_matrix( + matrix=cm.tolist(), labels=self.classes) diff --git a/whoot_model_training/whoot_model_training/models/__init__.py b/whoot_model_training/whoot_model_training/models/__init__.py new file mode 100644 index 0000000..3c539ff --- /dev/null +++ b/whoot_model_training/whoot_model_training/models/__init__.py @@ -0,0 +1,17 @@ +"""a Bioacoustic Model Zoo! + +Example: + `from whoot_model_training.models import TimmModel +""" + +from .timm_model import TimmModel, TimmInputs, TimmModelConfig +from .model import Model, ModelInput, ModelOutput + +__all__ = [ + "TimmModel", + "TimmInputs", + "TimmModelConfig", + "Model", + "ModelInput", + "ModelOutput" +] diff --git a/whoot_model_training/whoot_model_training/models/model.py b/whoot_model_training/whoot_model_training/models/model.py new file mode 100644 index 0000000..aedc56f --- /dev/null +++ b/whoot_model_training/whoot_model_training/models/model.py @@ -0,0 +1,225 @@ +"""Abstract Model Class for training. + +Any model trained with this repo SHOULD inherit from these classes found here + +There are 3 main classes +- ModelInput: dict-like class that define required input params to function +- ModelOutput: dict-like class that defines the output from the model +- Model: A PyTorch nn.Module class + +See timm_model.py for example about how these classes can be implemented. +""" + +from abc import abstractmethod +from functools import wraps +from collections import UserDict + +from pyha_analyzer.models.base_model import BaseModel +from transformers import PreTrainedModel, PretrainedConfig +import numpy as np + + +def has_required_inputs(): + """Wrapper for formatting input for a given Model! + + Checks to make sure a model is passed in the correct input + format, and returns the correct output format. 
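+
+    For example, a custom model's forward is declared the same way
+    timm_model.py does it (sketch):
+
+        @has_required_inputs()
+        def forward(self, x: ModelInput) -> ModelOutput:
+            ...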
+ + Usually this is defined by `model.input_format` and + `model.output_format` + + MUST ALWAYS WRAP FORWARD FUNCTION OF MODEL + """ + def decorator(forward): + @wraps(forward) + def wrapper(self, x=None, **kwarg): + # During training, data is passed in as kwargs, (**ModelInput) + # due to how hugging face is designed + # this can be confusing if you are making custom models + # During inference, data is passed in as x, (ModelInput) + if x is None: + # ... but during training we just have the model + # pretend like it was passed in a ModelInput + x = self.input_format.from_dict(kwarg) + + assert isinstance(x, self.input_format) + model_output = forward(self, x) + assert isinstance(model_output, self.output_format) + + return model_output + + return wrapper + + return decorator + + +class ModelOutput(dict, UserDict): + """ModelOutput. + + Object that stores the output of a model + This allows for standardizing model outputs + So upstream applications don't need to change for specific models + + Inspired by HuggingFace Models + + Developer: recommended for each Model, to have an associated + ModelOutput class + """ + + # ignore some of the outputs when computing metrics + # When overwriting DON"T FORGET TO INCLUDE THIS + ignore_keys = ["predictions", "labels", "embeddings", "loss"] + + def __init__( + self, + _map: dict | None = None, + logits: np.ndarray | None = None, + embeddings: np.ndarray | None = None, + labels: np.ndarray | None = None, + loss: np.ndarray | None = None, + ): + """Create a new output to a model! + + Args: + logits: raw output from model + embeddings: some latent space encoding of data + Useful for transfer learning! + labels: labels for computing metrics + loss: loss as computed by the model + """ + super().__init__({ + "predictions": logits, + "logits": logits, + "labels": [labels], + # "label_ids": [labels], + "embeddings": embeddings, + "loss": loss + }) + + def items(self): + """Get all items in dict. + + But only if they are defined (not null)! + """ + return [ + (key, value) for ( + key, value + ) in super().items() if value is not None] + + +class ModelInput(UserDict, dict): + """ModelInput. + + Specifies Input Types + Hopefully should help standardize formatting for models + + Inspired by HuggingFace Models and Tokenizers + + Developer: recommended for each Model, to have an + associated ModelInput class + ALWAYS HAS A LABEL CATEGORY + """ + + def __init__( + self, + labels: np.ndarray, + waveform: np.ndarray | None = None, + spectrogram: np.ndarray | None = None, + ): + """Create a new input to a model! + + Args: + labels: one_hot encoded labels + waveform: raw audio signal + spectrogram: 2d matrix to represent the waveform + """ + super().__init__({ + "labels": labels, + "waveform": waveform, + "spectrogram": spectrogram + }) + + def items(self): + """Get all items in dict. + + But only if they are defined (not null)! + """ + return [ + (key, value) for ( + key, value + ) in super( + ).items() if value is not None] + + @classmethod + def from_dict(cls, some_input: dict): + """Recreates input for models! + + Sometimes inputs are given as kwargs + So lets recreate correct inputs for model + via building from a dictionary! 
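+
+        Example (`labels` and `spectrogram` are assumed to be arrays
+        already produced by a preprocessor):
+
+            x = ModelInput.from_dict(
+                {"labels": labels, "spectrogram": spectrogram}
+            )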
+ """ + spectrogram, waveform = None, None + labels = some_input["labels"] + if "spectrogram" in some_input: + spectrogram = some_input["spectrogram"] + if "waveform" in some_input: + waveform = some_input["waveform"] + + assert spectrogram is not None or waveform is not None + + return cls(labels, spectrogram=spectrogram, waveform=waveform) + + +class Model(PreTrainedModel, BaseModel): + """BaseModel Class for Whoot.""" + def __init__(self): + """Creates a basic model format. + + Anytime you create a new model, check if you need + to specify an input and output format for this model! + """ + self.input_format = ModelInput + self.output_format = ModelOutput + PreTrainedModel.__init__(self, PretrainedConfig()) + + assert hasattr(self.forward, "__wrapped__"), ( + "Please put `@has_required_inputs()", + "on the forward function of the model" + ) + + def get_embeddings(self, x: ModelInput) -> np.array: + """Gets an embedding for the model. + + This can be the final layer of a model backbone + or a set of useful features + + Args + x: Any | Either np.array or Torch.Tensor, the input for the model + + Returns + embedding: np.array, + some embedding vector representing the input data + """ + return self.forward(**x).embeddings + + @abstractmethod + @has_required_inputs() + def forward(self, x: ModelInput) -> ModelOutput: + """Runs some input x through the model. + + In PyTorch models, this is the same forward function logits + We just apply the convention for non Pytorch models, + Args: + x: Any + + Returns: + ModelOutput: dict, a dictionary like object that describes + """ + + def get_position_embeddings(self): + """Required by PretrainedModel, not needed for our work yet!""" + print("this model doesn't support position_embeddings") + + def resize_position_embeddings(self, new_num_position_embeddings: int): + """Required by PretrainedModel, not needed for our work yet!""" + print("this model doesn't support position_embeddings") diff --git a/whoot_model_training/whoot_model_training/models/timm_model.py b/whoot_model_training/whoot_model_training/models/timm_model.py new file mode 100644 index 0000000..aac2094 --- /dev/null +++ b/whoot_model_training/whoot_model_training/models/timm_model.py @@ -0,0 +1,135 @@ +"""Wrapper around the timms model zoo! + +See https://timm.fast.ai/ + +Timm model zoo good for computer vision models +Like CNNs, which are useful for spectrograms + +Great repo for models, but currently using this for demoing pipeline +""" + +import timm +from torch import nn +from transformers import PretrainedConfig + +from .model import Model, ModelInput, ModelOutput, has_required_inputs + + +class TimmInputs(ModelInput): + """Input for TimmModels. + + Specifies TimmModels needs labels and spectrograms that are Tensors + """ + def __init__(self, labels, waveform=None, spectrogram=None): + """Creates TimmInputs. + + Args: + labels: the data's label for this batch + spectrogram: audio's spectrogram + waveform: Optional, audio waveform + """ + # # Can use inputs to verify correct shape for upstream model + # assert spectrogram.shape[1:] == (1, 100, 100) + super().__init__(labels, waveform, spectrogram) + self.labels = labels + self.spectrogram = spectrogram + + +class TimmModelConfig(PretrainedConfig): + """Config for Timm Model Zoo Models!""" + def __init__( + self, + timm_model="resnet34", + pretrained=True, + in_chans=1, + num_classes=6, + **kwargs + ): + """Creates Config. 
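+
+        Example (mirrors the model built in train.py; the class count
+        here is illustrative):
+
+            config = TimmModelConfig(
+                timm_model="efficientnet_b1",
+                num_classes=2,
+            )
+            model = TimmModel(config)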
+ + Args: + timm_model (str): name of a model in timm model zoo + pretrained (bool): use pretrain weights from timms + in_chans (int): channels in audio, mono is 1 + num_classes (int): number of classes in dataset, for cls + """ + self.timm_model = timm_model + self.pretrained = pretrained + self.in_chans = in_chans + self.num_classes = num_classes + super().__init__(**kwargs) + + +class TimmModel(Model, nn.Module): + """Model that uses a timm's model.""" + config_class = TimmModelConfig + + def __init__( + self, + config: TimmModelConfig + ): + """Init for TimmModel. + + kwargs: + timm_model (str): name of model backbone from timms to use, + Default: "resnet34" + pretrained (bool): use a pretrained model from timms, Default: True + in_chans (int): number of channels of audio: Default: 1 + num_classes (int): number of classes in the dataset: Default 6 + loss (any): custom loss function Default: BCEWithLogitsLoss + """ + super().__init__() + self.input_format = TimmInputs + self.output_format = ModelOutput + self.config = config + assert config.num_classes > 0 + + # Deep learning CNN backbone + self.backbone = timm.create_model( + config.timm_model, + pretrained=config.pretrained, + in_chans=config.in_chans + ) + + # Unsure if 1000 is default for all timm models. Need to check this + self.linear = nn.Linear(1000, config.num_classes) + + # different losses if you want to train for different problems + # BCEWithLogitsLoss is default as for Bioacoustics, the problem tends + # multilabel! + # the probability of class A occurring doesn't + # change the probability of Class B + # Many individuals can make calls at the same time! + self.loss = nn.BCEWithLogitsLoss() + + def set_custom_loss(self, loss_fn): + """Set a different loss function. + + For cases where we don't want BCEWithLogitsLoss + + Args: + loss_fn: Function to compute loss, ideally in pytorch + """ + self.loss = loss_fn + + @has_required_inputs() + def forward(self, x: TimmInputs) -> ModelOutput: + """Model forward function. + + Args: + x: (TimmInputs): The specific input format for Timm Models + + Returns + (ModelOutput): The model output (logits), + latent space representations (embeddings), loss and labels. + """ + embed = self.backbone(x.spectrogram) + logits = self.linear(embed) + loss = self.loss(logits, x.labels) + + return ModelOutput( + logits=logits, + embeddings=embed, + loss=loss, + labels=x.labels + ) diff --git a/whoot_model_training/whoot_model_training/preprocessors/__init__.py b/whoot_model_training/whoot_model_training/preprocessors/__init__.py new file mode 100644 index 0000000..8efacfb --- /dev/null +++ b/whoot_model_training/whoot_model_training/preprocessors/__init__.py @@ -0,0 +1,20 @@ +"""A collection of online preprocessors. + +During training online preprocessors convert data +into data ready to be given to a model + +In traditional pytorch world, this would be like +the __get_item__ function of a dataset +""" + +from .base_preprocessor import ( + MelModelInputPreprocessor +) +from .spectrogram_preprocessors import ( + BuowMelSpectrogramPreprocessors +) + +__all__ = [ + "MelModelInputPreprocessor", + "BuowMelSpectrogramPreprocessors" +] diff --git a/whoot_model_training/whoot_model_training/preprocessors/base_preprocessor.py b/whoot_model_training/whoot_model_training/preprocessors/base_preprocessor.py new file mode 100644 index 0000000..a7ad953 --- /dev/null +++ b/whoot_model_training/whoot_model_training/preprocessors/base_preprocessor.py @@ -0,0 +1,107 @@ +"""Default Class for Preprocessing the data. 
+ +The dataset is one thing, what we feed into the models is another +Models may require spectrograms, waveforms, etc + +Not to mention any online augmentation we want to do + +The preprocessor class defines a function to preprocess our data during +training + +The default preprocessor allows for many types of preprocessors to run, +but it forces the output to fit the ModelInput class structure. +see `whoot_model_training/models/model.py` for more info. +""" +# pylint: disable=too-few-public-methods + +from pyha_analyzer.preprocessors import PreProcessorBase + +from .spectrogram_preprocessors import ( + BuowMelSpectrogramPreprocessors, + SpectrogramParams, + Augmentations +) +from ..models.model import ModelInput + + +class SpectrogramModelInPreprocessors(PreProcessorBase): + """Defines a preprocessor that after formatting the audio. + + Passes a spectrogram into a ModelInput object. + """ + def __init__( + self, + spec_preprocessor: PreProcessorBase, + model_input: ModelInput, + ): + """Wrapper to get the raw spectrogram output of spec_preprocessor. + + and format it neatly into a model_input + + Args: + spec_preprocessor (PreProcessorBase): a preprocessor that + creates spectrograms + model_input (ModelInput): How the model like input data formatted + """ + self.spec_preprocessor = spec_preprocessor + self.model_input = model_input + super().__init__(name="SpectrogramModelInPreprocessors") + + def __call__(self, batch: dict) -> ModelInput: + """Processes a batch of AudioDataset rows. + + For this specific preprocessor, it creates a spectrogram then + Formats the data as a ModelInput + """ + batch = self.spec_preprocessor(batch) + return self.model_input( + labels=batch["labels"], + spectrogram=batch["audio"] + ) + + +class MelModelInputPreprocessor(SpectrogramModelInPreprocessors): + """Demo of how SpectrogramModelInPreprocessors works. + + Uses a kind of Spectrogram Preprocessor, BuowMelSpectrogramPreprocessors + + This was created in part because legacy implementation of + SpectrogramModelInputPreprocessors had these parameters and subclassed + BuowMelSpectrogramPreprocessors. This class replicates the + format of the old SpectrogramModelInputPreprocessors + class with the new functionality + """ + def __init__( + self, + model_input: ModelInput, + duration=5, + augments: Augmentations = Augmentations(), + spectrogram_params: SpectrogramParams = SpectrogramParams(), + ): + """Creates a Online preprocessor for MelSpectrograms Based Models. + + Formats input into spefific ModelInput format. 
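+
+        Example (mirrors the usage in train.py):
+
+            preprocessor = MelModelInputPreprocessor(
+                TimmInputs, duration=3
+            )
+            ds["valid"].set_transform(preprocessor)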
+ + Args: + model_input (ModelInput): How the model like input data formatted + duration (int): Length in seconds of input + augments (dict): contains two keys: audio, + spectrogram each defining + a dict of augmentation names and augmentations to run + spectrogram_params (SpectrogramParams): + has the following parameters: + class_list (list): the classes we are + working with one-hot-encoding + n_fft (int): number of ffts + hop_length (int): hop length + power (int): power, defined by librosa + n_mels (int): number of mels for a melspectrogram + dataset_ref (AudioDataset): a + external ref to an AudioDataset + """ + spec_preprocessor = BuowMelSpectrogramPreprocessors( + duration=duration, + augments=augments, + spectrogram_params=spectrogram_params + ) + super().__init__(spec_preprocessor, model_input) diff --git a/whoot_model_training/whoot_model_training/preprocessors/spectrogram_preprocessors.py b/whoot_model_training/whoot_model_training/preprocessors/spectrogram_preprocessors.py new file mode 100644 index 0000000..555a801 --- /dev/null +++ b/whoot_model_training/whoot_model_training/preprocessors/spectrogram_preprocessors.py @@ -0,0 +1,143 @@ +"""Defines preprocessors for creating spectrograms. + +Pulled from pyha_analyzer/preprocessors/spectogram_preprocessors.py +""" +from dataclasses import dataclass + +import librosa +import numpy as np +from torchvision import transforms + +from pyha_analyzer.preprocessors import PreProcessorBase + + +@dataclass +class SpectrogramParams: + """Dataclass for spectrogram Parameters. + + n_fft: (int) number of fft bins + hop_length (int) skip count + power: (float) usually 2 + n_mels: (int) number of mel bins + """ + n_fft: int = 2048 + hop_length: int = 256 + power: float = 2.0 + n_mels: int = 256 + + +@dataclass +class Augmentations(): + """Dataclass for the augmentations of the model. + + audio (list[dict]): per item key name of augmentation, + value is the augmentation + spectrogram (list[dict]): same idea but augmentations + applied onto spectrograms + """ + audio = None + spectrogram = None + + +class BuowMelSpectrogramPreprocessors(PreProcessorBase): + """Preprocessor for processing audio into spectrograms. + + Particularly for the buow dataset + """ + + def __init__( + self, + duration=5, + augments: Augmentations = Augmentations(), + spectrogram_params: SpectrogramParams = SpectrogramParams() + ): + """Defines a BuowMelSpectrogramPreprocessors. 
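+
+        Example (parameter values are illustrative, not tuned):
+
+            spec_params = SpectrogramParams(n_fft=1024, hop_length=128)
+            preprocessor = BuowMelSpectrogramPreprocessors(
+                duration=3, spectrogram_params=spec_params
+            )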
+ + Args: + duration (float): length of chunk of data to train on + augments (Augmentations): An augmentation to apply to waveforms + spectrogram_params (SpectrogramParams): + config for spectrogram generation + """ + self.duration = duration + self.augments = augments + + # Below parameter defaults from https://arxiv.org/pdf/2403.10380 pg 25 + self.n_fft = spectrogram_params.n_fft + self.hop_length = spectrogram_params.hop_length + self.power = spectrogram_params.power + self.n_mels = spectrogram_params.n_mels + self.spectrogram_params = spectrogram_params + + super().__init__(name="MelSpectrogramPreprocessor") + + def __call__(self, batch): + """Process a batch of data from an AudioDataset.""" + new_audio = [] + new_labels = [] + for item_idx in range(len(batch["audio"])): + label = batch["labels"][item_idx] + y, sr = librosa.load(path=batch["audio"][item_idx]["path"]) + start = 0 + + # Handle out of bound issues + end_sr = int(start * sr) + int(sr * self.duration) + if y.shape[-1] <= end_sr: + y = np.pad(y, end_sr - y.shape[-1]) + + # Audio Based Augmentations + if self.augments.audio is not None: + y, label = self.augments.audio(y, sr, label) + + pillow_transforms = transforms.ToPILImage() + + mels = ( + np.array( + pillow_transforms( + librosa.feature.melspectrogram( + y=y[int(start * sr):end_sr], + sr=sr, + n_fft=self.n_fft, + hop_length=self.hop_length, + power=self.power, + n_mels=self.n_mels, + ) + ), + np.float32, + )[np.newaxis, ::] + / 255 + ) + + if self.augments.spectrogram is not None: + mels = self.augments.spectrogram(mels) + + new_audio.append(mels) + new_labels.append(label) + + batch["audio"] = new_audio + batch["labels"] = np.array(new_labels, dtype=np.float32) + + return batch + + def get_augmentations(self): + """Returns a list of augmentations. + + Perhaps for logging purposes + + Returns: + (list) all the augmentations + """ + return self.augments + + def __repr__(self): + """Use representation to describe the augmentations. + + Returns: + (str) all information about this preprocessor + """ + return ( + f"""{self.name} + Augmentations: {self.augments} + MelSpectrogram: {self.spectrogram_params} + """ + ) diff --git a/whoot_model_training/whoot_model_training/trainer.py b/whoot_model_training/whoot_model_training/trainer.py new file mode 100644 index 0000000..93f8702 --- /dev/null +++ b/whoot_model_training/whoot_model_training/trainer.py @@ -0,0 +1,105 @@ +"""Everything needed to train! + +given a model and a dataset + +WhootTrainingArguments: A container for the + many many args for WhootTrainer + +WhootTrainer: The class that is going to run training +""" + +from datetime import datetime +import os + +from pyha_analyzer import PyhaTrainingArguments +from pyha_analyzer import PyhaTrainer + +from .metrics import WhootMutliClassMetrics +from .dataset import AudioDataset +from .models import Model + + +class WhootTrainingArguments(PyhaTrainingArguments): + """Holds arguments use for training.""" + def __init__(self, + run_name: str, + subproject_name: str = "TESTING", + dataset_name: str = "DS_404"): + """Create Arguments. 
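+
+        Example (names are illustrative; train.py shows the full setup):
+
+            args = WhootTrainingArguments(
+                run_name="efficientnet_b1_buowset",
+                subproject_name="buow_detection",
+                dataset_name="buowset",
+            )
+            args.num_train_epochs = 2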
+ + Args: + run_name (str): name of the current run + subproject_name (str): name of subproject + These experiments are a part of + dataset_name (str): name of dataset + used for model experiments + """ + assert subproject_name is not None + assert dataset_name is not None + + default_checkpoint_path = "model_checkpoints" + checkpoint_created_at = datetime.now().strftime("%m_%d_%Y_%H:%M:%S") + + # run_name is name of the model + # task_name is name of the model task and dataset trained + self.run_name = f"{subproject_name}_{dataset_name}_{run_name}" + self.task_name = f"{subproject_name}_{dataset_name}" + + print( + f"Starting training on {dataset_name} for {subproject_name}" + ) + + super().__init__(os.path.join(f"{default_checkpoint_path}", + f"{run_name}_{checkpoint_created_at}")) + + # Required for whoot: override defaults in PyhaTrainingArguments + self.label_names = ["labels"] + self.remove_unused_columns = False + self.report_to = "comet_ml" + + +class WhootTrainer(PyhaTrainer): + """Trainers run the training of a model.""" + # WhootTrainer is ment to mimic the huggingface trainer + # Including number of arguments + # Aside, we really should consider how useful R0913,R0917 is... + + # pylint: disable-next=R0913,R0917 + def __init__( + self, + model: Model, + dataset: AudioDataset, + training_args: WhootTrainingArguments = None, + logger=None, + preprocessor=None, + ): + """Creates a trainer to hold training setup. + + Args: + model (Model): a pytorch model for training + should inherit from BaseModel see `models/model.py` + dataset (AudioDataset): A canonical audio dataset + Ideally attached some a preprocessor and returns ModelInputs + training_args (WhootTrainingArugments): + All the parameters that define training + logger (CometMLLoggerSupplement): + Class that adds additional logging + On top of logging done by PyhaTrainer + preprocessor (PreProcessorBase): + Preprocessor used for formatting the data + """ + metrics = WhootMutliClassMetrics(dataset.get_class_labels().names) + + if logger is not None: + logger.log_task(training_args.task_name) + + super().__init__( + model, + dataset, + metrics, + training_args, + logger, + None, # Data Collator, about to be deprecated + preprocessor, + model.output_format.ignore_keys + )
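+
+
+# Example usage (a sketch; train.py and train_binary.py wire this up
+# end to end with a real dataset, model, and logger):
+#
+#     trainer = WhootTrainer(
+#         model=model,
+#         dataset=ds,
+#         training_args=WhootTrainingArguments(run_name="demo"),
+#         logger=CometMLLoggerSupplement(augmentations=None, name="demo"),
+#     )
+#     trainer.train()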