From 4861efd8eb9f609a74ef531160401c71567b1c63 Mon Sep 17 00:00:00 2001 From: Paul Paczuski <6165713+plpxsk@users.noreply.github.com> Date: Fri, 27 Sep 2024 12:29:34 -0400 Subject: [PATCH] Add tf requirements and initial nb --- alt/tf/requirements.txt | 1 + tf_qa.ipynb | 1162 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 1163 insertions(+) create mode 100644 alt/tf/requirements.txt create mode 100644 tf_qa.ipynb diff --git a/alt/tf/requirements.txt b/alt/tf/requirements.txt new file mode 100644 index 0000000..0f57144 --- /dev/null +++ b/alt/tf/requirements.txt @@ -0,0 +1 @@ +tensorflow diff --git a/tf_qa.ipynb b/tf_qa.ipynb new file mode 100644 index 0000000..65f1562 --- /dev/null +++ b/tf_qa.ipynb @@ -0,0 +1,1162 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "b049f8e0-c381-496f-9a17-17e48313f94a", + "metadata": {}, + "source": [ + "# With Tensorflow\n", + "\n", + "Maybe don't do full tensorflow alt implementation\n", + "\n", + "Instead:\n", + "* use a mix of existing code (HF, PT, MLX)\n", + "* and use functions in TF, eg, from google/bert" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c364076-fff7-46fb-bef1-bc123078d3cf", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 116, + "id": "3b9b1637-3eca-40c3-89b4-bb9a809b0c10", + "metadata": {}, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "\n", + "import torch\n", + "import torch.nn as tnn\n", + "\n", + "import mlx.nn as nn\n", + "import mlx.core as mx\n", + "\n", + "from alt.pt.infer import load_model_tokenizer\n", + "from utils import preprocess_tokenize_function" + ] + }, + { + "cell_type": "markdown", + "id": "f535ed54-0e6c-421b-8c84-75f09bae409c", + "metadata": {}, + "source": [ + "# Load" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f093fb8b-5b2d-4d49-b9a6-e22c660daae1", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "04166153-91fe-4c02-9940-12a511496525", + "metadata": {}, + "source": [ + "# Train" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f63ad29-cbbf-44f5-b53f-68784f7bdcc2", + "metadata": {}, + "outputs": [], + "source": [ + "model, tokenizer = load_model_tokenizer(device=\"mps\")\n", + "\n", + "inputs = tokenizer(question, context, return_tensors=\"pt\")\n", + "inputs.to(\"mps\")\n", + "\n", + "with torch.no_grad():\n", + " outputs = model(**inputs)\n", + "\n", + "inputs.keys(), outputs.keys()" + ] + }, + { + "cell_type": "markdown", + "id": "2cfb2d5a-0b42-4f2e-9c7d-7858ae89873f", + "metadata": {}, + "source": [ + "### Outputs" + ] + }, + { + "cell_type": "code", + "execution_count": 139, + "id": "b5d64f84-d3db-4aa1-8233-76268067989e", + "metadata": {}, + "outputs": [], + "source": [ + "def get_out():\n", + " start_logits = [0.9, 1.9, 0.1, 10.9, 1.5]\n", + " end_logits = [1.9, 0.8, 0.2, 3.3, 11.5]\n", + " start_positions = [5, 1, 3, 11, 2]\n", + " end_positions = [1, 6, 4, 9, 12]\n", + " return start_logits, end_logits, start_positions, end_positions" + ] + }, + { + "cell_type": "markdown", + "id": "973f079f-3d86-4f7b-afee-e1351fc35a77", + "metadata": {}, + "source": [ + "# Compute loss\n", + "\n", + "with TF, HF/PT, etc" + ] + }, + { + "cell_type": "code", + "execution_count": 105, + "id": "089ffe7d-b8a2-4d86-a1e6-151c787b1d40", + "metadata": {}, + "outputs": [], + "source": [ + "question = \"How many programming languages does BLOOM support?\"\n", + "context = \"BLOOM has 176 billion parameters and can generate text in 46 languages natural languages and 13 programming languages.\"" + ] + }, + { + "cell_type": "markdown", + "id": "db02012b-8433-4f70-a52f-56dc22978958", + "metadata": {}, + "source": [ + "### MLX" + ] + }, + { + "cell_type": "code", + "execution_count": 223, + "id": "0fca502f-dc10-4793-9daa-e8efc31a81e9", + "metadata": {}, + "outputs": [], + "source": [ + "start_logits, end_logits, start_positions, end_positions = (\n", + " mx.array(x, dtype=mx.float32) for x in get_out() \n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 224, + "id": "3ec604b5-f1c6-42b0-b1a6-f6bee0b1fd9e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((5,), array([0.9, 1.9, 0.1, 10.9, 1.5], dtype=float32))" + ] + }, + "execution_count": 224, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "start_logits.shape, start_logits" + ] + }, + { + "cell_type": "code", + "execution_count": 234, + "id": "46401661-84a3-4c87-a715-be1bf361b8c0", + "metadata": {}, + "outputs": [], + "source": [ + "def ce(l, p):\n", + " return nn.losses.cross_entropy(l, p, reduction=\"none\").item()\n", + "\n", + "def ce2(l, p):\n", + " return nn.losses.nll_loss(nn.log_softmax(l), p).item()" + ] + }, + { + "cell_type": "code", + "execution_count": 243, + "id": "7df9daf8-0852-451e-8084-280b3c5d5172", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(-118.69972229003906, -163.69961547851562)" + ] + }, + "execution_count": 243, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ce(start_logits, start_positions), ce(end_logits, end_positions)" + ] + }, + { + "cell_type": "code", + "execution_count": 245, + "id": "51f1019c-3790-49ba-9258-4446808ff521", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([-10.0003, -9.00027, -10.8003, -0.000271797, -9.40027], dtype=float32),\n", + " (5,))" + ] + }, + "execution_count": 245, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x = nn.log_softmax(start_logits)\n", + "x, x.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 303, + "id": "13f7c455-42bd-48f0-b7e7-d18728d820af", + "metadata": {}, + "outputs": [], + "source": [ + "logits = mx.array([[0.2, -2.0], [-1.9, 3.8]])\n", + "targets = mx.array([0, 1])\n", + "\n", + "logits = mx.array([0.105361, 0.223144, 1.20397, 0.916291])\n", + "targets = mx.array([0, 0, 1, 1])" + ] + }, + { + "cell_type": "code", + "execution_count": 304, + "id": "e365e3ce-587b-44c9-87b8-6de2c1d1a198", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0.105361, 0.223144, 1.20397, 0.916291], dtype=float32)" + ] + }, + "execution_count": 304, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "logits" + ] + }, + { + "cell_type": "code", + "execution_count": 305, + "id": "70de58a5-f70d-44f4-8d00-57099f500792", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0, 0, 1, 1], dtype=int32)" + ] + }, + "execution_count": 305, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "targets" + ] + }, + { + "cell_type": "code", + "execution_count": 312, + "id": "51524755-cdc2-4a29-a7f3-6bc35ced9af5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array(0, dtype=float32)" + ] + }, + "execution_count": 312, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mx.log(1)" + ] + }, + { + "cell_type": "code", + "execution_count": 309, + "id": "9d4e7603-1f39-430e-9524-e2fb6100cea7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array(0.105361, dtype=float32)" + ] + }, + "execution_count": 309, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "logits[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 306, + "id": "f314a404-285c-4aff-ad79-230c91c49b09", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0.747215, 0.81093, 0.262365, 0.336472], dtype=float32)" + ] + }, + "execution_count": 306, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nn.losses.binary_cross_entropy(logits, targets, with_logits=True, reduction=\"none\")" + ] + }, + { + "cell_type": "code", + "execution_count": 313, + "id": "068f7183-1e35-45c1-8437-e26837c119b9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0.747215, 0.81093, 0.262365, 0.336472], dtype=float32)" + ] + }, + "execution_count": 313, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mx.logaddexp(0.0, logits) - logits * targets" + ] + }, + { + "cell_type": "code", + "execution_count": 314, + "id": "35266440-5eb3-4933-9c6d-118444e41fc5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array(-0.0168054, dtype=float32)" + ] + }, + "execution_count": 314, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nn.losses.cross_entropy(logits, targets, reduction=\"none\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7ce95ba9-e7dd-4a79-8785-7f393f9c9433", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c1dc7daf-a04e-497e-8593-6282dfd1b7f2", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2459c1d2-487b-4440-b4e2-58dbb0b79869", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7558418a-db14-49d0-92bc-f6ec4030ec9d", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 276, + "id": "077df03a-3ecf-4e07-b15b-ba5cb3f3a08f", + "metadata": {}, + "outputs": [ + { + "ename": "ValueError", + "evalue": "[take_along_axis] Indices of dimension 2 does not match array of dimension 1.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[276], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlosses\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnll_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlogits\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtargets\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Envs/PyTorTenFloHuffinFace/lib/python3.12/site-packages/mlx/nn/losses.py:249\u001b[0m, in \u001b[0;36mnll_loss\u001b[0;34m(inputs, targets, axis, reduction)\u001b[0m\n\u001b[1;32m 233\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mnll_loss\u001b[39m(\n\u001b[1;32m 234\u001b[0m inputs: mx\u001b[38;5;241m.\u001b[39marray, targets: mx\u001b[38;5;241m.\u001b[39marray, axis: \u001b[38;5;28mint\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, reduction: Reduction \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnone\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 235\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m mx\u001b[38;5;241m.\u001b[39marray:\n\u001b[1;32m 236\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 237\u001b[0m \u001b[38;5;124;03m Computes the negative log likelihood loss.\u001b[39;00m\n\u001b[1;32m 238\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[38;5;124;03m array: The computed NLL loss.\u001b[39;00m\n\u001b[1;32m 248\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 249\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[43mmx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtake_along_axis\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtargets\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39msqueeze(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _reduce(loss, reduction)\n", + "\u001b[0;31mValueError\u001b[0m: [take_along_axis] Indices of dimension 2 does not match array of dimension 1." + ] + } + ], + "source": [ + "nn.losses.nll_loss(logits, targets)" + ] + }, + { + "cell_type": "code", + "execution_count": 259, + "id": "04e75c4a-e817-4ab8-b8be-a07ea72a15c5", + "metadata": {}, + "outputs": [ + { + "ename": "ValueError", + "evalue": "[take_along_axis] Indices of dimension 2 does not match array of dimension 1.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[259], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;241m-\u001b[39m\u001b[43mmx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtake_along_axis\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstart_logits\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstart_positions\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\n", + "\u001b[0;31mValueError\u001b[0m: [take_along_axis] Indices of dimension 2 does not match array of dimension 1." + ] + } + ], + "source": [ + "-mx.take_along_axis(start_logits, start_positions[..., None], axis=-1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a457ea94-cf4a-492a-a5e6-92c2411bcce9", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 251, + "id": "8a09365d-2e94-4b02-b4f3-cd665ca782ad", + "metadata": {}, + "outputs": [ + { + "ename": "ValueError", + "evalue": "[gather] Got indices with invalid dtype. Indices must be integral.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[251], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlosses\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnll_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstart_positions\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Envs/PyTorTenFloHuffinFace/lib/python3.12/site-packages/mlx/nn/losses.py:249\u001b[0m, in \u001b[0;36mnll_loss\u001b[0;34m(inputs, targets, axis, reduction)\u001b[0m\n\u001b[1;32m 233\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mnll_loss\u001b[39m(\n\u001b[1;32m 234\u001b[0m inputs: mx\u001b[38;5;241m.\u001b[39marray, targets: mx\u001b[38;5;241m.\u001b[39marray, axis: \u001b[38;5;28mint\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, reduction: Reduction \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnone\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 235\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m mx\u001b[38;5;241m.\u001b[39marray:\n\u001b[1;32m 236\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 237\u001b[0m \u001b[38;5;124;03m Computes the negative log likelihood loss.\u001b[39;00m\n\u001b[1;32m 238\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[38;5;124;03m array: The computed NLL loss.\u001b[39;00m\n\u001b[1;32m 248\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 249\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[43mmx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtake_along_axis\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtargets\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39msqueeze(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _reduce(loss, reduction)\n", + "\u001b[0;31mValueError\u001b[0m: [gather] Got indices with invalid dtype. Indices must be integral." + ] + } + ], + "source": [ + "nn.losses.nll_loss(x[..., None], start_positions)" + ] + }, + { + "cell_type": "code", + "execution_count": 252, + "id": "b7d704ea-72d0-4787-9236-7080199103e7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((5,), (5, 1))" + ] + }, + "execution_count": 252, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x.shape, x[..., None].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 254, + "id": "f72106e3-868f-4cfa-ad5a-ab682a407b60", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(mlx.core.float32, mlx.core.float32)" + ] + }, + "execution_count": 254, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x.dtype, x[..., None].dtype" + ] + }, + { + "cell_type": "code", + "execution_count": 248, + "id": "1d8b0477-a645-4c7c-b954-0fc79e3671a5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(5, 1)" + ] + }, + "execution_count": 248, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "start_positions[..., None].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17369360-0a82-47b1-83e0-80b06c874174", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f44b587a-a54f-4adb-b175-405aac1f5919", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 242, + "id": "c0f7da6b-0873-4c1e-ad0f-b09584a259e4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(5,)" + ] + }, + "execution_count": 242, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "start_positions.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 240, + "id": "9f1e1497-1d3a-4453-ba91-b5739030325a", + "metadata": {}, + "outputs": [ + { + "ename": "ValueError", + "evalue": "[take_along_axis] Indices of dimension 2 does not match array of dimension 1.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[240], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlosses\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnll_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstart_positions\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Envs/PyTorTenFloHuffinFace/lib/python3.12/site-packages/mlx/nn/losses.py:249\u001b[0m, in \u001b[0;36mnll_loss\u001b[0;34m(inputs, targets, axis, reduction)\u001b[0m\n\u001b[1;32m 233\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mnll_loss\u001b[39m(\n\u001b[1;32m 234\u001b[0m inputs: mx\u001b[38;5;241m.\u001b[39marray, targets: mx\u001b[38;5;241m.\u001b[39marray, axis: \u001b[38;5;28mint\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, reduction: Reduction \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnone\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 235\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m mx\u001b[38;5;241m.\u001b[39marray:\n\u001b[1;32m 236\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 237\u001b[0m \u001b[38;5;124;03m Computes the negative log likelihood loss.\u001b[39;00m\n\u001b[1;32m 238\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[38;5;124;03m array: The computed NLL loss.\u001b[39;00m\n\u001b[1;32m 248\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 249\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[43mmx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtake_along_axis\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtargets\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39msqueeze(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _reduce(loss, reduction)\n", + "\u001b[0;31mValueError\u001b[0m: [take_along_axis] Indices of dimension 2 does not match array of dimension 1." + ] + } + ], + "source": [ + "nn.losses.nll_loss(x, start_positions)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b5602281-6fbe-414b-83aa-fb2b9a0150f1", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d7c731d5-e74a-4adc-9424-abd78f7956fd", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 236, + "id": "385d4590-4f19-4d88-bd8d-43d0b192c283", + "metadata": {}, + "outputs": [ + { + "ename": "ValueError", + "evalue": "[take_along_axis] Indices of dimension 2 does not match array of dimension 1.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[236], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mce2\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstart_logits\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstart_positions\u001b[49m\u001b[43m)\u001b[49m, ce2(end_logits, end_positions)\n", + "Cell \u001b[0;32mIn[234], line 5\u001b[0m, in \u001b[0;36mce2\u001b[0;34m(l, p)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mce2\u001b[39m(l, p):\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlosses\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnll_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlog_softmax\u001b[49m\u001b[43m(\u001b[49m\u001b[43ml\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mp\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mitem()\n", + "File \u001b[0;32m~/Envs/PyTorTenFloHuffinFace/lib/python3.12/site-packages/mlx/nn/losses.py:249\u001b[0m, in \u001b[0;36mnll_loss\u001b[0;34m(inputs, targets, axis, reduction)\u001b[0m\n\u001b[1;32m 233\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mnll_loss\u001b[39m(\n\u001b[1;32m 234\u001b[0m inputs: mx\u001b[38;5;241m.\u001b[39marray, targets: mx\u001b[38;5;241m.\u001b[39marray, axis: \u001b[38;5;28mint\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, reduction: Reduction \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnone\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 235\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m mx\u001b[38;5;241m.\u001b[39marray:\n\u001b[1;32m 236\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 237\u001b[0m \u001b[38;5;124;03m Computes the negative log likelihood loss.\u001b[39;00m\n\u001b[1;32m 238\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[38;5;124;03m array: The computed NLL loss.\u001b[39;00m\n\u001b[1;32m 248\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 249\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[43mmx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtake_along_axis\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtargets\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39msqueeze(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _reduce(loss, reduction)\n", + "\u001b[0;31mValueError\u001b[0m: [take_along_axis] Indices of dimension 2 does not match array of dimension 1." + ] + } + ], + "source": [ + "ce2(start_logits, start_positions), ce2(end_logits, end_positions)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d34b4997-110d-48d2-8809-0076512cdaee", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "0f3d366e-b7c9-4cae-b06b-b9d226835e88", + "metadata": {}, + "source": [ + "### PT\n", + "\n", + "https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/bert/modeling_bert.py#L1996\n", + "\n", + "such as:\n", + "\n", + "```\n", + "total_loss = None\n", + "if start_positions is not None and end_positions is not None:\n", + " # If we are on multi-GPU, split add a dimension\n", + " if len(start_positions.size()) > 1:\n", + " start_positions = start_positions.squeeze(-1)\n", + " if len(end_positions.size()) > 1:\n", + " end_positions = end_positions.squeeze(-1)\n", + " # sometimes the start/end positions are outside our model inputs, we ignore these terms\n", + " ignored_index = start_logits.size(1)\n", + " start_positions = start_positions.clamp(0, ignored_index)\n", + " end_positions = end_positions.clamp(0, ignored_index)\n", + "\n", + " loss_fct = CrossEntropyLoss(ignore_index=ignored_index)\n", + " start_loss = loss_fct(start_logits, start_positions)\n", + " end_loss = loss_fct(end_logits, end_positions)\n", + " total_loss = (start_loss + end_loss) / 2\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 181, + "id": "e01751cc-b256-42e5-a205-692674834834", + "metadata": {}, + "outputs": [], + "source": [ + "start_logits, end_logits, start_positions, end_positions = (\n", + " torch.tensor(x, dtype=torch.float32) for x in get_out() \n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 182, + "id": "326c6996-2407-4069-a912-31da3c71dd99", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(torch.Size([5]),\n", + " torch.Size([5]),\n", + " tensor([ 0.9000, 1.9000, 0.1000, 10.9000, 1.5000]))" + ] + }, + "execution_count": 182, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "start_logits.shape, start_logits.size(), start_logits" + ] + }, + { + "cell_type": "code", + "execution_count": 189, + "id": "aac5885e-a6d8-460b-97f0-413bb61af293", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(tensor(110.2060), tensor(192.8121))" + ] + }, + "execution_count": 189, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loss_fct = tnn.CrossEntropyLoss(reduction=\"none\") #ignore_index=ignored_index)\n", + "loss_fct(start_logits, start_positions), loss_fct(end_logits, end_positions)" + ] + }, + { + "cell_type": "markdown", + "id": "1e55e001-74af-4506-b479-f62f6ce42335", + "metadata": {}, + "source": [ + "### TF\n", + "\n", + "from https://github.com/google-research/bert/blob/master/run_squad.py#L646" + ] + }, + { + "cell_type": "code", + "execution_count": 190, + "id": "45804292-7143-4e08-a752-77dd3cf4bf3c", + "metadata": {}, + "outputs": [], + "source": [ + "start_logits, end_logits, start_positions, end_positions = (\n", + " tf.convert_to_tensor(x, dtype=tf.float32) for x in get_out() \n", + ")\n", + "\n", + "start_positions, end_positions = (\n", + " tf.cast(x, dtype=tf.int32) for x in \n", + " (start_positions, end_positions)\n", + ")\n", + "\n", + "# run_squad.py has:\n", + "# seq_length = modeling.get_shape_list(input_ids)[1]\n", + "seq_length = len(start_logits)" + ] + }, + { + "cell_type": "code", + "execution_count": 191, + "id": "924a64a3-a820-4a7d-aed8-28ac4279920d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(TensorShape([5]),\n", + " )" + ] + }, + "execution_count": 191, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "start_logits.shape, start_logits" + ] + }, + { + "cell_type": "code", + "execution_count": 192, + "id": "933b8997-27a6-4c12-8e28-bd04a8a7e747", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 192, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "start_positions" + ] + }, + { + "cell_type": "code", + "execution_count": 193, + "id": "0a229e7d-8535-4302-8481-2ba666b74ffb", + "metadata": {}, + "outputs": [], + "source": [ + "def compute_loss(logits, positions):\n", + " one_hot_positions = tf.one_hot(\n", + " positions, depth=seq_length, dtype=tf.float32)\n", + " log_probs = tf.nn.log_softmax(logits, axis=-1)\n", + " loss = -tf.reduce_mean(\n", + " tf.reduce_sum(one_hot_positions * log_probs, axis=-1))\n", + " return loss" + ] + }, + { + "cell_type": "code", + "execution_count": 194, + "id": "f6ed4032-3ae3-4f81-98fb-4d97c543623c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(3.960163, 2.140151)" + ] + }, + "execution_count": 194, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "compute_loss(start_logits, start_positions).numpy(), compute_loss(end_logits, end_positions).numpy()" + ] + }, + { + "cell_type": "markdown", + "id": "5450a4ce-e205-4a52-a5dd-f755fc825b5e", + "metadata": {}, + "source": [ + "Check components" + ] + }, + { + "cell_type": "code", + "execution_count": 195, + "id": "7f763c02-896c-4112-b499-dc001277267c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 195, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tf.one_hot(start_positions, depth=seq_length, dtype=tf.float32)" + ] + }, + { + "cell_type": "code", + "execution_count": 201, + "id": "949c2afc-7fd1-4b30-8e3f-978fb509b907", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([-1.0000272e+01, -9.0002718e+00, -1.0800271e+01, -2.7187943e-04,\n", + " -9.4002714e+00], dtype=float32)" + ] + }, + "execution_count": 201, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tf.nn.log_softmax(start_logits, axis=-1).numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": 203, + "id": "048648c9-944c-4de6-bbf7-11d84133c5c7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([-10.0003, -9.00027, -10.8003, -0.000271797, -9.40027], dtype=float32)" + ] + }, + "execution_count": 203, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nn.log_softmax(mx.array(start_logits), axis=-1)" + ] + }, + { + "cell_type": "code", + "execution_count": 204, + "id": "05c8006c-e1e0-4608-885d-044fc4f17293", + "metadata": {}, + "outputs": [ + { + "ename": "SyntaxError", + "evalue": "invalid syntax. Maybe you meant '==' or ':=' instead of '='? (1215032111.py, line 1)", + "output_type": "error", + "traceback": [ + "\u001b[0;36m Cell \u001b[0;32mIn[204], line 1\u001b[0;36m\u001b[0m\n\u001b[0;31m -tf.reduce_mean((one_hot_positions * log_probs, axis=-1))\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax. Maybe you meant '==' or ':=' instead of '='?\n" + ] + } + ], + "source": [ + "-tf.reduce_mean((one_hot_positions * log_probs, axis=-1))" + ] + }, + { + "cell_type": "code", + "execution_count": 206, + "id": "7f461951-20ec-4e76-abd9-14d594fbec92", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array(-118.7, dtype=float32)" + ] + }, + "execution_count": 206, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 207, + "id": "2fa9acda-fea7-4fab-afe2-eef217f3c4f8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array(-118.7, dtype=float32)" + ] + }, + "execution_count": 207, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nn.losses.cross_entropy(mx.array(start_logits), mx.array(start_positions), reduction='sum')" + ] + }, + { + "cell_type": "code", + "execution_count": 208, + "id": "5fc92270-7c0e-422a-afee-91a09fbe82ca", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array(-118.7, dtype=float32)" + ] + }, + "execution_count": 208, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nn.losses.cross_entropy(mx.array(start_logits), mx.array(start_positions), reduction='mean')" + ] + }, + { + "cell_type": "code", + "execution_count": 209, + "id": "ba17df81-09a5-4d47-a9a1-fec49a94092f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array(-22.5947, dtype=float32)" + ] + }, + "execution_count": 209, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nn.losses.binary_cross_entropy(mx.array(start_logits), mx.array(start_positions))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20908ff7-85e1-4319-a886-b36b2b44c758", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "399f8628-547e-43e6-9cf4-a46ef8a5fb7a", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "683279f5-5716-453d-afc8-90635c59643f", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16a96a53-6a7d-43c8-b7c9-ce3787eec743", + "metadata": {}, + "outputs": [], + "source": [ + "start_positions = features[\"start_positions\"]\n", + "end_positions = features[\"end_positions\"]\n", + "\n", + "start_loss = compute_loss(start_logits, start_positions)\n", + "end_loss = compute_loss(end_logits, end_positions)\n", + "\n", + "total_loss = (start_loss + end_loss) / 2.0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "189b2606-7ed2-4117-97f9-cd683b7f8253", + "metadata": {}, + "outputs": [], + "source": [ + "def loss_fn(model, input_ids, token_type_ids, attention_mask, start_positions,\n", + " end_positions, reduce=True):\n", + " start_logits, end_logits = model(\n", + " input_ids=input_ids,\n", + " token_type_ids=token_type_ids,\n", + " attention_mask=attention_mask)\n", + " slosses = nn.losses.cross_entropy(start_logits, start_positions)\n", + " elosses = nn.losses.cross_entropy(end_logits, end_positions)\n", + " if reduce:\n", + " slosses = mx.mean(slosses)\n", + " elosses = mx.mean(elosses)\n", + " loss = (slosses + elosses) / 2\n", + " return loss\n", + "\n", + "\n", + "def loss_from_logits_positions(logits, positions):\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7596bc03-a4eb-4ae6-aaf0-b2f98f709449", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8bc94236-d8a1-4b06-ab76-780b0d884fa5", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d3e75aab-23c0-4c36-abdf-de8b2cc3fb4c", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Envs_PyTorTenFloHuffinFace", + "language": "python", + "name": "envs_pytortenflohuffinface" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}