Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions create_dataset/strat_k_folds.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""Split buowset into stratified k-folds.

Groups detections from the same wav file into 'groups'
Groups detections from the same site into 'groups'
and then determines the overall class distribution and
the class distribution for each 'group'. It allocates
all the groups to a 'fold' in a way where the folds
are roughly the same class distribution as the overall
dataset.

Usage:
python3 strat_k_folds.py /path/to/metadata.csv
python3 strat_k_folds.py /path/to/metadata.csv /path/to/site_list.txt
"""
import argparse
import pandas as pd
Expand All @@ -18,17 +18,23 @@
from k_fold_split_copy import solve


def create_strat_folds(df):
def create_strat_folds(df, site_ids):
"""Create grouped stratified k-folds.

Args:
df (pd.Dataframe): The metadata csv from when the dataset was created.
site_ids (list): Site names (found in original file path) to group by.

Returns:
pd.DataFrame: The same metadata but with labels as ints and a new fold
column to denote the fold that segment is apart of.
"""
num_classes = 6
for index, row in df.iterrows():
for site_id in site_ids:
if site_id in row['original_path']:
df.loc[index, 'site_id'] = site_id
print(df.head)
original_df = df
df['label'] = df['label'].replace('cluck', 0)
df['label'] = df['label'].replace('coocoo', 1)
Expand All @@ -37,7 +43,7 @@ def create_strat_folds(df):
df['label'] = df['label'].replace('chick begging', 4)
df['label'] = df['label'].replace('no_buow', 5)
# group is the subset of the index which is the wav file they all come from
grouped = df.groupby('original_path')
grouped = df.groupby('site_id')
group_names = []
group_matrix = []
for index, group in grouped:
Expand All @@ -59,7 +65,7 @@ def create_strat_folds(df):
# the % of each class in each fold
print(f"Fold percents: {fold_percents}")
print(folds)
grouped_original = original_df.groupby('original_path')
grouped_original = original_df.groupby('site_id')
df_with_folds = pd.DataFrame()
count = 0
for i, group in grouped_original:
Expand All @@ -69,14 +75,17 @@ def create_strat_folds(df):
return df_with_folds


def main(meta):
def main(meta, sites):
"""Execute main script.

Args:
meta (str): Path to metadata csv from creating the dataset.
sites (str): Path to sites to group by.
"""
df = pd.read_csv(meta, index_col=0)
df_with_folds = create_strat_folds(df)
with open(sites, 'r', encoding='utf-8') as file:
site_ids = [line.strip() for line in file.readlines()]
df_with_folds = create_strat_folds(df, site_ids)
df_with_folds.to_csv("5-fold_meta.csv")


Expand All @@ -86,5 +95,7 @@ def main(meta):
)
parser.add_argument('meta', type=str,
help='Path to metadata csv')
parser.add_argument('sites', type=str,
help='Path to site list')
args = parser.parse_args()
main(args.meta)
main(args.meta, args.sites)