diff --git a/examples/data/mnist.py b/examples/data/mnist.py index f5cce3e..9339fc0 100644 --- a/examples/data/mnist.py +++ b/examples/data/mnist.py @@ -44,7 +44,7 @@ def parse_labels(filepath): test_labels = parse_labels(os.path.join(data_dir, "t10k-labels-idx1-ubyte.gz")) if normalize: - train_images = train_images.astype(np.float32) / 255.0 - test_images = test_images.astype(np.float32) / 255.0 + train_images = train_images.astype(np.float32) / 127.5 - 1. + test_images = test_images.astype(np.float32) / 127.5 - 1. return train_images, train_labels, test_images, test_labels \ No newline at end of file diff --git a/examples/hello-gpt.ipynb b/examples/hello-gpt.ipynb index c3c171f..a8ed039 100644 --- a/examples/hello-gpt.ipynb +++ b/examples/hello-gpt.ipynb @@ -32,21 +32,7 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Downloading Shakespeare dataset...\n", - "Processing Shakespeare dataset...\n", - "Length of dataset in characters: 1,115,394\n", - "Vocabulary size: 65\n", - "Train has 1,003,854 tokens\n", - "Val has 111,540 tokens\n", - "Shakespeare dataset processing complete.\n" - ] - } - ], + "outputs": [], "source": [ "context = 64\n", "batch_size = 12\n", @@ -362,228 +348,228 @@ "name": "stdout", "output_type": "stream", "text": [ - "Step 0: loss 4.226325988769531\n", - "--> val loss 4.179544448852539\n", - "Step 10: loss 3.8738746643066406\n", - "Step 20: loss 3.3448646068573\n", - "Step 30: loss 2.805002212524414\n", - "Step 40: loss 2.68573260307312\n", - "Step 50: loss 2.6098480224609375\n", - "Step 60: loss 2.407468557357788\n", - "Step 70: loss 2.418379783630371\n", - "Step 80: loss 2.359757423400879\n", - "Step 90: loss 2.2685279846191406\n", - "Step 100: loss 2.314124584197998\n", - "--> val loss 2.541980743408203\n", - "Step 110: loss 2.283424139022827\n", - "Step 120: loss 2.2063167095184326\n", - "Step 130: loss 2.1598031520843506\n", - "Step 140: loss 2.252727508544922\n", - "Step 150: loss 2.124152660369873\n", - "Step 160: loss 2.23785662651062\n", - "Step 170: loss 2.2059123516082764\n", - "Step 180: loss 2.102996587753296\n", - "Step 190: loss 2.132392168045044\n", - "Step 200: loss 2.130244255065918\n", - "--> val loss 2.359212636947632\n", - "Step 210: loss 2.0895276069641113\n", - "Step 220: loss 2.1278815269470215\n", - "Step 230: loss 1.9647449254989624\n", - "Step 240: loss 2.1118733882904053\n", - "Step 250: loss 1.9459623098373413\n", - "Step 260: loss 2.118051290512085\n", - "Step 270: loss 2.0605385303497314\n", - "Step 280: loss 2.0378551483154297\n", - "Step 290: loss 2.0237479209899902\n", - "Step 300: loss 1.982785940170288\n", - "--> val loss 2.2887392044067383\n", - "Step 310: loss 2.073058605194092\n", - "Step 320: loss 2.082066535949707\n", - "Step 330: loss 2.130162239074707\n", - "Step 340: loss 2.092909336090088\n", - "Step 350: loss 1.9229984283447266\n", - "Step 360: loss 1.9037134647369385\n", - "Step 370: loss 2.0083131790161133\n", - "Step 380: loss 2.0236263275146484\n", - "Step 390: loss 2.0116419792175293\n", - "Step 400: loss 2.091407299041748\n", - "--> val loss 2.2199790477752686\n", - "Step 410: loss 2.0855846405029297\n", - "Step 420: loss 1.8506882190704346\n", - "Step 430: loss 1.9745848178863525\n", - "Step 440: loss 1.9135173559188843\n", - "Step 450: loss 2.0486648082733154\n", - "Step 460: loss 1.983982801437378\n", - "Step 470: loss 1.9958977699279785\n", - "Step 480: loss 1.9868993759155273\n", - "Step 490: loss 
2.009216785430908\n", - "Step 500: loss 2.073169231414795\n", - "--> val loss 2.141632556915283\n", - "Step 510: loss 2.0603322982788086\n", - "Step 520: loss 2.0025858879089355\n", - "Step 530: loss 1.9482192993164062\n", - "Step 540: loss 1.9092429876327515\n", - "Step 550: loss 2.109374761581421\n", - "Step 560: loss 1.9060167074203491\n", - "Step 570: loss 1.9423940181732178\n", - "Step 580: loss 1.9405231475830078\n", - "Step 590: loss 1.9132475852966309\n", - "Step 600: loss 2.0125274658203125\n", - "--> val loss 2.2273831367492676\n", - "Step 610: loss 2.0854687690734863\n", - "Step 620: loss 1.9796791076660156\n", - "Step 630: loss 1.982351303100586\n", - "Step 640: loss 2.044363021850586\n", - "Step 650: loss 2.030698299407959\n", - "Step 660: loss 2.0731544494628906\n", - "Step 670: loss 1.9660027027130127\n", - "Step 680: loss 1.933128833770752\n", - "Step 690: loss 1.8852118253707886\n", - "Step 700: loss 1.8401598930358887\n", - "--> val loss 2.0958476066589355\n", - "Step 710: loss 1.9790323972702026\n", - "Step 720: loss 2.0329394340515137\n", - "Step 730: loss 1.929424524307251\n", - "Step 740: loss 1.950282335281372\n", - "Step 750: loss 1.938680648803711\n", - "Step 760: loss 1.9717748165130615\n", - "Step 770: loss 1.8411779403686523\n", - "Step 780: loss 2.085500717163086\n", - "Step 790: loss 1.8778104782104492\n", - "Step 800: loss 1.9712986946105957\n", - "--> val loss 2.1469686031341553\n", - "Step 810: loss 1.949462652206421\n", - "Step 820: loss 1.9898126125335693\n", - "Step 830: loss 1.9045312404632568\n", - "Step 840: loss 1.9053363800048828\n", - "Step 850: loss 1.8944416046142578\n", - "Step 860: loss 1.8389015197753906\n", - "Step 870: loss 1.9189136028289795\n", - "Step 880: loss 2.0141639709472656\n", - "Step 890: loss 1.9987534284591675\n", - "Step 900: loss 1.947631597518921\n", - "--> val loss 2.1903281211853027\n", - "Step 910: loss 2.031083106994629\n", - "Step 920: loss 1.988853931427002\n", - "Step 930: loss 2.0356318950653076\n", - "Step 940: loss 1.8823192119598389\n", - "Step 950: loss 2.0429515838623047\n", - "Step 960: loss 2.021817684173584\n", - "Step 970: loss 2.003168821334839\n", - "Step 980: loss 2.0105528831481934\n", - "Step 990: loss 2.014195680618286\n", - "Step 1000: loss 1.9518741369247437\n", - "--> val loss 2.0813283920288086\n", - "Step 1010: loss 2.016996383666992\n", - "Step 1020: loss 2.04374098777771\n", - "Step 1030: loss 1.8839387893676758\n", - "Step 1040: loss 1.96620512008667\n", - "Step 1050: loss 2.0463950634002686\n", - "Step 1060: loss 1.9169645309448242\n", - "Step 1070: loss 2.038651943206787\n", - "Step 1080: loss 2.0474071502685547\n", - "Step 1090: loss 1.9452462196350098\n", - "Step 1100: loss 1.8884999752044678\n", - "--> val loss 2.1541106700897217\n", - "Step 1110: loss 1.9775495529174805\n", - "Step 1120: loss 1.96068274974823\n", - "Step 1130: loss 1.8553755283355713\n", - "Step 1140: loss 1.9422013759613037\n", - "Step 1150: loss 2.0833449363708496\n", - "Step 1160: loss 1.840619444847107\n", - "Step 1170: loss 2.032219409942627\n", - "Step 1180: loss 1.9345749616622925\n", - "Step 1190: loss 1.934565544128418\n", - "Step 1200: loss 1.9528722763061523\n", - "--> val loss 2.1688506603240967\n", - "Step 1210: loss 1.8676490783691406\n", - "Step 1220: loss 1.9311145544052124\n", - "Step 1230: loss 1.9905321598052979\n", - "Step 1240: loss 1.8773740530014038\n", - "Step 1250: loss 1.9832658767700195\n", - "Step 1260: loss 1.8256521224975586\n", - "Step 1270: loss 2.037313461303711\n", - "Step 1280: loss 
1.9440114498138428\n", - "Step 1290: loss 1.9472723007202148\n", - "Step 1300: loss 1.862718105316162\n", - "--> val loss 2.0632894039154053\n", - "Step 1310: loss 1.944453239440918\n", - "Step 1320: loss 1.869157075881958\n", - "Step 1330: loss 1.9843480587005615\n", - "Step 1340: loss 1.9083728790283203\n", - "Step 1350: loss 1.920233130455017\n", - "Step 1360: loss 1.7926225662231445\n", - "Step 1370: loss 1.8765363693237305\n", - "Step 1380: loss 1.9374698400497437\n", - "Step 1390: loss 1.9032771587371826\n", - "Step 1400: loss 1.8976068496704102\n", - "--> val loss 2.0361690521240234\n", - "Step 1410: loss 1.8799960613250732\n", - "Step 1420: loss 1.9112414121627808\n", - "Step 1430: loss 1.8797309398651123\n", - "Step 1440: loss 1.9040837287902832\n", - "Step 1450: loss 1.8828296661376953\n", - "Step 1460: loss 1.83419930934906\n", - "Step 1470: loss 1.8327134847640991\n", - "Step 1480: loss 1.857541799545288\n", - "Step 1490: loss 1.8209788799285889\n", - "Step 1500: loss 1.780470371246338\n", - "--> val loss 2.0466208457946777\n", - "Step 1510: loss 1.8544996976852417\n", - "Step 1520: loss 1.8710064888000488\n", - "Step 1530: loss 1.8195044994354248\n", - "Step 1540: loss 1.874974250793457\n", - "Step 1550: loss 1.7101812362670898\n", - "Step 1560: loss 1.8439801931381226\n", - "Step 1570: loss 1.967679500579834\n", - "Step 1580: loss 1.888682246208191\n", - "Step 1590: loss 1.6926288604736328\n", - "Step 1600: loss 1.875901222229004\n", - "--> val loss 2.044935941696167\n", - "Step 1610: loss 1.8210939168930054\n", - "Step 1620: loss 1.7439773082733154\n", - "Step 1630: loss 1.7956527471542358\n", - "Step 1640: loss 1.792572021484375\n", - "Step 1650: loss 1.7985519170761108\n", - "Step 1660: loss 1.8520288467407227\n", - "Step 1670: loss 1.680544137954712\n", - "Step 1680: loss 1.7917392253875732\n", - "Step 1690: loss 1.8400462865829468\n", - "Step 1700: loss 1.6793416738510132\n", - "--> val loss 1.995697021484375\n", - "Step 1710: loss 1.7414367198944092\n", - "Step 1720: loss 1.8606326580047607\n", - "Step 1730: loss 1.7578084468841553\n", - "Step 1740: loss 1.6292760372161865\n", - "Step 1750: loss 1.7017428874969482\n", - "Step 1760: loss 1.8407533168792725\n", - "Step 1770: loss 1.7789411544799805\n", - "Step 1780: loss 1.802499532699585\n", - "Step 1790: loss 1.7586851119995117\n", - "Step 1800: loss 1.7281568050384521\n", - "--> val loss 1.9875770807266235\n", - "Step 1810: loss 1.7767337560653687\n", - "Step 1820: loss 1.7158925533294678\n", - "Step 1830: loss 1.7596324682235718\n", - "Step 1840: loss 1.7826766967773438\n", - "Step 1850: loss 1.7769875526428223\n", - "Step 1860: loss 1.6953961849212646\n", - "Step 1870: loss 1.7714271545410156\n", - "Step 1880: loss 1.6994340419769287\n", - "Step 1890: loss 1.7252253293991089\n", - "Step 1900: loss 1.566367506980896\n", - "--> val loss 1.9310436248779297\n", - "Step 1910: loss 1.7057380676269531\n", - "Step 1920: loss 1.7441104650497437\n", - "Step 1930: loss 1.7951183319091797\n", - "Step 1940: loss 1.8611491918563843\n", - "Step 1950: loss 1.787139654159546\n", - "Step 1960: loss 1.788725733757019\n", - "Step 1970: loss 1.7919573783874512\n", - "Step 1980: loss 1.706597089767456\n", - "Step 1990: loss 1.771501898765564\n", - "Step 2000: loss 1.7121562957763672\n", - "--> val loss 1.8968441486358643\n" + "Step 0: loss 4.203415393829346\n", + "--> val loss 4.184397220611572\n", + "Step 10: loss 3.8764960765838623\n", + "Step 20: loss 3.3471152782440186\n", + "Step 30: loss 2.8259036540985107\n", + "Step 40: loss 
2.6758830547332764\n", + "Step 50: loss 2.6396656036376953\n", + "Step 60: loss 2.4299607276916504\n", + "Step 70: loss 2.4175362586975098\n", + "Step 80: loss 2.3901185989379883\n", + "Step 90: loss 2.2698512077331543\n", + "Step 100: loss 2.3211278915405273\n", + "--> val loss 2.493386745452881\n", + "Step 110: loss 2.280397415161133\n", + "Step 120: loss 2.2105860710144043\n", + "Step 130: loss 2.1742873191833496\n", + "Step 140: loss 2.265498161315918\n", + "Step 150: loss 2.1193413734436035\n", + "Step 160: loss 2.1948325634002686\n", + "Step 170: loss 2.171501636505127\n", + "Step 180: loss 2.100039482116699\n", + "Step 190: loss 2.1237497329711914\n", + "Step 200: loss 2.1287221908569336\n", + "--> val loss 2.2659342288970947\n", + "Step 210: loss 2.053070306777954\n", + "Step 220: loss 2.1026082038879395\n", + "Step 230: loss 1.9534435272216797\n", + "Step 240: loss 2.1179442405700684\n", + "Step 250: loss 1.957367181777954\n", + "Step 260: loss 2.099113702774048\n", + "Step 270: loss 2.046205759048462\n", + "Step 280: loss 2.0597729682922363\n", + "Step 290: loss 2.0586631298065186\n", + "Step 300: loss 1.9834868907928467\n", + "--> val loss 2.2871968746185303\n", + "Step 310: loss 2.076115369796753\n", + "Step 320: loss 2.0850110054016113\n", + "Step 330: loss 2.159264087677002\n", + "Step 340: loss 2.1108834743499756\n", + "Step 350: loss 1.9584811925888062\n", + "Step 360: loss 1.9317572116851807\n", + "Step 370: loss 2.047001361846924\n", + "Step 380: loss 1.9950387477874756\n", + "Step 390: loss 1.9855890274047852\n", + "Step 400: loss 2.1213014125823975\n", + "--> val loss 2.281414747238159\n", + "Step 410: loss 2.056396484375\n", + "Step 420: loss 1.8612827062606812\n", + "Step 430: loss 1.948122262954712\n", + "Step 440: loss 1.9674848318099976\n", + "Step 450: loss 2.03259539604187\n", + "Step 460: loss 2.011803388595581\n", + "Step 470: loss 2.0299034118652344\n", + "Step 480: loss 1.985011100769043\n", + "Step 490: loss 2.021385669708252\n", + "Step 500: loss 2.052276372909546\n", + "--> val loss 2.203423023223877\n", + "Step 510: loss 2.0946555137634277\n", + "Step 520: loss 1.9836640357971191\n", + "Step 530: loss 2.009483575820923\n", + "Step 540: loss 1.9228500127792358\n", + "Step 550: loss 2.1320910453796387\n", + "Step 560: loss 1.9897129535675049\n", + "Step 570: loss 1.9571495056152344\n", + "Step 580: loss 1.9846587181091309\n", + "Step 590: loss 1.9535086154937744\n", + "Step 600: loss 2.020153045654297\n", + "--> val loss 2.2494421005249023\n", + "Step 610: loss 2.084792137145996\n", + "Step 620: loss 2.021752119064331\n", + "Step 630: loss 1.9838941097259521\n", + "Step 640: loss 2.0131120681762695\n", + "Step 650: loss 2.071953535079956\n", + "Step 660: loss 2.086606502532959\n", + "Step 670: loss 2.058492660522461\n", + "Step 680: loss 1.9665050506591797\n", + "Step 690: loss 1.943424940109253\n", + "Step 700: loss 1.8730368614196777\n", + "--> val loss 2.2508132457733154\n", + "Step 710: loss 2.0128204822540283\n", + "Step 720: loss 2.1044511795043945\n", + "Step 730: loss 2.0091733932495117\n", + "Step 740: loss 2.0575761795043945\n", + "Step 750: loss 2.080972909927368\n", + "Step 760: loss 2.093989372253418\n", + "Step 770: loss 1.932627558708191\n", + "Step 780: loss 2.214707851409912\n", + "Step 790: loss 1.9911472797393799\n", + "Step 800: loss 2.02778959274292\n", + "--> val loss 2.181839942932129\n", + "Step 810: loss 2.049180269241333\n", + "Step 820: loss 2.064502477645874\n", + "Step 830: loss 2.044832468032837\n", + "Step 840: loss 
2.010237455368042\n", + "Step 850: loss 2.007585287094116\n", + "Step 860: loss 1.928617238998413\n", + "Step 870: loss 2.0367650985717773\n", + "Step 880: loss 2.093597412109375\n", + "Step 890: loss 2.107056140899658\n", + "Step 900: loss 2.0655107498168945\n", + "--> val loss 2.271533727645874\n", + "Step 910: loss 2.094755172729492\n", + "Step 920: loss 1.978514552116394\n", + "Step 930: loss 2.0587754249572754\n", + "Step 940: loss 1.903182029724121\n", + "Step 950: loss 2.006319761276245\n", + "Step 960: loss 2.1091904640197754\n", + "Step 970: loss 2.0761780738830566\n", + "Step 980: loss 1.9954557418823242\n", + "Step 990: loss 2.0989832878112793\n", + "Step 1000: loss 2.0245680809020996\n", + "--> val loss 2.2492425441741943\n", + "Step 1010: loss 2.0572500228881836\n", + "Step 1020: loss 2.049919843673706\n", + "Step 1030: loss 2.0195538997650146\n", + "Step 1040: loss 2.065272331237793\n", + "Step 1050: loss 2.08404803276062\n", + "Step 1060: loss 1.991067886352539\n", + "Step 1070: loss 2.059558629989624\n", + "Step 1080: loss 2.0323941707611084\n", + "Step 1090: loss 1.988283395767212\n", + "Step 1100: loss 1.899468183517456\n", + "--> val loss 2.224306106567383\n", + "Step 1110: loss 2.032186269760132\n", + "Step 1120: loss 2.083733558654785\n", + "Step 1130: loss 1.9714925289154053\n", + "Step 1140: loss 2.0304412841796875\n", + "Step 1150: loss 2.1165449619293213\n", + "Step 1160: loss 1.8574061393737793\n", + "Step 1170: loss 2.053088665008545\n", + "Step 1180: loss 1.9868009090423584\n", + "Step 1190: loss 2.0184385776519775\n", + "Step 1200: loss 1.9368431568145752\n", + "--> val loss 2.214960813522339\n", + "Step 1210: loss 1.9542896747589111\n", + "Step 1220: loss 1.9777746200561523\n", + "Step 1230: loss 2.0913281440734863\n", + "Step 1240: loss 1.9256919622421265\n", + "Step 1250: loss 2.0141139030456543\n", + "Step 1260: loss 1.7947438955307007\n", + "Step 1270: loss 2.0271081924438477\n", + "Step 1280: loss 2.0013082027435303\n", + "Step 1290: loss 1.947591781616211\n", + "Step 1300: loss 1.9167909622192383\n", + "--> val loss 2.083812713623047\n", + "Step 1310: loss 2.013084650039673\n", + "Step 1320: loss 1.909340262413025\n", + "Step 1330: loss 1.9620397090911865\n", + "Step 1340: loss 1.9765715599060059\n", + "Step 1350: loss 1.8904399871826172\n", + "Step 1360: loss 1.7793501615524292\n", + "Step 1370: loss 1.9580650329589844\n", + "Step 1380: loss 1.9796538352966309\n", + "Step 1390: loss 1.959943175315857\n", + "Step 1400: loss 1.9130542278289795\n", + "--> val loss 2.0795230865478516\n", + "Step 1410: loss 1.8990237712860107\n", + "Step 1420: loss 1.953758955001831\n", + "Step 1430: loss 1.9493964910507202\n", + "Step 1440: loss 1.9433190822601318\n", + "Step 1450: loss 1.9356145858764648\n", + "Step 1460: loss 1.8823035955429077\n", + "Step 1470: loss 1.830541968345642\n", + "Step 1480: loss 1.9393571615219116\n", + "Step 1490: loss 1.900813102722168\n", + "Step 1500: loss 1.8914613723754883\n", + "--> val loss 2.0876729488372803\n", + "Step 1510: loss 1.9165732860565186\n", + "Step 1520: loss 1.9649326801300049\n", + "Step 1530: loss 1.8890880346298218\n", + "Step 1540: loss 1.9110441207885742\n", + "Step 1550: loss 1.7660144567489624\n", + "Step 1560: loss 1.9185322523117065\n", + "Step 1570: loss 2.0385913848876953\n", + "Step 1580: loss 1.877274513244629\n", + "Step 1590: loss 1.7854435443878174\n", + "Step 1600: loss 1.927487850189209\n", + "--> val loss 2.0727601051330566\n", + "Step 1610: loss 1.9156030416488647\n", + "Step 1620: loss 
1.7924439907073975\n", + "Step 1630: loss 1.8454530239105225\n", + "Step 1640: loss 1.8715399503707886\n", + "Step 1650: loss 1.8894051313400269\n", + "Step 1660: loss 1.897505760192871\n", + "Step 1670: loss 1.7877782583236694\n", + "Step 1680: loss 1.867870807647705\n", + "Step 1690: loss 1.842137336730957\n", + "Step 1700: loss 1.7311986684799194\n", + "--> val loss 1.9756139516830444\n", + "Step 1710: loss 1.8065001964569092\n", + "Step 1720: loss 1.9407305717468262\n", + "Step 1730: loss 1.8354434967041016\n", + "Step 1740: loss 1.7024110555648804\n", + "Step 1750: loss 1.7469775676727295\n", + "Step 1760: loss 1.8872644901275635\n", + "Step 1770: loss 1.8048818111419678\n", + "Step 1780: loss 1.8275573253631592\n", + "Step 1790: loss 1.779404640197754\n", + "Step 1800: loss 1.7527437210083008\n", + "--> val loss 1.9299205541610718\n", + "Step 1810: loss 1.7771308422088623\n", + "Step 1820: loss 1.7643531560897827\n", + "Step 1830: loss 1.8124539852142334\n", + "Step 1840: loss 1.8404573202133179\n", + "Step 1850: loss 1.8412466049194336\n", + "Step 1860: loss 1.7188396453857422\n", + "Step 1870: loss 1.8082847595214844\n", + "Step 1880: loss 1.7587134838104248\n", + "Step 1890: loss 1.7570083141326904\n", + "Step 1900: loss 1.6468451023101807\n", + "--> val loss 1.9047764539718628\n", + "Step 1910: loss 1.7798075675964355\n", + "Step 1920: loss 1.8022470474243164\n", + "Step 1930: loss 1.8122791051864624\n", + "Step 1940: loss 1.9106707572937012\n", + "Step 1950: loss 1.8597577810287476\n", + "Step 1960: loss 1.8255852460861206\n", + "Step 1970: loss 1.8626060485839844\n", + "Step 1980: loss 1.7441591024398804\n", + "Step 1990: loss 1.8292932510375977\n", + "Step 2000: loss 1.7385106086730957\n", + "--> val loss 1.8767629861831665\n" ] } ], @@ -638,16 +624,18 @@ "text": [ "Sample 0:\n", "\n", - "If where his elperiend and is here in think the comfore be pray virtue deather I the grouth a pears my\n", + "If what this death\n", + "The cours the hand thinke\n", + "This the enter there well drate but the grome wout sean t\n", "--------------------------------------------------------------------------------\n", "Sample 1:\n", "\n", - "If as the conture the weet to the man's death the greeen he with thought rame the prosates he palousen\n", + "If my lord, and preat,\n", + "And to the can here there you all Iss may thought natter with have I will did n\n", "--------------------------------------------------------------------------------\n", "Sample 2:\n", "\n", - "If him the be not me were and let for the earth the forth,\n", - "That the his a wort of you the fearshould a\n", + "If my the proving me with the life the know man the forther for the him to came haver the ither oum th\n", "--------------------------------------------------------------------------------\n" ] } @@ -671,6 +659,13 @@ " print(f\"Sample {seed}:\\n\\n{generate_text('If', max_tokens=100, seed=seed)}\")\n", " print(\"-\" * 80)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -689,7 +684,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.8" + "version": "3.11.10" } }, "nbformat": 4, diff --git a/examples/hello-mnist-vit.ipynb b/examples/hello-mnist-vit.ipynb new file mode 100644 index 0000000..fb05a11 --- /dev/null +++ b/examples/hello-mnist-vit.ipynb @@ -0,0 +1,775 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "dd286c88-ce33-4be7-8aec-3c3fe5176c40", + "metadata": {}, + 
"source": [ + "# Hello, MNIST!" + ] + }, + { + "cell_type": "markdown", + "id": "58d2fdb1-575a-4cc3-8aa5-0fb5ee6fca9f", + "metadata": {}, + "source": [ + "Let's extend [Hello, World!](hello-world.ipynb) to an actual learning problem: MNIST! First, we download the data." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "477a30ee-9242-4207-a1f7-c8e4b5c702b1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training images shape: (60000, 28, 28)\n", + "Training labels shape: (60000,)\n", + "Test images shape: (10000, 28, 28)\n", + "Test labels shape: (10000,)\n" + ] + } + ], + "source": [ + "from data.mnist import load_mnist\n", + "\n", + "# Load the MNIST dataset\n", + "train_images, train_labels, test_images, test_labels = load_mnist()\n", + "\n", + "# Print shapes to verify loading\n", + "print(f\"Training images shape: {train_images.shape}\")\n", + "print(f\"Training labels shape: {train_labels.shape}\")\n", + "print(f\"Test images shape: {test_images.shape}\")\n", + "print(f\"Test labels shape: {test_labels.shape}\")" + ] + }, + { + "cell_type": "markdown", + "id": "ad9d808d-7a30-4255-94dc-aac83c8023be", + "metadata": {}, + "source": [ + "Next, let's plot a few training points." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "7322b0da-0d22-4742-8885-69500e549774", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAA94AAADgCAYAAAD19b5rAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAHtZJREFUeJzt3Xu0llWdB/DfEZCL3IaLWpYoiZoF4gVwGBIM1FIsDJIsRcsxV4iyXMIwOqRMpqGCKd5i6RIlWYtcIGo2jTYDWBYeIdJZaNAJJQJZBCI3LzAMz/zRgok5z37hhfNwOO/5fNbiD777/T3P5ni27/vjOexdlWVZFgAAAEAhDqvvCQAAAEAl03gDAABAgTTeAAAAUCCNNwAAABRI4w0AAAAF0ngDAABAgTTeAAAAUCCNNwAAABRI4w0AAAAF0ngXbMWKFVFVVRWTJk2qs2vOnz8/qqqqYv78+XV2TThYrAnYkzUBe7ImYE/WRGXQeOd4/PHHo6qqKhYtWlTfUynEhAkToqqqqtavFi1a1PfUOERV+pqIiFi9enVccskl0b59+2jbtm18+ctfjrfeequ+p8UhqjGsib917rnnRlVVVYwaNaq+p8IhqtLXxLJly+KGG26Ivn37RosWLaKqqipWrFhR39PiEFbpayIiYubMmXH66adHixYtonPnznHVVVfF+vXr63tah6ym9T0B6s/DDz8crVu33v37Jk2a1ONsoP5s3bo1zjnnnNi0aVPcfPPN0axZs/jhD38Y/fv3j9deey06duxY31OEevP000/HggUL6nsaUK8WLFgQU6ZMiVNOOSU+/elPx2uvvVbfU4J69fDDD8fIkSNj4MCBcc8998SqVavivvvui0WLFkV1dbUHejk03o3YsGHDolOnTvU9Dah3Dz30UNTU1MSrr74avXr1ioiIL37xi/HZz342Jk+eHHfccUc9zxDqx0cffRQ33nhjjBs3Lm655Zb6ng7Umy996UuxcePGaNOmTUyaNEnjTaO2ffv2uPnmm+Pss8+OX/ziF1FVVRUREX379o2LLrooHnnkkbjuuuvqeZaHHj9qvp+2b98et9xyS5xxxhnRrl27OOKII+Jzn/tczJs3L1nzwx/+MLp06RItW7aM/v37x5IlS2q9ZunSpTFs2LDo0KFDtGjRIs4888x47rnn9jqfDz74IJYuXVrWj3dkWRabN2+OLMv2uQZSGvKamDVrVvTq1Wt30x0RcfLJJ8fAgQPjqaee2ms95GnIa2KXu+66K3bu3BljxozZ5xpIachrokOHDtGmTZu9vg7K0VDXxJIlS2Ljxo0xfPjw3U13RMTgwYOjdevWMXPmzL3eqzHSeO+nzZs3x6OPPhoDBgyIO++8MyZMmBDr1q2L888/P/dvQadPnx5TpkyJa6+9Nm666aZYsmRJfP7zn4+1a9fufs0bb7wRZ511Vvz+97+Pf/7nf47JkyfHEUccEUOGDIk5c+aUnM+rr74an/70p+OBBx7Y5z9D165do127dtGmTZu47LLL9pgLlKuhromdO3fGf/3Xf8WZZ55Za6x3796xfPny2LJly759EeBvNNQ1scvKlStj4sSJceedd0bLli3L+rNDnoa+JqCuNdQ1sW3btoiI3PeGli1bxu9+97vYuXPnPnwFGpmMWqZNm5ZFRLZw4cLka3bs2JFt27Ztj+y9997LjjrqqOxb3/rW7uztt9/OIiJr2bJltmrVqt15dXV1FhHZDTfcsDsbOHBg1r179+yjjz7ane3cuTPr27dv1q1bt93ZvHnzsojI5s2bVyu79dZb9/rnu/fee7NRo0ZlM2bMyGbNmpWNHj06a9q0adatW7ds06ZNe62n8ankNbFu3bosIrLvfe97tcYefPDBLCKypUuXlrwGjU8lr4ldhg0blvXt23f37yMiu/baa/eplsanMayJXe6+++4sIrK33367rDoal0peE+vWrcuqqqqyq666ao986dKlWURkEZGtX7++5DUaI0+891OTJk3i8MMPj4i/PjHbsGFD7NixI84888xYvHhxrdcPGTIkjjnmmN2/7927d/Tp0yf+7d/+LSIiNmzYEHPnzo1LLrkktmzZEuvXr4/169fHu+++
G+eff37U1NTE6tWrk/MZMGBAZFkWEyZM2OvcR48eHffff398/etfj6FDh8a9994bTzzxRNTU1MRDDz1U5lcC/qqhrokPP/wwIiKaN29ea2zXxiC7XgPlaKhrIiJi3rx5MXv27Lj33nvL+0NDCQ15TUARGuqa6NSpU1xyySXxxBNPxOTJk+Ott96KX/3qVzF8+PBo1qxZRPjslEfjfQCeeOKJ6NGjR7Ro0SI6duwYnTt3jp/97GexadOmWq/t1q1brezEE0/cfRTFH//4x8iyLL773e9G586d9/h16623RkTEX/7yl8L+LF//+tfj6KOPjv/4j/8o7B5Uvoa4Jnb9mNSuH5v6Wx999NEer4FyNcQ1sWPHjrj++uvj8ssv32PfA6gLDXFNQJEa6pqYOnVqXHDBBTFmzJj41Kc+FWeffXZ07949LrroooiIPU5O4q/sar6fnnzyybjyyitjyJAhMXbs2DjyyCOjSZMm8YMf/CCWL19e9vV2/TuIMWPGxPnnn5/7mhNOOOGA5rw3n/zkJ2PDhg2F3oPK1VDXRIcOHaJ58+axZs2aWmO7so9//OMHfB8an4a6JqZPnx7Lli2LqVOn1jqneMuWLbFixYo48sgjo1WrVgd8LxqXhromoCgNeU20a9cunn322Vi5cmWsWLEiunTpEl26dIm+fftG586do3379nVyn0qi8d5Ps2bNiq5du8bTTz+9x25+u/426f+rqamplf3hD3+I4447LiL+utFZRESzZs1i0KBBdT/hvciyLFasWBGnnXbaQb83laGhronDDjssunfvHosWLao1Vl1dHV27drWTLfuloa6JlStXxn//93/HP/zDP9Qamz59ekyfPj3mzJkTQ4YMKWwOVKaGuiagKJWwJo499tg49thjIyJi48aN8dvf/jaGDh16UO7d0PhR8/3UpEmTiIg9juKqrq6OBQsW5L7+mWee2ePfVLz66qtRXV0dX/ziFyMi4sgjj4wBAwbE1KlTc5+8rVu3ruR8yjkSI+9aDz/8cKxbty6+8IUv7LUe8jTkNTFs2LBYuHDhHs33smXLYu7cufHVr351r/WQp6Guia997WsxZ86cWr8iIi644IKYM2dO9OnTp+Q1IE9DXRNQlEpbEzfddFPs2LEjbrjhhv2qr3SeeJfw2GOPxb//+7/XykePHh2DBw+Op59+Oi6++OK48MIL4+23344f/ehHccopp8TWrVtr1ZxwwgnRr1+/+M53vhPbtm2Le++9Nzp27Bj/9E//tPs1Dz74YPTr1y+6d+8eV199dXTt2jXWrl0bCxYsiFWrVsXrr7+enOurr74a55xzTtx666173RChS5cuMXz48OjevXu0aNEiXn755Zg5c2b07Nkzrrnmmn3/AtHoVOqaGDlyZDzyyCNx4YUXxpgxY6JZs2Zxzz33xFFHHRU33njjvn+BaHQqcU2cfPLJcfLJJ+eOHX/88Z50U1IlromIiE2bNsX9998fERG//vWvIyLigQceiPbt20f79u1j1KhR+/LloRGq1DUxceLEWLJkSfTp0yeaNm0azzzzTLz44ovx/e9/3/4gKQd/I/VD367t/1O//vznP2c7d+7M7rjjjqxLly5Z8+bNs9NOOy17/vnnsyuuuCLr0qXL7mvt2v7/7rvvziZPnpx98pOfzJo3b5597nOfy15//fVa916+fHk2YsSI7Oijj86aNWuWHXPMMdngwYOzWbNm7X7NgR6J8Y//+I/ZKaeckrVp0yZr1qxZdsIJJ2Tjxo3LNm/efCBfNipYpa+JLMuyP//5z9mwYcOytm3bZq1bt84GDx6c1dTU7O+XjArXGNbE/xeOE6OESl8Tu+aU9+tv5w67VPqaeP7557PevXtnbdq0yVq1apWdddZZ2VNPPXUgX7KKV5Vlf/OzDQAAAECd8m+8AQAAoEAabwAAACiQxhsAAAAKpPEGAACAAmm8AQAAoEAabwAAACiQxhsAAAAK1HRfX1hVVVXkPKBeHMgx9tYElciagNr2d11YE1Qi7xNQ276sC0+8AQAAoEAabwAAACiQxhsAAAAKpPEGAACAAmm8AQAAoEAabwAAACiQxhsAAAAKpPEGAACAAmm8AQAAoEAabwAAACiQxhsAAAAKpPEGAACAAmm8AQAAoEAabwAAACiQxhsAAAAKpPEGAACAAmm8AQAAoEAabwAAACiQxhsAAAAKpPEGAACAAmm8AQAAoEAabwAAACiQxhsAAAAKpPEGAACAAmm8AQAAoEBN63sCAPvrjDPOSI6NGjUqNx8xYkSyZvr06bn5/fffn6xZvHhxcgwAACI88QYAAIBCabwBAACgQBpvAAAAKJDGGwAAAAqk8QYAAIACVWVZlu3TC6uqip5LRWrSpElu3q5duzq9T2oH51atWiVrTjrppNz82muvTdZMmjQpN7/00kuTNR999FFuPnHixGTNv/7rvybH6tI+fvvnsiYOnp49e+bmc+fOTda0bdu2zu6/adOm5FjHjh3r7D6HAmuCujBw4MDcfMaMGcma/v375+bLli2rkzkdiP1dF9ZE5Rk/fnxuXupzy2GH5T/nGjBgQLLmpZdeKmteB5P3CahtX9aFJ94AAABQII03AAAAFEjjDQAAAAXSeAMAAECBNN4AAABQoKb1PQEAKMfZZ5+dHEvtMj9nzpyipkOOXr165eYLFy48yDOB8l155ZXJsXHjxuXmO3fuLPs+B7I7ONDwNNrG+9hjj83NDz/88GRN3759c/N+/fola9q3b5+bDx06ND25g2TVqlW5+ZQpU5I1F198cW6+ZcuWZM3rr7+emx/KR2Vw8PXu3Ts5Nnv27Ny81LF8qQ80pb5Xt2/fnpuXOjLsrLPOys0XL15c9n0AAKhMftQcAAAACqTxBgAAgAJpvAEAAKBAGm8AAAAokMYbAAAAClTRu5r37NkzOTZ37tzcvNQuyQ1RqeMtxo8fn5tv3bo1WTNjxozcfM2aNcma9957LzdftmxZsoaGrVWrVsmx008/PTd/8sknkzUf+9jHDnhOu9TU1CTH7rrrrtx85syZyZpf//rXuXlqfUVE/OAHP0iOsXcDBgxIjnXr1i03d5xY3TvssPTf3R9//PG5eZcuXZI1VVVVBzwnqAulvk9btGhxEGcCf9WnT5/k2GWXXZab9+/fP1nzmc98puw5jBkzJjn2zjvv5OalTn5Kfe6rrq4ub2INiCfeAAAAUCCNNwAAABRI4w0AAAAF0ngDAABAgTTeAAAAUKCK3tUcgMozYsSI5NiCBQsO4kwat1KnDVx99dW5eanTC5YuXXrAc4JyDBo0KDe/7rrryr5Wqe/fwYMH5+Zr164t+z5UtuHDh+fm9913X7KmU6dOuXmpkyLmz5+fHOvcuXNufvfddydrUkrNIXWfr33ta2Xfp6Go6MZ75cqVybF33303Nz8UjhNLbaO/cePGZM0555yTm2/fvj1Z8+Mf/7isecG+mjp1anLs0ksvPYgzqS11nFlEROvWrXPzl156KVmTOtqqR48eZc0LAIDK5UfNAQAAoEAabwAAACiQxhsAAAAKpPEGAAC
AAmm8AQAAoEAVvav5hg0bkmNjx47NzVNHPkRE/O53v8vNp0yZUt7EIuK1115Ljp177rm5+fvvv5+s+cxnPpObjx49uqx5QTnOOOOM3PzCCy9M1pQ6WiIltav4T3/602TNpEmTcvN33nknWZNa4++9916y5vOf/3xuvj9/TvbNYYf5O+NDwaOPPlp2TU1NTQEzgbR+/folx6ZNm5ab788JN6WOWvrTn/5U9vVo+Jo2zW+zzjzzzGTNI488kpu3atUqWfPLX/4yN7/tttuSNS+//HJyrHnz5rn5U089law577zzkmMpixYtKrumofPpBQAAAAqk8QYAAIACabwBAACgQBpvAAAAKJDGGwAAAApU0buaA9Bw9ejRIzc/6qijDvJMyLM/Oz//4he/KGAmkHbFFVckxz7+8Y+Xfb358+fn5tOnTy/7WlS2yy67LDffnxMhSv2/c/jw4bn55s2by75Pqevtz87lq1atSo498cQTZV+voWu0jfczzzyTm8+dOzdZs2XLltz81FNPTdZcddVVuXnqqKOI0seGpbzxxhu5+be//e2yrwV/q2fPnsmx1BtB27ZtkzVZluXmP//5z5M1l156aW7ev3//ZM348eNz81JveOvWrcvNX3/99WTNzp07c/NSR6qdfvrpufnixYuTNQAANFx+1BwAAAAKpPEGAACAAmm8AQAAoEAabwAAACiQxhsAAAAK1Gh3NU/Zn633N23aVHbN1VdfnRz7yU9+kpundk+GunDiiSfm5mPHjk3WpI4TWr9+fbJmzZo1uXmpYyW2bt2am//sZz9L1pQaOxhatmyZHLvxxhtz82984xtFTadBuuCCC3LzUl9b6l7q+Lbjjz++7GutXr36QKcDuTp16pSbf+tb30rWpD5Xbdy4MVnz/e9/v6x5Udluu+225NjNN9+cm6dOd4mIeOihh3Lz1EktEft/bFjKv/zLv9TZta6//vrkWOokmUrmiTcAAAAUSOMNAAAABdJ4AwAAQIE03gAAAFAgjTcAAAAUyK7mABySTjrppLJr3njjjQJm0rhNmjQpN0/tdh4R8Yc//CE337JlS53MicbpuOOOS47Nnj27zu5z//33J8fmzZtXZ/eh4bjlllty89TO5RER27dvz81feOGFZM24ceNy8w8//LDE7PK1aNEiOXbeeeclx4499tjcvKqqKlmT2u3/2WefTdY0RhrvOjBhwoTk2BlnnJGb9+/fP1kzaNCg3PzFF18sa17w/zVv3jw5lvpwnTrSKSL9IXrEiBHJmkWLFuXmjemIqNSbGgAAlcmPmgMAAECBNN4AAABQII03AAAAFEjjDQAAAAXSeAMAAECB7GpeB95///3k2NVXX52bL168OFnzyCOP5OaljrBI7RT94IMPJmuyLEuOUZlOO+205Fip3ctTvvzlL+fmL730UtnXgrqwcOHC+p5CvWvbtm1y7Atf+EJuftlllyVrSh07k3Lbbbfl5hs3biz7WrBL6vs3IqJHjx5lX+8///M/c/P77ruv7GvR8LVv3z45NnLkyNy81Gfp1LFhQ4YMKWdae3XCCSfk5jNmzEjWpE5dKmXWrFnJsbvuuqvs6zVGnngDAABAgTTeAAAAUCCNNwAAABRI4w0AAAAF0ngDAABAgexqDkDF6NChw0G5z6mnnpqbV1VVJWsGDRqUm3/iE59I1hx++OG5+Te+8Y1kzWGHpf9O/cMPP8zNq6urkzXbtm3LzZs2TX+E+O1vf5scg71J7fo8ceLEsq/18ssvJ8euuOKK3HzTpk1l34eGL/X/24iITp06lX2966+/Pjc/8sgjkzXf/OY3c/MvfelLyZrPfvazuXnr1q2TNaV2Y0+NPfnkk8maUic88X803gVbvnx5bn7llVcma6ZNm5abX3755cma1NgRRxyRrJk+fXpuvmbNmmQNDds999yTHEs1DKWOBnNsWLrJ2blz50GeCQAAhyo/ag4AAAAF0ngDAABAgTTeAAAAUCCNNwAAABRI4w0AAAAFsqt5PZkzZ05yrKamJjcvtSP1wIEDc/M77rgjWdOlS5fc/Pbbb0/WrF69OjnGoWPw4MG5ec+ePZM1qeMjnnvuubqYUsVK7V5e6qiO1157raDZVJbU0VelvrY/+tGPcvObb765Tua0S48ePXLzUseJ7dixIzf/4IMPkjVvvvlmbv7YY48laxYtWpQcS51EsHbt2mTNqlWrcvOWLVsma5YuXZocg4iI4447Ljk2e/bsOrvPW2+9lRwr9X1P47N9+/bk2Lp163Lzzp07J2vefvvt3LzUe9j+eOedd3LzzZs3J2s+9rGPJcfWr1+fm//0pz8tb2LU4ok3AAAAFEjjDQAAAAXSeAMAAECBNN4AAABQII03AAAAFEjjDQAAAAVynNghaMmSJbn5JZdckqy56KKLcvNp06Yla6655prcvFu3bsmac889NznGoSN1zM/hhx+erPnLX/6Sm//kJz+pkzk1BM2bN8/NJ0yYUPa15s6dmxy76aabyr5eYzRy5Mjc/E9/+lOypm/fvkVNZw8rV67MzZ955plkze9///vc/JVXXqmLKR2Qb3/728mx1HE5pY5pgr0ZN25ccix1TOP+mDhxYp1di8q2cePG5NiQIUNy8+effz5Z06FDh9x8+fLlyZpnn302N3/88ceTNRs2bMjNZ86cmawpdZxYqToOjCfeAAAAUCCNNwAAABRI4w0AAAAF0ngDAABAgTTeAAAAUCC7mjcgpXZb/PGPf5ybP/roo8mapk3z//OfffbZyZoBAwbk5vPnz0/W0DBs27YtN1+zZs1BnkmxUjuXR0SMHz8+Nx87dmyyZtWqVbn55MmTkzVbt25NjrF3d955Z31PoeIMHDiw7JrZs2cXMBMqTc+ePXPz8847r07vk9oNetmyZXV6Hxqn6urq3Dx16sPBlPrc3r9//2RNqZMDnFhRHE+8AQAAoEAabwAAACiQxhsAAAAKpPEGAACAAmm8AQAAoEAabwAAACiQ48QOQT169MjNhw0blqzp1atXbp46MqyUN998Mzn2y1/+suzr0TA899xz9T2FOpU6wqbU0WDDhw/PzVPH1EREDB06tKx5QaWYM2dOfU+BBuDFF1/Mzf/u7/6u7Gu98sorybErr7yy7OtBJWjZsmVuXurIsCzLkmMzZ8484DmRzxNvAAAAKJDGGwAAAAqk8QYAAIACabwBAACgQBpvAAAAKJBdzQt20kkn5eajRo1K1nzlK1/JzY8++ug6mdMu//M//5Obr1mzJllTaodEDh1VVVVl5RERQ4YMyc1Hjx5dF1MqxA033JAc++53v5ubt2vXLlkzY8aM3HzEiBHlTQyAiIjo2LFjbr4/nyceeuih5NjWrVvLvh5UghdeeKG+p8A+8sQbAAAACqTxBgAAgAJpvAEAAKBAGm8AAAAokMYbAAAACqTxBgAAgAI5TqwMqeO8Lr300mRN6tiw4447ri6mtFeLFi1Kjt1+++25+XPPPVfUdDhIsiwrK49If39PmTIlWfPYY4/l5u+++26y5qyzzsrNL7/88mTNqaeempt/4hOfSNasXL
kyNy917Eapo2qgsUodQ3jiiScma1555ZWipsMhaNq0acmxww6ru2c8v/nNb+rsWlApzj///PqeAvvIE28AAAAokMYbAAAACqTxBgAAgAJpvAEAAKBAGm8AAAAoUKPd1fyoo47KzU855ZRkzQMPPJCbn3zyyXUyp72prq5Ojt199925+bPPPpus2blz5wHPicrRpEmT3HzkyJHJmqFDh+bmmzdvTtZ069atvImVUGqH23nz5uXmt9xyS53dHxqD1GkIdblbNQ1Dz549c/NBgwYla1KfNbZv356sefDBB3PztWvXpicHjVTXrl3rewrsI++aAAAAUCCNNwAAABRI4w0AAAAF0ngDAABAgTTeAAAAUCCNNwAAABSoIo4T69ChQ24+derUZE3qSIyDtSV/qWOQJk+enJu/8MILyZoPP/zwgOdE5ViwYEFuvnDhwmRNr169yr7P0UcfnZunjusr5d13302OzZw5MzcfPXp02fcB6sbf//3fJ8cef/zxgzcRDpr27dvn5qn3glJWr16dHBszZkzZ14PG6le/+lVuXurIR0cK1w9PvAEAAKBAGm8AAAAokMYbAAAACqTxBgAAgAJpvAEAAKBAh9yu5n369MnNx44dm6zp3bt3bn7MMcfUyZz25oMPPkiOTZkyJTe/4447kjXvv//+Ac+Jxm3VqlW5+Ve+8pVkzTXXXJObjx8/vk7mtMt9992Xmz/88MPJmj/+8Y91Ogdg31VVVdX3FABIWLJkSW5eU1OTrCl1itOnPvWp3HzdunXlTYxaPPEGAACAAmm8AQAAoEAabwAAACiQxhsAAAAKpPEGAACAAmm8AQAAoECH3HFiF198cVn5/nrzzTdz8+effz5Zs2PHjtx88uTJyZqNGzeWNS8o0po1a5JjEyZMKCsHKsfPf/7z5NhXv/rVgzgTDmVLly7NzX/zm98ka/r161fUdIASSh1d/OijjybHbr/99tz8uuuuS9ak+ir25Ik3AAAAFEjjDQAAAAXSeAMAAECBNN4AAABQII03AAAAFKgqy7Jsn15YVVX0XOCg28dv/1zWBJXImoDa9nddWBNUIu8TDUPbtm2TY0899VRybNCgQbn5008/naz55je/mZu///77yZpKsy/rwhNvAAAAKJDGGwAAAAqk8QYAAIACabwBAACgQBpvAAAAKJDGGwAAAArkODEaNUdiwJ6sCajNcWLwf7xPNHyljhq7/fbbc/PvfOc7yZoePXrk5m+++WZ5E2vAHCcGAAAA9UzjDQAAAAXSeAMAAECBNN4AAABQII03AAAAFMiu5jRqduaEPVkTUJtdzeH/eJ+A2uxqDgAAAPVM4w0AAAAF0ngDAABAgTTeAAAAUCCNNwAAABRI4w0AAAAF2ufjxAAAAIDyeeINAAAABdJ4AwAAQIE03gAAAFAgjTcAAAAUSOMNAAAABdJ4AwAAQIE03gAAAFAgjTcAAAAUSOMNAAAABfpfJ9gHtUu3/aQAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "# Create a figure with 5 subplots\n", + "fig, axes = plt.subplots(1, 5, figsize=(10, 3))\n", + "\n", + "# Plot each image and its label\n", + "for i in range(5):\n", + " axes[i].imshow(train_images[i], cmap='gray')\n", + " axes[i].set_title(f'Label: {train_labels[i]}')\n", + " axes[i].axis('off')\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "b79b80ac-34fd-4f15-b83c-41d43f37c07e", + "metadata": {}, + "source": [ + "Now we flatten the inputs into vectors, encode the targets as one-hot vectors, and write a mini-batch sampler." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "5a7a804b-06ec-4773-864c-db8a3b01c3e1", + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "X_train = jnp.reshape(train_images, (-1, 28, 28, 1))\n", + "y_train = train_labels\n", + "\n", + "# Getting a batch like this results in horrible performance!\n", + "# def get_batch(key, batch_size):\n", + "# idx = jax.random.choice(key, X_train.shape[1], shape=(batch_size,))\n", + "# return X_train[idx, :], y_train[idx]\n", + "\n", + "def stream(batch_size):\n", + " num_samples = X_train.shape[0]\n", + " while True:\n", + " key = jax.random.PRNGKey(0)\n", + " perm = jax.random.permutation(key, num_samples)\n", + " for i in range(num_samples // batch_size):\n", + " indices = perm[i * batch_size:(i + 1) * batch_size]\n", + " yield X_train[indices], y_train[indices]" + ] + }, + { + "cell_type": "markdown", + "id": "24819e5b-f810-4d55-ab53-d57f592eed82", + "metadata": {}, + "source": [ + "Now we're ready to build our ViT." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a7a14a1b-1428-4432-8e89-6b7cfed3d765", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CompositeModule\n", + "...consists of 54 atoms and 91 bonds\n", + "...non-smooth\n", + "...input sensitivity is 51.72763417352536\n", + "...contributes proportion 12.0 to feature learning of any supermodule\n" + ] + } + ], + "source": [ + "from modula.compound import ViT\n", + "\n", + "input_dim = X_train.shape[1:-1]\n", + "output_dim = 10\n", + "\n", + "vit = ViT(output_dim, image_size=input_dim, patch_size=(7, 7), num_heads=8, d_embed=64, d_query=8, d_value=8, num_blocks=3, blocks_mass=6, attention_scale=1.0, final_scale=1.0)\n", + "\n", + "print(vit)\n", + "\n", + "vit.jit()" + ] + }, + { + "cell_type": "markdown", + "id": "fa210de2-0ea1-4466-96f3-7cc1e2aca11d", + "metadata": {}, + "source": [ + "Let's train the MLP for 1000 steps." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "8abb8bc4-1643-4d1d-8dea-006e5c873b74", + "metadata": {}, + "outputs": [], + "source": [ + "from tqdm.notebook import tqdm\n", + "\n", + "def cross_entropy_loss(w, inputs, targets):\n", + " # We use the logsumexp trick for stable cross entropy\n", + " logits = vit(inputs, w) # shape is [batch, num_classes]\n", + " batch_indices = jnp.arange(logits.shape[0])\n", + " losses = -logits[batch_indices, targets] + jax.nn.logsumexp(logits, axis=-1)\n", + " return losses.mean()\n", + "\n", + "loss_and_grad = jax.jit(jax.value_and_grad(cross_entropy_loss))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "7b9d477e-afa9-4150-851e-5c685e1dad01", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9c5737ccfa3f4c6eb73fe117321154c9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loss: 0.0000: 0%| | 0/1000 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy on shown samples: 5/5\n", + "Overall test accuracy: 94.69%\n" + ] + } + ], + "source": [ + "# Get predictions for test images\n", + "X_test = jnp.reshape(test_images, (-1, 28, 28, 1))\n", + "test_outputs = vit(X_test, w)\n", + "predicted_labels = jnp.argmax(test_outputs, axis=1)\n", + "\n", + "# Create a figure with subplots for multiple test images\n", + "n_samples = 5 # Number of test images to display\n", + "fig, axes = plt.subplots(1, n_samples, figsize=(10, 3))\n", + "\n", + "# Plot each test image with predicted labels\n", + "for i in range(n_samples):\n", + " axes[i].imshow(test_images[i], cmap='gray')\n", + " axes[i].set_title(f'Predicted label: {int(predicted_labels[i])}')\n", + " axes[i].axis('off')\n", + "\n", + "plt.tight_layout()\n", + "plt.show()\n", + "\n", + "# Print accuracy for these samples\n", + "correct = (predicted_labels[:n_samples] == test_labels[:n_samples]).sum()\n", + "print(f\"Accuracy on shown samples: {correct}/{n_samples}\")\n", + "\n", + "# Calculate and print overall test accuracy\n", + "total_correct = (predicted_labels == test_labels).sum()\n", + "total_samples = len(test_labels)\n", + "print(f\"Overall test accuracy: {100 * total_correct/total_samples:.2f}%\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "ca803b74-6095-482b-8608-ddd43e478f4f", + "metadata": {}, + "outputs": [], + "source": [ + "# Barebone without bias or LN\n", + "# vit = ViT(output_dim, image_size=input_dim, patch_size=(7, 7), num_heads=8, d_embed=64, d_query=8, d_value=8, num_blocks=3, blocks_mass=6, attention_scale=1.0, final_scale=1.0, LN=False, bias=False, scale=False)\n", + "# Just LN, no bias or scale\n", + "# vit = ViT(output_dim, image_size=input_dim, patch_size=(7, 7), num_heads=8, d_embed=64, d_query=8, d_value=8, num_blocks=3, blocks_mass=6, attention_scale=1.0, final_scale=1.0, LN=True, bias=False, scale=False)\n", + "# Both would diverge" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "860605de-198b-402e-b77a-6e4e0f4ded4d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CompositeModule\n", + "...consists of 40 atoms and 77 bonds\n", + "...non-smooth\n", + "...input sensitivity is 2.84765625\n", + "...contributes proportion 10.0 to feature learning of any supermodule\n" + ] + } + ], + "source": [ + "# No LN or scale, just bias:\n", + "vit = ViT(output_dim, image_size=input_dim, 
patch_size=(7, 7), num_heads=8, d_embed=64, d_query=8, d_value=8, num_blocks=3, blocks_mass=6, attention_scale=1.0, final_scale=1.0, LN=False, bias=True, scale=False)\n", + "\n", + "print(vit)\n", + "\n", + "vit.jit()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "61303c30-072d-4316-8b7c-8b9bdeca5173", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "56f30e99a1ad487e9b8342c710ef8cc6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loss: 0.0000: 0%| | 0/1000 [00:00" ] @@ -97,7 +97,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 3, "id": "5a7a804b-06ec-4773-864c-db8a3b01c3e1", "metadata": {}, "outputs": [], @@ -129,7 +129,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 4, "id": "a7a14a1b-1428-4432-8e89-6b7cfed3d765", "metadata": {}, "outputs": [ @@ -172,14 +172,14 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 5, "id": "080bbf4f-0b73-4d6a-a3d5-f64a2875da9c", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "b71f8a04aa254476b5b883d8121bbf3d", + "model_id": "ba34de950ef549d3bb7c3408ec946be5", "version_major": 2, "version_minor": 0 }, @@ -229,13 +229,13 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 6, "id": "9a08a8ea-d1e8-49b5-8166-05dcbde47f4c", "metadata": {}, "outputs": [ { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAA94AAADgCAYAAAD19b5rAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAHkhJREFUeJzt3QmQVNXZMOA7CCigQURQNAqI0YiIAuKK+4YoigpKJMYlATVuKfc9iluCiYkrlqmKe5SIuyJuhSBxiRB3QSNGkCgRBRXEhaX/Ovev4Rvg3pGGPjNM9/NUTYjv6XPvme5+e/q9yzlVhUKhkAAAAABRNIqzWQAAACBQeAMAAEBECm8AAACISOENAAAAESm8AQAAICKFNwAAAESk8AYAAICIFN4AAAAQkcIbAAAAIqrYwrtDhw7JMcccs/i/n3vuuaSqqir9d1UdY54w7ksuuaTo7d92221p3wkTJiSlEsYRtknDIyfkBEuSE3KCZckLecGS5IScWKUL7+oXp/pnjTXWSDbbbLPk5JNPTv73v/8lDcmoUaNW6A3Ksj788MMl3hdL/wwePDgpV3KCLPPmzUtuvPHGZN99903atWuXrLXWWkm3bt2S4cOHJwsXLkzKmZwgz1NPPZX88pe/TLp06ZKsttpq6RfKSiEvqM0LL7yQ9OrVK2nevHmy/vrrJ6eeemoyd+7cpJzJCZbHF198kbRt2zZ9j4wcOTKpL43rbc9JkgwdOjTp2LFj8u233ybjx49Pv0yGN91bb72VfmjUpV133TX55ptvkqZNmxbVL4w3fDGWKCuvTZs2yZ133rlMfPTo0cndd9+dFh/lTk5Q0wcffJCccsopyV577ZWcfvrpyY9+9KPkySefTH79618nL730UnL77bcn5U5OsLS//e1vyYgRI5Lu3bsnG2ywQVKJ5AVLe+2119K/FVtssUVyzTXXJNOnT0/+8Ic/JP/+97+TJ554Iil3coLaXHzxxenJjPpWr4X3/vvvn2y77bbp///Vr36VtG7dOv2wePjhh5Of/exnmX2+/vrrpEWLFiUfS6NGjdKjZNSf8Lr+/Oc/zzyaGQqOvn37JuVOTlBTOGPx5ptvJltuueXi2PHHH58cd9xxya233ppcdNFFyaabbpqUMznB0q688srkL3/5S9KkSZPkwAMPTL9YVxp5wdLOP//8pFWrVunlzeE7UxCuBglXC4arRMr95IWcIE/4GxEOxITiO/zUp1XqHu8999wz/fc///lP+m+4F2HNNddMpkyZkvTp0ye9zHLQoEFp26JFi5I///nP6RfS8OZeb7310i+ks2fPXmKbhUIhufzyy5Mf//jH6RGvPfbYI3n77beX2Xfe/Rgvv/xyuu/wYRaSs2vXrsm11167eHzhyFRQ8zKXaqUe4/KaOnVqekZs8803T5o1a5Z++AwYMCC9lDtLOAIUxhUeFz6sf/GLXywzxiAcMd1ll13S5yG8FgcccMByjfOzzz5LJk+evEJHmj755JNkzJgxyaGHHlqRH2JyorJzYt11112i6K52yCGHpP9OmjQpqTRyorJzIghnuUPRzf+RF5WdF1999VXy9NNPpycvqovuIIwnvA/+/ve/J5VGTlR2TtR02mmnpd+bwv4q+oz30kIyBOHFqrZgwYJkv/32S+9ZCZfMVF8uEl7UcCb02GOPTe9hCYl1ww03JK+++mryj3/8Y/Ef5XBkI7wBwxs9/PzrX/9Kj/p9//33Pzie8CEWjqaHeyvDixbOPoUvuo899lj632EMH3/8cfq4rEuk62KMWV555ZX0Pp+BAwemiReSIxzp2X333ZN33nlnmUtuwn0wa6+9dnppy7vvvps+NiRa9QdHEH6/o48+On0tfv/736dv+PC48LqE36e2e+zC73zppZemBXQYQzHuvffe9MOm+sOx0sgJOZFlxowZiwvzSiMn5ATLkheVnRfhyqjwelef8a0WLnXeZptt0v1UGjlR2TlR7b777kvHH57rvIMFdapQD2699dZC2PUzzzxTmDlzZuGjjz4q3HvvvYXWrVsXmjVrVpg+fXr6uKOPPjp93L
nnnrtE/+effz6N33333UvER48evUT8008/LTRt2rRwwAEHFBYtWrT4ceeff376uLD9amPGjElj4d9gwYIFhY4dOxbat29fmD179hL7qbmtk046Ke23tBhjzBMe99vf/nbxf8+bN2+Zx7z44ovp4+64445lXocePXoUvv/++8XxYcOGpfGHH344/e85c+YU1l577cLgwYOX2OaMGTMKLVu2XCIexrH081Edq35uixHG1q5du8LChQsL5UxOyInl9d133xU6d+6cvhbz588vlCs5ISeWR3hOwvNfKeSFvMhy3333pY8bN27cMm0DBgworL/++oVyJSfkRJ4w9o033rhw3nnnLfG6hHypL/V6qfnee++dTqi10UYbpUdSwiUgDz74YLLhhhsu8bgTTzxxmaMXLVu2TPbZZ5/0koPqnx49eqTbCEdBgmeeeSY9whMmJ6p5ucZvfvObHxxbOOISjiaFx4YjNzUtz9T2dTHGPOFSkGrz589PPv/88/Q+0PB7hCNfSxsyZMgSl+2F57tx48bpJA9BOPoWZgMM98jU/F3CbLLbb7/94t8lTzjqFXK52LMY7733XjJx4sT0vRHul6kEckJO/JBwNDkcZQ5HfcOYyp2ckBMsS17Ii5rCRF7B6quvvkxbuCy5ur2cyQk5sbTf/e536ZjD/Aerinr91hbuZQhT/ocXJNyrEO4fWLrACm3h0oaawgyNX375ZTotfJZPP/00/Tdc2hD85Cc/WaI9JGa4v2J5LlEJy5WsiLoYY57wAXvVVVelky/997//Td+g1cKYlrb0vkMSh0thqi/JCL9LzftlllbzfqJSCjOZB5V0mbmckBO1ufrqq9NJpS677LL0ErJKICfkBMuSF/Iiqzj67rvvlmkLs3zXLJ7KlZyQEzWF/YXvTOF9EcawqqjXwnu77bZb5n6UpYWjd0snTrjnN7z5qguzpYU3WH2rzzGGI10hQcIRrh133DE9ShaOfIUjgGFcxaruE+7JCPekLC3WWbewZEz44AxH9CqFnIijHHIi3Nt1zjnnJCeccEJy4YUXJpVCTsRRDjlRyeRFHA01L0JhUz0h7dJCrBKW3ZMTcTTUnLj44ovTqx3CmfHqor96fpyZM2emsY033rjOr6htkH8JO3XqlF5OsfPOO9d6FK99+/aLj65ssskmi+PhCc+aYW/pfVRPQR8uX8mTd4lIXYwxT1gYPkxa8Mc//nGJI57h0o4sYd9h1sNqc+fOTT+oq8+oVT8XIelrey5KKcz8+P7776frMvLD5ER550RYDiUsjxJm96+e9ZTayYnyzglWjLwoz7wIZ1JDwTJhwoTk8MMPXxwPlx2H9b1rxliSnCjPnJg2bVpaR9R8HqqFWdqD8Jwsfel/bA3yxtnwAbJw4cL0csulhRkLq98M4QUN9xlcf/31S1waEabj/yHdu3dPOnbsmD526TdXzW1Vr/+39GPqYox5wn0SNbcVhO2H8WS55ZZb0nsgqoWZBcMYw5qIQZh1MFz6EdZOrfm4mgld6qn/w9nu4Mgjj1zuPpVMTpRvTowbNy49srzrrrumR7srZb6DlSUnyjcnWHHyojzzIpyFDM/HXXfdlcyZM2dxPJxVDIVPWP6JbHKiPHPi8ssvT+/xr/lT/fydffbZ6X/HWMO9LM9477bbbum0+uGeg3AkL0yTH95o4ShLmIAgrInXv3//9NKLM888M31cmMI/HG0JExyEteN+aBme8OU2vFn69u2bLsUQpu8Pl/KEFzusM/fkk0+mj6u+DDpM7R/eTOENGr4k18UY84TthA/b8EHcuXPn5MUXX0yPlNVcUqGmcER0r732ShM7TP1/0003pVP6H3TQQWl7SJDwXBx11FHph0f4/cK4w9Gkxx9/PD0CFyZ6KtXU/yGZR4wYkeywww6Lj4xROzlRnjkR7tcK+wxHwcNzE56nmsIaoOGHZcmJ8syJ4I033kgeeeSR9P+HMxrhPsPwJSvYeuut09eDbPKifPPiiiuuSHbaaaf0+QsTXE2fPj09Sxmev969e6/Q81EJ5ER55kSvXr2WiVWf3e7Zs2fSr1+/pF7Ux1Tq1VPOv/LKK7U+Lkx736JFi9z2W265JZ22PiwXsNZaaxW22mqrwtlnn134+OOPFz8mLEN16aWXpktShcftvvvuhbfeeiud0r+2qf+rjR8/vrDPPvuk2w9j6dq1a+H6669f3B6WCDjllFMKbdq0KVRVVS0z7X0px7i8U/+HpQqOPfbYwrrrrltYc801C/vtt19h8uTJy2yv+nUYO3ZsYciQIYVWrVqljx80aFDh888/X2Y/4bkJ2wrT/a+xxhqFTp06FY455pjChAkTSrpMTPXyCNddd12hUsgJOZGl+jXI+6n5O5YbOSEn8lSPKetneZ6LhkxeyIvahGWndtppp3Qf4XkNy1N99dVXhXImJ+TE8loVlhOrCv9TPyU/AAAAlD83CwIAAEBECm8AAACISOENAAAAESm8AQAAICKFNwAAAESk8AYAAICIFN4AAAAQUePlfWBVVVXMcUC9WJll7OUE5UhOQOnyQk5QjvydgBXLC2e8AQAAICKFNwAAAESk8AYAAICIFN4AAAAQkcIbAAAAIlJ4AwAAQEQKbwAAAIhI4Q0AAAARKbwBAAAgIoU3AAAARKTwBgAAgIgU3gAAABCRwhsAAAAiUngDAABARApvAAAAiEjhDQAAABEpvAEAACAihTcAAABE1DjmxoHyduaZZ2bGmzVrltuna9eumfH+/fsXvf/hw4fntr344ouZ8TvvvLPo/QAAwMpwxhsAAAAiUngDAABARApvAAAAiEjhDQAAABEpvAEAACAis5oDQIXYbLPNMuOTJ0/O7XPaaadlxq+//vqSjQtqatGiRWb86quvzu1z/PHHZ8YnTpyY22fAgAGZ8alTp/7gGAGKpfAGajVixIjcthVZAizPokWLiu6T90Ur2HvvvTPjY8eOze0zbdq0oscAAAA/xKXmAAAAEJHCGwAAACJSeAMAAEBECm8AAACISOENAAAAEZnVHKh19vJSzlxe27JFTz75ZG6fTTbZJDPet2/f3D6dOnXKjA8aNCi3z1VXXZXbBuWgW7duRa8qMH369IgjgmW1a9cuMz548ODcPnnv4R49euT2OfDAAzPjN9544w+OEVZU9+7dc9seeOCBzHiHDh2ShmjffffNbZs0aVJm/KOPPkrKlTPeAAAAEJHCGwAAACJSeAMAAEBECm8AAACISOENAAAAEZnVHAAqxDbbbJMZ//rrr3P7PPjggxFHRKVq06ZNbtvtt99ep2OBurTffvvltq2++upJOalt9ZnjjjsuMz5w4MCkXCm8oYJsu+22uW2HHHJI0dt7++23M+MHHXRQbp/PPvssMz537tzcPk2bNs2Mv/TSS7l9tt5668x469atc/sAAEAMLjUHAACAiBTeAAAAEJHCGwAAACJSeAMAAEBECm8AAACIqCxmNe/fv39mfPDgwbl9Pv7448z4t99+m9vn7rvvzozPmDEjt8/777+f2wZ1rV27drltVVVVRc1cXtuSGJ988klSSmecc
UZmvHPnzkVv6/HHHy/BiGDV1aVLl9y2k08+OTN+5513RhwRlezUU0/NjPfr1y+3z3bbbZfUhV133TUz3qhR/nmp119/PTM+bty4ko2L8tC4cXaZ1adPn6RSTJw4Mbft9NNPz4y3aNEit09tS182BM54AwAAQEQKbwAAAIhI4Q0AAAARKbwBAAAgIoU3AAAARFQWs5oDAP/fT3/609y2vNliR4wYEXFEVLI//elPmfFFixYl9e3QQw8tKh5MnTo1M37EEUes0MzOlK899tgjM77jjjvm9hk2bFhSTlq1apXb1jlnZZrmzZuX7azmZVF4571JO3ToUNL9HH/88ZnxOXPm5PapbSmmhmj69OlFf1BMmDAh4ogoxqOPPprbtummmxb9/p41a1ZSFwYOHJgZb9KkSZ3sHwAAVoZLzQEAACAihTcAAABEpPAGAACAiBTeAAAAEJHCGwAAACIqi1nNBw8enBnv2rVrbp9JkyZlxrfYYovcPt27d8+M77777rl9dthhh8z4Rx99lNtno402SkplwYIFuW0zZ87MjLdr167o/UybNi23zazmDUPeEil15ayzzspt22yzzYre3ssvv1xUHMrF2WefXXSe+5xmZYwaNSq3rVGj+j3H8/nnn+e2zZ07NzPevn373D4dO3bMjP/zn//M7bPaaqvVOkYari5duuS23XPPPZnxKVOm5Pa58sork3Jy8MEH1/cQVinOeAMAAEBECm8AAACISOENAAAAESm8AQAAICKFNwAAAERUFrOaA0Al6dChQ27btttum9v23nvvZca//vrrkoyL8rbbbrtlxjfffPPcPosWLSoqvqJuvvnmzPhTTz2V2+fLL7/MjO+55565fS644IKix3biiSdmxocPH170tli1XHjhhbltLVq0yIz37t276Jn2V3XrrLNOUZ8ZMT4DGoKyKLyfffbZouK1GT16dNF9WrVqldu2zTbbZMYnTpyY26dnz55JqXz77bdFfwHLW2qttsSqbWkEqOnAAw/MjA8dOjS3T9OmTTPjn376aW6f8847LzM+b968HxwjAACUkkvNAQAAICKFNwAAAESk8AYAAICIFN4AAAAQkcIbAAAAIiqLWc3r2+zZs3PbxowZU/T2VmQ29hVx2GGHFT1L+5tvvpkZHzFiRMnGRXnLW+oob+by2tT2vhs7dmzR24OGorYlWmozc+bMko+Fylmq7t57782Mr7vuuiUdw9SpUzPj999/f26fSy+9tGQrWeTtPxgyZEhmvE2bNrl9hg0blhlfY401cvvccMMNmfH58+fn9iGe/v37Z8b79OmT2+f999/PjE+YMCEpN3nL7NW2ZNhzzz2XGf/iiy+ScuWMNwAAAESk8AYAAICIFN4AAAAQkcIbAAAAIlJ4AwAAQERmNQeABmarrbZaoX55sytDtcaN878alnL28tpWnhg4cGBm/LPPPkvqQm2zml911VWZ8WuuuSa3T/PmzYvOx0ceeSQzPmXKlNw+xDNgwICiXtvgpptuSiplxYNBgwZlxhcuXJjb5/LLL6+4mfsV3mWubdu2RX8gNGqUfyHE0KFDM+OzZs1agdFRrh566KHctn333bfo7d1xxx2Z8QsvvLDobQEAQF1zqTkAAABEpPAGAACAiBTeAAAAEJHCGwAAACJSeAMAAEBEZjUvcyeddFJuW5s2bTLjs2fPzu3z7rvvlmRclId27dplxnfaaafcPquvvnrRy8TkLTkxd+7cHxwjNGQ77LBDZvzYY4/N7fPqq6/mtj399NMlGRcsrwkTJmTGjzvuuNw+dbVs2IrIW+YrbzmloGfPnhFHRKm0bNmy6M/i2gwfPjwpJ0OGDCl6qcFJkybl9hkzZkxSaZzxBgAAgIgU3gAAABCRwhsAAAAiUngDAABARApvAAAAiMis5gCwitp7770z4+uss05un9GjR+e2ffvttyUZF5WpUaPiz9dsv/32STmpqqoq+rlZkeftkksuyYwfddRRRW+L5ZO36kqw4YYbZsbvueeepFJ06tSp6D5vvfVWlLE0VArvMrHzzjtnxs8999yit9WvX7/cNglETffff39mvHXr1kVv66677sptmzJlStHbAwCAVYVLzQEAACAihTcAAABEpPAGAACAiBTeAAAAEJHCGwAAACIyq3mZ6NOnT2a8SZMmuX2effbZzPiLL75YsnHR8B100EG5bd27dy96e88991xm/Le//W3R24Jyt/XWW2fGC4VCbp+RI0dGHBHl7oQTTshtW7RoUVLp+vbtmxnv1q1b0c9bbc9n3nJixDNnzpzcttdeey0z3rVr19w+ecs+zpo1K1mVtW3bNjPev3//orc1fvz4EoyofDjjDQAAABEpvAEAACAihTcAAABEpPAGAACAiBTeAAAAEJHCGwAAACKynFgD0qxZs9y23r17Z8a///773D55yzfNnz9/BUZHQ9e6devM+Pnnn5/bp7bl6opdkmPu3LlFbwvKwfrrr5/btssuu2TG33333dw+Dz74YEnGRWXKWy6rHLVp0yYz3rlz59w+tf1NLNbMmTNz23wXq3vffPNNbtuUKVMy44cddlhun8cffzwzfs011yR1oUuXLrltm2yySW5bhw4dil7GMo8lCJfkjDcAAABEpPAGAACAiBTeAAAAEJHCGwAAACJSeAMAAEBEZjVvQM4666zctm7dumXGR48endvnhRdeKMm4KA9nnHFGZrxnz55Fb+uhhx4qejZ9qFTHHHNMblvbtm0z40888UTEEUFluOCCCzLjJ510Ukn38+GHH2bGjz766Nw+06ZNK+kYWDl5312qqqpy+xxwwAGZ8XvuuSepC5999lluW20zlK+77rolG8Ntt91Wsm2VA2e8AQAAICKFNwAAAESk8AYAAICIFN4AAAAQkcIbAAAAIlJ4AwAAQESWE1sF5S0/cNFFF+X2+eqrrzLjQ4cOLdm4KG+nn356ybZ18skn57bNnTu3ZPuBctC+ffui+8yePTvKWKDcjBo1Krdt8803r5MxvPPOO5nx8ePH18n+WXmTJ0/OjB9++OG5fbbZZpvM+KabbprUhZEjR65Qv9tvvz0zPmjQoKK39c0336zQGMqVM94AAAAQkcIbAAAAIlJ4AwAAQEQKbwAAAIhI4Q0AAAARmdW8nrRu3Tq37brrrsuMr7baakXP2vnSSy+twOhg5ayzzjq5bfPnz6+TMXz55ZdF779JkyaZ8ZYtWxa9/7XXXrtOZpBfuHBhbts555yTGZ83b17J9s/KO/DAA4vu8+ijj0YZC1RVVeW2NWpU/Pma/fffv+g+t9xyS2Z8gw02KHpbtY150aJFSV3o27dvneyHVctrr71WVHxV8cEHH5RsW126dMlte+utt5JK44w3AAAARKTwBgAAgIgU3gAAABCRwhsAAAAiUngDAABARApvAAAAiMhyYpHlLQE2evTo3D4dO3bMjE+ZMiW3z0UXXbQCo4M43njjjfoeQnLfffdlxj/55JPcPuutt15m/IgjjkgaohkzZmTGr7jiijofC0nSq1evzPj6669f52OBPMOHD89tGzZsWNHbe+yx
x0q2lFepl/8q5fZuvvnmkm0LVsUlBWtbajBPJS4ZVhtnvAEAACAihTcAAABEpPAGAACAiBTeAAAAEJHCGwAAACIyq3lknTp1yoz36NGj6G2dfvrpuW21zXgOy2PUqFGZ8YMPPjhpiAYMGFAn+1mwYEHJZst95JFHctsmTJhQ9Paef/75ovsQzyGHHFLU6hfBq6++mhkfN25cycYFNT3wwAO5bWeddVZmvE2bNklDNHPmzMz4pEmTcvsMGTKk6BUzoCEpFApFxVl+zngDAABARApvAAAAiEjhDQAAABEpvAEAACAihTcAAABEpPAGAACAiCwnVgLt27fPbXvqqaeK3l7ech2PPfZY0duC5XXooYdmxs8+++zcPk2aNCnZ/rfccsvctiOOOKJk+/nrX/+a2/bhhx8Wvb37778/Mz558uSit0XD17x589y2Pn36FL29kSNHZsYXLlxY9LZgeUydOjW3beDAgZnxfv365fY57bTTklXVFVdckRm/8cYb63wssKpYY401iu7zzTffRBlLuXHGGwAAACJSeAMAAEBECm8AAACISOENAAAAESm8AQAAIKKqQqFQWK4HVlXFHEeDljcrZnDeeecVvb3tttsuMz5hwoSit0XtlvPtn0lOUI7kxMqpbab/sWPHZsY//fTT3D5HHnlkZnzevHkrMDrqOi/kRJL07t07Mz5kyJDcPn379s2MP/LII7l9brnllqJfg3feeSczPm3atNw++DtR7mbMmJEZb9w4fzGsyy67LDN+7bXXlmxc5ZAXzngDAABARApvAAAAiEjhDQAAABEpvAEAACAihTcAAABEpPAGAACAiCwnVoRevXplxkeNGpXbZ8011yx6P5YTqzuWxIAlyQlYluXE4P/4O1HeHn300cz4Nddck9tnzJgxSaUrWE4MAAAA6pfCGwAAACJSeAMAAEBECm8AAACISOENAAAAETWOufFys8suu5Rs5vIpU6bkts2dO7fo7QEAAKyMvn371vcQypYz3gAAABCRwhsAAAAiUngDAABARApvAAAAiEjhDQAAABEpvAEAACAiy4lF9vrrr2fG99prr9w+s2bNijgiAAAA6pIz3gAAABCRwhsAAAAiUngDAABARApvAAAAiEjhDQAAABFVFQqFwnI9sKoq5jigXizn2z+TnKAcyQkoXV7ICcqRvxOwYnnhjDcAAABEpPAGAACAiBTeAAAAEJHCGwAAACJSeAMAAEBECm8AAABYFZYTAwAAAIrnjDcAAABEpPAGAACAiBTeAAAAEJHCGwAAACJSeAMAAEBECm8AAACISOENAAAAESm8AQAAICKFNwAAACTx/D/o14rPqwfOdQAAAABJRU5ErkJggg==", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAA94AAADgCAYAAAD19b5rAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAIWZJREFUeJzt3X9UVVUWwPH9FBQERUMI+oWmo5MRqWQ/MS0NDMHU1CgrtUmo/NWysrIyM8sZK6dSomVrjRZZOViWFZraQs3RmjCsTDSlwJx0xEST0OTHnT9aMALnPLnwznu8x/ezlmvpPu+cu32wwe29nOOwLMsSAAAAAABgRCtPJwAAAAAAgC+j8QYAAAAAwCAabwAAAAAADKLxBgAAAADAIBpvAAAAAAAMovEGAAAAAMAgGm8AAAAAAAyi8QYAAAAAwCAabwAAAAAADGqxjXeXLl1k/PjxNX/esGGDOBwO2bBhg8dyqqtujjoOh0Nmz55te/2lS5eKw+GQ3Nxc+8lpzJ49WxwOh8vWg/tQE9QEaqMmqAnUR11QF6iNmqAmGsojjXf1B6f6V0BAgPTo0UMmT54s//3vfz2RUqNlZ2c36hMU9RUWFtb6vKj7a+LEiZ5O0RhqAiplZWWSnp4u8fHxEhkZKe3bt5c+ffpIRkaGVFZWejo9o6gJ6Kxdu1b+8pe/SHR0tLRu3Vq6dOni6ZTchrqAM1u2bJG4uDhp166dREREyNSpU6W0tNTTaRlFTaAhjh49KuHh4eJwOGTFihUey8PPY1cWkTlz5kjXrl3l5MmTsnnzZsnIyJDs7GzZsWOHtGvXzq25XHvttXLixAlp06aNrXnZ2dmSnp5OobhAWFiYZGZm1ouvWbNGli1bJvHx8R7Iyr2oCZzuhx9+kClTpsigQYNk+vTp0qFDB/nkk0/kvvvuk88//1xef/11T6doHDWBut566y1Zvny59O3bV8455xxPp+MR1AXq2r59uwwaNEguuugiWbBggezfv1+ef/552bNnj6xevdrT6RlHTcCZWbNmSVlZmafT8GzjfeONN8pll10mIiJ33323hIaGyoIFC+SDDz6QW2+9VTnnt99+k6CgIJfn0qpVKwkICHD5umi4oKAguf322+vFly5dKh06dJDk5GQPZOVe1AROFxERId9++61cfPHFNbG0tDS56667ZMmSJfLEE09I9+7dPZihedQE6nr22WfltddeE39/f0lKSpIdO3Z4OiW3oy5Q18yZM6VTp06yYcMG6dChg4j88XjxxIkTZe3atT5/84KagM6OHTskIyNDZs2aJbNmzfJoLs3qZ7yvv/56ERH58ccfRURk/PjxEhwcLAUFBZKYmCjt27eXsWPHiohIVVWVvPjii3LxxRdLQECAnH322ZKWliYlJSW11rQsS+bOnSvnnXeetGvXTq677jr57rvv6l1b9/MYX3zxhSQmJkqnTp0kKChIYmJi5KWXXqrJLz09XUSk1mMu1VydY0MVFRXJfffdJz179pTAwEAJDQ2V0aNHS2FhofL1ZWVlkpaWJqGhodKhQwe588476+UoIrJ69Wrp37+/BAUFSfv27WXo0KENyvPw4cOya9euRv1P04EDByQnJ0dGjhzZIr+IURMtuyY6d+5cq+muNmLECBERyc/PP+O1fA010bJrQkTknHPOEX9//zO+riWhLlp2Xfz666+ybt06uf3222uabhGRO++8U4KDg+Wf//znGa/la6iJll0Tp5s2bZqMGDFC+vfv3+A5pnj0jnddBQUFIiISGhpaE6uoqJCEhASJi4uT559/vuZxkbS0NFm6dKlMmDBBpk6dKj/++KMsWrRI8vLy5F//+lfNN+VZs2bJ3LlzJTExURITE+Wrr76S+Ph4OXXq1BnzWbdunSQlJUlkZKRMmzZNIiIiJD8/Xz766COZNm2apKWlyc8//yzr1q1TPiLtjhxVvvzyS9myZYukpKTIeeedJ4WFhZKRkSEDBw6UnTt31nvkZvLkydKxY0eZPXu27N69WzIyMqSoqKjmC4eISGZmpowbN04SEhLkb3/7m5SVlUlGRobExcVJXl6e05+xW7RokTz11FOSk5MjAwcOtPV3eeedd6Sqqqrmi2NLQ01QEyoHDx4UkT8a85aGmqAmUB910bLr4ttvv5WKioqaO77V2rRpI71795a8
vLxGvR/ejJpo2TVRLSsrS7Zs2SL5+fna/yxwK8sDlixZYomItX79equ4uNj66aefrHfeeccKDQ21AgMDrf3791uWZVnjxo2zRMR65JFHas3/7LPPLBGxli1bViu+Zs2aWvFDhw5Zbdq0sYYOHWpVVVXVvG7mzJmWiFjjxo2rieXk5FgiYuXk5FiWZVkVFRVW165draioKKukpKTWdU5fa9KkSZbqbTSRo46IWE8++WTNn8vKyuq9ZuvWrZaIWG+88UZNrPrjEBsba506daomPn/+fEtErA8++MCyLMs6fvy41bFjR2vixIm11jx48KAVEhJSK/7kk0/Wez+qY9XvrR2xsbFWZGSkVVlZaXuuN6EmqImG+v33361evXpZXbt2tcrLy23P9xbUBDXREEOHDrWioqJszfFm1AV1oZKVlWWJiLVp06Z6Y6NHj7YiIiKczvdm1AQ1oVNWVmZdcMEF1qOPPmpZ1v8/LllZWWeca4pHHzUfPHiwhIWFyfnnny8pKSkSHBwsK1eulHPPPbfW6+69995af87KypKQkBC54YYb5PDhwzW/YmNjJTg4WHJyckREZP369XLq1CmZMmVKrcc17r///jPmlpeXJz/++KPcf//90rFjx1pjDdna3h056gQGBtb8vry8XH755Rfp3r27dOzYUb766qt6r09NTa312N69994rfn5+kp2dLSJ//C/d0aNH5dZbb631d2ndurVcccUVNX8XndmzZ4tlWbbvYnz//feybds2SUlJkVatmtVPRRhDTVATZzJ58mTZuXOnLFq0SPz8mtVDS0ZQE9QE6qMuqIvTnThxQkRE2rZtW28sICCgZtyXURPURF1//etfpby8XGbOnNnAv615Hv1XW3p6uvTo0UP8/Pzk7LPPlp49e9ZrsPz8/OS8886rFduzZ48cO3ZMwsPDleseOnRIRP74uQQRkT/96U+1xsPCwqRTp05Oc6t+RCU6OrrhfyE356hz4sQJmTdvnixZskT+85//iGVZNWPHjh2r9/q61w4ODpbIyMiaRzL27NkjIv//eZm6Tv95IldatmyZiEiLesycmqAmnHnuuefktddek6effloSExONXKO5oSaoCdRHXVAXp6tujn7//fd6YydPnqzVPPkqaoKaOF1hYaE899xzkp6eLsHBwS5Z0xU82nhffvnl9X4epa62bdvWK5yqqioJDw+vaczqCgsLc1mOjeXJHKdMmSJLliyR+++/X6666ioJCQkRh8MhKSkpUlVVZXu96jmZmZkSERFRb9zUXbe33npLevbsKbGxsUbWb46oCTN8oSaWLl0qDz/8sNxzzz3y+OOPu3z95oqaMMMXaqIloy7M8Na6iIyMFJE/NqSt68CBAy3i2D1qwgxvrYlZs2bJueeeKwMHDqxp+qv3xykuLpbCwkK54IIL3P5ErVd+J+zWrZusX79errnmGqf/ixcVFSUif/zvyoUXXlgTLy4uVu6wV/caIn9sQT948GDt63SPiLgjR50VK1bIuHHj5IUXXqiJnTx5Uo4ePap8/Z49e+S6666r+XNpaakcOHCg5o5a9XsRHh7u9L1wpS+++EL27t0rc+bMccv1vB014Zy318QHH3wgd999t4wcObJm11M4R0045+01gcahLpzz1rqIjo4WPz8/yc3NlTFjxtTET506Jdu3b68VQ23UhHPeWhP79u2TvXv31nofqt13330iIlJSUlLv0X/TvPIHZ8eMGSOVlZXy9NNP1xurqKio+WQYPHiw+Pv7y8KFC2s9GvHiiy+e8Rp9+/aVrl27yosvvljvk+v0tarP/6v7GnfkqNO6detaa4mILFy4UCorK5WvX7x4sZSXl9f8OSMjQyoqKuTGG28UEZGEhATp0KGDPPvss7VeV624uNhpPo3Z+v+tt94SEZHbbrutwXNaMmrCOW+uiU2bNklKSopce+21smzZshaz30FTURPOeXNNoPGoC+e8tS5CQkJk8ODB8uabb8rx48dr4pmZmVJaWiqjR492Or8loyac89aamDt3rqxcubLWr+r3b8aMGbJy5UojZ7ifiVfe8R4wYICkpaXJvHnzZPv27RIfHy/+/v6yZ88eycrKkpdeeklGjRolYWFh8uCDD8q8efMkKSlJEhMTJS8vT1avXn3GY3hatWolGRkZkpycLL1795YJEyZIZGSk7Nq1S7777jv55JNPRERqHoOeOnWqJCQkSOvWrSUlJcUtOeokJSVJZmamhISESK9evWTr1q2yfv36WkcqnO7UqVMyaNAgGTNmjOzevVteeeUViYuLk2HDhonIHz9vkZGRIXfccYf07dtXUlJSJCwsTPbt2ycff/yxXHPNNbJo0SJtPna3/q+srJTly5fLlVdeWfM/Y3COmnDOW2uiqKhIhg0bJg6HQ0aNGiVZWVm1xmNiYiQmJsb+G9ICUBPOeWtNiIh88803smrVKhER2bt3rxw7dkzmzp0rIiKXXnqpJCcnN+IdaRmoC+e8uS6eeeYZufrqq2XAgAGSmpoq+/fvlxdeeEHi4+NlyJAhjXo/WgJqwjlvrYm4uLh6seq72/369ZPhw4c3+D1wKXdsnV5X9ZbzX375pdPXjRs3zgoKCtKOL1682IqNjbUCAwOt9u3bW5dccok1Y8YM6+eff655TWVlpfXUU09ZkZGRVmBgoDVw4EBrx44dVlRUlNOt/6tt3rzZuuGGG6z27dtbQUFBVkxMjLVw4cKa8YqKCmvKlClWWFiY5XA46m1778ocdaTO1v8lJSXWhAkTrM6dO1vBwcFWQkKCtWvXrnrrVX8cNm7caKWmplqdOnWygoODrbFjx1q//PJLvevk5ORYCQkJVkhIiBUQEGB169bNGj9+vJWbm1vzGlccE1N9PMLLL7/coNf7AmqCmlCp/hjofp3+d/Q11AQ1oVOdk+pXQ94Lb0ZdUBfOfPbZZ9bVV19tBQQEWGFhYdakSZOsX3/9tUFzvRU1QU00VHM4TsxhWXWeHwAAAAAAAC7DDwsCAAAAAGAQjTcAAAAAAAbReAMAAAAAYBCNNwAAAAAABtF4AwAAAABgEI03AAAAAAAG0XgDAAAAAGCQX0Nf6HA4TOYBeERTjrGnJuCLqAmgvsbWBTUBX8T3CaC+htQFd7wBAAAAADCIxhsAAAAAAINovAEAAAAAMIjGGwAAAAAAg2i8AQAAAAAwiMYbAAAAAACDaLwBAAAAADCIxhsAAAAAAINovAEAAAAAMIjGGwAAAAAAg2i8AQAAAAAwiMYbAAAAAACDaLwBAAAAADCIxhsAAAAAAINovAEAAAAAMIjGGwAAAAAAg2i8AQAAAAAwiMYbAAAAAACD/DydAADv9eCDDyrjgYGB2jkxMTHK+KhRo2xfPyMjQzu2detWZTwzM9P2dQAAAICm4I43AAAAAAAG0XgDAAAAAGAQjTcAAAAAAAbReAMAAAAAYBCNNwAAAAAABrGrOQAALUSPHj2U8V27dmnnTJs2TRlfuHChS3IC6goKClLGn3vuOe2ctLQ0ZXzbtm3aOaNHj1bGi4qKnGQHAI1D4w3AqeXLl2vHGnM
EmE5VVZXtObp/aImIDB48WBnfuHGjds6+ffts5wAAAACcCY+aAwAAAABgEI03AAAAAAAG0XgDAAAAAGAQjTcAAAAAAAbReAMAAAAAYBC7mgMQEf3u5a7cuVxEf2zRJ598op1z4YUXKuPJycnaOd26dVPGx44dq50zb9487RjgC/r06aOMOztVYP/+/abSAZQiIyOV8YkTJ2rn6D6HY2NjtXOSkpKU8fT0dCfZAU3Tt29f7dh7772njHfp0sVQNmbFx8drx/Lz85Xxn376yVQ6HscdbwAAAAAADKLxBgAAAADAIBpvAAAAAAAMovEGAAAAAMAgGm8AAAAAAAxiV3MAAFqI3r17K+O//fabds7KlSsNZYOWLCwsTDv2+uuvuzETwL0SEhK0Y23btnVjJuY5O33mrrvuUsZTUlJMpeNxNN5AC3LZZZdpx0aMGGF7ve+++04ZHzZsmHbO4cOHlfHS0lLtnDZt2ijjn3/+uXbOpZdeqoyHhoZq5wAAAAAm8Kg5AAAAAAAG0XgDAAAAAGAQjTcAAAAAAAbReAMAAAAAYBCNNwAAAAAABvnEruajRo1SxidOnKid8/PPPyvjJ0+e1M5ZtmyZMn7w4EHtnL1792rHAHeLjIzUjjkcDmVct3O5iP5IjAMHDthL7AweeOABZbxXr1621/r444+bmg7QrEVHR2vHJk+erIxnZmaaSgct3NSpU5Xx4cOHa+dcfvnlhrKp7dprr1XGW7XS35f6+uuvlfFNmza5JCf4Dj8/dZuVmJjo5kw8Z9u2bdqx6dOnK+NBQUHaOc6OvvQG3PEGAAAAAMAgGm8AAAAAAAyi8QYAAAAAwCAabwAAAAAADKLxBgAAAADAIJ/Y1RwAAPzhz3/+s3ZMt1vs8uXLTaWDFu7vf/+7Ml5VVeXmTOobOXKkrbiISFFRkTJ+yy23aOc429kZvuu6665Txq+66irtnPnz55tKxyM6deqkHdOdTNOuXTvtHG/f1dwnGm/dJ2mXLl1cep20tDRl/Pjx49o5zo5i8kb79+9Xxp19ocjNzTWVDmz68MMPtWPdu3dXxp19fh85cqTJOTVESkqKMu7v7++W6wMAAABNwaPmAAAAAAAYROMNAAAAAIBBNN4AAAAAABhE4w0AAAAAgEE03gAAAAAAGOQTu5pPnDhRGY+JidHOyc/PV8Yvuugi7Zy+ffsq4wMHDtTOufLKK5Xxn376STvn/PPP147ZVVFRoR0rLi5WxiMjI21fZ9++fdoxdjX3DrojUtzloYce0o716NHD9npffPGFrTjgK2bMmKEd09U5X6fRFNnZ2dqxVq08e4/nl19+0Y6VlpYq41FRUdo5Xbt2Vcb//e9/a+e0bt1aOwbvFh0drR17++23lfGCggLtnGeffbbJOTUnN910k6dTaFa44w0AAAAAgEE03gAAAAAAGETjDQAAAACAQTTeAAAAAAAYROMNAAAAAIBBPrGrOQAALUmXLl20Y5dddpl27Pvvv1fGf/vtt6amhBZgwIABynjPnj21c6qqqmzFG+vVV19VxteuXaudc+zYMWX8+uuv18557LHH7CUmIvfee68ynpGRYXstNC+PP/64diwoKEgZHzJkiHaObqf95u6ss85SxnVfM0Rc/zXAG/hE4/3pp5/aijuzZs0a23M6deqkHevdu7cyvm3bNu2cfv362c5B5+TJk9ox3T/AdEetiegLy9nRCMDpkpKSlPE5c+Zo57Rp00YZP3TokHbOo48+qoyXlZU5yQ4AAABwPR41BwAAAADAIBpvAAAAAAAMovEGAAAAAMAgGm8AAAAAAAyi8QYAAAAAwCCf2NXc00pKSrRjOTk5ttdrzG7sjXHzzTcr4852af/222+V8eXLl7skJ/g+3VFHup3LnXH2ebdx40bb6wHewtkRLc4UFxe7OBP4GmdH1b3zzjvKeOfOnV2aQ1FRkTL+7rvvauc89dRTynhjTrLQXV9EJDU1VRkPCwvTzpk/f74yHhAQoJ2zaNEiZby8vFw7B+aMGjVKGU9MTNTO2bt3rzKem5vrkpyaE90xe86ODNuwYYMyfvToURdk1DxxxxsAAAAAAINovAEAAAAAMIjGGwAAAAAAg2i8AQAAAAAwiMYbAAAAAACD2NUcAAAvc8kllzRqnm53ZaCan5/+n4au3L3c2ckTKSkpyvjhw4dddn1nnO1qPm/ePGV8wYIF2jnt2rVTxp3V46pVq5TxgoIC7RyYM3r0aGVc97EVEXnllVdMpeMRzk48GDt2rDJeWVmpnTN37lxl3Jd37qfx9nHh4eHaMd0XhFat9A9CzJkzRxk/cuSIvcTg095//33tWHx8vO313njjDWX88ccft70WAAAA4G48ag4AAAAAgEE03gAAAAAAGETjDQAAAACAQTTeAAAAAAAYROMNAAAAAIBB7Gru4yZNmqQdCwsLU8ZLSkq0c3bv3t3knOA7IiMjlfGrr75aO6dt27bKuLNjYnRHTpSWljrJDvB+V155pTI+YcIE7Zy8vDzt2Lp165qcE2BHbm6uMn7XXXdp57jr2LDG0B3zpTtOSUSkX79+ptKBC4WEhGjHdF+LncnIyGhKOs1Oamqqdkx31GB+fr52Tk5OTpNz8jbc8QYAAAAAwCAabwAAAAAADKLxBgAAAADAIBpvAAAAAAAMovEGAAAAAMAgdjUHAKCZGjx4sDJ+1llnaeesWbNGO3by5Mkm54SWq1Ur+/drrrjiCgOZeI7D4VDGnb03jXnfZs+erYzfcccdttdCw+hOXREROffcc5Xxt99+21Q6zU63bt1sz9mxY4eBTLwXjbePuOaaa5TxRx55xPZaw4cP145RQDjdu+++q4yHhobaXuvNN9/UjhUUFNheDwAAAGgueNQcAAAAAACDaLwBAAAAADCIxhsAAAAAAINovAEAAAAAMIjGGwAAAAAAg9jV3EckJiYq4/7+/to5n376qTK+detWl+QE3zBs2DDtWN++fW2vt2HDBmX8ySeftL0W4OsuvfRSZdyyLO2cFStWmEoHLcA999yjHauqqnJjJs1TcnKyMt6nTx/tHN375uz91B0nBnOOHz+uHdu+fbsyHhMTo52jO/bxyJEjtvJyt/DwcGV81KhRttfavHlzU9PxKdzxBgAAAADAIBpvAAAAAAAMovEGAAAAAMAgGm8AAAAAAAyi8QYAAAAAwCAabwAAAAAADOI4MS8SGBioHRsyZIgyfurUKe0c3fFN5eXl9hKDTwgNDVXGZ86cqZ3j7Lg6Hd2RHKWlpbbXAnxBRESEdqx///7K+O7du7VzVq5c2eSc0HLpjsvyRWFhYcp4r169tHOcfU+0q7i4WDvGv8Xc78SJE9qxgoICZfzmm2/Wzvn444+V8QULFthLrJGio6O1YxdeeKF2rEuXLsq4s2MsdTiCsDbueAMAAAAAYBCNNwAAAAAABtF4AwAAAABgEI03AAAAAAAG0XgDAAAAAGAQu5p7kYceekg71qdPH2V8zZo12jlbtmxpck7wHQ888IAy3q9fP9trvf/++9ox3W76QEs1fvx47Vh4eLgyvnr1akPZAC3HY4
89poxPmjTJpdcpLCxUxseNG6eds2/fPpfmgKbR/dvF4XBo5wwdOlQZf/vtt12S05kcPnxYO+Zsh/LOnTu7LIelS5e6bC1fwB1vAAAAAAAMovEGAAAAAMAgGm8AAAAAAAyi8QYAAAAAwCAabwAAAAAADKLxBgAAAADAII4Ta4Z0xw888cQT2jm//vqrMj5nzhyX5ATfN336dJetNXnyZO1YaWmpy64D+IKoqCjbc0pKSgxkAvie7Oxs7VjPnj3dksPOnTuV8c2bN7vl+mi6Xbt2KeNjxozRzundu7cy3r17d1ekdEYrVqxo1LzXX39dGR87dqzttU6cONGoHHwVd7wBAAAAADCIxhsAAAAAAINovAEAAAAAMIjGGwAAAAAAg2i8AQAAAAAwiF3NPSQ0NFQ79vLLLyvjrVu31s7R7dr5+eef20sMcIGzzjpLO1ZeXu6WHI4dO2b7+v7+/sp4SEiI7et37NhRO+bKHeQrKyu1Yw8//LAyXlZW5rLro+mSkpJsz/nwww8NZAKIOBwO7VirVvbv19x444225yxevFgZP+ecc2yv5Sznqqoq2+s1RnJysluug+Zl+/bttuLNxQ8//OCytaKjo7VjO3bscNl1vAV3vAEAAAAAMIjGGwAAAAAAg2i8AQAAAAAwiMYbAAAAAACDaLwBAAAAADCIxhsAAAAAAIM4Tsww3RFga9as0c7p2rWrMl5QUKCd88QTT9hLDDDom2++8XQKkpWVpYwfOHBAO+fss89Wxm+55RaX5ORuBw8eVMafeeYZN2cCEZG4uDhlPCIiws2ZAHoZGRnasfnz59te76OPPlLGG3OUl6uP/3Lleq+++qrL1gI8SXekoLOjBnVa4pFhznDHGwAAAAAAg2i8AQAAAAAwiMYbAAAAAACDaLwBAAAAADCIxhsAAAAAAIPY1dywbt26KeOxsbG215o+fbp2zNmO50BDZGdnK+M33XSTmzNxjdGjR7vlOhUVFcp4Y3bLXbVqlXYsNzfX9nqfffaZ7TkwZ8SIEcq47vQLEZG8vDxlfNOmTS7JCajrvffe04499NBDynhYWJipdIwqLi5WxvPz87VzUlNTlXFnJ2YA3sSyLFtxNBx3vAEAAAAAMIjGGwAAAAAAg2i8AQAAAAAwiMYbAAAAAACDaLwBAAAAADCIxhsAAAAAAIM4TswFoqKitGNr1661vZ7uuI6PPvrI9lpAQ40cOVIZnzFjhnaOv7+/y65/8cUXa8duueUWl13nH//4h3assLDQ9nrvvvuuMr5r1y7ba8H7tWvXTjuWmJhoe70VK1Yo45WVlbbXAhqiqKhIO5aSkqKMDx8+XDtn2rRpTU3JmGeeeUYZT09Pd3MmQPMREBBge86JEycMZOJ7uOMNAAAAAIBBNN4AAAAAABhE4w0AAAAAgEE03gAAAAAAGETjDQAAAACAQQ7LsqwGvdDhMJ2L19Ltiiki8uijj9pe7/LLL1fGc3Nzba8F5xr46a9ETcAXURNN42yn/40bNyrjhw4d0s657bbblPGysjJ7iaFJGlsX1ITIkCFDlPHU1FTtnOTkZGV81apV2jmLFy9Wxp19DHbu3KmM79u3TzsHfJ/wdQcPHlTG/fz0h2E9/fTTyvhLL73kkpy8QUPqgjveAAAAAAAYROMNAAAAAIBBNN4AAAAAABhE4w0AAAAAgEE03gAAAAAAGETjDQAAAACAQRwnZkNcXJwynp2drZ0THBxs+zocJ+Y+HIkB1EZNAPVxnBjwf3yf8G0ffvihMr5gwQLtnJycHFPpeA2OEwMAAAAAwMNovAEAAAAAMIjGGwAAAAAAg2i8AQAAAAAwiMYbAAAAAACD/DydgDfp37+/Mt6YncsLCgq0Y6WlpbbXAwAAAICmSE5O9nQKPos73gAAAAAAGETjDQAAAACAQTTeAAAAAAAYROMNAAAAAIBBNN4AAAAAABhE4w0AAAAAgEEcJ2bY119/rYwPGjRIO+fIkSOm0gEAAAAAuBl3vAEAAAAAMIjGGwAAAAAAg2i8AQAAAAAwiMYbAAAAAACDaLwBAAAAADDIYVmW1aAXOhymcwHcroGf/krUBHwRNQHU19i6oCbgi/g+AdTXkLrgjjcAAAAAAAbReAMAAAAAYBCNNwAAAAAABtF4AwAAAABgEI03AAAAAAAG0XgDAAAAAGBQg48TAwAAAAAA9nHHGwAAAAAAg2i8AQAAAAAwiMYbAAAAAACDaLwBAAAAADCIxhsAAAAAAINovAEAAAAAMIjGGwAAAAAAg2i8AQAAAAAwiMYbAAAAAACD/gfo14rP3b3PMwAAAABJRU5ErkJggg==", "text/plain": [ "
" ] @@ -248,7 +248,7 @@ "output_type": "stream", "text": [ "Accuracy on shown samples: 5/5\n", - "Overall test accuracy: 90.58%\n" + "Overall test accuracy: 89.48%\n" ] } ], @@ -280,6 +280,14 @@ "total_samples = len(test_labels)\n", "print(f\"Overall test accuracy: {100 * total_correct/total_samples:.2f}%\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba18cc4c-75d0-4f9e-a7fb-35c997a0e951", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -298,7 +306,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.8" + "version": "3.11.10" } }, "nbformat": 4, diff --git a/modula/abstract.py b/modula/abstract.py index ec9ca01..c2e971f 100644 --- a/modula/abstract.py +++ b/modula/abstract.py @@ -1,6 +1,8 @@ -import jax import copy +import jax +import einops + class Module: def __init__(self): self.children = [] @@ -204,3 +206,26 @@ def __init__(self, scalar): def forward(self, x, w): return x * self.sensitivity + +class Mean(Bond): + def __init__(self, axis, size): + super().__init__() + self.smooth = True + self.axis = axis + self.size = size + self.sensitivity = 1 / size + + def forward(self, x, w): + assert x.shape[self.axis] == self.size + return jax.numpy.mean(x, axis=self.axis) + +class Patchify(Bond): + def __init__(self, size): + super().__init__() + self.smooth = True + self.sensitivity = 1 + self.size = size + + def forward(self, x, w): + p1, p2 = self.size + return einops.rearrange(x, 'b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=p1, p2=p2) diff --git a/modula/atom.py b/modula/atom.py index c46855f..e994e80 100644 --- a/modula/atom.py +++ b/modula/atom.py @@ -90,6 +90,54 @@ def dualize(self, grad_w, target_norm=1.0): d_weight = jnp.nan_to_num(d_weight) return [d_weight] +class Bias(Atom): + def __init__(self, d): + super().__init__() + self.d = d + self.smooth = True + self.mass = 1 + self.sensitivity = 1 + + def forward(self, x, w): + weights = w[0] # shape [d] + return weights + + def initialize(self, key): + return [jnp.zeros(shape=self.d)] + + def project(self, w): + weight = w[0] + weight = weight / jnp.linalg.norm(weight) * jnp.sqrt(self.d) + return [weight] + + def dualize(self, grad_w, target_norm=1.0): + grad = grad_w[0] + d_weight = grad / jnp.linalg.norm(grad) * jnp.sqrt(self.d) * target_norm + d_weight = jnp.nan_to_num(d_weight) + return [d_weight] + +class Scale(Atom): + def __init__(self, d): + super().__init__() + self.d = d + self.smooth = True + self.mass = 1 + self.sensitivity = 1 + + def forward(self, x, w): + weights = w[0] # shape [d] + return weights * x + + def initialize(self, key): + return [jnp.ones(shape=self.d)] + + def project(self, w): + weight = w[0] + return [jnp.sign(weight)] + + def dualize(self, grad_w, target_norm=1.0): + grad = grad_w[0] + return [jnp.sign(grad) * target_norm] if __name__ == "__main__": diff --git a/modula/bond.py b/modula/bond.py index 7bdbf00..7814022 100644 --- a/modula/bond.py +++ b/modula/bond.py @@ -97,6 +97,28 @@ def forward(self, x, w): v, scores = x return scores @ v +class Constant(Bond): + def __init__(self, f): + super().__init__() + self.f = f + self.smooth = True + self.sensitivity = 0 + + def forward(self, x, w): + return self.f() + +class LayerNorm(Bond): + def __init__(self, eps=1e-6): + super().__init__() + self.eps = eps + self.smooth = True + self.sensitivity = 1 + + def forward(self, x, w): + mean = jnp.mean(x, axis=-1, keepdims=True) + var = jnp.var(x, axis=-1, keepdims=True) + return (x - mean) / jnp.sqrt(var + 
self.eps) + class Rope(Bond): """Rotates queries and keys by relative context window distance.""" def __init__(self, d_head, base=10000): @@ -106,18 +128,14 @@ def __init__(self, d_head, base=10000): self.rope_dim = d_head // 2 self.inverse_frequencies = 1/base**(jnp.arange(self.rope_dim) / self.rope_dim) - self.seq_len_cached = None - self.sin_cached = None - self.cos_cached = None def get_cached(self, seq_len): - if self.seq_len_cached != seq_len: - self.seq_len_cached = seq_len - distance = jnp.arange(seq_len) - freqs = jnp.outer(distance, self.inverse_frequencies) # shape [seq_len, rope_dim] - self.cos_cached = jnp.expand_dims(jnp.cos(freqs), (0, 1)) # shape [seq_len, rope_dim] - self.sin_cached = jnp.expand_dims(jnp.sin(freqs), (0, 1)) # shape [seq_len, rope_dim] - return self.sin_cached, self.cos_cached + # Actually caching the return value may lead to leaked intermediate value error + distance = jnp.arange(seq_len) + freqs = jnp.outer(distance, self.inverse_frequencies) # shape [seq_len, rope_dim] + cos = jnp.expand_dims(jnp.cos(freqs), (0, 1)) # shape [seq_len, rope_dim] + sin = jnp.expand_dims(jnp.sin(freqs), (0, 1)) # shape [seq_len, rope_dim] + return sin, cos def rotate(self, x): batch, n_heads, seq_len, d_head = x.shape @@ -126,6 +144,7 @@ def rotate(self, x): x1 = x[..., self.rope_dim:] # shape [batch, n_heads, seq_len, rope_dim] x2 = x[..., :self.rope_dim] # shape [batch, n_heads, seq_len, rope_dim] + # Why is the order reversed!? cos, sin = self.get_cached(seq_len) y1 = cos * x1 + sin * x2 y2 = -sin * x1 + cos * x2 diff --git a/modula/compound.py b/modula/compound.py index 7d3701a..3cb0107 100644 --- a/modula/compound.py +++ b/modula/compound.py @@ -1,3 +1,5 @@ +import jax.numpy as jnp + from modula.abstract import * from modula.atom import * from modula.bond import * @@ -8,14 +10,21 @@ def MLP(output_dim, input_dim, width, depth): m = m @ Linear(width, width) @ ReLU() return m @ Linear(width, input_dim) -def Attention(num_heads, d_embed, d_query, d_value, softmax_scale, causal): +def Attention(num_heads, d_embed, d_query, d_value, softmax_scale, causal, posemb="rope", bias=False): """Multi-head attention""" - Q = SplitIntoHeads(num_heads) @ Linear(num_heads * d_query, d_embed) - K = SplitIntoHeads(num_heads) @ Linear(num_heads * d_query, d_embed) - V = SplitIntoHeads(num_heads) @ Linear(num_heads * d_value, d_embed) - W = Linear(d_embed, num_heads * d_value) @ MergeHeads() - - AttentionScores = Softmax(softmax_scale) @ CausalMask() @ AttentionQK() @ Rope(d_query) @ (Q, K) + Q, K, V = Linear(num_heads * d_query, d_embed), Linear(num_heads * d_query, d_embed), Linear(num_heads * d_value, d_embed) + Q = SplitIntoHeads(num_heads) @ (Q + Bias(num_heads * d_query) if bias else Q) + K = SplitIntoHeads(num_heads) @ (K + Bias(num_heads * d_query) if bias else K) + V = SplitIntoHeads(num_heads) @ (V + Bias(num_heads * d_value) if bias else V) + W = Linear(d_embed, num_heads * d_value) + W = (W + Bias(d_embed) if bias else W) @ MergeHeads() + QK = (Q, K) + if posemb == "rope": + QK = Rope(d_query) @ QK + attn = AttentionQK() @ QK + if causal: + attn = CausalMask() @ attn + AttentionScores = Softmax(softmax_scale) @ attn return W @ (1/3 * ApplyAttentionScores()) @ (V, AttentionScores) def GPT(vocab_size, num_heads, d_embed, d_query, d_value, num_blocks, blocks_mass=5, attention_scale=1.0, final_scale=1.0): @@ -31,4 +40,50 @@ def GPT(vocab_size, num_heads, d_embed, d_query, d_value, num_blocks, blocks_mas out = final_scale * Linear(vocab_size, d_embed) - return out @ blocks @ embed \ 
No newline at end of file + return out @ blocks @ embed + +def posemb_sincos_2d(h, w, width, temperature=10_000., dtype=jnp.float32): + """Follows the MoCo v3 logic.""" + y, x = jnp.mgrid[:h, :w] + + assert width % 4 == 0, "Width must be mult of 4 for sincos posemb" + omega = jnp.arange(width // 4) / (width // 4 - 1) + omega = 1. / (temperature**omega) + y = jnp.einsum("m,d->md", y.flatten(), omega) + x = jnp.einsum("m,d->md", x.flatten(), omega) + pe = jnp.concatenate([jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1) + return jnp.asarray(pe, dtype)[None, :, :] + +def ViT(num_classes, image_size=(28, 28), patch_size=(7, 7), num_heads=4, d_embed=32, d_query=8, d_value=8, num_blocks=4, blocks_mass=5, attention_scale=1.0, final_scale=1.0, channels=1, LN=True, bias=True, scale=True): + i1, i2 = image_size + p1, p2 = patch_size + h, w = i1 // p1, i2 // p2 + patchify = Linear(d_embed, p1 * p2 * channels) @ Patchify(patch_size) + if bias: + patchify = patchify + Bias(d_embed) + posemb = Constant(lambda: posemb_sincos_2d(h, w, d_embed)) + + att = Attention(num_heads, d_embed, d_query, d_value, attention_scale, causal=False, posemb="none", bias=bias) + mlp = (Linear(d_embed, 4*d_embed) + Bias(d_embed) if bias else Linear(d_embed, 4*d_embed)) @ GeLU() @ (Linear(4*d_embed, d_embed) + Bias(4*d_embed) if bias else Linear(4*d_embed, d_embed)) + if LN: + ln = LayerNorm() + if bias and scale: + ln = (Scale(d_embed) + Bias(d_embed)) @ ln + elif bias: + ln = ln + Bias(d_embed) + elif scale: + ln = Scale(d_embed) @ ln + att = att @ ln + mlp = mlp @ ln + att_block = (1-1/(2*num_blocks)) * Identity() + 1/(2*num_blocks) * att + mlp_block = (1-1/(2*num_blocks)) * Identity() + 1/(2*num_blocks) * mlp + blocks = (mlp_block @ att_block) ** num_blocks + blocks.tare(absolute=blocks_mass) + + gap = Mean(axis=1, size=h * w) + out = final_scale * (Linear(num_classes, d_embed) + Bias(num_classes) if bias else Linear(num_classes, d_embed)) + + ret = blocks @ (patchify + posemb) + if LN: # Final LN + ret = ln @ ret + return out @ gap @ ret