diff --git a/src/vak/__main__.py b/src/vak/__main__.py
index a25d3f833..e9f1c3aca 100644
--- a/src/vak/__main__.py
+++ b/src/vak/__main__.py
@@ -2,51 +2,31 @@
 Invokes __main__ when the module is run as a script.
 Example: python -m vak --help
 """
-
-import argparse
-from pathlib import Path
+import sys
 
 from .cli import cli
 
 
-def get_parser():
-    """returns ArgumentParser instance used by main()"""
-    parser = argparse.ArgumentParser(
-        prog="vak",
-        description="vak command-line interface",
-        formatter_class=argparse.RawTextHelpFormatter,
-    )
-    parser.add_argument(
-        "command",
-        type=str,
-        metavar="command",
-        choices=cli.CLI_COMMANDS,
-        help="Command to run, valid options are:\n"
-        f"{cli.CLI_COMMANDS}\n"
-        "$ vak train ./configs/config_2018-12-17.toml",
-    )
-    parser.add_argument(
-        "configfile",
-        type=Path,
-        help="name of config.toml file to use \n"
-        "$ vak train ./configs/config_2018-12-17.toml",
-    )
-    return parser
-
-
-def main(args=None):
+def main(args_list: list[str] | None = None):
     """Main function called when run as script or through command-line interface
 
     called when package is run with `python -m vak` or alternatively
     just calling `vak` at the command line
     (because this function is installed under just `vak` as a console script)
 
-    ``args`` is used for unit testing only
+    ``args_list`` is used for unit testing only
     """
-    if args is None:
-        parser = get_parser()
+    parser = cli.get_parser()
+
+    if args_list is None and len(sys.argv) < 2:
+        parser.print_help()
+        sys.exit(1)
+
+    if args_list is None:
         args = parser.parse_args()
-    cli.cli(command=args.command, config_file=args.configfile)
+    else:
+        args = parser.parse_args(args_list)
+    cli.cli(args)
 
 
 if __name__ == "__main__":
diff --git a/src/vak/cli/cli.py b/src/vak/cli/cli.py
index d6d2eaca3..468db7972 100644
--- a/src/vak/cli/cli.py
+++ b/src/vak/cli/cli.py
@@ -1,56 +1,220 @@
-def eval(toml_path):
+"""Implements the vak command-line interface"""
+import argparse
+import pathlib
+from dataclasses import dataclass
+from typing import Callable
+
+
+def eval(args):
     from .eval import eval
 
-    eval(toml_path=toml_path)
+    eval(toml_path=args.configfile)
 
 
-def train(toml_path):
+def train(args):
     from .train import train
 
-    train(toml_path=toml_path)
+    train(toml_path=args.configfile)
 
 
-def learncurve(toml_path):
+def learncurve(args):
     from .learncurve import learning_curve
 
-    learning_curve(toml_path=toml_path)
+    learning_curve(toml_path=args.configfile)
 
 
-def predict(toml_path):
+def predict(args):
     from .predict import predict
 
-    predict(toml_path=toml_path)
+    predict(toml_path=args.configfile)
 
 
-def prep(toml_path):
+def prep(args):
     from .prep import prep
 
-    prep(toml_path=toml_path)
+    prep(toml_path=args.configfile)
 
 
-COMMAND_FUNCTION_MAP = {
-    "prep": prep,
-    "train": train,
-    "eval": eval,
-    "predict": predict,
-    "learncurve": learncurve,
-}
+def configfile(args):
+    from ..config._generate import generate
+    generate(
+        kind=args.kind,
+        add_prep=args.add_prep,
+        dst=args.dst,
+    )
 
-CLI_COMMANDS = tuple(COMMAND_FUNCTION_MAP.keys())
 
+@dataclass
+class CLICommand:
+    """Dataclass representing a CLI command
+
+    Attributes
+    ----------
+    name : str
+        Name of the command, which gets added to the CLI as a sub-parser
+    help : str
+        Help for the command, which gets added to the CLI as a sub-parser
+    func : Callable
+        Function to call for the command
+    add_parser_args_func: Callable
+        Function to call to add arguments to the sub-parser representing the command
+    """
+    name: str
+    help: str
+    func: Callable
+    add_parser_args_func: Callable
+
+
+def add_single_arg_configfile_to_command(
cli_command, + cli_command_parser +): + """Most of the CLICommands call this function + to add arguments to their sub-parser. + It adds a single positional argument, `configfile`. + Not to be confused with the *command* configfile, + that adds different arguments + """ + cli_command_parser.add_argument( + "configfile", + type=pathlib.Path, + help="name of TOML configuration file to use \n" + f"$ vak {cli_command.name} ./configs/config_rat01337.toml", + ) + + +KINDS_OF_CONFIG_FILES = [ + # FIXME: there's no way to have a stand-alone prep file right now + # we need to add a `purpose` key-value pair to the file format + # to make this possible + # "prep", + "train", + "eval", + "predict", + "learncurve", +] + + +def add_args_to_configfile_command( + cli_command, + cli_command_parser +): + """This is the function that gets called + to add arguments to the sub-parser + for the configfile command + """ + cli_command_parser.add_argument( + "kind", + type=str, + choices=KINDS_OF_CONFIG_FILES, + help="kind: the kind of TOML configuration file to generate" + ) + cli_command_parser.add_argument( + "--add-prep", + action=argparse.BooleanOptionalAction, + default=False, + help="Adding this option will add a 'prep' table to the TOML configuration file. Default is False." + ) + cli_command_parser.add_argument( + "--dst", + type=pathlib.Path, + default=pathlib.Path.cwd(), + help="Destination, where TOML configuration file should be generated. Default is current working directory." + ) + # TODO: add this option + # cli_command_parser.add_argument( + # "--from", + # type=pathlib.Path, + # help="Path to another configuration file that this file should be generated from\n" + # ) + + +CLI_COMMANDS = [ + CLICommand( + name='prep', + help='prepare a dataset', + func=prep, + add_parser_args_func=add_single_arg_configfile_to_command, + ), + CLICommand( + name='train', + help='train a model', + func=train, + add_parser_args_func=add_single_arg_configfile_to_command, + ), + CLICommand( + name='eval', + help='evaluate a trained model', + func=eval, + add_parser_args_func=add_single_arg_configfile_to_command, + ), + CLICommand( + name='predict', + help='generate predictions from trained model', + func=predict, + add_parser_args_func=add_single_arg_configfile_to_command, + ), + CLICommand( + name='learncurve', + help='run a learning curve', + func=learncurve, + add_parser_args_func=add_single_arg_configfile_to_command, + ), + CLICommand( + name='configfile', + help='generate a TOML configuration file for vak', + func=configfile, + add_parser_args_func=add_args_to_configfile_command, + ), +] + + +def get_parser(): + """returns ArgumentParser instance used by main()""" + parser = argparse.ArgumentParser( + prog="vak", + description="Vak command-line interface", + formatter_class=argparse.RawTextHelpFormatter, + ) + + # create sub-parser + sub_parsers = parser.add_subparsers( + title="Command", + description="Commands for the vak command-line interface", + dest="command", + required=True, + ) + + for cli_command in CLI_COMMANDS: + cli_command_parser = sub_parsers.add_parser( + cli_command.name, + help=cli_command.help + ) + cli_command.add_parser_args_func( + cli_command, + cli_command_parser + ) + + return parser + + +CLI_COMMAND_FUNCTION_MAP = { + cli_command.name: cli_command.func + for cli_command in CLI_COMMANDS +} -def cli(command, config_file): +def cli(args: argparse.Namespace): """Execute the commands of the command-line interface. 
Parameters ---------- - command : string - One of {'prep', 'train', 'eval', 'predict', 'learncurve'} - config_file : str, Path - path to a config.toml file + args : argparse.Namespace + Result of calling :meth:`ArgumentParser.parse_args` + on the :class:`ArgumentParser` instance returned by + :func:`vak.cli.cli.get_parser`. """ - if command in COMMAND_FUNCTION_MAP: - COMMAND_FUNCTION_MAP[command](toml_path=config_file) + if args.command in CLI_COMMAND_FUNCTION_MAP: + CLI_COMMAND_FUNCTION_MAP[args.command](args) else: - raise ValueError(f"command not recognized: {command}") + raise ValueError(f"command not recognized: {args.command}") diff --git a/src/vak/cli/prep.py b/src/vak/cli/prep.py index d86c4c0a9..de40ae810 100644 --- a/src/vak/cli/prep.py +++ b/src/vak/cli/prep.py @@ -51,7 +51,7 @@ def purpose_from_toml( # note NO LOGGING -- we configure logger inside `core.prep` # so we can save log file inside dataset directory -# see https://github.com/NickleDave/vak/issues/334 +# see https://github.com/vocalpy/vak/issues/334 TABLES_PREP_SHOULD_PARSE = "prep" diff --git a/src/vak/config/__init__.py b/src/vak/config/__init__.py index c1828aff9..4c8ab1947 100644 --- a/src/vak/config/__init__.py +++ b/src/vak/config/__init__.py @@ -17,6 +17,7 @@ from .config import Config from .dataset import DatasetConfig from .eval import EvalConfig +from ._generate import generate from .learncurve import LearncurveConfig from .model import ModelConfig from .predict import PredictConfig @@ -29,6 +30,7 @@ "config", "dataset", "eval", + "generate", "learncurve", "model", "load", diff --git a/src/vak/config/_generate.py b/src/vak/config/_generate.py new file mode 100644 index 000000000..4393cb3e2 --- /dev/null +++ b/src/vak/config/_generate.py @@ -0,0 +1,113 @@ +import importlib.resources +import pathlib + +import tomlkit + + +CONFIGFILE_KIND_FILENAME_MAP = { + "train": "configfile_train.toml", + "eval": "configfile_eval.toml", + "predict": "configfile_predict.toml", + "learncurve": "configfile_learncurve.toml", +} + +# next line: can't use `.items()`, we'll get `RuntimeError` about dictionary changed sized during iteration +for key in list(CONFIGFILE_KIND_FILENAME_MAP.keys()): + val = CONFIGFILE_KIND_FILENAME_MAP[key] + CONFIGFILE_KIND_FILENAME_MAP[f"{key}_prep"] = val.replace(key, f"{key}_prep") + + +def generate( + kind: str, + add_prep: bool = False, + dst: str | pathlib.Path | None = None, +) -> None: + """Generate a TOML configuration file for :mod:`vak`. + + Parameters + ---------- + kind : str + The kind of TOML configuration file to generate. + One of: ``{'train', 'eval', 'predict', 'learncurve'}`` + add_prep : bool + If True, add a ``[vak.prep]`` table to the + TOML configuration file. + dst : string, pathlib.Path + Destination for the generated configuration file. + Either a full path including filename, + or a directory, in which case a default filename + will be used. + The default `dst` is the current working directory. + + Examples + -------- + + Generate a TOML configuration file in the current working directory to prepare a dataset and train a model. + + >>> vak.config.generate("train", add_prep=True) + + Generate a TOML configuration file in a specified directory to train a model, e.g. on an existing dataset. + + >>> import pathlib + >>> dst = pathlib.Path("./data/configs") + >>> vak.config.generate("train", add_prep=True, dst=dst) + + Generate a TOML configuration file with a specific file name to train a model, e.g. on an existing dataset. 
+ + >>> import pathlib + >>> dst = pathlib.Path("./data/configs/train-bfsongrepo.toml") + >>> vak.config.generate("train", add_prep=True, dst=dst) + + + Notes + ----- + This is the function called by + :func:`vak.cli.cli.generate` + when a user runs the command ``vak configfile`` + using the command-line interface. + """ + if dst is None: + # we can't make this the default value of the parameter in the function signature + # since it would get the value at import time, and we need the value at runtime + dst = pathlib.Path.cwd() + + dst = pathlib.Path(dst) + if not dst.is_dir() and dst.exists(): + raise FileExistsError( + f"Destination for generated config file `dst` is already a file that exists:\n{dst}\n" + "Please specify a value for the `--dst` argument that will not overwrite an existing file." + ) + + if not dst.is_dir() and dst.suffix != ".toml": + raise ValueError( + f"If `dst` is a path that ends in a filename, not a directory, then the extension must be '.toml', but was: {dst.suffix}" + ) + + # for now, we "add a prep section" by using a naming convention + # and loading an existing toml file that has a `[vak.prep]` table + if add_prep: + kind = f"{kind}_prep" + + try: + src_filename = CONFIGFILE_KIND_FILENAME_MAP[kind] + except KeyError: + raise ValueError( + f"Invalid kind: {kind}" + ) + + src_path = pathlib.Path( + importlib.resources.files("vak.config._toml_config_templates").joinpath(src_filename) + ) + # even though we are loading an existing file, + # we use tomlkit to load and dump. + # TODO: add "interactive" arg and use tomlkit with `input` to interactively build config file + with src_path.open("r") as fp: + tomldoc = tomlkit.load(fp) + + if dst.is_dir(): + dst_path = dst / src_filename + else: + dst_path = dst + + with dst_path.open("w") as fp: + tomlkit.dump(tomldoc, fp) diff --git a/src/vak/config/_toml_config_templates/configfile_eval.toml b/src/vak/config/_toml_config_templates/configfile_eval.toml new file mode 100644 index 000000000..229a5d53d --- /dev/null +++ b/src/vak/config/_toml_config_templates/configfile_eval.toml @@ -0,0 +1,58 @@ +# [vak.eval]: options for evaluating a trained model. This is done using the "test" split in a dataset by default. 
+[vak.eval] +# checkpoint_path: path to saved model checkpoint +checkpoint_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" +# labelmap_path: path to file that maps from outputs of model (integers) to text labels in annotations; +# this is used when generating predictions +labelmap_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/labelmap.json" +# frames_standardizer_path: path to file containing SpectScaler that was fit to training set +# We want to transform the data we predict on in the exact same way +frames_standardizer_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/StandardizeSpect" +# batch_size +# for predictions with a frame classification model, this should always be 1 +# and will be ignored if it's not +batch_size = 11 +# num_workers: number of workers to use when loading data with multiprocessing +num_workers = 16 +# device: name of device to run model on, one of "cuda", "cpu" + +# output_dir: directory where output should be saved, as a sub-directory within `output_dir` +output_dir = "/PATH/TO/FOLDER/results/eval" +# dataset_path : path to dataset created by prep +# ADD THE dataset_path OPTION FROM THE TRAIN FILE HERE (we already created a test split when we ran `vak prep` with that config) + +# [vak.eval.post_tfm_kwargs]: options for post-processing +[vak.eval.post_tfm_kwargs] +# both these transforms require that there is an "unlabeled" label, +# and they will only be applied to segments that are bordered on both sides +# by the "unlabeled" label. +# Such a label class is added by default by vak. +# majority_vote: post-processing transformation that takes majority vote within segments that +# do not have the 'unlabeled' class label. Only applied if `majority_vote` is `true` +# (default is false). +majority_vote = true +# min_segment_dur: post-processing transformation removes any segments +# with a duration shorter than `min_segment_dur` that do not have the 'unlabeled' class. +# Only applied if this option is specified. 
+min_segment_dur = 0.02 + +# dataset.params = parameters used for datasets +# for a frame classification model, we use dataset classes with a specific `window_size` +[vak.eval.dataset] +path = "/copy/path/from/train/config/here" +params = { window_size = 176 } + +# [vak.eval.model.TweetyNet]: We put this table so vak knows which model we are using +# We then add additional sub-tables to configure the model, e.g., [vak.eval.model.TweetyNet.network] +[vak.eval.model.TweetyNet.network] +# hidden_size: the number of elements in the hidden state in the recurrent layer of the network +# we trained with hidden size = 256 so we need to evaluate with the same hidden size; +# otherwise we'll get an error about "shapes do not match" when torch tries to load the checkpoint +hidden_size = 256 + +# [vak.eval.trainer]: this sub-table configures the `lightning.pytorch.Trainer` +[vak.eval.trainer] +# setting to 'gpu' means "train models on 'gpu' (not 'cpu')" +accelerator = "gpu" +# use the first GPU (numbering starts from 0) +devices = [0] diff --git a/src/vak/config/_toml_config_templates/configfile_eval_prep.toml b/src/vak/config/_toml_config_templates/configfile_eval_prep.toml new file mode 100644 index 000000000..ca8398281 --- /dev/null +++ b/src/vak/config/_toml_config_templates/configfile_eval_prep.toml @@ -0,0 +1,87 @@ +# [vak.prep]: options for preparing dataset +[vak.prep] +# dataset_type: corresponds to the model family such as "frame classification" or "parametric umap" +dataset_type = "frame classification" +# input_type: input to model, either audio ("audio") or spectrogram ("spect") +input_type = "spect" +# data_dir: directory with data to use when preparing dataset +data_dir = "/PATH/TO/FOLDER/gy6or6/032212" +# output_dir: directory where dataset will be created (as a sub-directory within output_dir) +output_dir = "/PATH/TO/FOLDER/prep/train" +# audio_format: format of audio, either wav or cbin +audio_format = "wav" +# annot_format: format of annotations +annot_format = "simple-seq" +# labelset: string or array with unique set of labels used in annotations +labelset = "iabcdefghjk" +# train_dur: duration of training split in dataset, in seconds +train_dur = 50 +# val_dur: duration of validation split in dataset, in seconds +val_dur = 15 + +# [vak.prep.spect_params]: parameters for computing spectrograms +[vak.prep.spect_params] +# fft_size: size of window used for Fast Fourier Transform, in number of samples +fft_size = 512 +# step_size: size of step to take when computing spectra with FFT for spectrogram +# also known as hop size +step_size = 64 + +# [vak.eval]: options for evaluating a trained model. This is done using the "test" split in a dataset by default. 
+[vak.eval] +# checkpoint_path: path to saved model checkpoint +checkpoint_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" +# labelmap_path: path to file that maps from outputs of model (integers) to text labels in annotations; +# this is used when generating predictions +labelmap_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/labelmap.json" +# frames_standardizer_path: path to file containing SpectScaler that was fit to training set +# We want to transform the data we predict on in the exact same way +frames_standardizer_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/StandardizeSpect" +# batch_size +# for predictions with a frame classification model, this should always be 1 +# and will be ignored if it's not +batch_size = 11 +# num_workers: number of workers to use when loading data with multiprocessing +num_workers = 16 +# device: name of device to run model on, one of "cuda", "cpu" + +# output_dir: directory where output should be saved, as a sub-directory within `output_dir` +output_dir = "/PATH/TO/FOLDER/results/eval" +# dataset_path : path to dataset created by prep +# ADD THE dataset_path OPTION FROM THE TRAIN FILE HERE (we already created a test split when we ran `vak prep` with that config) + +# [vak.eval.post_tfm_kwargs]: options for post-processing +[vak.eval.post_tfm_kwargs] +# both these transforms require that there is an "unlabeled" label, +# and they will only be applied to segments that are bordered on both sides +# by the "unlabeled" label. +# Such a label class is added by default by vak. +# majority_vote: post-processing transformation that takes majority vote within segments that +# do not have the 'unlabeled' class label. Only applied if `majority_vote` is `true` +# (default is false). +majority_vote = true +# min_segment_dur: post-processing transformation removes any segments +# with a duration shorter than `min_segment_dur` that do not have the 'unlabeled' class. +# Only applied if this option is specified. 
+min_segment_dur = 0.02 + +# dataset.params = parameters used for datasets +# for a frame classification model, we use dataset classes with a specific `window_size` +[vak.eval.dataset] +path = "/copy/path/from/train/config/here" +params = { window_size = 176 } + +# [vak.eval.model.TweetyNet]: We put this table so vak knows which model we are using +# We then add additional sub-tables to configure the model, e.g., [vak.eval.model.TweetyNet.network] +[vak.eval.model.TweetyNet.network] +# hidden_size: the number of elements in the hidden state in the recurrent layer of the network +# we trained with hidden size = 256 so we need to evaluate with the same hidden size; +# otherwise we'll get an error about "shapes do not match" when torch tries to load the checkpoint +hidden_size = 256 + +# [vak.eval.trainer]: this sub-table configures the `lightning.pytorch.Trainer` +[vak.eval.trainer] +# setting to 'gpu' means "train models on 'gpu' (not 'cpu')" +accelerator = "gpu" +# use the first GPU (numbering starts from 0) +devices = [0] diff --git a/src/vak/config/_toml_config_templates/configfile_learncurve.toml b/src/vak/config/_toml_config_templates/configfile_learncurve.toml new file mode 100644 index 000000000..a3d5294f4 --- /dev/null +++ b/src/vak/config/_toml_config_templates/configfile_learncurve.toml @@ -0,0 +1,59 @@ +# [vak.learncurve]: options for running the learning curve +# that estimates model performance +# as a function of training set size +[vak.learncurve] +# root_results_dir: directory where results should be saved, as a sub-directory within `root_results_dir` +root_results_dir = "./tests/data_for_tests/generated/results/learncurve/audio_cbin_annot_notmat/TweetyNet" +# batch_size: number of samples from dataset per batch fed into network +batch_size = 11 +# num_epochs: number of training epochs, where an epoch is one iteration through all samples in training split +num_epochs = 2 +# standardize_frames: if true, standardize (normalize) frames (input to neural network) per frequency bin, so mean of each is 0.0 and std is 1.0 +# across the entire training split +standardize_frames = true +# val_step: step number on which to compute metrics with validation set, every time step % val_step == 0 +# (a step is one batch fed through the network) +# saves a checkpoint if the monitored evaluation metric improves (which is model specific) +val_step = 50 +# ckpt_step: step number on which to save a checkpoint (as a backup, regardless of validation metrics) +ckpt_step = 200 +# patience: number of validation steps to wait before stopping training early +# if the monitored evaluation metrics does not improve after `patience` validation steps, +# then we stop training +patience = 4 +# num_workers: number of workers to use when loading data with multiprocessing +num_workers = 16 + +[vak.learncurve.post_tfm_kwargs] +majority_vote = true +min_segment_dur = 0.02 + +[vak.learncurve.dataset] +# params : parameters that configure the `vak.datapipes` or `vak.datasets` class +# for a frame classification model, we use dataset classes with a specific `window_size` +# Bigger windows work better. +# For frame classification models, prefer smaller batch sizes with bigger windows +# Intuitively, bigger windows give the model more "contexts" for each frame per batch. +# See https://github.com/vocalpy/Nicholson-Cohen-SfN-2023-poster for more detail +params = { window_size = 88 } +# path : path to dataset created by prep. 
This will be added when you run `vak prep`, you don't have to add it + +# TweetyNet.network: we specify options for the model's network in this table +# To indicate the model to train, we use a "dotted key" with `model` followed by the string name of the model. +# This name must be a name within `vak.models` or added e.g. with `vak.model.decorators.model` +# We use another dotted key to indicate options for configuring the model, e.g. `TweetyNet.optimizer` +[vak.train.model.TweetyNet.optimizer] +# vak.train.model.TweetyNet.optimizer: we specify options for the model's optimizer in this table +# lr: the learning rate +lr = 0.001 + +[vak.learncurve.model.TweetyNet.network] +# hidden_size: the number of elements in the hidden state in the recurrent layer of the network +hidden_size = 256 + +# this sub-table configures the `lightning.pytorch.Trainer` +[vak.learncurve.trainer] +# setting to 'gpu' means "train models on 'gpu' (not 'cpu')" +accelerator = "gpu" +# use the first GPU (numbering starts from 0) +devices = [0] diff --git a/src/vak/config/_toml_config_templates/configfile_learncurve_prep.toml b/src/vak/config/_toml_config_templates/configfile_learncurve_prep.toml new file mode 100644 index 000000000..43ca6080f --- /dev/null +++ b/src/vak/config/_toml_config_templates/configfile_learncurve_prep.toml @@ -0,0 +1,106 @@ +# [vak.prep]: options for preparing dataset +[vak.prep] +# dataset_type: corresponds to the model family such as "frame classification" or "parametric umap" +dataset_type = "frame classification" +# input_type: input to model, either audio ("audio") or spectrogram ("spect") +input_type = "spect" +# data_dir: directory with data to use when preparing dataset +data_dir = "/Users/davidnicholson/Documents/repos/vocalpy/vak/tests/scripts/vaktestdata/../../data_for_tests/generated/spect-output-dir/audio_cbin_annot_notmat/gy6or6/032312" +# output_dir: directory where dataset will be created (as a sub-directory within output_dir) +output_dir = "./tests/data_for_tests/generated/prep/learncurve/audio_cbin_annot_notmat/TweetyNet" +# audio_format: format of audio, either wav or cbin +spect_format = "npz" +# annot_format: format of annotations +annot_format = "notmat" +# labelset: string or array with unique set of labels used in annotations +labelset = "iabcdefghjk" +# train_dur: duration of training split in dataset, in seconds +train_dur = 50 +# val_dur: duration of validation split in dataset, in seconds +val_dur = 15 +# test_dur: duration of test split in dataset, in seconds +test_dur = 30 +train_set_durs = [ 4, 6,] +num_replicates = 2 + +# [vak.prep.spect_params]: parameters for computing spectrograms +[vak.prep.spect_params] +# fft_size: size of window used for Fast Fourier Transform, in number of samples +fft_size = 512 +# step_size: size of step to take when computing spectra with FFT for spectrogram +# also known as hop size +step_size = 64 +# qualitatively, we find that log transforming the spectrograms improves performance; +# think of this as increasing the contrast between high power and low power regions +transform_type = "log_spect" +# specifying cutoff frequencies of the spectrogram can (1) make the model more +# computationally efficient and (2) improve performance by only fitting the model +# to parts of the spectrum that are relevant for sounds of interest. +# Note these cutoffs are applied by computing the whole spectrogram first +# and then throwing away frequencies above and below the cutoffs; +# we do not apply a bandpass filter to the audio. 
+freq_cutoffs = [ 500, 10000,] +# Note that for the TweetyNet model, the default is to set the hidden_size of the RNN +# equal to the input_size, so if you reduce the size of the spectrogram, this will reduce the +# hidden size of the RNN. If you observe impaired performance of TweetyNet after applying the frequency cutoffs, +# consider manually specifying a larger hidden (see `[vak.train.model.TweetyNet]` table below). + +# [vak.learncurve]: options for running the learning curve +# that estimates model performance +# as a function of training set size +[vak.learncurve] +# root_results_dir: directory where results should be saved, as a sub-directory within `root_results_dir` +root_results_dir = "./tests/data_for_tests/generated/results/learncurve/audio_cbin_annot_notmat/TweetyNet" +# batch_size: number of samples from dataset per batch fed into network +batch_size = 11 +# num_epochs: number of training epochs, where an epoch is one iteration through all samples in training split +num_epochs = 2 +# standardize_frames: if true, standardize (normalize) frames (input to neural network) per frequency bin, so mean of each is 0.0 and std is 1.0 +# across the entire training split +standardize_frames = true +# val_step: step number on which to compute metrics with validation set, every time step % val_step == 0 +# (a step is one batch fed through the network) +# saves a checkpoint if the monitored evaluation metric improves (which is model specific) +val_step = 50 +# ckpt_step: step number on which to save a checkpoint (as a backup, regardless of validation metrics) +ckpt_step = 200 +# patience: number of validation steps to wait before stopping training early +# if the monitored evaluation metrics does not improve after `patience` validation steps, +# then we stop training +patience = 4 +# num_workers: number of workers to use when loading data with multiprocessing +num_workers = 16 + +[vak.learncurve.post_tfm_kwargs] +majority_vote = true +min_segment_dur = 0.02 + +[vak.learncurve.dataset] +# params : parameters that configure the `vak.datapipes` or `vak.datasets` class +# for a frame classification model, we use dataset classes with a specific `window_size` +# Bigger windows work better. +# For frame classification models, prefer smaller batch sizes with bigger windows +# Intuitively, bigger windows give the model more "contexts" for each frame per batch. +# See https://github.com/vocalpy/Nicholson-Cohen-SfN-2023-poster for more detail +params = { window_size = 88 } +# path : path to dataset created by prep. This will be added when you run `vak prep`, you don't have to add it + +# TweetyNet.network: we specify options for the model's network in this table +# To indicate the model to train, we use a "dotted key" with `model` followed by the string name of the model. +# This name must be a name within `vak.models` or added e.g. with `vak.model.decorators.model` +# We use another dotted key to indicate options for configuring the model, e.g. 
`TweetyNet.optimizer` +[vak.train.model.TweetyNet.optimizer] +# vak.train.model.TweetyNet.optimizer: we specify options for the model's optimizer in this table +# lr: the learning rate +lr = 0.001 + +[vak.learncurve.model.TweetyNet.network] +# hidden_size: the number of elements in the hidden state in the recurrent layer of the network +hidden_size = 256 + +# this sub-table configures the `lightning.pytorch.Trainer` +[vak.learncurve.trainer] +# setting to 'gpu' means "train models on 'gpu' (not 'cpu')" +accelerator = "gpu" +# use the first GPU (numbering starts from 0) +devices = [0] diff --git a/src/vak/config/_toml_config_templates/configfile_predict.toml b/src/vak/config/_toml_config_templates/configfile_predict.toml new file mode 100644 index 000000000..6303c6537 --- /dev/null +++ b/src/vak/config/_toml_config_templates/configfile_predict.toml @@ -0,0 +1,57 @@ +# [vak.predict]: options for generating predictions with a trained model +[vak.predict] +# checkpoint_path: path to saved model checkpoint +checkpoint_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" +# labelmap_path: path to file that maps from outputs of model (integers) to text labels in annotations; +# this is used when generating predictions +labelmap_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/labelmap.json" +# frames_standardizer_path: path to file containing SpectScaler that was fit to training set +# We want to transform the data we predict on in the exact same way +frames_standardizer_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/StandardizeSpect" +# batch_size +# for predictions with a frame classification model, this should always be 1 +# and will be ignored if it's not +batch_size = 1 +# num_workers: number of workers to use when loading data with multiprocessing +num_workers = 4 +# device: name of device to run model on, one of "cuda", "cpu" + +# output_dir: directory where output should be saved, as a sub-directory within `output_dir` +output_dir = "/PATH/TO/FOLDER/results/predict" +# annot_csv_filename +annot_csv_filename = "gy6or6.032312.annot.csv" +# The next two options are for post-processing transforms. +# Both these transforms require that there is an "unlabeled" label, +# and they will only be applied to segments that are bordered on both sides +# by the "unlabeled" label. +# Such a label class is added by default by vak. +# majority_vote: post-processing transformation that takes majority vote within segments that +# do not have the 'unlabeled' class label. Only applied if `majority_vote` is `true` +# (default is false). +majority_vote = true +# min_segment_dur: post-processing transformation removes any segments +# with a duration shorter than `min_segment_dur` that do not have the 'unlabeled' class. +# Only applied if this option is specified. +min_segment_dur = 0.01 +# dataset_path : path to dataset created by prep. 
This will be added when you run `vak prep`, you don't have to add it + +# dataset.params = parameters used for datasets +# for a frame classification model, we use dataset classes with a specific `window_size` +[vak.predict.dataset] +path = "/copy/path/from/train/config/here" +params = { window_size = 176 } + +# [vak.predict.model.TweetyNet]: We put this table so vak knows which model we are using +# We then add additional sub-tables to configure the model, e.g., [vak.eval.model.TweetyNet.network] +[vak.predict.model.TweetyNet.network] +# hidden_size: the number of elements in the hidden state in the recurrent layer of the network +# we trained with hidden size = 256 so we need to evaluate with the same hidden size; +# otherwise we'll get an error about "shapes do not match" when torch tries to load the checkpoint +hidden_size = 256 + +# [vak.predict.trainer]: this sub-table configures the `lightning.pytorch.Trainer` +[vak.predict.trainer] +# setting to 'gpu' means "train models on 'gpu' (not 'cpu')" +accelerator = "gpu" +# use the first GPU (numbering starts from 0) +devices = [0] diff --git a/src/vak/config/_toml_config_templates/configfile_predict_prep.toml b/src/vak/config/_toml_config_templates/configfile_predict_prep.toml new file mode 100644 index 000000000..b82cf048c --- /dev/null +++ b/src/vak/config/_toml_config_templates/configfile_predict_prep.toml @@ -0,0 +1,80 @@ +# PREP: options for preparing dataset +[vak.prep] +# dataset_type: corresponds to the model family such as "frame classification" or "parametric umap" +dataset_type = "frame classification" +# input_type: input to model, either audio ("audio") or spectrogram ("spect") +input_type = "spect" +# data_dir: directory with data to use when preparing dataset +data_dir = "/PATH/TO/FOLDER/gy6or6/032312" +# output_dir: directory where dataset will be created (as a sub-directory within output_dir) +output_dir = "/PATH/TO/FOLDER/prep/predict" +# audio_format: format of audio, either wav or cbin +audio_format = "wav" +# note that for predictions we don't need to specify labelset or annot_format +# note also that we do not specify train_dur / val_dur / test_dur; +# all data found in `data_dir` will be assigned to a "predict split" instead + +# SPECT_PARAMS: parameters for computing spectrograms +[vak.prep.spect_params] +# fft_size: size of window used for Fast Fourier Transform, in number of samples +fft_size = 512 +# step_size: size of step to take when computing spectra with FFT for spectrogram +# also known as hop size +step_size = 64 + +# PREDICT: options for generating predictions with a trained model +[vak.predict] +# checkpoint_path: path to saved model checkpoint +checkpoint_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" +# labelmap_path: path to file that maps from outputs of model (integers) to text labels in annotations; +# this is used when generating predictions +labelmap_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/labelmap.json" +# frames_standardizer_path: path to file containing SpectScaler that was fit to training set +# We want to transform the data we predict on in the exact same way +frames_standardizer_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/StandardizeSpect" +# batch_size +# for predictions with a frame classification model, this should always be 1 +# and will be ignored if it's not +batch_size = 1 +# num_workers: number of workers to use when loading data with multiprocessing +num_workers = 4 +# device: name of device to 
run model on, one of "cuda", "cpu" + +# output_dir: directory where output should be saved, as a sub-directory within `output_dir` +output_dir = "/PATH/TO/FOLDER/results/predict" +# annot_csv_filename +annot_csv_filename = "gy6or6.032312.annot.csv" +# The next two options are for post-processing transforms. +# Both these transforms require that there is an "unlabeled" label, +# and they will only be applied to segments that are bordered on both sides +# by the "unlabeled" label. +# Such a label class is added by default by vak. +# majority_vote: post-processing transformation that takes majority vote within segments that +# do not have the 'unlabeled' class label. Only applied if `majority_vote` is `true` +# (default is false). +majority_vote = true +# min_segment_dur: post-processing transformation removes any segments +# with a duration shorter than `min_segment_dur` that do not have the 'unlabeled' class. +# Only applied if this option is specified. +min_segment_dur = 0.01 +# dataset_path : path to dataset created by prep. This will be added when you run `vak prep`, you don't have to add it + +# dataset.params = parameters used for datasets +# for a frame classification model, we use dataset classes with a specific `window_size` +[vak.predict.dataset] +path = "/copy/path/from/train/config/here" +params = { window_size = 176 } + +# We put this table though vak knows which model we are using +[vak.predict.model.TweetyNet.network] +# hidden_size: the number of elements in the hidden state in the recurrent layer of the network +# we trained with hidden size = 256 so we need to evaluate with the same hidden size; +# otherwise we'll get an error about "shapes do not match" when torch tries to load the checkpoint +hidden_size = 256 + +# this sub-table configures the `lightning.pytorch.Trainer` +[vak.predict.trainer] +# setting to 'gpu' means "train models on 'gpu' (not 'cpu')" +accelerator = "gpu" +# use the first GPU (numbering starts from 0) +devices = [0] diff --git a/src/vak/config/_toml_config_templates/configfile_train.toml b/src/vak/config/_toml_config_templates/configfile_train.toml new file mode 100644 index 000000000..a6300d1d3 --- /dev/null +++ b/src/vak/config/_toml_config_templates/configfile_train.toml @@ -0,0 +1,56 @@ +# [vak.train]: options for training model +[vak.train] +# root_results_dir: directory where results should be saved, as a sub-directory within `root_results_dir` +root_results_dir = "/PATH/TO/FOLDER/results/train" +# batch_size: number of samples from dataset per batch fed into network +batch_size = 8 +# num_epochs: number of training epochs, where an epoch is one iteration through all samples in training split +num_epochs = 2 +# standardize_frames: if true, standardize (normalize) frames (input to neural network) per frequency bin, so mean of each is 0.0 and std is 1.0 +# across the entire training split +standardize_frames = true +# val_step: step number on which to compute metrics with validation set, every time step % val_step == 0 +# (a step is one batch fed through the network) +# saves a checkpoint if the monitored evaluation metric improves (which is model specific) +val_step = 1000 +# ckpt_step: step number on which to save a checkpoint (as a backup, regardless of validation metrics) +ckpt_step = 500 +# patience: number of validation steps to wait before stopping training early +# if the monitored evaluation metrics does not improve after `patience` validation steps, +# then we stop training +patience = 6 +# num_workers: number of workers to use when 
loading data with multiprocessing +num_workers = 4 +# device: name of device to run model on, one of "cuda", "cpu" + +# dataset_path : path to dataset created by prep. This will be added when you run `vak prep`, you don't have to add it + +# dataset.params = parameters used for datasets +# for a frame classification model, we use dataset classes with a specific `window_size` +[vak.train.dataset.params] +# Bigger windows work better. +# For frame classification models, prefer smaller batch sizes with bigger windows +# Intuitively, bigger windows give the model more "contexts" for each frame per batch. +# See https://github.com/vocalpy/Nicholson-Cohen-SfN-2023-poster for more detail +window_size = 2000 + +# TweetyNet.network: we specify options for the model's network in this table +# To indicate the model to train, we use a "dotted key" with `model` followed by the string name of the model. +# This name must be a name within `vak.models` or added e.g. with `vak.model.decorators.model` +# We use another dotted key to indicate options for configuring the model, e.g. `TweetyNet.optimizer` +[vak.train.model.TweetyNet] +[vak.train.model.TweetyNet.optimizer] +# vak.train.model.TweetyNet.optimizer: we specify options for the model's optimizer in this table +# lr: the learning rate +lr = 0.001 + +[vak.train.model.TweetyNet.network] +# hidden_size: the number of elements in the hidden state in the recurrent layer of the network +hidden_size = 256 + +# this sub-table configures the `lightning.pytorch.Trainer` +[vak.train.trainer] +# setting to 'gpu' means "train models on 'gpu' (not 'cpu')" +accelerator = "gpu" +# use the first GPU (numbering starts from 0) +devices = [0] diff --git a/src/vak/config/_toml_config_templates/configfile_train_prep.toml b/src/vak/config/_toml_config_templates/configfile_train_prep.toml new file mode 100644 index 000000000..1aa0dc45f --- /dev/null +++ b/src/vak/config/_toml_config_templates/configfile_train_prep.toml @@ -0,0 +1,101 @@ +# [vak.prep]: options for preparing dataset +[vak.prep] +# dataset_type: corresponds to the model family such as "frame classification" or "parametric umap" +dataset_type = "frame classification" +# input_type: input to model, either audio ("audio") or spectrogram ("spect") +input_type = "spect" +# data_dir: directory with data to use when preparing dataset +data_dir = "/PATH/TO/FOLDER/gyor6/032212" +# output_dir: directory where dataset will be created (as a sub-directory within output_dir) +output_dir = "/PATH/TO/FOLDER/prep/train" +# audio_format: format of audio, either wav or cbin +audio_format = "wav" +# annot_format: format of annotations +annot_format = "simple-seq" +# labelset: string or array with unique set of labels used in annotations +labelset = "iabcdefghjk" +# train_dur: duration of training split in dataset, in seconds +train_dur = 2000 +# val_dur: duration of validation split in dataset, in seconds +val_dur = 170 +# test_dur: duration of test split in dataset, in seconds +test_dur = 350 + +# [vak.prep.spect_params]: parameters for computing spectrograms +[vak.prep.spect_params] +# fft_size: size of window used for Fast Fourier Transform, in number of samples +fft_size = 512 +# step_size: size of step to take when computing spectra with FFT for spectrogram +# also known as hop size +step_size = 64 +# qualitatively, we find that log transforming the spectrograms improves performance; +# think of this as increasing the contrast between high power and low power regions +transform_type = "log_spect" +# specifying cutoff 
frequencies of the spectrogram can (1) make the model more +# computationally efficient and (2) improve performance by only fitting the model +# to parts of the spectrum that are relevant for sounds of interest. +# Note these cutoffs are applied by computing the whole spectrogram first +# and then throwing away frequencies above and below the cutoffs; +# we do not apply a bandpass filter to the audio. +freq_cutoffs = [500, 8000] +# Note that for the TweetyNet model, the default is to set the hidden_size of the RNN +# equal to the input_size, so if you reduce the size of the spectrogram, this will reduce the +# hidden size of the RNN. If you observe impaired performance of TweetyNet after applying the frequency cutoffs, +# consider manually specifying a larger hidden (see `[vak.train.model.TweetyNet]` table below). + +# [vak.train]: options for training model +[vak.train] +# root_results_dir: directory where results should be saved, as a sub-directory within `root_results_dir` +root_results_dir = "/PATH/TO/FOLDER/results/train" +# batch_size: number of samples from dataset per batch fed into network +batch_size = 8 +# num_epochs: number of training epochs, where an epoch is one iteration through all samples in training split +num_epochs = 2 +# standardize_frames: if true, standardize (normalize) frames (input to neural network) per frequency bin, so mean of each is 0.0 and std is 1.0 +# across the entire training split +standardize_frames = true +# val_step: step number on which to compute metrics with validation set, every time step % val_step == 0 +# (a step is one batch fed through the network) +# saves a checkpoint if the monitored evaluation metric improves (which is model specific) +val_step = 1000 +# ckpt_step: step number on which to save a checkpoint (as a backup, regardless of validation metrics) +ckpt_step = 500 +# patience: number of validation steps to wait before stopping training early +# if the monitored evaluation metrics does not improve after `patience` validation steps, +# then we stop training +patience = 6 +# num_workers: number of workers to use when loading data with multiprocessing +num_workers = 4 +# device: name of device to run model on, one of "cuda", "cpu" + +# dataset_path : path to dataset created by prep. This will be added when you run `vak prep`, you don't have to add it + +# dataset.params = parameters used for datasets +# for a frame classification model, we use dataset classes with a specific `window_size` +[vak.train.dataset.params] +# Bigger windows work better. +# For frame classification models, prefer smaller batch sizes with bigger windows +# Intuitively, bigger windows give the model more "contexts" for each frame per batch. +# See https://github.com/vocalpy/Nicholson-Cohen-SfN-2023-poster for more detail +window_size = 2000 + +# TweetyNet.network: we specify options for the model's network in this table +# To indicate the model to train, we use a "dotted key" with `model` followed by the string name of the model. +# This name must be a name within `vak.models` or added e.g. with `vak.model.decorators.model` +# We use another dotted key to indicate options for configuring the model, e.g. 
`TweetyNet.optimizer` +[vak.train.model.TweetyNet] +[vak.train.model.TweetyNet.optimizer] +# vak.train.model.TweetyNet.optimizer: we specify options for the model's optimizer in this table +# lr: the learning rate +lr = 0.001 + +[vak.train.model.TweetyNet.network] +# hidden_size: the number of elements in the hidden state in the recurrent layer of the network +hidden_size = 256 + +# this sub-table configures the `lightning.pytorch.Trainer` +[vak.train.trainer] +# setting to 'gpu' means "train models on 'gpu' (not 'cpu')" +accelerator = "gpu" +# use the first GPU (numbering starts from 0) +devices = [0] diff --git a/src/vak/config/load.py b/src/vak/config/load.py index 3134dc85e..6d4597490 100644 --- a/src/vak/config/load.py +++ b/src/vak/config/load.py @@ -1,4 +1,4 @@ -"""Functions to parse toml config files.""" +"""Functions to load TOML configuration files.""" from __future__ import annotations diff --git a/src/vak/config/validators.py b/src/vak/config/validators.py index f349db746..f4f519286 100644 --- a/src/vak/config/validators.py +++ b/src/vak/config/validators.py @@ -79,7 +79,7 @@ def are_tables_valid(config_dict, toml_path=None): from ..cli.cli import CLI_COMMANDS # avoid circular import cli_commands_besides_prep = [ - command for command in CLI_COMMANDS if command != "prep" + command.name for command in CLI_COMMANDS if command.name != "prep" ] tables_that_are_commands_besides_prep = [ table for table in tables if table in cli_commands_besides_prep diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py index 18d506be7..e0899b8ae 100644 --- a/tests/fixtures/__init__.py +++ b/tests/fixtures/__init__.py @@ -8,6 +8,7 @@ from .device import * from .trainer import * from .model import * +from .parser import * from .path import * from .source_files import * from .spect import * diff --git a/tests/fixtures/parser.py b/tests/fixtures/parser.py new file mode 100644 index 000000000..81d48ab6e --- /dev/null +++ b/tests/fixtures/parser.py @@ -0,0 +1,10 @@ +import pytest + +import vak.cli.cli + + +@pytest.fixture +def parser(): + """Return an instance of the parser used by the command-line interface, + by calling :func:`vak.cli.cli.get_parser`""" + return vak.cli.cli.get_parser() diff --git a/tests/test___main__.py b/tests/test___main__.py index eb94797a5..00b2d2304 100644 --- a/tests/test___main__.py +++ b/tests/test___main__.py @@ -1,4 +1,4 @@ -import pathlib +import subprocess from unittest import mock import pytest @@ -6,73 +6,57 @@ import vak -@pytest.fixture -def parser(): - return vak.__main__.get_parser() - - -def test_parser_usage(parser, - capsys): - with pytest.raises(SystemExit): - parser.parse_args(args=['']) - captured = capsys.readouterr() - assert captured.err.startswith( - "usage: vak [-h] command configfile" - ) - - -def test_parser_help(parser, - capsys): - with pytest.raises(SystemExit): - parser.parse_args(['-h']) - captured = capsys.readouterr() - assert captured.out.startswith( - "usage: vak [-h] command configfile" - ) - - DUMMY_CONFIGFILE = './configs/config_2018-12-17.toml' @pytest.mark.parametrize( - 'command, raises', + 'args_list', [ - ('prep', False), - ('train', False), - ('learncurve', False), - ('eval', False), - ('predict', False), - ('not-a-valid-command', True), + ['prep', DUMMY_CONFIGFILE], + ['train', DUMMY_CONFIGFILE], + ['learncurve', DUMMY_CONFIGFILE], + ['eval', DUMMY_CONFIGFILE], + ['predict', DUMMY_CONFIGFILE], + ['configfile', 'train', '--add-prep', '--dst', DUMMY_CONFIGFILE] ] ) -def test_parser(command, - raises, - parser, - capsys): - 
if raises:
-        with pytest.raises(SystemExit):
-            parser.parse_args([command, DUMMY_CONFIGFILE])
-    else:
-        args = parser.parse_args([command, DUMMY_CONFIGFILE])
-        assert args.command == command
-        assert args.configfile == pathlib.Path(DUMMY_CONFIGFILE)
-
-
-@pytest.mark.parametrize(
-    'command',
-    [
-        'prep',
-        'train',
-        'learncurve',
-        'eval',
-        'predict',
-    ]
-)
-def test_main(command,
-              parser):
-    args = parser.parse_args([command, DUMMY_CONFIGFILE])
+def test_main(args_list):
+    """Test that :func:`vak.__main__.main` calls the function we expect through :func:`vak.cli.cli`
+
+    Notes
+    -----
+    We mock the command functions and call this a unit test,
+    because actually calling and running :func:`vak.cli.prep`
+    would be expensive.
+
+    The exception is `vak configfile`,
+    which we test directly (in other test functions below).
+    """
+    command = args_list[0]
     mock_cli_function = mock.Mock(name=f'mock_{command}')
-    with mock.patch.dict(vak.cli.cli.COMMAND_FUNCTION_MAP,
-                         {command: mock_cli_function}) as mock_command_function_map:
-        vak.__main__.main(args)
+    with mock.patch.dict(
+        vak.cli.cli.CLI_COMMAND_FUNCTION_MAP, {command: mock_cli_function}
+    ):
+        # AFAICT we can't do this with `subprocess` since the function won't be mocked in the subprocess,
+        # so we need to test indirectly with `args_list` passed into `main`
+        vak.__main__.main(args_list)
         mock_cli_function.assert_called()
+
+
+def test___main__prints_help_with_no_args(parser, capsys):
+    """Test that if we don't pass in any args, we get the parser's help message"""
+    parser.print_help()
+    expected_output = capsys.readouterr().out.rstrip()
+
+    # doing this by calling a `subprocess` is slow but lets us test the CLI directly
+    result = subprocess.run("vak", capture_output=True, text=True)  # call `vak` at the command line with no arguments
+    output = result.stdout.rstrip()
+
+    assert output == expected_output
+
+
+def test_configfile_command():
+    # FIXME: copy whatever unit tests we write for `vak.config.generate.generate`
+    # FIXME: except we change the actual part of the test where we call the function
+    # FIXME: and we're going to use an `args_list` instead of providing parameters directly
+    assert False
\ No newline at end of file
diff --git a/tests/test_cli/test_cli.py b/tests/test_cli/test_cli.py
new file mode 100644
index 000000000..f2cac331c
--- /dev/null
+++ b/tests/test_cli/test_cli.py
@@ -0,0 +1,96 @@
+import argparse
+import pathlib
+from unittest import mock
+
+import pytest
+
+import vak.cli.cli
+
+
+DUMMY_CONFIGFILE_STR = './configs/config_2018-12-17.toml'
+DUMMY_CONFIGFILE_PATH = pathlib.Path(DUMMY_CONFIGFILE_STR)
+
+
+@pytest.mark.parametrize(
+    'args_list, expected_attributes',
+    [
+        (
+            ['prep', DUMMY_CONFIGFILE_STR],
+            dict(command="prep", configfile=DUMMY_CONFIGFILE_PATH)
+        ),
+        (
+            ['train', DUMMY_CONFIGFILE_STR],
+            dict(command="train", configfile=DUMMY_CONFIGFILE_PATH)
+        ),
+        (
+            ['learncurve', DUMMY_CONFIGFILE_STR],
+            dict(command="learncurve", configfile=DUMMY_CONFIGFILE_PATH)
+        ),
+        (
+            ['eval', DUMMY_CONFIGFILE_STR],
+            dict(command="eval", configfile=DUMMY_CONFIGFILE_PATH)
+        ),
+        (
+            ['predict', DUMMY_CONFIGFILE_STR],
+            dict(command="predict", configfile=DUMMY_CONFIGFILE_PATH)
+        ),
+        (
+            ['configfile', 'train'],
+            dict(command="configfile", kind="train", add_prep=False, dst=pathlib.Path.cwd())
+        ),
+        (
+            ['configfile', 'eval'],
+            dict(command="configfile", kind="eval", add_prep=False, dst=pathlib.Path.cwd())
+        ),
+        (
+            ['configfile', 'train', "--add-prep"],
+            dict(command="configfile", kind="train", add_prep=True, dst=pathlib.Path.cwd())
+        )
+    ]
+)
+def 
test_parser_commands_with_configfile(args_list, expected_attributes): + """Test that calling parser.parse_args gives us a Namespace with the expected args""" + parser = vak.cli.cli.get_parser() + assert isinstance(parser, argparse.ArgumentParser) + + args = parser.parse_args(args_list) + assert isinstance(args, argparse.Namespace) + + for attr_name, expected_value in expected_attributes.items(): + assert hasattr(args, attr_name) + assert getattr(args, attr_name) == expected_value + + + +def test_parser_raises(parser): + """Test that an invalid command passed into our ArgumentParser raises a SystemExit""" + with pytest.raises(SystemExit): + parser.parse_args(["not-a-valid-command", DUMMY_CONFIGFILE_STR]) + + +@pytest.mark.parametrize( + 'args_list', + [ + ['prep', DUMMY_CONFIGFILE_STR], + ['train', DUMMY_CONFIGFILE_STR], + ['learncurve', DUMMY_CONFIGFILE_STR], + ['eval', DUMMY_CONFIGFILE_STR], + ['predict', DUMMY_CONFIGFILE_STR], + ['configfile', 'train', '--add-prep', '--dst', DUMMY_CONFIGFILE_STR] + ] +) +def test_cli( + args_list, parser, +): + """Test that :func:`vak.cli.cli.cli` calls the functions we expect""" + args = parser.parse_args(args_list) + + command = args_list[0] + mock_cli_function = mock.Mock(name=f'mock_{command}') + with mock.patch.dict( + vak.cli.cli.CLI_COMMAND_FUNCTION_MAP, {command: mock_cli_function} + ): + # we can't do this with `subprocess` since the function won't be mocked in the subprocess, + # so we need to test indirectly with `arg_list` passed into `main` + vak.cli.cli.cli(args) + mock_cli_function.assert_called() diff --git a/tests/test_config/test_generate.py b/tests/test_config/test_generate.py new file mode 100644 index 000000000..96cd18a65 --- /dev/null +++ b/tests/test_config/test_generate.py @@ -0,0 +1,203 @@ +import os +import tempfile + +import pytest + +import vak.config + + +@pytest.mark.parametrize( + 'kind, add_prep, dst_name', + [ + # ---- train + ( + "train", + False, + None + ), + ( + "train", + True, + None + ), + ( + "train", + False, + "configs-dir" + ), + ( + "train", + True, + "configs-dir" + ), + ( + "train", + False, + "configs-dir/config.toml" + ), + ( + "train", + True, + "configs-dir/config.toml" + ), + # ---- eval + ( + "eval", + False, + None + ), + ( + "eval", + True, + None + ), + ( + "eval", + False, + "configs-dir" + ), + ( + "eval", + True, + "configs-dir" + ), + ( + "eval", + False, + "configs-dir/config.toml" + ), + ( + "eval", + True, + "configs-dir/config.toml" + ), + # ---- predict + ( + "predict", + False, + None + ), + ( + "predict", + True, + None + ), + ( + "predict", + False, + "configs-dir" + ), + ( + "predict", + True, + "configs-dir" + ), + ( + "predict", + False, + "configs-dir/config.toml" + ), + ( + "predict", + True, + "configs-dir/config.toml" + ), + # ---- learncurve + ( + "learncurve", + False, + None + ), + ( + "learncurve", + True, + None + ), + ( + "learncurve", + False, + "configs-dir" + ), + ( + "learncurve", + True, + "configs-dir" + ), + ( + "learncurve", + False, + "configs-dir/config.toml" + ), + ( + "learncurve", + True, + "configs-dir/config.toml" + ), + ] +) +def test_generate(kind, add_prep, dst_name, tmp_path): + """Test :func:`vak.config.generate.generate`""" + # FIXME: handle case where `dst` is a filename -- handle .toml extension + if dst_name is None: + dst = tmp_path / "tmp-dst-None" + else: + dst = tmp_path / dst_name + if dst.suffix == ".toml": + # if dst ends with a toml extension + # then its *parent* is the dir we need to make + dst.parent.mkdir() + else: + dst.mkdir() + + if 
dst_name is None:
+        os.chdir(dst)
+        vak.config.generate(kind=kind, add_prep=add_prep)
+    else:
+        vak.config.generate(kind=kind, add_prep=add_prep, dst=dst)
+
+    if dst.is_dir():
+        # we need to get the actual generated TOML
+        generated_toml_path = sorted(dst.glob("*toml"))
+        assert len(generated_toml_path) == 1
+        generated_toml_path = generated_toml_path[0]
+    else:
+        generated_toml_path = dst
+        # next line: the rest of the assertions would fail if this one did,
+        # but we're being super explicit here:
+        # if we specified a file name for dst then it should exist as a file
+        assert generated_toml_path.exists()
+
+    # we can't load with `vak.config.Config.from_toml_path`
+    # because the generated config doesn't have a [vak.dataset.path] key-value pair yet,
+    # and the corresponding attrs class that represents that table will throw an error.
+    # So we load as a Python dict and check the expected keys are there.
+    # I don't have any better ideas at the moment for how to test
+    cfg_dict = vak.config.load._load_toml_from_path(generated_toml_path)
+    # N.B. that `vak.config.load._load_toml_from_path` accesses the top-level key "vak"
+    # and returns the result of that, so we don't need to do something like `cfg_dict["vak"]["prep"]`
+    assert kind in cfg_dict
+    if add_prep:
+        assert "prep" in cfg_dict
+    else:
+        assert "prep" not in cfg_dict
+
+
+def test_generate_raises_FileExistsError(tmp_path):
+    """Test that :func:`vak.config.generate` raises
+    a FileExistsError if `dst` already exists"""
+    dst = tmp_path / "fake.config.toml"
+    with dst.open("w") as fp:
+        fp.write("[fake.config]")
+    with pytest.raises(FileExistsError):
+        vak.config.generate("train", add_prep=True, dst=dst)
+
+
+def test_generate_raises_ValueError(tmp_path):
+    """Test that :func:`vak.config.generate` raises
+    a ValueError if `dst` is a path to a filename but the extension is not '.toml'"""
+    dst = tmp_path / "fake.config.json"
+    with pytest.raises(ValueError):
+        vak.config.generate("train", add_prep=True, dst=dst)
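A minimal usage sketch of the two entry points this patch adds, for anyone trying the change locally: the `vak configfile` sub-command wired up in src/vak/cli/cli.py and the `vak.config.generate` function in src/vak/config/_generate.py. The `./configs` destination directory is made up for illustration; the function and argument names come from the diff above.

    import pathlib

    import vak.cli.cli
    import vak.config

    # Destination directory for the generated file (hypothetical path)
    dst = pathlib.Path("./configs")
    dst.mkdir(exist_ok=True)

    # Python API: generate a [vak.train] config that also includes a [vak.prep] table
    vak.config.generate(kind="train", add_prep=True, dst=dst)

    # The same thing through the CLI layer, exercised the way the new tests do:
    # parse the args for `vak configfile train --add-prep --dst ./configs`
    # and dispatch them with vak.cli.cli.cli()
    parser = vak.cli.cli.get_parser()
    args = parser.parse_args(["configfile", "train", "--add-prep", "--dst", str(dst)])
    vak.cli.cli.cli(args)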