diff --git a/create_dataset/strat_k_folds.py b/create_dataset/strat_k_folds.py index 0de2eb0..fbb799d 100644 --- a/create_dataset/strat_k_folds.py +++ b/create_dataset/strat_k_folds.py @@ -1,6 +1,6 @@ """Split buowset into stratified k-folds. -Groups detections from the same wav file into 'groups' +Groups detections from the same site into 'groups' and then determines the overall class distribution and the class distribution for each 'group'. It allocates all the groups to a 'fold' in a way where the folds @@ -8,7 +8,7 @@ dataset. Usage: - python3 strat_k_folds.py /path/to/metadata.csv + python3 strat_k_folds.py /path/to/metadata.csv /path/to/site_list.txt """ import argparse import pandas as pd @@ -18,17 +18,23 @@ from k_fold_split_copy import solve -def create_strat_folds(df): +def create_strat_folds(df, site_ids): """Create grouped stratified k-folds. Args: df (pd.Dataframe): The metadata csv from when the dataset was created. + site_ids (list): Site names (found in original file path) to group by. Returns: pd.DataFrame: The same metadata but with labels as ints and a new fold column to denote the fold that segment is apart of. """ num_classes = 6 + for index, row in df.iterrows(): + for site_id in site_ids: + if site_id in row['original_path']: + df.loc[index, 'site_id'] = site_id + print(df.head) original_df = df df['label'] = df['label'].replace('cluck', 0) df['label'] = df['label'].replace('coocoo', 1) @@ -37,7 +43,7 @@ def create_strat_folds(df): df['label'] = df['label'].replace('chick begging', 4) df['label'] = df['label'].replace('no_buow', 5) # group is the subset of the index which is the wav file they all come from - grouped = df.groupby('original_path') + grouped = df.groupby('site_id') group_names = [] group_matrix = [] for index, group in grouped: @@ -59,7 +65,7 @@ def create_strat_folds(df): # the % of each class in each fold print(f"Fold percents: {fold_percents}") print(folds) - grouped_original = original_df.groupby('original_path') + grouped_original = original_df.groupby('site_id') df_with_folds = pd.DataFrame() count = 0 for i, group in grouped_original: @@ -69,14 +75,17 @@ def create_strat_folds(df): return df_with_folds -def main(meta): +def main(meta, sites): """Execute main script. Args: meta (str): Path to metadata csv from creating the dataset. + sites (str): Path to sites to group by. """ df = pd.read_csv(meta, index_col=0) - df_with_folds = create_strat_folds(df) + with open(sites, 'r', encoding='utf-8') as file: + site_ids = [line.strip() for line in file.readlines()] + df_with_folds = create_strat_folds(df, site_ids) df_with_folds.to_csv("5-fold_meta.csv") @@ -86,5 +95,7 @@ def main(meta): ) parser.add_argument('meta', type=str, help='Path to metadata csv') + parser.add_argument('sites', type=str, + help='Path to site list') args = parser.parse_args() - main(args.meta) + main(args.meta, args.sites)