diff --git a/examples/data/mnist.py b/examples/data/mnist.py index f5cce3e..9339fc0 100644 --- a/examples/data/mnist.py +++ b/examples/data/mnist.py @@ -44,7 +44,7 @@ def parse_labels(filepath): test_labels = parse_labels(os.path.join(data_dir, "t10k-labels-idx1-ubyte.gz")) if normalize: - train_images = train_images.astype(np.float32) / 255.0 - test_images = test_images.astype(np.float32) / 255.0 + train_images = train_images.astype(np.float32) / 127.5 - 1. + test_images = test_images.astype(np.float32) / 127.5 - 1. return train_images, train_labels, test_images, test_labels \ No newline at end of file diff --git a/examples/hello-gpt.ipynb b/examples/hello-gpt.ipynb index c3c171f..a8ed039 100644 --- a/examples/hello-gpt.ipynb +++ b/examples/hello-gpt.ipynb @@ -32,21 +32,7 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Downloading Shakespeare dataset...\n", - "Processing Shakespeare dataset...\n", - "Length of dataset in characters: 1,115,394\n", - "Vocabulary size: 65\n", - "Train has 1,003,854 tokens\n", - "Val has 111,540 tokens\n", - "Shakespeare dataset processing complete.\n" - ] - } - ], + "outputs": [], "source": [ "context = 64\n", "batch_size = 12\n", @@ -362,228 +348,228 @@ "name": "stdout", "output_type": "stream", "text": [ - "Step 0: loss 4.226325988769531\n", - "--> val loss 4.179544448852539\n", - "Step 10: loss 3.8738746643066406\n", - "Step 20: loss 3.3448646068573\n", - "Step 30: loss 2.805002212524414\n", - "Step 40: loss 2.68573260307312\n", - "Step 50: loss 2.6098480224609375\n", - "Step 60: loss 2.407468557357788\n", - "Step 70: loss 2.418379783630371\n", - "Step 80: loss 2.359757423400879\n", - "Step 90: loss 2.2685279846191406\n", - "Step 100: loss 2.314124584197998\n", - "--> val loss 2.541980743408203\n", - "Step 110: loss 2.283424139022827\n", - "Step 120: loss 2.2063167095184326\n", - "Step 130: loss 2.1598031520843506\n", - "Step 140: loss 
2.252727508544922\n", - "Step 150: loss 2.124152660369873\n", - "Step 160: loss 2.23785662651062\n", - "Step 170: loss 2.2059123516082764\n", - "Step 180: loss 2.102996587753296\n", - "Step 190: loss 2.132392168045044\n", - "Step 200: loss 2.130244255065918\n", - "--> val loss 2.359212636947632\n", - "Step 210: loss 2.0895276069641113\n", - "Step 220: loss 2.1278815269470215\n", - "Step 230: loss 1.9647449254989624\n", - "Step 240: loss 2.1118733882904053\n", - "Step 250: loss 1.9459623098373413\n", - "Step 260: loss 2.118051290512085\n", - "Step 270: loss 2.0605385303497314\n", - "Step 280: loss 2.0378551483154297\n", - "Step 290: loss 2.0237479209899902\n", - "Step 300: loss 1.982785940170288\n", - "--> val loss 2.2887392044067383\n", - "Step 310: loss 2.073058605194092\n", - "Step 320: loss 2.082066535949707\n", - "Step 330: loss 2.130162239074707\n", - "Step 340: loss 2.092909336090088\n", - "Step 350: loss 1.9229984283447266\n", - "Step 360: loss 1.9037134647369385\n", - "Step 370: loss 2.0083131790161133\n", - "Step 380: loss 2.0236263275146484\n", - "Step 390: loss 2.0116419792175293\n", - "Step 400: loss 2.091407299041748\n", - "--> val loss 2.2199790477752686\n", - "Step 410: loss 2.0855846405029297\n", - "Step 420: loss 1.8506882190704346\n", - "Step 430: loss 1.9745848178863525\n", - "Step 440: loss 1.9135173559188843\n", - "Step 450: loss 2.0486648082733154\n", - "Step 460: loss 1.983982801437378\n", - "Step 470: loss 1.9958977699279785\n", - "Step 480: loss 1.9868993759155273\n", - "Step 490: loss 2.009216785430908\n", - "Step 500: loss 2.073169231414795\n", - "--> val loss 2.141632556915283\n", - "Step 510: loss 2.0603322982788086\n", - "Step 520: loss 2.0025858879089355\n", - "Step 530: loss 1.9482192993164062\n", - "Step 540: loss 1.9092429876327515\n", - "Step 550: loss 2.109374761581421\n", - "Step 560: loss 1.9060167074203491\n", - "Step 570: loss 1.9423940181732178\n", - "Step 580: loss 1.9405231475830078\n", - "Step 590: loss 
1.9132475852966309\n", - "Step 600: loss 2.0125274658203125\n", - "--> val loss 2.2273831367492676\n", - "Step 610: loss 2.0854687690734863\n", - "Step 620: loss 1.9796791076660156\n", - "Step 630: loss 1.982351303100586\n", - "Step 640: loss 2.044363021850586\n", - "Step 650: loss 2.030698299407959\n", - "Step 660: loss 2.0731544494628906\n", - "Step 670: loss 1.9660027027130127\n", - "Step 680: loss 1.933128833770752\n", - "Step 690: loss 1.8852118253707886\n", - "Step 700: loss 1.8401598930358887\n", - "--> val loss 2.0958476066589355\n", - "Step 710: loss 1.9790323972702026\n", - "Step 720: loss 2.0329394340515137\n", - "Step 730: loss 1.929424524307251\n", - "Step 740: loss 1.950282335281372\n", - "Step 750: loss 1.938680648803711\n", - "Step 760: loss 1.9717748165130615\n", - "Step 770: loss 1.8411779403686523\n", - "Step 780: loss 2.085500717163086\n", - "Step 790: loss 1.8778104782104492\n", - "Step 800: loss 1.9712986946105957\n", - "--> val loss 2.1469686031341553\n", - "Step 810: loss 1.949462652206421\n", - "Step 820: loss 1.9898126125335693\n", - "Step 830: loss 1.9045312404632568\n", - "Step 840: loss 1.9053363800048828\n", - "Step 850: loss 1.8944416046142578\n", - "Step 860: loss 1.8389015197753906\n", - "Step 870: loss 1.9189136028289795\n", - "Step 880: loss 2.0141639709472656\n", - "Step 890: loss 1.9987534284591675\n", - "Step 900: loss 1.947631597518921\n", - "--> val loss 2.1903281211853027\n", - "Step 910: loss 2.031083106994629\n", - "Step 920: loss 1.988853931427002\n", - "Step 930: loss 2.0356318950653076\n", - "Step 940: loss 1.8823192119598389\n", - "Step 950: loss 2.0429515838623047\n", - "Step 960: loss 2.021817684173584\n", - "Step 970: loss 2.003168821334839\n", - "Step 980: loss 2.0105528831481934\n", - "Step 990: loss 2.014195680618286\n", - "Step 1000: loss 1.9518741369247437\n", - "--> val loss 2.0813283920288086\n", - "Step 1010: loss 2.016996383666992\n", - "Step 1020: loss 2.04374098777771\n", - "Step 1030: loss 
1.8839387893676758\n", - "Step 1040: loss 1.96620512008667\n", - "Step 1050: loss 2.0463950634002686\n", - "Step 1060: loss 1.9169645309448242\n", - "Step 1070: loss 2.038651943206787\n", - "Step 1080: loss 2.0474071502685547\n", - "Step 1090: loss 1.9452462196350098\n", - "Step 1100: loss 1.8884999752044678\n", - "--> val loss 2.1541106700897217\n", - "Step 1110: loss 1.9775495529174805\n", - "Step 1120: loss 1.96068274974823\n", - "Step 1130: loss 1.8553755283355713\n", - "Step 1140: loss 1.9422013759613037\n", - "Step 1150: loss 2.0833449363708496\n", - "Step 1160: loss 1.840619444847107\n", - "Step 1170: loss 2.032219409942627\n", - "Step 1180: loss 1.9345749616622925\n", - "Step 1190: loss 1.934565544128418\n", - "Step 1200: loss 1.9528722763061523\n", - "--> val loss 2.1688506603240967\n", - "Step 1210: loss 1.8676490783691406\n", - "Step 1220: loss 1.9311145544052124\n", - "Step 1230: loss 1.9905321598052979\n", - "Step 1240: loss 1.8773740530014038\n", - "Step 1250: loss 1.9832658767700195\n", - "Step 1260: loss 1.8256521224975586\n", - "Step 1270: loss 2.037313461303711\n", - "Step 1280: loss 1.9440114498138428\n", - "Step 1290: loss 1.9472723007202148\n", - "Step 1300: loss 1.862718105316162\n", - "--> val loss 2.0632894039154053\n", - "Step 1310: loss 1.944453239440918\n", - "Step 1320: loss 1.869157075881958\n", - "Step 1330: loss 1.9843480587005615\n", - "Step 1340: loss 1.9083728790283203\n", - "Step 1350: loss 1.920233130455017\n", - "Step 1360: loss 1.7926225662231445\n", - "Step 1370: loss 1.8765363693237305\n", - "Step 1380: loss 1.9374698400497437\n", - "Step 1390: loss 1.9032771587371826\n", - "Step 1400: loss 1.8976068496704102\n", - "--> val loss 2.0361690521240234\n", - "Step 1410: loss 1.8799960613250732\n", - "Step 1420: loss 1.9112414121627808\n", - "Step 1430: loss 1.8797309398651123\n", - "Step 1440: loss 1.9040837287902832\n", - "Step 1450: loss 1.8828296661376953\n", - "Step 1460: loss 1.83419930934906\n", - "Step 1470: loss 
1.8327134847640991\n", - "Step 1480: loss 1.857541799545288\n", - "Step 1490: loss 1.8209788799285889\n", - "Step 1500: loss 1.780470371246338\n", - "--> val loss 2.0466208457946777\n", - "Step 1510: loss 1.8544996976852417\n", - "Step 1520: loss 1.8710064888000488\n", - "Step 1530: loss 1.8195044994354248\n", - "Step 1540: loss 1.874974250793457\n", - "Step 1550: loss 1.7101812362670898\n", - "Step 1560: loss 1.8439801931381226\n", - "Step 1570: loss 1.967679500579834\n", - "Step 1580: loss 1.888682246208191\n", - "Step 1590: loss 1.6926288604736328\n", - "Step 1600: loss 1.875901222229004\n", - "--> val loss 2.044935941696167\n", - "Step 1610: loss 1.8210939168930054\n", - "Step 1620: loss 1.7439773082733154\n", - "Step 1630: loss 1.7956527471542358\n", - "Step 1640: loss 1.792572021484375\n", - "Step 1650: loss 1.7985519170761108\n", - "Step 1660: loss 1.8520288467407227\n", - "Step 1670: loss 1.680544137954712\n", - "Step 1680: loss 1.7917392253875732\n", - "Step 1690: loss 1.8400462865829468\n", - "Step 1700: loss 1.6793416738510132\n", - "--> val loss 1.995697021484375\n", - "Step 1710: loss 1.7414367198944092\n", - "Step 1720: loss 1.8606326580047607\n", - "Step 1730: loss 1.7578084468841553\n", - "Step 1740: loss 1.6292760372161865\n", - "Step 1750: loss 1.7017428874969482\n", - "Step 1760: loss 1.8407533168792725\n", - "Step 1770: loss 1.7789411544799805\n", - "Step 1780: loss 1.802499532699585\n", - "Step 1790: loss 1.7586851119995117\n", - "Step 1800: loss 1.7281568050384521\n", - "--> val loss 1.9875770807266235\n", - "Step 1810: loss 1.7767337560653687\n", - "Step 1820: loss 1.7158925533294678\n", - "Step 1830: loss 1.7596324682235718\n", - "Step 1840: loss 1.7826766967773438\n", - "Step 1850: loss 1.7769875526428223\n", - "Step 1860: loss 1.6953961849212646\n", - "Step 1870: loss 1.7714271545410156\n", - "Step 1880: loss 1.6994340419769287\n", - "Step 1890: loss 1.7252253293991089\n", - "Step 1900: loss 1.566367506980896\n", - "--> val loss 
1.9310436248779297\n", - "Step 1910: loss 1.7057380676269531\n", - "Step 1920: loss 1.7441104650497437\n", - "Step 1930: loss 1.7951183319091797\n", - "Step 1940: loss 1.8611491918563843\n", - "Step 1950: loss 1.787139654159546\n", - "Step 1960: loss 1.788725733757019\n", - "Step 1970: loss 1.7919573783874512\n", - "Step 1980: loss 1.706597089767456\n", - "Step 1990: loss 1.771501898765564\n", - "Step 2000: loss 1.7121562957763672\n", - "--> val loss 1.8968441486358643\n" + "Step 0: loss 4.203415393829346\n", + "--> val loss 4.184397220611572\n", + "Step 10: loss 3.8764960765838623\n", + "Step 20: loss 3.3471152782440186\n", + "Step 30: loss 2.8259036540985107\n", + "Step 40: loss 2.6758830547332764\n", + "Step 50: loss 2.6396656036376953\n", + "Step 60: loss 2.4299607276916504\n", + "Step 70: loss 2.4175362586975098\n", + "Step 80: loss 2.3901185989379883\n", + "Step 90: loss 2.2698512077331543\n", + "Step 100: loss 2.3211278915405273\n", + "--> val loss 2.493386745452881\n", + "Step 110: loss 2.280397415161133\n", + "Step 120: loss 2.2105860710144043\n", + "Step 130: loss 2.1742873191833496\n", + "Step 140: loss 2.265498161315918\n", + "Step 150: loss 2.1193413734436035\n", + "Step 160: loss 2.1948325634002686\n", + "Step 170: loss 2.171501636505127\n", + "Step 180: loss 2.100039482116699\n", + "Step 190: loss 2.1237497329711914\n", + "Step 200: loss 2.1287221908569336\n", + "--> val loss 2.2659342288970947\n", + "Step 210: loss 2.053070306777954\n", + "Step 220: loss 2.1026082038879395\n", + "Step 230: loss 1.9534435272216797\n", + "Step 240: loss 2.1179442405700684\n", + "Step 250: loss 1.957367181777954\n", + "Step 260: loss 2.099113702774048\n", + "Step 270: loss 2.046205759048462\n", + "Step 280: loss 2.0597729682922363\n", + "Step 290: loss 2.0586631298065186\n", + "Step 300: loss 1.9834868907928467\n", + "--> val loss 2.2871968746185303\n", + "Step 310: loss 2.076115369796753\n", + "Step 320: loss 2.0850110054016113\n", + "Step 330: loss 
2.159264087677002\n", + "Step 340: loss 2.1108834743499756\n", + "Step 350: loss 1.9584811925888062\n", + "Step 360: loss 1.9317572116851807\n", + "Step 370: loss 2.047001361846924\n", + "Step 380: loss 1.9950387477874756\n", + "Step 390: loss 1.9855890274047852\n", + "Step 400: loss 2.1213014125823975\n", + "--> val loss 2.281414747238159\n", + "Step 410: loss 2.056396484375\n", + "Step 420: loss 1.8612827062606812\n", + "Step 430: loss 1.948122262954712\n", + "Step 440: loss 1.9674848318099976\n", + "Step 450: loss 2.03259539604187\n", + "Step 460: loss 2.011803388595581\n", + "Step 470: loss 2.0299034118652344\n", + "Step 480: loss 1.985011100769043\n", + "Step 490: loss 2.021385669708252\n", + "Step 500: loss 2.052276372909546\n", + "--> val loss 2.203423023223877\n", + "Step 510: loss 2.0946555137634277\n", + "Step 520: loss 1.9836640357971191\n", + "Step 530: loss 2.009483575820923\n", + "Step 540: loss 1.9228500127792358\n", + "Step 550: loss 2.1320910453796387\n", + "Step 560: loss 1.9897129535675049\n", + "Step 570: loss 1.9571495056152344\n", + "Step 580: loss 1.9846587181091309\n", + "Step 590: loss 1.9535086154937744\n", + "Step 600: loss 2.020153045654297\n", + "--> val loss 2.2494421005249023\n", + "Step 610: loss 2.084792137145996\n", + "Step 620: loss 2.021752119064331\n", + "Step 630: loss 1.9838941097259521\n", + "Step 640: loss 2.0131120681762695\n", + "Step 650: loss 2.071953535079956\n", + "Step 660: loss 2.086606502532959\n", + "Step 670: loss 2.058492660522461\n", + "Step 680: loss 1.9665050506591797\n", + "Step 690: loss 1.943424940109253\n", + "Step 700: loss 1.8730368614196777\n", + "--> val loss 2.2508132457733154\n", + "Step 710: loss 2.0128204822540283\n", + "Step 720: loss 2.1044511795043945\n", + "Step 730: loss 2.0091733932495117\n", + "Step 740: loss 2.0575761795043945\n", + "Step 750: loss 2.080972909927368\n", + "Step 760: loss 2.093989372253418\n", + "Step 770: loss 1.932627558708191\n", + "Step 780: loss 2.214707851409912\n", + 
"Step 790: loss 1.9911472797393799\n", + "Step 800: loss 2.02778959274292\n", + "--> val loss 2.181839942932129\n", + "Step 810: loss 2.049180269241333\n", + "Step 820: loss 2.064502477645874\n", + "Step 830: loss 2.044832468032837\n", + "Step 840: loss 2.010237455368042\n", + "Step 850: loss 2.007585287094116\n", + "Step 860: loss 1.928617238998413\n", + "Step 870: loss 2.0367650985717773\n", + "Step 880: loss 2.093597412109375\n", + "Step 890: loss 2.107056140899658\n", + "Step 900: loss 2.0655107498168945\n", + "--> val loss 2.271533727645874\n", + "Step 910: loss 2.094755172729492\n", + "Step 920: loss 1.978514552116394\n", + "Step 930: loss 2.0587754249572754\n", + "Step 940: loss 1.903182029724121\n", + "Step 950: loss 2.006319761276245\n", + "Step 960: loss 2.1091904640197754\n", + "Step 970: loss 2.0761780738830566\n", + "Step 980: loss 1.9954557418823242\n", + "Step 990: loss 2.0989832878112793\n", + "Step 1000: loss 2.0245680809020996\n", + "--> val loss 2.2492425441741943\n", + "Step 1010: loss 2.0572500228881836\n", + "Step 1020: loss 2.049919843673706\n", + "Step 1030: loss 2.0195538997650146\n", + "Step 1040: loss 2.065272331237793\n", + "Step 1050: loss 2.08404803276062\n", + "Step 1060: loss 1.991067886352539\n", + "Step 1070: loss 2.059558629989624\n", + "Step 1080: loss 2.0323941707611084\n", + "Step 1090: loss 1.988283395767212\n", + "Step 1100: loss 1.899468183517456\n", + "--> val loss 2.224306106567383\n", + "Step 1110: loss 2.032186269760132\n", + "Step 1120: loss 2.083733558654785\n", + "Step 1130: loss 1.9714925289154053\n", + "Step 1140: loss 2.0304412841796875\n", + "Step 1150: loss 2.1165449619293213\n", + "Step 1160: loss 1.8574061393737793\n", + "Step 1170: loss 2.053088665008545\n", + "Step 1180: loss 1.9868009090423584\n", + "Step 1190: loss 2.0184385776519775\n", + "Step 1200: loss 1.9368431568145752\n", + "--> val loss 2.214960813522339\n", + "Step 1210: loss 1.9542896747589111\n", + "Step 1220: loss 1.9777746200561523\n", + "Step 
1230: loss 2.0913281440734863\n", + "Step 1240: loss 1.9256919622421265\n", + "Step 1250: loss 2.0141139030456543\n", + "Step 1260: loss 1.7947438955307007\n", + "Step 1270: loss 2.0271081924438477\n", + "Step 1280: loss 2.0013082027435303\n", + "Step 1290: loss 1.947591781616211\n", + "Step 1300: loss 1.9167909622192383\n", + "--> val loss 2.083812713623047\n", + "Step 1310: loss 2.013084650039673\n", + "Step 1320: loss 1.909340262413025\n", + "Step 1330: loss 1.9620397090911865\n", + "Step 1340: loss 1.9765715599060059\n", + "Step 1350: loss 1.8904399871826172\n", + "Step 1360: loss 1.7793501615524292\n", + "Step 1370: loss 1.9580650329589844\n", + "Step 1380: loss 1.9796538352966309\n", + "Step 1390: loss 1.959943175315857\n", + "Step 1400: loss 1.9130542278289795\n", + "--> val loss 2.0795230865478516\n", + "Step 1410: loss 1.8990237712860107\n", + "Step 1420: loss 1.953758955001831\n", + "Step 1430: loss 1.9493964910507202\n", + "Step 1440: loss 1.9433190822601318\n", + "Step 1450: loss 1.9356145858764648\n", + "Step 1460: loss 1.8823035955429077\n", + "Step 1470: loss 1.830541968345642\n", + "Step 1480: loss 1.9393571615219116\n", + "Step 1490: loss 1.900813102722168\n", + "Step 1500: loss 1.8914613723754883\n", + "--> val loss 2.0876729488372803\n", + "Step 1510: loss 1.9165732860565186\n", + "Step 1520: loss 1.9649326801300049\n", + "Step 1530: loss 1.8890880346298218\n", + "Step 1540: loss 1.9110441207885742\n", + "Step 1550: loss 1.7660144567489624\n", + "Step 1560: loss 1.9185322523117065\n", + "Step 1570: loss 2.0385913848876953\n", + "Step 1580: loss 1.877274513244629\n", + "Step 1590: loss 1.7854435443878174\n", + "Step 1600: loss 1.927487850189209\n", + "--> val loss 2.0727601051330566\n", + "Step 1610: loss 1.9156030416488647\n", + "Step 1620: loss 1.7924439907073975\n", + "Step 1630: loss 1.8454530239105225\n", + "Step 1640: loss 1.8715399503707886\n", + "Step 1650: loss 1.8894051313400269\n", + "Step 1660: loss 1.897505760192871\n", + "Step 1670: 
loss 1.7877782583236694\n", + "Step 1680: loss 1.867870807647705\n", + "Step 1690: loss 1.842137336730957\n", + "Step 1700: loss 1.7311986684799194\n", + "--> val loss 1.9756139516830444\n", + "Step 1710: loss 1.8065001964569092\n", + "Step 1720: loss 1.9407305717468262\n", + "Step 1730: loss 1.8354434967041016\n", + "Step 1740: loss 1.7024110555648804\n", + "Step 1750: loss 1.7469775676727295\n", + "Step 1760: loss 1.8872644901275635\n", + "Step 1770: loss 1.8048818111419678\n", + "Step 1780: loss 1.8275573253631592\n", + "Step 1790: loss 1.779404640197754\n", + "Step 1800: loss 1.7527437210083008\n", + "--> val loss 1.9299205541610718\n", + "Step 1810: loss 1.7771308422088623\n", + "Step 1820: loss 1.7643531560897827\n", + "Step 1830: loss 1.8124539852142334\n", + "Step 1840: loss 1.8404573202133179\n", + "Step 1850: loss 1.8412466049194336\n", + "Step 1860: loss 1.7188396453857422\n", + "Step 1870: loss 1.8082847595214844\n", + "Step 1880: loss 1.7587134838104248\n", + "Step 1890: loss 1.7570083141326904\n", + "Step 1900: loss 1.6468451023101807\n", + "--> val loss 1.9047764539718628\n", + "Step 1910: loss 1.7798075675964355\n", + "Step 1920: loss 1.8022470474243164\n", + "Step 1930: loss 1.8122791051864624\n", + "Step 1940: loss 1.9106707572937012\n", + "Step 1950: loss 1.8597577810287476\n", + "Step 1960: loss 1.8255852460861206\n", + "Step 1970: loss 1.8626060485839844\n", + "Step 1980: loss 1.7441591024398804\n", + "Step 1990: loss 1.8292932510375977\n", + "Step 2000: loss 1.7385106086730957\n", + "--> val loss 1.8767629861831665\n" ] } ], @@ -638,16 +624,18 @@ "text": [ "Sample 0:\n", "\n", - "If where his elperiend and is here in think the comfore be pray virtue deather I the grouth a pears my\n", + "If what this death\n", + "The cours the hand thinke\n", + "This the enter there well drate but the grome wout sean t\n", "--------------------------------------------------------------------------------\n", "Sample 1:\n", "\n", - "If as the conture the weet to 
the man's death the greeen he with thought rame the prosates he palousen\n", + "If my lord, and preat,\n", + "And to the can here there you all Iss may thought natter with have I will did n\n", "--------------------------------------------------------------------------------\n", "Sample 2:\n", "\n", - "If him the be not me were and let for the earth the forth,\n", - "That the his a wort of you the fearshould a\n", + "If my the proving me with the life the know man the forther for the him to came haver the ither oum th\n", "--------------------------------------------------------------------------------\n" ] } @@ -671,6 +659,13 @@ " print(f\"Sample {seed}:\\n\\n{generate_text('If', max_tokens=100, seed=seed)}\")\n", " print(\"-\" * 80)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -689,7 +684,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.8" + "version": "3.11.10" } }, "nbformat": 4, diff --git a/examples/hello-mnist-vit.ipynb b/examples/hello-mnist-vit.ipynb new file mode 100644 index 0000000..fb05a11 --- /dev/null +++ b/examples/hello-mnist-vit.ipynb @@ -0,0 +1,775 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "dd286c88-ce33-4be7-8aec-3c3fe5176c40", + "metadata": {}, + "source": [ + "# Hello, MNIST!" + ] + }, + { + "cell_type": "markdown", + "id": "58d2fdb1-575a-4cc3-8aa5-0fb5ee6fca9f", + "metadata": {}, + "source": [ + "Let's extend [Hello, World!](hello-world.ipynb) to an actual learning problem: MNIST! First, we download the data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "477a30ee-9242-4207-a1f7-c8e4b5c702b1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training images shape: (60000, 28, 28)\n", + "Training labels shape: (60000,)\n", + "Test images shape: (10000, 28, 28)\n", + "Test labels shape: (10000,)\n" + ] + } + ], + "source": [ + "from data.mnist import load_mnist\n", + "\n", + "# Load the MNIST dataset\n", + "train_images, train_labels, test_images, test_labels = load_mnist()\n", + "\n", + "# Print shapes to verify loading\n", + "print(f\"Training images shape: {train_images.shape}\")\n", + "print(f\"Training labels shape: {train_labels.shape}\")\n", + "print(f\"Test images shape: {test_images.shape}\")\n", + "print(f\"Test labels shape: {test_labels.shape}\")" + ] + }, + { + "cell_type": "markdown", + "id": "ad9d808d-7a30-4255-94dc-aac83c8023be", + "metadata": {}, + "source": [ + "Next, let's plot a few training points." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "7322b0da-0d22-4742-8885-69500e549774", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "# Create a figure with 5 subplots\n", + "fig, axes = plt.subplots(1, 5, figsize=(10, 3))\n", + "\n", + "# Plot each image and its label\n", + "for i in range(5):\n", + " axes[i].imshow(train_images[i], cmap='gray')\n", + " axes[i].set_title(f'Label: {train_labels[i]}')\n", + " axes[i].axis('off')\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "b79b80ac-34fd-4f15-b83c-41d43f37c07e", + "metadata": {}, + "source": [ + "Now we flatten the inputs into vectors, encode the targets as one-hot vectors, and write a mini-batch sampler." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "5a7a804b-06ec-4773-864c-db8a3b01c3e1", + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "X_train = jnp.reshape(train_images, (-1, 28, 28, 1))\n", + "y_train = train_labels\n", + "\n", + "# Getting a batch like this results in horrible performance!\n", + "# def get_batch(key, batch_size):\n", + "# idx = jax.random.choice(key, X_train.shape[1], shape=(batch_size,))\n", + "# return X_train[idx, :], y_train[idx]\n", + "\n", + "def stream(batch_size):\n", + " num_samples = X_train.shape[0]\n", + " while True:\n", + " key = jax.random.PRNGKey(0)\n", + " perm = jax.random.permutation(key, num_samples)\n", + " for i in range(num_samples // batch_size):\n", + " indices = perm[i * batch_size:(i + 1) * batch_size]\n", + " yield X_train[indices], y_train[indices]" + ] + }, + { + "cell_type": "markdown", + "id": "24819e5b-f810-4d55-ab53-d57f592eed82", + "metadata": {}, + "source": [ + "Now we're ready to build our ViT." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a7a14a1b-1428-4432-8e89-6b7cfed3d765", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CompositeModule\n", + "...consists of 54 atoms and 91 bonds\n", + "...non-smooth\n", + "...input sensitivity is 51.72763417352536\n", + "...contributes proportion 12.0 to feature learning of any supermodule\n" + ] + } + ], + "source": [ + "from modula.compound import ViT\n", + "\n", + "input_dim = X_train.shape[1:-1]\n", + "output_dim = 10\n", + "\n", + "vit = ViT(output_dim, image_size=input_dim, patch_size=(7, 7), num_heads=8, d_embed=64, d_query=8, d_value=8, num_blocks=3, blocks_mass=6, attention_scale=1.0, final_scale=1.0)\n", + "\n", + "print(vit)\n", + "\n", + "vit.jit()" + ] + }, + { + "cell_type": "markdown", + "id": "fa210de2-0ea1-4466-96f3-7cc1e2aca11d", + "metadata": {}, + "source": [ + "Let's train the ViT for 1000 steps." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "8abb8bc4-1643-4d1d-8dea-006e5c873b74", + "metadata": {}, + "outputs": [], + "source": [ + "from tqdm.notebook import tqdm\n", + "\n", + "def cross_entropy_loss(w, inputs, targets):\n", + " # We use the logsumexp trick for stable cross entropy\n", + " logits = vit(inputs, w) # shape is [batch, num_classes]\n", + " batch_indices = jnp.arange(logits.shape[0])\n", + " losses = -logits[batch_indices, targets] + jax.nn.logsumexp(logits, axis=-1)\n", + " return losses.mean()\n", + "\n", + "loss_and_grad = jax.jit(jax.value_and_grad(cross_entropy_loss))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "7b9d477e-afa9-4150-851e-5c685e1dad01", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9c5737ccfa3f4c6eb73fe117321154c9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loss: 0.0000: 0%| | 0/1000 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, +
{ + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy on shown samples: 5/5\n", + "Overall test accuracy: 94.69%\n" + ] + } + ], + "source": [ + "# Get predictions for test images\n", + "X_test = jnp.reshape(test_images, (-1, 28, 28, 1))\n", + "test_outputs = vit(X_test, w)\n", + "predicted_labels = jnp.argmax(test_outputs, axis=1)\n", + "\n", + "# Create a figure with subplots for multiple test images\n", + "n_samples = 5 # Number of test images to display\n", + "fig, axes = plt.subplots(1, n_samples, figsize=(10, 3))\n", + "\n", + "# Plot each test image with predicted labels\n", + "for i in range(n_samples):\n", + " axes[i].imshow(test_images[i], cmap='gray')\n", + " axes[i].set_title(f'Predicted label: {int(predicted_labels[i])}')\n", + " axes[i].axis('off')\n", + "\n", + "plt.tight_layout()\n", + "plt.show()\n", + "\n", + "# Print accuracy for these samples\n", + "correct = (predicted_labels[:n_samples] == test_labels[:n_samples]).sum()\n", + "print(f\"Accuracy on shown samples: {correct}/{n_samples}\")\n", + "\n", + "# Calculate and print overall test accuracy\n", + "total_correct = (predicted_labels == test_labels).sum()\n", + "total_samples = len(test_labels)\n", + "print(f\"Overall test accuracy: {100 * total_correct/total_samples:.2f}%\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "ca803b74-6095-482b-8608-ddd43e478f4f", + "metadata": {}, + "outputs": [], + "source": [ + "# Barebone without bias or LN\n", + "# vit = ViT(output_dim, image_size=input_dim, patch_size=(7, 7), num_heads=8, d_embed=64, d_query=8, d_value=8, num_blocks=3, blocks_mass=6, attention_scale=1.0, final_scale=1.0, LN=False, bias=False, scale=False)\n", + "# Just LN, no bias or scale\n", + "# vit = ViT(output_dim, image_size=input_dim, patch_size=(7, 7), num_heads=8, d_embed=64, d_query=8, d_value=8, num_blocks=3, blocks_mass=6, attention_scale=1.0, final_scale=1.0, LN=True, bias=False, scale=False)\n", + "# Both would diverge" + ] + }, + { + 
"cell_type": "code", + "execution_count": 9, + "id": "860605de-198b-402e-b77a-6e4e0f4ded4d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CompositeModule\n", + "...consists of 40 atoms and 77 bonds\n", + "...non-smooth\n", + "...input sensitivity is 2.84765625\n", + "...contributes proportion 10.0 to feature learning of any supermodule\n" + ] + } + ], + "source": [ + "# No LN or scale, just bias:\n", + "vit = ViT(output_dim, image_size=input_dim, patch_size=(7, 7), num_heads=8, d_embed=64, d_query=8, d_value=8, num_blocks=3, blocks_mass=6, attention_scale=1.0, final_scale=1.0, LN=False, bias=True, scale=False)\n", + "\n", + "print(vit)\n", + "\n", + "vit.jit()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "61303c30-072d-4316-8b7c-8b9bdeca5173", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "56f30e99a1ad487e9b8342c710ef8cc6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loss: 0.0000: 0%| | 0/1000 [00:00" ] @@ -97,7 +97,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 3, "id": "5a7a804b-06ec-4773-864c-db8a3b01c3e1", "metadata": {}, "outputs": [], @@ -129,7 +129,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 4, "id": "a7a14a1b-1428-4432-8e89-6b7cfed3d765", "metadata": {}, "outputs": [ @@ -172,14 +172,14 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 5, "id": "080bbf4f-0b73-4d6a-a3d5-f64a2875da9c", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "b71f8a04aa254476b5b883d8121bbf3d", + "model_id": "ba34de950ef549d3bb7c3408ec946be5", "version_major": 2, "version_minor": 0 }, @@ -229,13 +229,13 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 6, "id": "9a08a8ea-d1e8-49b5-8166-05dcbde47f4c", "metadata": {}, "outputs": [ { "data": { - "image/png": 
"", + "image/png": "", "text/plain": [ "
" ] @@ -248,7 +248,7 @@ "output_type": "stream", "text": [ "Accuracy on shown samples: 5/5\n", - "Overall test accuracy: 90.58%\n" + "Overall test accuracy: 89.48%\n" ] } ], @@ -280,6 +280,14 @@ "total_samples = len(test_labels)\n", "print(f\"Overall test accuracy: {100 * total_correct/total_samples:.2f}%\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba18cc4c-75d0-4f9e-a7fb-35c997a0e951", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -298,7 +306,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.8" + "version": "3.11.10" } }, "nbformat": 4, diff --git a/modula/abstract.py b/modula/abstract.py index ec9ca01..c2e971f 100644 --- a/modula/abstract.py +++ b/modula/abstract.py @@ -1,6 +1,8 @@ -import jax import copy +import jax +import einops + class Module: def __init__(self): self.children = [] @@ -204,3 +206,26 @@ def __init__(self, scalar): def forward(self, x, w): return x * self.sensitivity + +class Mean(Bond): + def __init__(self, axis, size): + super().__init__() + self.smooth = True + self.axis = axis + self.size = size + self.sensitivity = 1 / size + + def forward(self, x, w): + assert x.shape[self.axis] == self.size + return jax.numpy.mean(x, axis=self.axis) + +class Patchify(Bond): + def __init__(self, size): + super().__init__() + self.smooth = True + self.sensitivity = 1 + self.size = size + + def forward(self, x, w): + p1, p2 = self.size + return einops.rearrange(x, 'b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=p1, p2=p2) diff --git a/modula/atom.py b/modula/atom.py index c46855f..e994e80 100644 --- a/modula/atom.py +++ b/modula/atom.py @@ -90,6 +90,54 @@ def dualize(self, grad_w, target_norm=1.0): d_weight = jnp.nan_to_num(d_weight) return [d_weight] +class Bias(Atom): + def __init__(self, d): + super().__init__() + self.d = d + self.smooth = True + self.mass = 1 + self.sensitivity = 1 + + def forward(self, x, w): + weights = w[0] # shape [d] + return 
weights + + def initialize(self, key): + return [jnp.zeros(shape=self.d)] + + def project(self, w): + weight = w[0] + weight = weight / jnp.linalg.norm(weight) * jnp.sqrt(self.d) + return [weight] + + def dualize(self, grad_w, target_norm=1.0): + grad = grad_w[0] + d_weight = grad / jnp.linalg.norm(grad) * jnp.sqrt(self.d) * target_norm + d_weight = jnp.nan_to_num(d_weight) + return [d_weight] + +class Scale(Atom): + def __init__(self, d): + super().__init__() + self.d = d + self.smooth = True + self.mass = 1 + self.sensitivity = 1 + + def forward(self, x, w): + weights = w[0] # shape [d] + return weights * x + + def initialize(self, key): + return [jnp.ones(shape=self.d)] + + def project(self, w): + weight = w[0] + return [jnp.sign(weight)] + + def dualize(self, grad_w, target_norm=1.0): + grad = grad_w[0] + return [jnp.sign(grad) * target_norm] if __name__ == "__main__": diff --git a/modula/bond.py b/modula/bond.py index 7bdbf00..7814022 100644 --- a/modula/bond.py +++ b/modula/bond.py @@ -97,6 +97,28 @@ def forward(self, x, w): v, scores = x return scores @ v +class Constant(Bond): + def __init__(self, f): + super().__init__() + self.f = f + self.smooth = True + self.sensitivity = 0 + + def forward(self, x, w): + return self.f() + +class LayerNorm(Bond): + def __init__(self, eps=1e-6): + super().__init__() + self.eps = eps + self.smooth = True + self.sensitivity = 1 + + def forward(self, x, w): + mean = jnp.mean(x, axis=-1, keepdims=True) + var = jnp.var(x, axis=-1, keepdims=True) + return (x - mean) / jnp.sqrt(var + self.eps) + class Rope(Bond): """Rotates queries and keys by relative context window distance.""" def __init__(self, d_head, base=10000): @@ -106,18 +128,14 @@ def __init__(self, d_head, base=10000): self.rope_dim = d_head // 2 self.inverse_frequencies = 1/base**(jnp.arange(self.rope_dim) / self.rope_dim) - self.seq_len_cached = None - self.sin_cached = None - self.cos_cached = None def get_cached(self, seq_len): - if self.seq_len_cached != seq_len: 
- self.seq_len_cached = seq_len - distance = jnp.arange(seq_len) - freqs = jnp.outer(distance, self.inverse_frequencies) # shape [seq_len, rope_dim] - self.cos_cached = jnp.expand_dims(jnp.cos(freqs), (0, 1)) # shape [seq_len, rope_dim] - self.sin_cached = jnp.expand_dims(jnp.sin(freqs), (0, 1)) # shape [seq_len, rope_dim] - return self.sin_cached, self.cos_cached + # Actually caching the return value may lead to leaked intermediate value error + distance = jnp.arange(seq_len) + freqs = jnp.outer(distance, self.inverse_frequencies) # shape [seq_len, rope_dim] + cos = jnp.expand_dims(jnp.cos(freqs), (0, 1)) # shape [seq_len, rope_dim] + sin = jnp.expand_dims(jnp.sin(freqs), (0, 1)) # shape [seq_len, rope_dim] + return sin, cos def rotate(self, x): batch, n_heads, seq_len, d_head = x.shape @@ -126,6 +144,7 @@ def rotate(self, x): x1 = x[..., self.rope_dim:] # shape [batch, n_heads, seq_len, rope_dim] x2 = x[..., :self.rope_dim] # shape [batch, n_heads, seq_len, rope_dim] + # Why is the order reversed!? 
cos, sin = self.get_cached(seq_len) y1 = cos * x1 + sin * x2 y2 = -sin * x1 + cos * x2 diff --git a/modula/compound.py b/modula/compound.py index 7d3701a..3cb0107 100644 --- a/modula/compound.py +++ b/modula/compound.py @@ -1,3 +1,5 @@ +import jax.numpy as jnp + from modula.abstract import * from modula.atom import * from modula.bond import * @@ -8,14 +10,21 @@ def MLP(output_dim, input_dim, width, depth): m = m @ Linear(width, width) @ ReLU() return m @ Linear(width, input_dim) -def Attention(num_heads, d_embed, d_query, d_value, softmax_scale, causal): +def Attention(num_heads, d_embed, d_query, d_value, softmax_scale, causal, posemb="rope", bias=False): """Multi-head attention""" - Q = SplitIntoHeads(num_heads) @ Linear(num_heads * d_query, d_embed) - K = SplitIntoHeads(num_heads) @ Linear(num_heads * d_query, d_embed) - V = SplitIntoHeads(num_heads) @ Linear(num_heads * d_value, d_embed) - W = Linear(d_embed, num_heads * d_value) @ MergeHeads() - - AttentionScores = Softmax(softmax_scale) @ CausalMask() @ AttentionQK() @ Rope(d_query) @ (Q, K) + Q, K, V = Linear(num_heads * d_query, d_embed), Linear(num_heads * d_query, d_embed), Linear(num_heads * d_value, d_embed) + Q = SplitIntoHeads(num_heads) @ (Q + Bias(num_heads * d_query) if bias else Q) + K = SplitIntoHeads(num_heads) @ (K + Bias(num_heads * d_query) if bias else K) + V = SplitIntoHeads(num_heads) @ (V + Bias(num_heads * d_value) if bias else V) + W = Linear(d_embed, num_heads * d_value) + W = (W + Bias(d_embed) if bias else W) @ MergeHeads() + QK = (Q, K) + if posemb == "rope": + QK = Rope(d_query) @ QK + attn = AttentionQK() @ QK + if causal: + attn = CausalMask() @ attn + AttentionScores = Softmax(softmax_scale) @ attn return W @ (1/3 * ApplyAttentionScores()) @ (V, AttentionScores) def GPT(vocab_size, num_heads, d_embed, d_query, d_value, num_blocks, blocks_mass=5, attention_scale=1.0, final_scale=1.0): @@ -31,4 +40,50 @@ def GPT(vocab_size, num_heads, d_embed, d_query, d_value, num_blocks, 
blocks_mas out = final_scale * Linear(vocab_size, d_embed) - return out @ blocks @ embed \ No newline at end of file + return out @ blocks @ embed + +def posemb_sincos_2d(h, w, width, temperature=10_000., dtype=jnp.float32): + """Follows the MoCo v3 logic.""" + y, x = jnp.mgrid[:h, :w] + + assert width % 4 == 0, "Width must be mult of 4 for sincos posemb" + omega = jnp.arange(width // 4) / (width // 4 - 1) + omega = 1. / (temperature**omega) + y = jnp.einsum("m,d->md", y.flatten(), omega) + x = jnp.einsum("m,d->md", x.flatten(), omega) + pe = jnp.concatenate([jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1) + return jnp.asarray(pe, dtype)[None, :, :] + +def ViT(num_classes, image_size=(28, 28), patch_size=(7, 7), num_heads=4, d_embed=32, d_query=8, d_value=8, num_blocks=4, blocks_mass=5, attention_scale=1.0, final_scale=1.0, channels=1, LN=True, bias=True, scale=True): + i1, i2 = image_size + p1, p2 = patch_size + h, w = i1 // p1, i2 // p2 + patchify = Linear(d_embed, p1 * p2 * channels) @ Patchify(patch_size) + if bias: + patchify = patchify + Bias(d_embed) + posemb = Constant(lambda: posemb_sincos_2d(h, w, d_embed)) + + att = Attention(num_heads, d_embed, d_query, d_value, attention_scale, causal=False, posemb="none", bias=bias) + mlp = (Linear(d_embed, 4*d_embed) + Bias(d_embed) if bias else Linear(d_embed, 4*d_embed)) @ GeLU() @ (Linear(4*d_embed, d_embed) + Bias(4*d_embed) if bias else Linear(4*d_embed, d_embed)) + if LN: + ln = LayerNorm() + if bias and scale: + ln = (Scale(d_embed) + Bias(d_embed)) @ ln + elif bias: + ln = ln + Bias(d_embed) + elif scale: + ln = Scale(d_embed) @ ln + att = att @ ln + mlp = mlp @ ln + att_block = (1-1/(2*num_blocks)) * Identity() + 1/(2*num_blocks) * att + mlp_block = (1-1/(2*num_blocks)) * Identity() + 1/(2*num_blocks) * mlp + blocks = (mlp_block @ att_block) ** num_blocks + blocks.tare(absolute=blocks_mass) + + gap = Mean(axis=1, size=h * w) + out = final_scale * (Linear(num_classes, d_embed) + Bias(num_classes) if 
bias else Linear(num_classes, d_embed)) + + ret = blocks @ (patchify + posemb) + if LN: # Final LN + ret = ln @ ret + return out @ gap @ ret