diff --git a/colabs/llm-finetuning-handson/Alpaca_finetunning_with_WandB.ipynb b/colabs/llm-finetuning-handson/Alpaca_finetunning_with_WandB.ipynb
new file mode 100644
index 00000000..29eb9d7e
--- /dev/null
+++ b/colabs/llm-finetuning-handson/Alpaca_finetunning_with_WandB.ipynb
@@ -0,0 +1,740 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "3c7c21b5-4457-481f-b2cc-fb20cdcbfbe3"
+ },
+ "source": [
+ "
From Llama to Alpaca: Finetunning and LLM with Weights & Biases
\n",
+ "In this notebooks you will learn how to finetune a model on an Instruction dataset. We will use an updated version of the Alpaca dataset that, instead of davinci-003 (GPT3) generations uses GPT4 to get an even better instruction dataset!\n",
+ "\n",
+ "
\n",
+ "\n",
+ "\n",
+ "original github: https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM#how-good-is-the-data"
+ ],
+ "id": "3c7c21b5-4457-481f-b2cc-fb20cdcbfbe3"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "03ef319f-bf26-4192-8951-8d536181ab67"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -q git+https://github.com/huggingface/transformers.git@main git+https://github.com/huggingface/peft.git\n",
+ "!pip install -q wandb\n",
+ "!pip install -q ctranslate2\n",
+ "!pip install -q bitsandbytes datasets accelerate loralib"
+ ],
+ "id": "03ef319f-bf26-4192-8951-8d536181ab67"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "-psLFcagZD0D"
+ },
+ "outputs": [],
+ "source": [
+ "import bitsandbytes as bnb\n",
+ "import copy\n",
+ "import glob\n",
+ "import os\n",
+ "import wandb\n",
+ "import json\n",
+ "from tqdm import tqdm\n",
+ "from types import SimpleNamespace\n",
+ "import datasets\n",
+ "from datasets import Dataset\n",
+ "import transformers\n",
+ "from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, Trainer, TrainingArguments, default_data_collator,GenerationConfig\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from torch.utils.data import Dataset, DataLoader\n",
+ "from torch.nn.utils.rnn import pad_sequence\n",
+ "import pandas as pd\n",
+ "from peft import PeftModel, PeftConfig, LoraConfig, get_peft_model\n",
+ "\n",
+ "os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\"\n",
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\""
+ ],
+ "id": "-psLFcagZD0D"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "7804f904-5746-4530-867d-c766f4501dea"
+ },
+ "source": [
+ "## Prepare your Instruction Dataset"
+ ],
+ "id": "7804f904-5746-4530-867d-c766f4501dea"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "59bd7517-70d9-4dee-9f92-1a2891caf385"
+ },
+ "source": [
+ "An Instruction dataset is a list of instructions/outputs pairs that are relevant to your own domain. For instance it could be question and answers from an specific domain, problems and solution for a technical domain, or just instruction and outputs.\n",
+ "\n",
+ "\n",
+ "https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM#how-good-is-the-data\n",
+ "\n",
+ "So let's explore how one could do this?\n",
+ "After grabbing a finetuned model and curated your own dataset, how do I create a dataset that has the right format to fine tune a model?"
+ ],
+ "id": "59bd7517-70d9-4dee-9f92-1a2891caf385"
+ },
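+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For reference, each record in an Alpaca-style dataset is a plain dict with `instruction`, `input` and `output` keys (the `input` field may be empty). Below is a minimal, made-up example of what one record looks like:"
+   ],
+   "id": "record-format-example-md"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# A hypothetical record in the Alpaca format (illustrative only, not taken from the real dataset)\n",
+    "example_record = {\n",
+    "    \"instruction\": \"Summarize the following paragraph in one sentence.\",\n",
+    "    \"input\": \"Weights & Biases lets you track experiments, version datasets and models, and share results with your team.\",\n",
+    "    \"output\": \"Weights & Biases is a tool for tracking experiments and versioning datasets, models, and results.\"\n",
+    "}\n",
+    "print(list(example_record.keys()))"
+   ],
+   "id": "record-format-example-code"
+  },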
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "da04c0a5-f481-4364-880d-10c254388987"
+ },
+ "source": [
+ "Let's grab the Alpaca (GPT-4 curated instructions and outputs) dataset:"
+ ],
+ "id": "da04c0a5-f481-4364-880d-10c254388987"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "52ff363e-8a24-4085-9b7e-6564d106d2e9"
+ },
+ "outputs": [],
+ "source": [
+ "!wget https://raw.githubusercontent.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM/main/data/alpaca_gpt4_data.json"
+ ],
+ "id": "52ff363e-8a24-4085-9b7e-6564d106d2e9"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "0fce67d2-3703-4042-816c-7a13ba9eab3e"
+ },
+ "outputs": [],
+ "source": [
+ "dataset_file = \"alpaca_gpt4_data.json\"\n",
+ "with open(dataset_file, \"r\") as f:\n",
+ " alpaca = json.load(f)\n",
+ "print(alpaca[0])"
+ ],
+ "id": "0fce67d2-3703-4042-816c-7a13ba9eab3e"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "e596e369-56aa-4721-9271-6686eed8fb35"
+ },
+ "source": [
+ "So the dataset has instruction and outputs. The model is trained to predict the next token, so one option would be just to concat both, and train on that. We ideally format the prompt in a way that we make explicit where is the input and output. Let's log the dataset to W&B so we keep everything organised"
+ ],
+ "id": "e596e369-56aa-4721-9271-6686eed8fb35"
+ },
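+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A minimal sketch of the prompt-and-concatenate idea (the template here is only illustrative; the actual templates used for training, `PROMPT_DICT`, are defined later in the notebook):"
+   ],
+   "id": "prompt-concat-sketch-md"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Minimal sketch (illustrative template): build a prompt from one record and append the output.\n",
+    "def sketch_prompt(row):\n",
+    "    prompt = (\n",
+    "        \"Below is an instruction that describes a task. \"\n",
+    "        \"Write a response that appropriately completes the request.\\n\\n\"\n",
+    "        f\"### Instruction:\\n{row['instruction']}\\n\\n### Response:\\n\"\n",
+    "    )\n",
+    "    return prompt + row[\"output\"]\n",
+    "\n",
+    "print(sketch_prompt(alpaca[0])[:300])"
+   ],
+   "id": "prompt-concat-sketch-code"
+  },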
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "gqAFsykRoQGh"
+ },
+ "outputs": [],
+ "source": [
+ "os.environ[\"WANDB_ENTITY\"]=\"keisuke-kamata\"\n",
+ "os.environ[\"WANDB_PROJECT\"]=\"alpaca_finetuning_with_wandb\"\n",
+ "wandb.login()"
+ ],
+ "id": "gqAFsykRoQGh"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "c9786466-4012-4cc4-8f9a-e673c74aa965"
+ },
+ "outputs": [],
+ "source": [
+ "# log to wandb\n",
+ "with wandb.init():\n",
+ " # log as a table\n",
+ " table = wandb.Table(columns=list(alpaca[0].keys()))\n",
+ " for row in alpaca:\n",
+ " table.add_data(*row.values())\n",
+ " wandb.log({\"alpaca_gpt4_table\": table})\n",
+ "\n",
+ " # log file with artifact\n",
+ " artifact = wandb.Artifact(\n",
+ " name=\"alpaca_gpt4\",\n",
+ " type=\"dataset\",\n",
+ " description=\"A GPT4 generated Alpaca like dataset for instruction finetunning\",\n",
+ " metadata={\"url\":\"https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM#how-good-is-the-data\"},\n",
+ " )\n",
+ " artifact.add_file(dataset_file)\n",
+ " wandb.log_artifact(artifact)"
+ ],
+ "id": "c9786466-4012-4cc4-8f9a-e673c74aa965"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "--J-1AEPoodp"
+ },
+ "source": [
+ "Data split"
+ ],
+ "id": "--J-1AEPoodp"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "5dfa1958-c58d-4b40-a7d9-5e0cc6d2abf5"
+ },
+ "outputs": [],
+ "source": [
+ "import random\n",
+ "import pandas as pd\n",
+ "\n",
+ "seed = 42\n",
+ "random.seed(seed)\n",
+ "random.shuffle(alpaca) # this could also be a parameter\n",
+ "train_dataset_alpaca = alpaca[:10000]\n",
+ "val_dataset_alpaca = alpaca[-1000:]"
+ ],
+ "id": "5dfa1958-c58d-4b40-a7d9-5e0cc6d2abf5"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "72fc9e43-55b5-4a7b-902f-b8a3b82dbf06"
+ },
+ "source": [
+ "We should save the split to W&B"
+ ],
+ "id": "72fc9e43-55b5-4a7b-902f-b8a3b82dbf06"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "BgJI83G-wKz6"
+ },
+ "outputs": [],
+ "source": [
+ "artifact_path = f'{os.environ[\"WANDB_ENTITY\"]}/{os.environ[\"WANDB_PROJECT\"]}/alpaca_gpt4:latest'"
+ ],
+ "id": "BgJI83G-wKz6"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "fdbc0fd0-de6c-447d-abb9-3176c6ceeaf6"
+ },
+ "outputs": [],
+ "source": [
+ "with wandb.init(job_type=\"split_data\") as run:\n",
+ " artifact = run.use_artifact(artifact_path, type='dataset')\n",
+ " #artifact_folder = artifact.download()\n",
+ "\n",
+ " train_df = pd.DataFrame(train_dataset_alpaca)\n",
+ " eval_df = pd.DataFrame(val_dataset_alpaca)\n",
+ "\n",
+ " train_df.to_json(\"alpaca_gpt4_train.jsonl\", orient='records', lines=True)\n",
+ " eval_df.to_json(\"alpaca_gpt4_eval.jsonl\", orient='records', lines=True)\n",
+ "\n",
+ "\n",
+ " at = wandb.Artifact(\n",
+ " name=\"alpaca_gpt4_splitted\",\n",
+ " type=\"dataset\",\n",
+ " description=\"A GPT4 generated Alpaca like dataset for instruction finetunning\",\n",
+ " metadata={\"url\":\"https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM#how-good-is-the-data\"},\n",
+ " )\n",
+ " at.add_file(\"alpaca_gpt4_train.jsonl\")\n",
+ " at.add_file(\"alpaca_gpt4_eval.jsonl\")\n",
+ " train_table = wandb.Table(dataframe=train_df)\n",
+ " eval_table = wandb.Table(dataframe=eval_df)\n",
+ " run.log_artifact(at)\n",
+ " run.log({\"train_dataset\":train_table, \"eval_dataset\":eval_table})"
+ ],
+ "id": "fdbc0fd0-de6c-447d-abb9-3176c6ceeaf6"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "26cb96e0-b2f2-4a79-ba65-e1c8d6395d54"
+ },
+ "source": [
+ "Let's log the dataset also as a table so we can inspect it on the workspace."
+ ],
+ "id": "26cb96e0-b2f2-4a79-ba65-e1c8d6395d54"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "fb99a7dc-0654-4985-ac3f-60ecdc6f6558"
+ },
+ "source": [
+ "## Train"
+ ],
+ "id": "fb99a7dc-0654-4985-ac3f-60ecdc6f6558"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Y2Q7FFhOYRj1"
+ },
+ "outputs": [],
+ "source": [
+ "config = {\n",
+ " \"BASE_MODEL\":\"facebook/opt-125m\",\n",
+ " \"lora_config\":{\n",
+ " \"r\":32,\n",
+ " \"lora_alpha\":16,\n",
+ " 'target_modules': [f\"model.decoder.layers.{i}.self_attn.{proj}_proj\" for i in range(31) for proj in ['q', 'k', 'v']],\n",
+ " \"lora_dropout\":.1,\n",
+ " \"bias\":\"none\",\n",
+ " \"task_type\":\"CAUSAL_LM\"\n",
+ " },\n",
+ " \"training_args\":{\n",
+ " \"dataloader_num_workers\":16,\n",
+ " \"evaluation_strategy\":\"steps\",\n",
+ " \"per_device_train_batch_size\":8,\n",
+ " \"max_steps\": 50,\n",
+ " \"gradient_accumulation_steps\":2,\n",
+ " \"report_to\":\"wandb\",#wandb integration\n",
+ " \"warmup_steps\":10,\n",
+ " \"num_train_epochs\":1,\n",
+ " \"learning_rate\":2e-4,\n",
+ " \"fp16\":True,\n",
+ " \"logging_steps\":10,\n",
+ " \"save_steps\":10,\n",
+ " \"output_dir\":'./outputs'\n",
+ " }\n",
+ "}"
+ ],
+ "id": "Y2Q7FFhOYRj1"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "iFMlXGgmYVh9"
+ },
+ "outputs": [],
+ "source": [
+ "torch.cuda.empty_cache()\n",
+ "model = AutoModelForCausalLM.from_pretrained(\n",
+ " config[\"BASE_MODEL\"],\n",
+ " #load_in_8bit=True,\n",
+ " device_map=\"auto\",\n",
+ ")\n",
+ "tokenizer = AutoTokenizer.from_pretrained(config[\"BASE_MODEL\"])\n",
+ "tokenizer.pad_token = tokenizer.eos_token"
+ ],
+ "id": "iFMlXGgmYVh9"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "89XTPJhjhz2H"
+ },
+ "outputs": [],
+ "source": [
+ "PROMPT_DICT = {\n",
+ " \"prompt_input\": (\n",
+ " \"Below is an instruction that describes a task. Write a response that appropriately completes the request.\"\n",
+ " \"### Instruction:{instruction} \\n\\n Input:{input} \\n\\n ###Response\"\n",
+ " ),\n",
+ " \"prompt_no_input\": (\n",
+ " \"Below is an instruction that describes a task. Write a response that appropriately completes the request.\"\n",
+ " \"### Instruction:{instruction} \\n\\n ###Response\"\n",
+ " )\n",
+ "}\n",
+ "\n",
+ "class InstructDataset(Dataset):\n",
+ " def __init__(self, json_list, tokenizer, ignore_index=-100):\n",
+ " self.tokenizer = tokenizer\n",
+ " self.ignore_index = ignore_index\n",
+ " self.features = []\n",
+ "\n",
+ " for j in tqdm(json_list):\n",
+ " if 'input' in j:\n",
+ " source_text = PROMPT_DICT['prompt_input'].format_map(j)\n",
+ " else:\n",
+ " source_text = PROMPT_DICT['prompt_no_input'].format_map(j)\n",
+ " example_text = source_text + j['output'] + self.tokenizer.eos_token\n",
+ "\n",
+ " source_tokenized = self.tokenizer(\n",
+ " source_text,\n",
+ " padding='longest',\n",
+ " truncation=True,\n",
+ " max_length=512,\n",
+ " return_length=True,\n",
+ " return_tensors='pt'\n",
+ " )\n",
+ "\n",
+ " example_tokenized = self.tokenizer(\n",
+ " example_text,\n",
+ " padding='longest',\n",
+ " truncation=True,\n",
+ " max_length=512,\n",
+ " return_tensors='pt'\n",
+ " )\n",
+ "\n",
+ " input_ids = example_tokenized['input_ids'][0]\n",
+ " labels = copy.deepcopy(input_ids)\n",
+ " source_len = source_tokenized['length'][0]\n",
+ " labels[:source_len] = self.ignore_index\n",
+ "\n",
+ " self.features.append({\n",
+ " 'input_ids': input_ids,\n",
+ " 'labels': labels\n",
+ " })\n",
+ "\n",
+ " def __len__(self):\n",
+ " return len(self.features)\n",
+ "\n",
+ " def __getitem__(self, idx):\n",
+ " return self.features[idx]\n",
+ "\n",
+ "\n",
+ "class InstructCollator():\n",
+ " def __init__(self, tokenizer, ignore_index=-100):\n",
+ " self.tokenizer = tokenizer\n",
+ " self.ignore_index = -100\n",
+ "\n",
+ " def __call__(self, examples):\n",
+ " input_batch = []\n",
+ " label_batch = []\n",
+ " for example in examples:\n",
+ " input_batch.append(example['input_ids'])\n",
+ " label_batch.append(example['labels'])\n",
+ " input_ids = pad_sequence(\n",
+ " input_batch, batch_first=True, padding_value=self.tokenizer.pad_token_id\n",
+ " )\n",
+ " labels = pad_sequence(\n",
+ " label_batch, batch_first=True, padding_value=self.ignore_index\n",
+ " )\n",
+ " attention_mask = input_ids.ne(self.tokenizer.pad_token_id)\n",
+ " return {\n",
+ " 'input_ids': input_ids,\n",
+ " 'labels': labels,\n",
+ " 'attention_mask': attention_mask\n",
+ " }\n",
+ "\n",
+ "\n",
+ "train_dataset = InstructDataset(train_dataset_alpaca, tokenizer)\n",
+ "val_dataset = InstructDataset(val_dataset_alpaca , tokenizer)\n",
+ "\n",
+ "# Create the collator with the device\n",
+ "collator = InstructCollator(tokenizer, ignore_index=-100)"
+ ],
+ "id": "89XTPJhjhz2H"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "x1COOZoWsoHj"
+ },
+ "outputs": [],
+ "source": [
+ "# cast the small parameters (e.g. layernorm) to fp32 for stability\n",
+ "for param in model.parameters():\n",
+ " param.requires_grad = False # freeze the model - train adapters later\n",
+ " if param.ndim == 1:\n",
+ " param.data = param.data.to(torch.float32)\n",
+ "model.gradient_checkpointing_enable() # reduce number of stored activations\n",
+ "model.enable_input_require_grads()\n",
+ "class CastOutputToFloat(nn.Sequential):\n",
+ " def forward(self, x): return super().forward(x).to(torch.float32)\n",
+ "model.lm_head = CastOutputToFloat(model.lm_head)"
+ ],
+ "id": "x1COOZoWsoHj"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "WoAiDU3c_xYG"
+ },
+ "outputs": [],
+ "source": [
+ "path_dataset_for_trainig = f'{os.environ[\"WANDB_ENTITY\"]}/{os.environ[\"WANDB_PROJECT\"]}/alpaca_gpt4_splitted:latest'"
+ ],
+ "id": "WoAiDU3c_xYG"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "cttlNl_3XkgF"
+ },
+ "outputs": [],
+ "source": [
+ "with wandb.init(config=config, job_type=\"training\") as run:\n",
+ " # track data\n",
+ " run.use_artifact(path_dataset_for_trainig)\n",
+ " # Setup for LoRa\n",
+ " lora_config = LoraConfig(**wandb.config[\"lora_config\"])\n",
+ " model_peft = get_peft_model(model, lora_config)\n",
+ " model_peft.print_trainable_parameters()\n",
+ " model_peft.config.use_cache = False\n",
+ "\n",
+ " trainer = transformers.Trainer(\n",
+ " model=model_peft,\n",
+ " data_collator= collator,\n",
+ " args=transformers.TrainingArguments(**wandb.config[\"training_args\"]),\n",
+ " train_dataset=train_dataset,\n",
+ " eval_dataset=val_dataset\n",
+ " )\n",
+ "\n",
+ "\n",
+ " trainer.train()\n",
+ " run.log_code()"
+ ],
+ "id": "cttlNl_3XkgF"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "a936e007-47fa-400f-873c-5e038bd685c5"
+ },
+ "source": [
+ "## Full Eval Dataset evaluation"
+ ],
+ "id": "a936e007-47fa-400f-873c-5e038bd685c5"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "399fafb9-b41f-401b-a1c4-37e1ede2e639"
+ },
+ "source": [
+ "Let's log a table with model predictions on the eval_dataset (or at least the 250 first samples)"
+ ],
+ "id": "399fafb9-b41f-401b-a1c4-37e1ede2e639"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "OjDqorpCB7K-"
+ },
+ "outputs": [],
+ "source": [
+ "def create_prompt(row):\n",
+ " return prompt_no_input(row) if row[\"input\"] == \"\" else prompt_input(row)\n",
+ "\n",
+ "def prompt_no_input(row):\n",
+ " return (\"Below is an instruction that describes a task. Write a response that appropriately completes the request.\"\n",
+ " \"### Instruction:{instruction} \\n\\n ###Response\").format_map(row)\n",
+ "def prompt_input(row):\n",
+ " return (\"Below is an instruction that describes a task. Write a response that appropriately completes the request.\"\n",
+ " \"### Instruction:{instruction} \\n\\n Input:{input} \\n\\n ###Response\").format_map(row)\n",
+ "\n",
+ "def pad_eos(ds):\n",
+ " EOS_TOKEN = \"\"\n",
+ " return [f\"{row['output']}{EOS_TOKEN}\" for row in ds]\n",
+ "\n",
+ "eval_prompts = [create_prompt(row) for row in val_dataset_alpaca]\n",
+ "eval_outputs = pad_eos(val_dataset_alpaca)\n",
+ "eval_dataset = [{\"prompt\":s, \"output\":t, \"example\": s + t} for s, t in zip(eval_prompts, eval_outputs)]"
+ ],
+ "id": "OjDqorpCB7K-"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "b6gVu8YEBXWF"
+ },
+ "outputs": [],
+ "source": [
+ "model_artifact_path ='' # change here!"
+ ],
+ "id": "b6gVu8YEBXWF"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "3b89717a-ac70-43e8-9b7c-edd8d2c54cdc"
+ },
+ "outputs": [],
+ "source": [
+ "gen_config = GenerationConfig.from_pretrained(config[\"BASE_MODEL\"])\n",
+ "test_config = SimpleNamespace(\n",
+ " max_new_tokens=256,\n",
+ " gen_config=gen_config)\n",
+ "\n",
+ "def prompt_table(examples, log=False, table_name=\"predictions\"):\n",
+ " table = wandb.Table(columns=[\"prompt\", \"generation\", \"concat\", \"GPT-4 output\"])\n",
+ " for example in tqdm(examples, leave=False):\n",
+ " prompt, gpt4_output = example[\"prompt\"], example[\"output\"]\n",
+ " out = generate(prompt, test_config.max_new_tokens, test_config.gen_config)\n",
+ " table.add_data(prompt, out, prompt+out, gpt4_output)\n",
+ " if log:\n",
+ " wandb.log({table_name:table})\n",
+ " return table\n",
+ "\n",
+ "def generate(prompt, max_new_tokens=test_config.max_new_tokens, gen_config=gen_config):\n",
+ " tokenized_prompt = tokenizer(prompt, return_tensors='pt')['input_ids'].cuda()\n",
+ " with torch.inference_mode():\n",
+ " output = model.generate(tokenized_prompt,\n",
+ " max_new_tokens=max_new_tokens,\n",
+ " generation_config=gen_config,\n",
+ " temperature=0.9,\n",
+ " top_k=40,\n",
+ " top_p=0.70,\n",
+ " do_sample=True)\n",
+ " return tokenizer.decode(output[0][len(tokenized_prompt[0]):], skip_special_tokens=True)\n",
+ "\n",
+ "\n",
+ "with wandb.init(entity=wandb_entity,\n",
+ " project=wandb_project,\n",
+ " job_type=\"eval\",\n",
+ " config=config):\n",
+ "\n",
+ " artifact = wandb.use_artifact(model_artifact_path)\n",
+ " artifact_dir = artifact.download()\n",
+ "\n",
+ " merged_model = PeftModel.from_pretrained(model, artifact_dir)\n",
+ " merged_model = merged_model.merge_and_unload()\n",
+ "\n",
+ " merged_model.eval();\n",
+ " prompt_table(eval_dataset[:10], log=True, table_name=\"eval_predictions\")"
+ ],
+ "id": "3b89717a-ac70-43e8-9b7c-edd8d2c54cdc"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "NQqpBZXAG0dj"
+ },
+ "source": [
+ "# (Advanced) Sweep"
+ ],
+ "id": "NQqpBZXAG0dj"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "sOgTNIydwcwn"
+ },
+ "outputs": [],
+ "source": [
+ "#wandb.sdk.wandb_setup._setup(_reset=True)"
+ ],
+ "id": "sOgTNIydwcwn"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "-9oEu7S6BfQe"
+ },
+ "outputs": [],
+ "source": [
+ "sweep_configuration= {\n",
+ " \"method\": \"random\",\n",
+ " \"metric\": {\"goal\": \"minimize\", \"name\": \"eval/loss\"},\n",
+ " \"parameters\": {\n",
+ " \"r\":{\"values\": [2,4,8,16,32]},\n",
+ " \"lora_alpha\":{\"values\": [2,4,8,16]},\n",
+ " \"learning_rate\":{'max': 2e-3, 'min': 2e-4}\n",
+ " }\n",
+ "}\n",
+ "\n",
+ "default_config = {\n",
+ " \"BASE_MODEL\":\"facebook/opt-125m\",\n",
+ " \"lora_config\":{\n",
+ " \"r\":32,\n",
+ " \"lora_alpha\":16,\n",
+ " \"target_modules\":[f\"model.decoder.layers.{i}.self_attn.{proj}_proj\" for i in range(31) for proj in ['q', 'k', 'v']],\n",
+ " \"lora_dropout\":.1,\n",
+ " \"bias\":\"none\",\n",
+ " \"task_type\":\"CAUSAL_LM\"\n",
+ " },\n",
+ " \"training_args\":{\n",
+ " \"dataloader_num_workers\":16,\n",
+ " \"evaluation_strategy\":\"steps\",\n",
+ " \"per_device_train_batch_size\":8,\n",
+ " \"max_steps\": 50,\n",
+ " \"gradient_accumulation_steps\":2,\n",
+ " \"report_to\":\"wandb\",#wandb integration\n",
+ " \"warmup_steps\":10,\n",
+ " \"num_train_epochs\":1,\n",
+ " \"learning_rate\":2e-4,\n",
+ " \"fp16\":True,\n",
+ " \"logging_steps\":10,\n",
+ " \"save_steps\":10,\n",
+ " \"output_dir\":'./outputs'\n",
+ " }\n",
+ "}\n",
+ "\n",
+ "\n",
+ "def train_func():\n",
+ " with wandb.init(project=wandb_project, config=config, job_type=\"training\") as run:\n",
+ " # Setup for LoRa\n",
+ " run.use_artifact(path_dataset_for_trainig)\n",
+ "\n",
+ " default_config[\"lora_config\"][\"r\"] = wandb.config[\"r\"]\n",
+ " default_config[\"lora_config\"][\"lora_alpha\"] = wandb.config[\"lora_alpha\"]\n",
+ " default_config[\"training_args\"][\"learning_rate\"] = wandb.config[\"learning_rate\"]\n",
+ "\n",
+ " lora_config = LoraConfig(**default_config[\"lora_config\"])\n",
+ " model_peft = get_peft_model(model, lora_config)\n",
+ " model_peft.print_trainable_parameters()\n",
+ " model_peft.config.use_cache = False\n",
+ "\n",
+ " trainer = transformers.Trainer(\n",
+ " model=model_peft,\n",
+ " data_collator=collator,\n",
+ " args=transformers.TrainingArguments(**default_config[\"training_args\"]),\n",
+ " train_dataset=train_dataset,\n",
+ " eval_dataset=val_dataset\n",
+ " )\n",
+ " trainer.train()\n",
+ " run.log_code()\n",
+ "\n",
+ "sweep_id = wandb.sweep(sweep=sweep_configuration)\n",
+ "wandb.agent(sweep_id, function=train_func, count=20)"
+ ],
+ "id": "-9oEu7S6BfQe"
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "provenance": [],
+ "gpuType": "V100"
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
\ No newline at end of file
diff --git a/colabs/llm-finetuning-handson/requirements.txt b/colabs/llm-finetuning-handson/requirements.txt
new file mode 100644
index 00000000..5bc861c0
--- /dev/null
+++ b/colabs/llm-finetuning-handson/requirements.txt
@@ -0,0 +1,252 @@
+Babel==2.13.0
+Bottleneck==1.3.7
+Brotli==1.1.0
+Cython==3.0.4
+GitPython==3.1.40
+Jinja2==3.1.2
+Mako==1.2.4
+MarkupSafe==2.1.3
+Pillow==10.1.0
+PyJWT==2.8.0
+PySocks==1.7.1
+PyWavelets==1.4.1
+PyYAML==6.0.1
+Pygments==2.16.1
+SQLAlchemy==2.0.22
+Send2Trash==1.8.2
+accelerate==0.28.0
+aiohttp==3.8.6
+aiosignal==1.3.1
+alembic==1.12.0
+altair==5.1.2
+anyio==4.0.0
+appdirs==1.4.4
+argon2-cffi-bindings==21.2.0
+argon2-cffi==23.1.0
+arrow==1.3.0
+asttokens==2.4.0
+async-generator==1.10
+async-lru==2.0.4
+async-timeout==4.0.3
+attrs==23.1.0
+backcall==0.2.0
+backports.functools-lru-cache==1.6.5
+beautifulsoup4==4.12.2
+bitsandbytes==0.43.0
+bleach==6.1.0
+blinker==1.6.3
+bokeh==3.3.0
+boltons==23.0.0
+cached-property==1.5.2
+certifi==2023.7.22
+certipy==0.1.3
+cffi==1.16.0
+charset-normalizer==3.3.0
+click==8.1.7
+cloudpickle==3.0.0
+colorama==0.4.6
+comm==0.1.4
+conda-package-handling==2.2.0
+conda==23.9.0
+conda_package_streaming==0.9.0
+contourpy==1.1.1
+cryptography==41.0.4
+ctranslate2==4.1.0
+cycler==0.12.1
+cytoolz==0.12.2
+dask==2023.10.0
+datasets==2.18.0
+debugpy==1.8.0
+decorator==5.1.1
+defusedxml==0.7.1
+dill==0.3.8
+distributed==2023.10.0
+docker-pycreds==0.4.0
+entrypoints==0.4
+et-xmlfile==1.1.0
+exceptiongroup==1.1.3
+executing==1.2.0
+fastjsonschema==2.18.1
+filelock==3.13.1
+fonttools==4.43.1
+fqdn==1.5.1
+frozenlist==1.4.0
+fsspec==2023.9.2
+gitdb==4.0.11
+gmpy2==2.1.2
+greenlet==3.0.0
+h5py==3.10.0
+huggingface-hub==0.21.4
+idna==3.4
+imagecodecs==2023.9.18
+imageio==2.31.5
+importlib-metadata==6.8.0
+importlib-resources==6.1.0
+ipykernel==6.25.2
+ipympl==0.9.3
+ipython-genutils==0.2.0
+ipython==8.16.1
+ipywidgets==8.1.1
+isoduration==20.11.0
+jedi==0.19.1
+joblib==1.3.2
+json5==0.9.14
+jsonpatch==1.33
+jsonpointer==2.4
+jsonschema-specifications==2023.7.1
+jsonschema==4.19.1
+jupyter-events==0.8.0
+jupyter-lsp==2.2.0
+jupyter-pluto-proxy==0.1.2
+jupyter-server-mathjax==0.2.6
+jupyter-telemetry==0.1.0
+jupyter_client==8.4.0
+jupyter_core==5.4.0
+jupyter_server==2.8.0
+jupyter_server_proxy==4.1.0
+jupyter_server_terminals==0.4.4
+jupyterhub==4.0.2
+jupyterlab-git==0.41.0
+jupyterlab-pygments==0.2.2
+jupyterlab-widgets==3.0.9
+jupyterlab==4.0.7
+jupyterlab_server==2.25.0
+kiwisolver==1.4.5
+lazy_loader==0.3
+libmambapy==1.5.2
+llvmlite==0.40.1
+locket==1.0.0
+loralib==0.1.2
+lz4==4.3.2
+mamba==1.5.2
+matplotlib-inline==0.1.6
+matplotlib==3.8.0
+mistune==3.0.1
+mpmath==1.3.0
+msgpack==1.0.6
+multidict==6.0.4
+multiprocess==0.70.16
+munkres==1.1.4
+nbclassic==1.0.0
+nbclient==0.8.0
+nbconvert==7.9.2
+nbdime==3.2.1
+nbformat==5.9.2
+nest-asyncio==1.5.8
+networkx==3.2
+notebook==7.0.6
+notebook_shim==0.2.3
+numba==0.57.1
+numexpr==2.8.7
+numpy==1.24.4
+nvidia-cublas-cu12==12.1.3.1
+nvidia-cuda-cupti-cu12==12.1.105
+nvidia-cuda-nvrtc-cu12==12.1.105
+nvidia-cuda-runtime-cu12==12.1.105
+nvidia-cudnn-cu12==8.9.2.26
+nvidia-cufft-cu12==11.0.2.54
+nvidia-curand-cu12==10.3.2.106
+nvidia-cusolver-cu12==11.4.5.107
+nvidia-cusparse-cu12==12.1.0.106
+nvidia-nccl-cu12==2.19.3
+nvidia-nvjitlink-cu12==12.4.99
+nvidia-nvtx-cu12==12.1.105
+oauthlib==3.2.2
+openpyxl==3.1.2
+overrides==7.4.0
+packaging==23.2
+pamela==1.1.0
+pandas==2.1.1
+pandocfilters==1.5.0
+parso==0.8.3
+partd==1.4.1
+patsy==0.5.3
+peft==0.9.1.dev0
+pexpect==4.8.0
+pickleshare==0.7.5
+pip==23.3
+pkgutil_resolve_name==1.3.10
+platformdirs==3.11.0
+pluggy==1.3.0
+prometheus-client==0.17.1
+prompt-toolkit==3.0.39
+protobuf==4.24.3
+psutil==5.9.5
+ptyprocess==0.7.0
+pure-eval==0.2.2
+py-cpuinfo==9.0.0
+pyOpenSSL==23.2.0
+pyarrow-hotfix==0.6
+pyarrow==13.0.0
+pycosat==0.6.6
+pycparser==2.21
+pycurl==7.45.1
+pyparsing==3.1.1
+python-dateutil==2.8.2
+python-json-logger==2.0.7
+pytz==2023.3.post1
+pyzmq==25.1.1
+referencing==0.30.2
+regex==2023.12.25
+requests==2.31.0
+rfc3339-validator==0.1.4
+rfc3986-validator==0.1.1
+rpds-py==0.10.6
+rpy2==3.5.11
+ruamel.yaml.clib==0.2.7
+ruamel.yaml==0.17.39
+safetensors==0.4.2
+scikit-image==0.22.0
+scikit-learn==1.3.1
+scipy==1.11.3
+seaborn==0.13.0
+sentry-sdk==1.43.0
+setproctitle==1.3.3
+setuptools==68.2.2
+simpervisor==1.0.0
+simplegeneric==0.8.1
+six==1.16.0
+smmap==5.0.0
+sniffio==1.3.0
+sortedcontainers==2.4.0
+soupsieve==2.5
+stack-data==0.6.2
+statsmodels==0.14.0
+sympy==1.12
+tables==3.9.1
+tblib==2.0.0
+terminado==0.17.1
+threadpoolctl==3.2.0
+tifffile==2023.9.26
+tinycss2==1.2.1
+tokenizers==0.15.2
+tomli==2.0.1
+toolz==0.12.0
+torch==2.2.1
+tornado==6.3.3
+tqdm==4.66.1
+traitlets==5.11.2
+transformers==4.40.0.dev0
+triton==2.2.0
+truststore==0.8.0
+types-python-dateutil==2.8.19.14
+typing-utils==0.1.0
+typing_extensions==4.8.0
+tzdata==2023.3
+tzlocal==5.1
+uri-template==1.3.0
+urllib3==2.0.7
+wandb==0.16.4
+wcwidth==0.2.8
+webcolors==1.13
+webencodings==0.5.1
+websocket-client==1.6.4
+wheel==0.41.2
+widgetsnbextension==4.0.9
+xlrd==2.0.1
+xxhash==3.4.1
+xyzservices==2023.10.0
+yarl==1.9.2
+zict==3.0.0
+zipp==3.17.0
+zstandard==0.21.0
\ No newline at end of file