from textwrap import wrap

import click
+from click.core import ParameterSource  # type: ignore[attr-defined]
from tabulate import tabulate

from together import Together
@@ -26,7 +27,22 @@ def fine_tuning(ctx: click.Context) -> None:
2627 "--n-checkpoints" , type = int , default = 1 , help = "Number of checkpoints to save"
2728)
2829@click .option ("--batch-size" , type = int , default = 32 , help = "Train batch size" )
29- @click .option ("--learning-rate" , type = float , default = 3e-5 , help = "Learning rate" )
30+ @click .option ("--learning-rate" , type = float , default = 1e-5 , help = "Learning rate" )
31+ @click .option (
32+ "--lora/--no-lora" ,
33+ type = bool ,
34+ default = False ,
35+ help = "Whether to use LoRA adapters for fine-tuning" ,
36+ )
37+ @click .option ("--lora-r" , type = int , default = 8 , help = "LoRA adapters' rank" )
38+ @click .option ("--lora-dropout" , type = float , default = 0 , help = "LoRA adapters' dropout" )
39+ @click .option ("--lora-alpha" , type = float , default = 8 , help = "LoRA adapters' alpha" )
40+ @click .option (
41+ "--lora-trainable-modules" ,
42+ type = str ,
43+ default = "all-linear" ,
44+ help = "Trainable modules for LoRA adapters. For example, 'all-linear', 'q_proj,v_proj'" ,
45+ )
3046@click .option (
3147 "--suffix" , type = str , default = None , help = "Suffix for the fine-tuned model name"
3248)
@@ -39,19 +55,44 @@ def create(
    n_checkpoints: int,
    batch_size: int,
    learning_rate: float,
+    lora: bool,
+    lora_r: int,
+    lora_dropout: float,
+    lora_alpha: float,
+    lora_trainable_modules: str,
    suffix: str,
    wandb_api_key: str,
) -> None:
    """Start fine-tuning"""
    client: Together = ctx.obj

+    if lora:
+        learning_rate_source = click.get_current_context().get_parameter_source(  # type: ignore[attr-defined]
+            "learning_rate"
+        )
+        if learning_rate_source == ParameterSource.DEFAULT:
+            learning_rate = 1e-3
+    else:
+        for param in ["lora_r", "lora_dropout", "lora_alpha", "lora_trainable_modules"]:
+            param_source = click.get_current_context().get_parameter_source(param)  # type: ignore[attr-defined]
+            if param_source != ParameterSource.DEFAULT:
+                raise click.BadParameter(
+                    f"You set LoRA parameter `{param}` for a full fine-tuning job. "
+                    f"Please change the job type with --lora or remove `{param}` from the arguments"
+                )
+
    response = client.fine_tuning.create(
        training_file=training_file,
        model=model,
        n_epochs=n_epochs,
        n_checkpoints=n_checkpoints,
        batch_size=batch_size,
        learning_rate=learning_rate,
+        lora=lora,
+        lora_r=lora_r,
+        lora_dropout=lora_dropout,
+        lora_alpha=lora_alpha,
+        lora_trainable_modules=lora_trainable_modules,
        suffix=suffix,
        wandb_api_key=wandb_api_key,
    )
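For reference, a minimal SDK-level sketch of what the new flags forward to. The training-file ID and model name below are hypothetical placeholders, not values from this change, and it assumes TOGETHER_API_KEY is set in the environment:

from together import Together

client = Together()  # picks up TOGETHER_API_KEY from the environment

response = client.fine_tuning.create(
    training_file="file-abc123",            # hypothetical uploaded training file ID
    model="togethercomputer/llama-2-7b",    # hypothetical base model name
    lora=True,                              # same effect as passing --lora on the CLI
    lora_r=16,
    lora_alpha=32,
    lora_trainable_modules="q_proj,v_proj",
    learning_rate=1e-3,                     # the CLI's LoRA default when --learning-rate is not set
)
print(response)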