diff --git a/examples/keras/blocksparse_example.ipynb b/examples/keras/blocksparse_example.ipynb new file mode 100644 index 0000000..e65d71b --- /dev/null +++ b/examples/keras/blocksparse_example.ipynb @@ -0,0 +1,407 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n", + "import numpy as np\n", + "from blocksparse_layer import BlockSparse\n", + "from sparsity_pattern_initializers import BarabasiAlbert\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Load the CIFAR-10 dataset and scale pixel values to [0, 1]\n", + "cifar10 = tf.keras.datasets.cifar10\n", + "\n", + "(x_train, y_train), (x_test, y_test) = cifar10.load_data()\n", + "x_train, x_test = x_train / 255.0, x_test / 255.0\n", + "\n", + "# Create data generator for augmentation\n", + "train_gen = ImageDataGenerator(\n", + " featurewise_center=False, # set input mean to 0 over the dataset\n", + " samplewise_center=False, # set each sample mean to 0\n", + " featurewise_std_normalization=False, # divide inputs by std of the dataset\n", + " samplewise_std_normalization=False, # divide each input by its std\n", + " zca_whitening=False, # apply ZCA whitening\n", + " rotation_range=10, # randomly rotate images by up to 10 degrees\n", + " width_shift_range=0.1, # randomly shift images horizontally (fraction of total width)\n", + " height_shift_range=0.1, # randomly shift images vertically (fraction of total height)\n", + " horizontal_flip=True, # randomly flip images horizontally\n", + " vertical_flip=False) # don't flip images vertically\n", + "train_gen.fit(x_train)\n", + "gen_flow = train_gen.flow(x_train, y_train, batch_size=64)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAALYAAAD8CAYAAADaM14OAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvDW2N/gAAE2tJREFUeJztnW+sZVV1wH+rbxinYMkwKvicGTuQjn+IccBMEGubEKippUb8gBZjzaSlmS8mxcZEh/aDadImJWkEPjQ0L1BDExOwI+lMiHHUET74wYFHh2pkHHhFCq8zCqZDNNpQJq5+uOfG28u55+5/Z599zlm/ZPLevXefvdc9b8/aa6+19jqiqhjG0Pi1rgUwjDawiW0MEpvYxiCxiW0MEpvYxiCxiW0MEpvYxiCJmtgi8kEROS0iGyJyKJVQhhGLhAZoRGQFeBr4ALAJPA58XFWfSieeYYSxJeLaa4ANVX0WQEQeAG4CFk7sN+5Y0T27L4gY0p+nv3vha95727t/4dTOldn+SuinK6bytyn7zzj3E1V907J2MRN7J/DCzOtN4L3zjUTkIHAQ4K07t/DYsd0RQ/rz+2+56jXvHTv2pFM7V2b7K6GfrpjK36bs39TD/+nSLmZiS817r7FrVHUNWAPYv29b9sSUY2deO1lqJ/uZ/z/Zp23m31/22aJ+m8ZsksuFpjFir3P5jj7yx8q6surWPmbzuAnMqt9dwJmI/gwjGTET+3Fgr4hcLiJbgVuAo2nEMow4gr0iACJyI3AXsAL8k6r+bVP7/fu2aZc29vwyWGemuOCzhLuYIL7EXF/3nWPlicVn/G/q4SdUdf+ydjE2Nqr6VeCrMX0YRhtEaWxfcmjsXnoTArV77MpROnWri6vGtpC6MUiiTBFjOS4aM1Qbu2jhPmrqKTGym8Y2BsngNLZPQCHVWC5atSnQM99f3XVN3p0m2RaN1XRNG56c3JjGNgaJTWxjkAzG3bdo6fYNSPRhye1Sxqaglou55TvGfD/m7jNGzWA0diypNP4UH62aUwO7aNyu5Whqv7K6YRrbGC+Dc/f54JIgFaq9QrR56H7AR56U2jjkHvnIHINpbGOQmI2dgDrN79K+qa2Lx2FR27r2Pqm6TSuHi4yhq4tLUMpsbGPU2MQ2BkkvNo8+h2dzjBGSG5LyoG2ImVLXxkXGkA1qk7kTvxnfcGpvGtsYJMVo7CYtlMNF1BSYWdS2Dhe3ms9YLtfPj90kV1N/dW1Tydr0vsv3yFl+wTCKpRh3n0+wIhaX5CefhJ7Q5J/UhW5yJCj5yLPsM5/+zN1nGNjENgZKVlPkYtmh75UbvKNfXRASHVzWPqTvWBMi5D42RVJdjqiFFiJywfKxjVHTibvP92BqjgCNy1g+uREueQ9N/ccGeEKuS7XBLOH0kWlsY5B04u4LdW/5EKvlfTLm6trkODkTEhJ3WRVCT7ekOlHUhNnYxqjpPKQeYmPnsLnr5AoJevh6VRaNH9pPiMy+3p6QU0ehnjELqRujxia2MUiKDdA00bScdh3YWUTKfGyfNrFj+Rz3it0Eu1xvm0dj1PRCY7u4pRb1F5slmCqw0maYuQ6XchI5C+SEjFl3z0xjG6NmqcYWkd3APwNvBn4JrKnq3SKyA3gQ2AM8B3xMVc819TXV2LHEatxU5xvbJFXwpmttnFqOlBr7PPAZVX0ncC3wKRG5EjgEHFfVvcDx6rVhFMHSAI2qngXOVr//TEROMXmO+k3AdVWz+4FHgc+5DBpqb4ZoXNfAhksKZt11821yhNJd+nMJmqROO3U5UZ8LLxtbRPYAVwMngMuqST+d/JemFs4wQnGe2CLyeuArwKdV9ace1x0UkXURWX+VV0JkNAxvnHJFROQCJpP6S6r6UPX2j0VkVVXPisgq8GLdtaq6BqzBZPMYI2zIstjk7kt1+NTlet/AkY954HM42eV7uX73LjaorizV2CIiwH3AKVX9wsxHR4ED1e8HgCPpxTOMMFw09vuBTwLfE5Gp2vhL4O+AL4vIrcDzwEddBw11B4WeEYwhNMc59HTNsrFcWaTVU250S9TUU1y8It8GZMHH8U5pw2iBTvKxfXNxp21C7NbYAI0LLmH7UI0XupItui5VrnQouUL7FlI3BklWjf22d/+CY8fctGSsNnM5AePSZ2iKbGw6ro8dnjrQE3pdSo9LLKaxjUFiE9sYJJ3kY6fCt7roPL4FakLKL4RujLoMfqTKl5/9LFUOjeVjG6OmF+UXXNrOf9bGyZX5Pn2y6lw2oU391F03xaeNz71zvb6J2JUrFNPYxiApLkDTZgjXR4s24eNm9NXUy9r6yuPTpo75cV0CZynHD8U0tjFIbGIbg6RYd19oPolPm9Bc7WVj1rV1vS6GUNNq/nrffOyYMX0xd58xanrl7nO5psl1FZJ3EasFY11nTdlw82Okdmcu6zMmH6btLEPT2MYg6UWJs0WEZt7Fjt1WUKmNPGifUzoln4iZYja2MWqK8YqkPv8XG4SJJXXyT1/xSQpzWcFMYxujxia2MUiKfRxeqlzrNg+m1skT47Z07aetjXEfTCIzRYxRU8zmcUrsiZMpvpuT+etCN52ptF5bbrpUq8uyvufHSPV9TGMbo6YYje1jb8aWRWszrzvEreUzVkpy2NSpxzCNbYyaYjR2CC6pob5pqz7j+dibqVaJpr5CS62lps39iWlsY9TYxDYGSeemyKLlx7fsQMjSmyN40zRuqGlUWiAlpzxmihijpnONPSWnC6+L84h148ee1mnLbRi7krW5EprGNkZNL0/Q+JwDdL1+Wdu69qmCMC5juV4Xcr0Pbf7tXIJbprGNUePzANMVETkpIg9Xry8XkRMi8oyIPCgiW9sT0zD88Cm/cBtwCri4en0HcKeqPiAi/wjcCtzj0lGbGzSX5Sy2fEIIuTdkqcpaNJWzCJE19u/hipPGFpFdwB8C91avBbgeOFw1uR/4SHLpDCMQV419F/BZ4Deq128AXlbV89XrTWBnSsF8cyPmiS2tUNePj/YKGT+0P5+SE77j5ywDkXJMl0dOfwh4UVWfmH27pmmte0VEDorIuoisv8orgWIahh+uj5z+sIjcCGxjYmPfBWwXkS2V1t4FnKm7WFXXgDWYuPsWDRIblg3JeHPRUKmCDanyspvap3I7xl7vq/l93H2uLNXYqnq7qu5S1T3ALcC3VPUTwCPAzVWzA8CRYCkMIzExRSk/BzwgIn8DnATuixHEZffc9JnLeUYXYota+rTNcS4yVgvmSDFoYwyvia2qjwKPVr8/C1yTXCLDSIBFHo1B0nl2X8zyHls+oc1SBLG0NUbuHPTY3PN5LFfEGDXFPQ6viaYN5qLASp3mbtpY+qwKbZLS9VXXb4q+U29eXbIlV1bdZDONbQySXtnYTbho7CbtkVozhpRc64o+yWE2tjFqOtfYi0hV/CWHjRxazCakTawcLtf7nuxPveqaxjaMBdjENgZJsaaIC23WiKsbw6VcgoscOZfumH5T9O2SuefzPcwUMUZNL8sv1OHj7mu6vi2NnzN8PTteF/
nYqcaoWxFXVjdMYxvjpZOQuguhZx5DC82kPrkee/4wx8kVnz7b1Nw+J4Jgw6lP09jGILGJbQySYk2RVEeyusp/iD2i1kV+uMsh59z53KGYxjYGSS/dfalqabdZj7pJjlI0XemuwLp7ZgEaY9T0OqTehG/t6lTaeJ42bNO2Q+m+/eYMQpnGNkZNVo29f982fezY7mQaoc385VjtNd9vjjLAXZPDO2Ma2xg1NrGNQdKLzaPvkaXZtiW62VxoKgOxKH+jq1IR87R5DM5MEWPU9EJj+xBy4LepnxyuLxfX5CwxWjjHoeK6vlKtHKaxjVHTa40de+bQtX0IXYerY0uKpSZmJZu9zjS2MWp6nQQ1S+qzk20GUVxObsf0F0ubnqTYv51pbGPU2MQ2BkmvNo8+WXSxxXCGSte1DM0UMYwInDS2iGxn8hz1dzF5Au+fAqeBB4E9wHPAx1T1XFM/qetjdx1IKC2EnaoGd+7v1WWJs7uBr6nqO4B9wCngEHBcVfcCx6vXhlEESzW2iFwM/Dtwhc40FpHTwHWqelZEVoFHVfXtTX3lPEHTFYv2Ab6BEpckqFhCZG2T3PWxrwBeAr4oIidF5F4RuQi4TFXPAlQ/L627WEQOisi6iKy/yisOwxlGPC4TewvwHuAeVb0a+DkeZoeqrqnqflXdfwGvCxTTMPxwMUXeDHxHVfdUr3+XycT+LTKZIj7VOHPUmisNnxITpWwCQ/tJZoqo6o+AF0RkOmlvAJ4CjgIHqvcOAEecpDaMDLi6+65i4u7bCjwL/AmT/xRfBt4KPA98VFX/u6mfJo2d2mXlQpsPCnIZtw+rS6q/S+7sPqfafar6JFDX2bBdHEZv6VV2n29hmmX9jzGkPktpOdtNTOWwJxoYo6bYMsJ1pNYaqUqczVKKhnMh9sxjjI3tu1r+qo090cAYMTaxjUFSbD62b07vPKFL6JhMkj7IOI/lYxujpliN7YtPpppPgZocmttXc/ZR06bCNLYxanpRHzsHqUuTLWo7i48Wz12arASsYI5hzNFLG7vNUHCqZJ3UdH2+shS73jS2MWpsYhuDpJemyJTQaquhY/jkT5eydDfRhYyxY5opYoyaXuVjz5Myr7sPp1nmaavKahtZj6kwjW2Mml7lY7uEu100TB8TnOpWpxw1wbuoO54C09jGIOmVxk5dkrZpjFh7M7WGajPQVFqtkRT3zjS2MUhsYhuDpNgATe78i5j63DnyulOR48FJbd4Pc/cZo6Y4jV3KmcMunmWT8jGBqVegUFL3aRrbGDXFaWwf2jxVUloouQ3XZs5in6kwjW2MGpvYxiDppSmSumZzU/tUGYQlb8ja7rdpLN/xzBQxRk2xGrurTV9bp2O6PozbBl1ssE1jG6OmWI3tgkvhylQusdi+unrejYtMpZzdzP0AU8PoHU752CLyF8CfAQp8j8lTw1aBB4AdwL8Bn1TV/21JTsCtfPCU0Mr5OelaU3adq93mWEs1tojsBP4c2K+q7wJWgFuAO4A7VXUvcA64NZlUhhGJqymyBfh1EdkCXAicBa4HDlef3w98JL14hhHGUlNEVf9LRP6eyUNK/wf4OvAE8LKqnq+abQI7W5OyYtEj71yumcXH9dbG0S6fA8ep3Zclm2YpcTFFLgFuAi4H3gJcBPxBTdNa94qIHBSRdRFZf5VXYmQ1DGdcNo+/B/xQVV8CEJGHgN8GtovIlkpr7wLO1F2sqmvAGkzcfSmEdgmltxkQidWiPuH72A3m/Gc588u7xMXGfh64VkQuFBFh8pjpp4BHgJurNgeAI+2IaBj+OAVoROSvgT8CzgMnmbj+dvIrd99J4I9VtdHWiA3QtBF0mVJKglNOcpRIS41rgMbJj62qnwc+P/f2s8A1AbIZRusUG1L3tfNKPQ1SWmi/DzTdMwupG6PGJrYxSIo1RUJpcwOzqO8cWYZt0uaTIULuWRNmihijpnON3VU1zrYpRca2XHp1fbZZZGjafmV1wzS2MV4619hTUiX7TOk617kP9PE+mI1tjBqb2MYg6fxRHandQaUtqyUv96lckyU+f940tjFIitk8NrFoY+iSpZcy+FCy9h065u4zDDJr7P37tuljx3ZHF5NMRSmZe7GUUj4tlRyW3WcYC+iFje1D6NnHUrRvKnzKl03pw3c3jW2MGpvYxiDpPECTmtDyC6VVUs1RCbUPpkcoprGNQVKcxm6zpFdbtJF/7JJi0NYKVDdumytPG5jGNgZJJxq7SXvF/s/Oqald5Fj0epaSXY1dBnpixjaNbQySXgZocpxEn5LjrGCb5FgNcgZ6LEBjjBqb2MYg6aUpEkqbhWFC5Eh1aihk7DbHCMXlUeKWj22MmqwaW0ReAn4O/CTboGl4I/2TGfop9zKZf1NV37Ssk6wTG0BE1l2WkpLoo8zQT7lTyWymiDFIbGIbg6SLib3WwZix9FFm6KfcSWTObmMbRg7MFDEGSbaJLSIfFJHTIrIhIodyjeuLiOwWkUdE5JSIfF9Ebqve3yEi3xCRZ6qfl3Qt6zwisiIiJ0Xk4er15SJyopL5QRHZ2rWMs4jIdhE5LCI/qO73+1Ld5ywTW0RWgH9g8qjqK4GPi8iVOcYO4DzwGVV9J3At8KlK1kPAcVXdCxyvXpfGbcCpmdd3AHdWMp8Dbu1EqsXcDXxNVd8B7GMie5r7rKqt/wPeBxybeX07cHuOsRPIfgT4AHAaWK3eWwVOdy3bnJy7qolwPfAwIEwCHVvq/gZd/wMuBn5Itc+beT/Jfc5liuwEXph5vVm9VzQisge4GjgBXKaqZwGqn5d2J1ktdwGfBX5ZvX4D8LJOnnUP5d3zK4CXgC9W5tO9InIRie5zroktNe8V7Y4RkdcDXwE+rao/7VqeJkTkQ8CLqvrE7Ns1TUu651uA9wD3qOrVTFItkpl3uSb2JrB75vUu4Eymsb0RkQuYTOovqepD1ds/FpHV6vNV4MWu5Kvh/cCHReQ5Js+3v56JBt8uItPjf6Xd801gU1VPVK8PM5noSe5zron9OLC32qVvBW4BjmYa2wsREeA+4JSqfmHmo6PAger3A0xs7yJQ1dtVdZeq7mFyb7+lqp8AHgFurpqVJvOPgBdE5O3VWzcAT5HqPmfcLNwIPA38B/BXXW9eGuT8HSZL9neBJ6t/NzKxWY8Dz1Q/d3Qt6wL5rwMern6/AngM2AD+BXhd1/LNyXoVsF7d638FLkl1ny3yaAwSizwag8QmtjFIbGIbg8QmtjFIbGIbg8QmtjFIbGIbg8QmtjFI/g9a5Swl5abv4gAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvDW2N/gAAFMFJREFUeJzt3X/s3VV9x/Hny1Ko6AhUgXW0WTGpDLPYYr6pEBaDKCtzRv/RRbcsbOnSZHELZi5YtmTBZcswS6b7Y3NpppM/3AB/TcKMteloMpOl5cv4IYilyJg07ShWiEYCUnzvj/upu1y/n/v93M895/P53J7XI2m+vb/Oed97v+f7OefzPp9zFBGYWVle1XcAZtY9N3yzArnhmxXIDd+sQG74ZgVywzcrkBu+WYHmaviSrpd0WNLjknanCsrM8lLbCTyS1gCPAdcBR4F7gQ9GxLfShWdmOZw1x2u3A49HxBMAkm4H3gvUNvzXr18TmzetnaPKbj320LlZy3/jm5+vrW/8sWlxtC2jaflnkqG+55RxvcCP+HG8qNWeN0/DvwR4auz2UeCt016wedNaDu3dNEeV3drxC9uylr937wO19Y0/Ni2OtmU0Lf9MMtT3nDKug7G/0fPmGeOv9FflZ8YNknZJWpa0/MzJl+eozsxSmWeMfxVwS0TsqG7fDBARf1X3mqWt62IIR/ypR9BjwzgqjMcxKXtPpMVnMBnvkI6oZ6K6z/tg7OcH8f1Vu/rzHPHvBbZIulTS2cAHgLvmKM/MOtJ6jB8RpyT9AbAXWAN8JiIeSRaZmWUzz8k9IuKrwFcTxWJmHWk9xm9jKGP8afoc/zctfyjnIVJI8V6mlZH6sxr6uYwuxvhmtqDc8M0K5K7+DF4xOWZKui1F+eNmqatpjNOeV/fYmdSNzlFX28+j7ntqE5O7+mZWyw3frEBu+GYFmiuPX5rcY9o6TVN7qV7XtIxFTjm2HYPneC99fD4+4psVyA3frEDFd/Wnpaiaaptu69O0dF5d13aWFFibtOJQroZsG0eXKcG6z3T7judXvH+Sj/hmBXLDNytQ8V39prPWJh+bJsXZ9KZdvraz6VLPPGw7SzD3zMC2n2PTx1Kf8Z/3d+yxONno9T7imxXIDd+sQG74ZgUqfozfVtsxcps0V9vyUzyv6Wva1pv6nEdTqa7Oa7MY6RBSuj7imxXIDd+sQF6II4EcabMUacWmceS+mKetNnXPkp5NXUbda2Z53SxlrsQLcZhZLTd8swK54ZsVyOm8BFKM46eVmWPhxi7H7m3H1qnjyD11OPdnmvLzWfWIL+kzkk5IenjsvvWS9kk6Uv28YKZazaxXTbr6nwWun7hvN7A/IrYA+6vbZrYgGqXzJG0G7o6IX65uHwauiYjjkjYAByListXKGU/n5Uh3LJoUn0GKLnzuFFXu1OQQZsLNI+Xai9t3PMXygy9kS+ddHBHHAaqfF7Usx8x6kP2svqRdkpYlLT9z8uXc1ZlZA23P6j8tacNYV/9E3RMjYg+wB0Zd/Zb1/VTubazq6uqivjaadntTnBGepYzU2YWmZeTe5iuHlBmL3Atx3AXcUP3/BuArLcsxsx40Sef9C/CfwGWSjkraCdwKXCfpCHBdddvMFsSqXf2I+GDNQ+9IHIuZdaS3mXttr4DqcjHJHFtcpR5ndjmmHer2Uamvzmury8VCvK6+mc3MDd+sQJ0uxHGe1sdbtfqpgS7TZrlTdikusJml/Jw78LYd0jRNt02rL3U3vW1qssuhVZv37IU4zKyWG75ZgdzwzQo0mDF+m3Fgl3vDTdN2TJij/DoptvJuu89gn+PzFK9rY5Y0bsoUde6r88xsgbnhmxVoMF39cW1nuw3x6rm26bzUXcOu19xLMRxJPeRoU1ef2rxPp/PMrJYbvlmBFm557aGsD9dml9dZ6m5bRl0XO8cZ7dRDk6bZi1m+26ZlNC1vUs5sQM5ZpT7imxXIDd+sQG74ZgVauDH+uK5TMH1duZdioY/cC0NMqy/FwirT5P496HJW3yznQObhI75ZgdzwzQq00F39adp2y/u6AKZt3Sk0TfXl6Ip31bWdxRBn8aXmI75ZgdzwzQrkhm9WoIUb46fecnla+ZPajMFnWUSjzZVkKRbpGOJ5h2lxTFqE+McNIY4mW2htknSPpEclPSLpxur+9ZL2STpS/bwgf7hmlkKTrv4p4CMRcTlwJfAhSW8CdgP7I2ILsL+6bWYLoMneeceB49X/fyjpUeAS4L3ANdXTbgMOAB+dVtYb3/w8e/eOul451pvvck21adrOaMs5666t1GvYpRie5Vjbvm0ZORccmcXpMrNsoSVpM3AFcBC4uPqjcPqPw0WzlGVm/Wnc8CW9Fvgi8OGI+MEMr9slaVnS8jMnX24To5kl1qjhS1rLqNF/LiK+VN39tKQN1eMbgBMrvTYi9kTEUkQsXfi6NSliNrM5rbrYpiQxGsN/PyI+PHb/XwMnI+JWSbuB9RFx07Sy2q6r30aK1XOaSrVN9hBTT0OMaZocC7WmWPu/q8+u6WKbTfL4VwO/DXxT0ul38yfArcCdknYC3wXe3zZYM+tWk7P63wDq/oKsvla2mQ1Op+vqL21dF4f2blrxsdzdqT63v26jzxSVLS6vq29mtdzwzQrU2xZaKbZBmibHWf0Us9FSz/Q6k7aMmibFWfema9aneqwP7uqbWS03fLMCueGbFWgw22R3eWVd7u20u0z1Tas7xdg09xg29XuZVkYJPMY3s1pu+GYF6m3mXopZdl1vC9W2zHnLTzE0yTGrMXdacSiLXCyC/1+I4ymWH3zBXX0z+1lu+GYFcsM3K9Bg0nltpLg6L0VqKNVVgqnH/4twBeG89XZdd5fa/F55jG9mtdzwzQq0EOm8nCm2tnF0OaOt7XAk9VrxqbYDr6urbRl9GsoMyNM8c8/MarnhmxVokGf1h7oQR9Pycl9U1OVyz0O80KfPs/pDzyi4q29mtdzwzQrkhm9WoEGm8yY1TS81HcenWKd+mkWYMTe0RSJLkjNNnGyML2mdpEOSHpT0iKSPVfdfKumgpCOS7pB09sxRmlkvmnT1XwSujYitwDbgeklXAh8HPhERW4BngZ35wjSzlGbq6ks6F/gG8PvAvwE/HxGnJF0F3BIRO6a9vm06r6kUKaqmr2ta75mc6pulzLqyPWxJK2k6T9KaaqfcE8A+4DvAcxFxqnrKUeCStsGaWbcaNfyIeDkitgEbge3A5Ss9baXXStolaVnS8ku82D5SM0tmpnReRDwHHACuBM6XdHqb7Y3AsZrX7ImIpYhYWss588RqZomctdoTJF0IvBQRz0l6NfBORif27gHeB9wO3AB8ZZaKp40ru9zueprc035zazr2bTruzrHfYYrzMm2mC6eYwjwp9wIsKc9lrNrwgQ3A
bZLWMOoh3BkRd0v6FnC7pL8A7gc+PVckZtaZVRt+RDwEXLHC/U8wGu+b2YIp/uq8PhcESXG126TU22R3mW4b+pVvs+jrvfjqPDOr5YZvVqBBdvWnabPe3ORj08pLsaBElzPJclzwkeJMeOqz6bmlvgArR9ZgWnleXtvMVuWGb1YgN3yzAg1yjN/nYpt9LWTZtu7VyqyTOzU5RKnPeUx73qSutgN3Os/MarnhmxWoyVz9LNpepJO6jBRdtz67uV2m8yYtQppuXO6YuhzizftefMQ3K5AbvlmB3PDNCjTIdN6kLhdTaFreom/vPK7P6c1N41q0qdS5OZ1nZjNzwzcr0EJ09cflTlHVvWZS7u26Umu7Xl7qLvukoX4+44YSYxPu6ptZLTd8swL1NnMvhdzryE1bE69NvdPiaFtO0zJmeZ91jy3C7Ly2mg7dzpQhgY/4ZgVywzcrkBu+WYEGk87LmW6bpYzUMwNzjPtyzmhb7bmpeZvstJKn86qtsu+XdHd1+1JJByUdkXSHpLPnCdjMujNLV/9G4NGx2x8HPhERW4BngZ0pAzOzfBql8yRtBH4d+EvgjyQJuBb4zeoptwG3AJ9qG0ibmXC5Z+61LW/Rupt9XtjS5Y67iy7ld9H0iP9J4CbgJ9Xt1wHPRcSp6vZR4JK5IjGzzqza8CW9GzgREfeN373CU1c8Syhpl6RlScsv8WLLMM0spSZd/auB90h6F7AOOI9RD+B8SWdVR/2NwLGVXhwRe4A9MDqrnyRqM5vLTOk8SdcAfxwR75b0eeCLEXG7pH8AHoqIv5/2+hRX541ruqDGpNTTUPucxjnUabR9pem63Fdglji6SlF3cXXeRxmd6Huc0Zj/03OUZWYdmukinYg4AByo/v8EsD19SGaW22Bm7qXQppuUY+beNCliTB3TtPrOpKFDCbwQh5nVcsM3K9BCL8Qxqc1iCm1n/9XVO4scr0uxFPkQu9JDWRAk9/Cvq/fiI75ZgdzwzQrkhm9WoMGk8/pK16TYFirVLK2+Fvroc9y6aAtxdLkteZvvxek8M6vlhm9WoMF09VNrO9ut7Tp1Tcufps+6hyjFRTo59yqYfG6KbdXmHba4q29mtdzwzQrkhm9WoMFM2R3iuvrTNB3PrVbfvHXnsMhXwrXd7zDFoi4pyveUXTPLxg3frECdpvOWtq6LQ3s3AYux5tm07tpQZxd2KfdnlbqMcUP9/ZuX03lmVssN36xAg5y5l2Np4qbltTm7O5SudyqLfFa/qaEOn+blrr6Z1XLDNyuQG75ZgQY5xp8m9WIHTV8za31tyq+rayjj0aEsHDKL1AuCDP0zaDrGbzRlV9KTwA+Bl4FTEbEkaT1wB7AZeBL4jYh4tm3AZtadWbr6b4+IbRGxVN3eDeyPiC3A/uq2mS2ARl396oi/FBHfG7vvMHBNRByXtAE4EBGXTSunz4U4cnavUi2MMfQ0Wl87wM6izwu3ZikzV/nbdzzF8oMvJEvnBfB1SfdJ2lXdd3FEHAeofl7UIl4z60HTy3Kvjohjki4C9kn6dtMKqj8UuwDWcW6LEM0stUZH/Ig4Vv08AXyZ0fbYT1ddfKqfJ2peuyciliJiaS3npInazOay6hhf0muAV0XED6v/7wP+HHgHcDIibpW0G1gfETdNK6vtGD/FFVY503ldp/2Gks6bZojnKNoa+rmXcSnTeRcDX5Z0+vn/HBFfk3QvcKekncB3gffPE7CZdWfVhh8RTwBbV7j/JKOjvpktmMHM3FuEbaGGUFdbKdaK7zqOpmXMW95kmW1n7vW1lv44X51nZrXc8M0K5IZvVqDBrKufe4+zNpqWn2ohyNTpwkVKQ80ix9VzbZ+XItXcx3fjI75ZgdzwzQo0mHTeuBxXWDUt40ztHucwlM+qyziGthXWJKfzzKyWG75ZgQZzVn9cim5S2zL6HAYMpes8ru0usqnrnvZ5NF2fMMfn2+WFWynj9xHfrEBu+GYFcsM3K9Ag03ltDT3V0rUu9yAYihRp3Fle11TqFHXda5zOM7NabvhmBRpkOq+toWw7NZStsbpMTQ5lMZIUadw+U5jTpPxMfcQ3K5AbvlmB3PDNCtRpOm9p67o4tHcTkGcxhaZllJjmyqGvz6PLvfhWq3toU3GdzjOzWm74ZgXqNJ332EPnzt2V6bO73dcWWjnk3FIslboYcw8Tp8nxnnNfQbiSRkd8SedL+oKkb0t6VNJVktZL2ifpSPXzgmxRmllSTbv6fwt8LSJ+idF2Wo8Cu4H9EbEF2F/dNrMFsGpXX9J5wNuA3wGIiB8DP5b0XuCa6mm3AQeAj6YIqumiCzm0mXU3pK59m8UrFt1Qv4s2uprZ2eSI/wbgGeCfJN0v6R+r7bIvjojjANXPi5JFZWZZNWn4ZwFvAT4VEVcAP2KGbr2kXZKWJS2/xIstwzSzlJo0/KPA0Yg4WN3+AqM/BE9L2gBQ/Tyx0osjYk9ELEXE0lrOSRGzmc2p0cw9Sf8B/F5EHJZ0C/Ca6qGTEXGrpN3A+oi4aVo5uRfimCZF+mrcUK4EzKHusxrKlmWLoK+tvJvO3Guax/9D4HOSzgaeAH6XUW/hTkk7ge8C729Ylpn1rFHDj4gHgKUVHurn8G1mc+ltzb2ud4Bt+rppz/Oafq/U9LMq5fPoUl378UU6ZlbLDd+sQG74ZgXqbbHNPq+sSzH+HMoClV0uSjFLvIs8rl+EtGLdOZXtO55v9Hof8c0K5IZvVqBO03mSngH+B3g98L3OKl7ZEGIAxzHJcbzSrHH8YkRcuNqTOm34P61UWo6IlSYEFRWD43AcfcXhrr5ZgdzwzQrUV8Pf01O944YQAziOSY7jlbLE0csY38z65a6+WYE6bfiSrpd0WNLj1eIdXdX7GUknJD08dl/ny4NL2iTpnmqJ8kck3dhHLJLWSTok6cEqjo9V918q6WAVxx3V+gvZSVpTred4d19xSHpS0jclPSBpubqvj9+RTpay76zhS1oD/B3wa8CbgA9KelNH1X8WuH7ivj6WBz8FfCQiLgeuBD5UfQZdx/IicG1EbAW2AddLuhL4OPCJKo5ngZ2Z4zjtRkZLtp/WVxxvj4htY+mzPn5HulnKPiI6+QdcBewdu30zcHOH9W8GHh67fRjYUP1/A3C4q1jGYvgKcF2fsQDnAv8FvJXRRJGzVvq+Mta/sfplvha4G1BPcTwJvH7ivk6/F+A84L+pzr3ljKPLrv4lwFNjt49W9/Wl1+XBJW0GrgAO9hFL1b1+gNEiqfuA7wDPRcSp6ildfT+fBG4CflLdfl1PcQTwdUn3SdpV3df199LZUvZdNvyVVgUpMqUg6bXAF4EPR8QP+oghIl6OiG2MjrjbgctXelrOGCS9GzgREfeN3911HJWrI+ItjIaiH5L0tg7qnDTXUvaz6LLhHwU2jd3eCBzrsP5JjZYHT03SWkaN/nMR8aU+YwGIiOcY7YJ0JXC+pNOXanfx/VwNvEfSk8DtjLr7n+whDiLiWPXzBPBlRn8Mu/5e5lrKfhZdNvx7gS3VGduzgQ8Ad3VY/6S7gBuq/9/AaLydlSQBnwY
ejYi/6SsWSRdKOr/6/6uBdzI6iXQP8L6u4oiImyNiY0RsZvT78O8R8VtdxyHpNZJ+7vT/gV8FHqbj7yUi/hd4StJl1V3vAL6VJY7cJ00mTlK8C3iM0XjyTzus91+A48BLjP6q7mQ0ltwPHKl+ru8gjl9h1G19CHig+veurmMB3gzcX8XxMPBn1f1vAA4BjwOfB87p8Du6Bri7jziq+h6s/j1y+nezp9+RbcBy9d38K3BBjjg8c8+sQJ65Z1YgN3yzArnhmxXIDd+sQG74ZgVywzcrkBu+WYHc8M0K9H/0vCcwmZoyAgAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# Define model\n", + "input_shape = (x_train.shape[1:])\n", + "\n", + "model = tf.keras.models.Sequential([\n", + " tf.keras.layers.Flatten(input_shape=input_shape),\n", + " BlockSparse(units=2048,\n", + " sparsity_mask_initializer=BarabasiAlbert(9),\n", + " blocksize=32,\n", + " feature_axis=0,\n", + " activation=tf.nn.relu),\n", + " BlockSparse(units=2048,\n", + " sparsity_mask_initializer=BarabasiAlbert(9),\n", + " blocksize=32,\n", + " feature_axis=0,\n", + " activation=tf.nn.relu),\n", + " BlockSparse(units=1024,\n", + " sparsity_mask_initializer=BarabasiAlbert(9),\n", + " blocksize=32,\n", + " feature_axis=0,\n", + " activation=tf.nn.relu),\n", + " tf.keras.layers.Dropout(0.2),\n", + " tf.keras.layers.Dense(10, activation=tf.nn.softmax)\n", + "])\n", + "\n", + "# Plot the layout of the blocksparse layers\n", + "plt.imshow(model.layers[1].bsmm.layout)\n", + "plt.show()\n", + "\n", + "plt.imshow(model.layers[2].bsmm.layout)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/70\n", + "782/782 [==============================] - 20s 26ms/step - loss: 2.1556 - acc: 0.1753 - val_loss: 1.9805 - val_acc: 0.2634\n", + "Epoch 2/70\n", + "782/782 [==============================] - 19s 24ms/step - loss: 1.9448 - acc: 0.2841 - val_loss: 1.8054 - val_acc: 0.3467\n", + "Epoch 3/70\n", + "782/782 [==============================] - 19s 24ms/step - loss: 1.8231 - acc: 0.3385 - val_loss: 1.7055 - val_acc: 0.3920\n", + "Epoch 4/70\n", + "782/782 [==============================] - 19s 24ms/step - loss: 1.7375 - acc: 0.3737 - val_loss: 1.6294 - val_acc: 0.4147\n", + "Epoch 5/70\n", + "782/782 [==============================] - 19s 24ms/step - loss: 1.6781 - acc: 0.3951 - val_loss: 1.5671 - val_acc: 0.4417\n", + "Epoch 6/70\n", + "782/782 [==============================] - 19s 24ms/step - loss: 1.6381 - acc: 0.4073 - val_loss: 1.5279 - val_acc: 0.4520\n", + "Epoch 7/70\n", + "782/782 [==============================] - 18s 24ms/step - loss: 1.5968 - acc: 0.4240 - val_loss: 1.5178 - val_acc: 0.4590\n", + "Epoch 8/70\n", + "782/782 [==============================] - 19s 24ms/step - loss: 1.5736 - acc: 0.4323 - val_loss: 1.4632 - val_acc: 0.4709\n", + "Epoch 9/70\n", + "782/782 [==============================] - 19s 25ms/step - loss: 1.5417 - acc: 0.4455 - val_loss: 1.4369 - val_acc: 0.4834\n", + "Epoch 10/70\n", + "782/782 [==============================] - 19s 24ms/step - loss: 1.5155 - acc: 0.4550 - val_loss: 1.4041 - val_acc: 0.5023\n", + "Epoch 11/70\n", + "782/782 [==============================] - 19s 24ms/step - loss: 1.4934 - acc: 0.4628 - val_loss: 1.4050 - val_acc: 0.5025\n", + "Epoch 12/70\n", + "782/782 [==============================] - 19s 24ms/step - loss: 1.4787 - acc: 0.4677 - val_loss: 1.3833 - val_acc: 0.5080\n", + "Epoch 13/70\n", + "782/782 [==============================] - 19s 24ms/step - loss: 1.4579 - acc: 0.4755 - val_loss: 1.4262 - val_acc: 0.4885\n", + "Epoch 14/70\n", + "782/782 [==============================] - 18s 24ms/step - loss: 1.4410 - acc: 0.4830 - val_loss: 1.3716 - val_acc: 0.5104\n", + "Epoch 15/70\n", + "782/782 [==============================] - 19s 24ms/step - loss: 1.4223 - acc: 0.4876 - val_loss: 1.3763 - val_acc: 0.5122\n", + "Epoch 16/70\n", + "782/782 [==============================] - 19s 
24ms/step - loss: 1.4018 - acc: 0.4952 - val_loss: 1.3439 - val_acc: 0.5173\n", + "Epoch 17/70\n", + "782/782 [==============================] - 19s 24ms/step - loss: 1.3894 - acc: 0.4996 - val_loss: 1.3252 - val_acc: 0.5228\n", + "Epoch 18/70\n", + "782/782 [==============================] - 19s 24ms/step - loss: 1.3788 - acc: 0.5032 - val_loss: 1.3389 - val_acc: 0.5164\n", + "Epoch 19/70\n", + "782/782 [==============================] - 19s 24ms/step - loss: 1.3614 - acc: 0.5114 - val_loss: 1.2923 - val_acc: 0.5370\n", + "Epoch 20/70\n", + "782/782 [==============================] - 18s 23ms/step - loss: 1.3545 - acc: 0.5124 - val_loss: 1.2964 - val_acc: 0.5354\n", + "Epoch 21/70\n", + "782/782 [==============================] - 18s 23ms/step - loss: 1.3363 - acc: 0.5184 - val_loss: 1.2856 - val_acc: 0.5435\n", + "Epoch 22/70\n", + "782/782 [==============================] - 19s 24ms/step - loss: 1.3255 - acc: 0.5239 - val_loss: 1.2981 - val_acc: 0.5343\n", + "Epoch 23/70\n", + "782/782 [==============================] - 18s 23ms/step - loss: 1.3168 - acc: 0.5266 - val_loss: 1.2496 - val_acc: 0.5538\n", + "Epoch 24/70\n", + "782/782 [==============================] - 18s 23ms/step - loss: 1.3075 - acc: 0.5307 - val_loss: 1.2631 - val_acc: 0.5495\n", + "Epoch 25/70\n", + "782/782 [==============================] - 18s 22ms/step - loss: 1.2909 - acc: 0.5369 - val_loss: 1.2515 - val_acc: 0.5509\n", + "Epoch 26/70\n", + "782/782 [==============================] - 18s 23ms/step - loss: 1.2864 - acc: 0.5380 - val_loss: 1.2395 - val_acc: 0.5534\n", + "Epoch 27/70\n", + "782/782 [==============================] - 18s 23ms/step - loss: 1.2723 - acc: 0.5449 - val_loss: 1.2600 - val_acc: 0.5462\n", + "Epoch 28/70\n", + "782/782 [==============================] - 18s 23ms/step - loss: 1.2596 - acc: 0.5487 - val_loss: 1.2315 - val_acc: 0.5579\n", + "Epoch 29/70\n", + "782/782 [==============================] - 18s 22ms/step - loss: 1.2581 - acc: 0.5486 - val_loss: 1.2147 - val_acc: 0.5679\n", + "Epoch 30/70\n", + "782/782 [==============================] - 18s 22ms/step - loss: 1.2478 - acc: 0.5516 - val_loss: 1.2240 - val_acc: 0.5618\n", + "Epoch 31/70\n", + "782/782 [==============================] - 18s 22ms/step - loss: 1.2322 - acc: 0.5577 - val_loss: 1.2365 - val_acc: 0.5624\n", + "Epoch 32/70\n", + "782/782 [==============================] - 17s 22ms/step - loss: 1.2227 - acc: 0.5594 - val_loss: 1.2121 - val_acc: 0.5713\n", + "Epoch 33/70\n", + "782/782 [==============================] - 18s 23ms/step - loss: 1.2175 - acc: 0.5631 - val_loss: 1.1999 - val_acc: 0.5690\n", + "Epoch 34/70\n", + "782/782 [==============================] - 18s 23ms/step - loss: 1.2092 - acc: 0.5664 - val_loss: 1.2147 - val_acc: 0.5636\n", + "Epoch 35/70\n", + "782/782 [==============================] - 18s 22ms/step - loss: 1.2017 - acc: 0.5682 - val_loss: 1.1827 - val_acc: 0.5775\n", + "Epoch 36/70\n", + "782/782 [==============================] - 18s 23ms/step - loss: 1.1947 - acc: 0.5700 - val_loss: 1.1996 - val_acc: 0.5667\n", + "Epoch 37/70\n", + "782/782 [==============================] - 18s 23ms/step - loss: 1.1822 - acc: 0.5731 - val_loss: 1.1867 - val_acc: 0.5754\n", + "Epoch 38/70\n", + "782/782 [==============================] - 18s 23ms/step - loss: 1.1815 - acc: 0.5746 - val_loss: 1.1945 - val_acc: 0.5737\n", + "Epoch 39/70\n", + "782/782 [==============================] - 18s 23ms/step - loss: 1.1690 - acc: 0.5795 - val_loss: 1.1883 - val_acc: 0.5806\n", + "Epoch 40/70\n", + "782/782 
[==============================] - 18s 23ms/step - loss: 1.1623 - acc: 0.5815 - val_loss: 1.1801 - val_acc: 0.5720\n", + "Epoch 41/70\n", + "782/782 [==============================] - 18s 23ms/step - loss: 1.1576 - acc: 0.5859 - val_loss: 1.1609 - val_acc: 0.5866\n", + "Epoch 42/70\n", + "782/782 [==============================] - 18s 23ms/step - loss: 1.1482 - acc: 0.5866 - val_loss: 1.2112 - val_acc: 0.5715\n", + "Epoch 43/70\n", + "782/782 [==============================] - 18s 23ms/step - loss: 1.1325 - acc: 0.5925 - val_loss: 1.1698 - val_acc: 0.5788\n", + "Epoch 44/70\n", + "782/782 [==============================] - 18s 22ms/step - loss: 1.1310 - acc: 0.5925 - val_loss: 1.1588 - val_acc: 0.5858\n", + "Epoch 45/70\n", + "782/782 [==============================] - 18s 23ms/step - loss: 1.1267 - acc: 0.5965 - val_loss: 1.1520 - val_acc: 0.5956\n", + "Epoch 46/70\n", + "782/782 [==============================] - 18s 24ms/step - loss: 1.1201 - acc: 0.6007 - val_loss: 1.1707 - val_acc: 0.5883\n", + "Epoch 47/70\n", + "782/782 [==============================] - 19s 25ms/step - loss: 1.1076 - acc: 0.6011 - val_loss: 1.1562 - val_acc: 0.5879\n", + "Epoch 48/70\n", + "782/782 [==============================] - 19s 24ms/step - loss: 1.1057 - acc: 0.6008 - val_loss: 1.1772 - val_acc: 0.5885\n", + "Epoch 49/70\n", + "782/782 [==============================] - 19s 25ms/step - loss: 1.0940 - acc: 0.6072 - val_loss: 1.1414 - val_acc: 0.5946\n", + "Epoch 50/70\n", + "782/782 [==============================] - 20s 26ms/step - loss: 1.0889 - acc: 0.6064 - val_loss: 1.1343 - val_acc: 0.5989\n", + "Epoch 51/70\n", + "782/782 [==============================] - 19s 25ms/step - loss: 1.0835 - acc: 0.6090 - val_loss: 1.1487 - val_acc: 0.5936\n", + "Epoch 52/70\n", + "782/782 [==============================] - 19s 24ms/step - loss: 1.0692 - acc: 0.6159 - val_loss: 1.1417 - val_acc: 0.5977\n", + "Epoch 53/70\n", + "782/782 [==============================] - 20s 25ms/step - loss: 1.0678 - acc: 0.6149 - val_loss: 1.1522 - val_acc: 0.5945\n", + "Epoch 54/70\n", + "782/782 [==============================] - 19s 24ms/step - loss: 1.0606 - acc: 0.6191 - val_loss: 1.1394 - val_acc: 0.5988\n", + "Epoch 55/70\n", + "782/782 [==============================] - 20s 25ms/step - loss: 1.0532 - acc: 0.6227 - val_loss: 1.1625 - val_acc: 0.5880\n", + "Epoch 56/70\n", + "782/782 [==============================] - 19s 25ms/step - loss: 1.0408 - acc: 0.6242 - val_loss: 1.1403 - val_acc: 0.6006\n", + "Epoch 57/70\n", + "782/782 [==============================] - 19s 25ms/step - loss: 1.0423 - acc: 0.6236 - val_loss: 1.1481 - val_acc: 0.5949\n", + "Epoch 58/70\n", + "782/782 [==============================] - 20s 25ms/step - loss: 1.0294 - acc: 0.6291 - val_loss: 1.1264 - val_acc: 0.6063\n", + "Epoch 59/70\n", + "782/782 [==============================] - 19s 25ms/step - loss: 1.0284 - acc: 0.6306 - val_loss: 1.1507 - val_acc: 0.5987\n", + "Epoch 60/70\n", + "782/782 [==============================] - 20s 25ms/step - loss: 1.0132 - acc: 0.6364 - val_loss: 1.1640 - val_acc: 0.5985\n", + "Epoch 61/70\n", + "782/782 [==============================] - 19s 24ms/step - loss: 1.0172 - acc: 0.6339 - val_loss: 1.1489 - val_acc: 0.6016\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 62/70\n", + "782/782 [==============================] - 19s 24ms/step - loss: 1.0013 - acc: 0.6379 - val_loss: 1.1774 - val_acc: 0.5915\n", + "Epoch 63/70\n", + "782/782 [==============================] - 19s 24ms/step - loss: 
0.9975 - acc: 0.6402 - val_loss: 1.1649 - val_acc: 0.5906\n", + "Epoch 64/70\n", + "782/782 [==============================] - 19s 24ms/step - loss: 0.9974 - acc: 0.6412 - val_loss: 1.1380 - val_acc: 0.6026\n", + "Epoch 65/70\n", + "782/782 [==============================] - 19s 24ms/step - loss: 0.9841 - acc: 0.6466 - val_loss: 1.1528 - val_acc: 0.6013\n", + "Epoch 66/70\n", + "782/782 [==============================] - 19s 24ms/step - loss: 0.9766 - acc: 0.6470 - val_loss: 1.1571 - val_acc: 0.5991\n", + "Epoch 67/70\n", + "782/782 [==============================] - 19s 24ms/step - loss: 0.9770 - acc: 0.6464 - val_loss: 1.1541 - val_acc: 0.6029\n", + "Epoch 68/70\n", + "782/782 [==============================] - 19s 24ms/step - loss: 0.9671 - acc: 0.6525 - val_loss: 1.1385 - val_acc: 0.6029\n", + "Epoch 69/70\n", + "782/782 [==============================] - 20s 25ms/step - loss: 0.9548 - acc: 0.6558 - val_loss: 1.1488 - val_acc: 0.6065\n", + "Epoch 70/70\n", + "782/782 [==============================] - 19s 24ms/step - loss: 0.9539 - acc: 0.6569 - val_loss: 1.1450 - val_acc: 0.6085\n", + "_________________________________________________________________\n", + "Layer (type) Output Shape Param # \n", + "=================================================================\n", + "flatten (Flatten) (None, 3072) 0 \n", + "_________________________________________________________________\n", + "block_sparse (BlockSparse) (None, 2048) 1415168 \n", + "_________________________________________________________________\n", + "block_sparse_1 (BlockSparse) (None, 2048) 1155072 \n", + "_________________________________________________________________\n", + "block_sparse_2 (BlockSparse) (None, 1024) 763904 \n", + "_________________________________________________________________\n", + "dropout (Dropout) (None, 1024) 0 \n", + "_________________________________________________________________\n", + "dense (Dense) (None, 10) 10250 \n", + "=================================================================\n", + "Total params: 3,344,394\n", + "Trainable params: 3,344,394\n", + "Non-trainable params: 0\n", + "_________________________________________________________________\n", + "10000/10000 [==============================] - 1s 82us/step\n" + ] + }, + { + "data": { + "text/plain": [ + "[1.1449687225341796, 0.6085]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Compile, train and evaluate\n", + "model.compile(optimizer=tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9),\n", + " loss='sparse_categorical_crossentropy',\n", + " metrics=['accuracy'])\n", + "model.fit_generator(gen_flow, epochs=70, workers=2, validation_data=(x_test, y_test))\n", + "model.summary()\n", + "model.evaluate(x_test, y_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING:tensorflow:TensorFlow optimizers do not make it possible to access optimizer attributes or optimizer state after instantiation. As a result, we cannot save the optimizer as part of the model save file.You will have to compile your model again after loading it. Prefer using a Keras optimizer instead (see keras.io/optimizers).\n", + "WARNING:tensorflow:No training configuration found in save file: the model was *not* compiled. 
Compile it manually.\n", + "10000/10000 [==============================] - 1s 92us/step\n" + ] + }, + { + "data": { + "text/plain": [ + "[1.144968727684021, 0.6085]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Save the model and reload it\n", + "model.save('blocksparse_model.h5')\n", + "del model\n", + "\n", + "model = tf.keras.models.load_model('blocksparse_model.h5', custom_objects={'BlockSparse': BlockSparse})\n", + "model.compile(optimizer=tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9),\n", + " loss='sparse_categorical_crossentropy',\n", + " metrics=['accuracy'])\n", + "model.evaluate(x_test, y_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10000/10000 [==============================] - 1s 94us/step\n" + ] + }, + { + "data": { + "text/plain": [ + "[1.1449687286376953, 0.6085]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Save and restore using the model config and the weights separately\n", + "config = model.get_config()\n", + "weights = model.get_weights()\n", + "del model\n", + "\n", + "model = tf.keras.Sequential.from_config(config, custom_objects={'BlockSparse': BlockSparse})\n", + "model.set_weights(weights)\n", + "\n", + "model.compile(optimizer=tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9),\n", + " loss='sparse_categorical_crossentropy',\n", + " metrics=['accuracy'])\n", + "model.evaluate(x_test, y_test)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.6" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/keras/blocksparse_layer.py b/examples/keras/blocksparse_layer.py new file mode 100755 index 0000000..532f0fe --- /dev/null +++ b/examples/keras/blocksparse_layer.py @@ -0,0 +1,192 @@ +"""Block sparse layer. +""" + +import tensorflow as tf +from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras import activations +from tensorflow.python.keras import constraints +from tensorflow.python.keras import initializers +from tensorflow.python.keras import regularizers +from tensorflow.python.keras.layers import InputSpec +from tensorflow.python.keras.layers import Layer +from tensorflow.python.ops import nn +from blocksparse.matmul import BlocksparseMatMul +import sparsity_pattern_initializers as sp_initializers +import numpy as np + +class BlockSparse(Layer): + """A blocksparse variant of a regular densely-connected NN layer. + + `BlockSparse` implements the operation: + `output = activation(dot(input, kernel) + bias)` + where `activation` is the element-wise activation function + passed as the `activation` argument, `dot` is a blocksparse variant + of the dot product implemented by `blocksparse.matmul.BlocksparseMatMul`, + `kernel` is a weights matrix created by the layer, + and `bias` is a bias vector created by the layer + (only applicable if `use_bias` is `True`). 
+ + Example: + + ```python + model = Sequential() + model.add(BlockSparse(64, blocksize=16, + sparsity_mask_initializer=BarabasiAlbert(2), + input_shape=(32,))) + # now the model will take as input arrays of shape (*, 32) + # and output arrays of shape (*, 64); note that both the input + # and output dimensions must be divisible by the blocksize. + # after the first layer, you don't need to specify + # the size of the input anymore: + model.add(BlockSparse(64, blocksize=16, + sparsity_mask_initializer=BarabasiAlbert(2))) + ``` + + Arguments: + units: Positive integer, dimensionality of the output space. + blocksize: Block size of the sparsity pattern; the values 32, 16 and 8 + are supported. + feature_axis: Feature axis of the activations (0 or 1). When blocksize is + less than 32, memory access becomes far more efficient with a (C, N) + activation layout, i.e. feature_axis=0. + sparsity_mask_initializer: Initializer for the sparsity mask of the + blocksparse `kernel` weights matrix. + sparsity_mask: Boolean numpy array, defines the sparsity mask for + the blocksparse weight matrix. If a mask is given the + sparsity_mask_initializer will not be used. + activation: Activation function to use. + If you don't specify anything, no activation is applied + (i.e. "linear" activation: `a(x) = x`). + use_bias: Boolean, whether the layer uses a bias vector. + kernel_initializer: Initializer for the `kernel` weights matrix. + bias_initializer: Initializer for the bias vector. + kernel_regularizer: Regularizer function applied to + the `kernel` weights matrix. + bias_regularizer: Regularizer function applied to the bias vector. + activity_regularizer: Regularizer function applied to + the output of the layer (its "activation"). + kernel_constraint: Constraint function applied to + the `kernel` weights matrix. + bias_constraint: Constraint function applied to the bias vector. + + Input shape: + a 2D tensor with shape `(batch_size, input_dim)`. + + Output shape: + a 2D tensor with shape `(batch_size, units)`. + """ + + def __init__(self, + units, + blocksize=32, + feature_axis=1, + sparsity_mask_initializer='barabasi_albert', + sparsity_mask=None, + activation=None, + use_bias=True, + kernel_initializer='glorot_uniform', + bias_initializer='zeros', + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + **kwargs): + + super(BlockSparse, self).__init__( + activity_regularizer=regularizers.get(activity_regularizer), **kwargs) + self.units = units + self.sparsity_mask_initializer = sp_initializers.get(sparsity_mask_initializer) + self.blocksize = blocksize + self.feature_axis = feature_axis + self.activation = activations.get(activation) + self.use_bias = use_bias + self.kernel_initializer = initializers.get(kernel_initializer) + self.bias_initializer = initializers.get(bias_initializer) + self.kernel_regularizer = regularizers.get(kernel_regularizer) + self.bias_regularizer = regularizers.get(bias_regularizer) + self.kernel_constraint = constraints.get(kernel_constraint) + self.bias_constraint = constraints.get(bias_constraint) + + if sparsity_mask is not None: + # A mask restored from a saved config arrives as a serialized ndarray dict + if isinstance(sparsity_mask, dict): + if sparsity_mask.get('type') == 'ndarray': + sparsity_mask = np.array(sparsity_mask['value']) + self._initial_sparsity_mask = sparsity_mask + else: + self._initial_sparsity_mask = None + + def build(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape) + if input_shape[-1] is None: + raise ValueError('The last dimension of the inputs to `BlockSparse` ' + 'should be defined. Found `None`.') + + if self.units % self.blocksize != 0: + raise ValueError('The number of units should be divisible by the blocksize. 
' + 'Got {} units and blocksize {}'.format(self.units, self.blocksize)) + if input_shape[-1].value % self.blocksize != 0: + raise ValueError('The input dimension should be divisible by the blocksize. ' + 'Got input dimension {} and blocksize {}'.format(input_shape[-1].value, self.blocksize)) + + # One mask entry per (input block, output block) pair + mask_shape = (input_shape[-1].value // self.blocksize, self.units // self.blocksize) + + if self._initial_sparsity_mask is not None: + if self._initial_sparsity_mask.shape != mask_shape: + raise ValueError('Incorrect shape for the initial sparsity mask: expected {}, got {}'.format(mask_shape, + self._initial_sparsity_mask.shape)) + sparsity_mask = self._initial_sparsity_mask + else: + sparsity_mask = self.sparsity_mask_initializer(mask_shape) + + self.bsmm = BlocksparseMatMul(sparsity_mask, + block_size=self.blocksize, + feature_axis=self.feature_axis, + name=self.name + '_bsmm') + + self.kernel = self.add_weight('kernel', + shape=self.bsmm.w_shape, + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint, + dtype=self.dtype, + trainable=True) + + if self.use_bias: + self.bias = self.add_weight('bias', + shape=[self.units], + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint, + dtype=self.dtype, + trainable=True) + + def call(self, inputs): + # The matmul was built with the requested feature_axis; with feature_axis=0 + # it expects a (C, N) activation layout, so transpose on the way in and out. + if self.feature_axis: + outputs = self.bsmm(inputs, self.kernel) + else: + outputs = self.bsmm(tf.transpose(inputs), self.kernel) + outputs = tf.transpose(outputs) + + if self.use_bias: + outputs = nn.bias_add(outputs, self.bias) + if self.activation is not None: + return self.activation(outputs) # pylint: disable=not-callable + return outputs + + def compute_output_shape(self, input_shape): + return (input_shape[0], self.units) + + def get_config(self): + config = { + 'units': self.units, + 'blocksize': self.blocksize, + 'feature_axis': self.feature_axis, + 'sparsity_mask_initializer': sp_initializers.serialize(self.sparsity_mask_initializer), + 'sparsity_mask': self.bsmm.layout, + 'activation': activations.serialize(self.activation), + 'use_bias': self.use_bias, + 'kernel_initializer': initializers.serialize(self.kernel_initializer), + 'bias_initializer': initializers.serialize(self.bias_initializer), + 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), + 'bias_regularizer': regularizers.serialize(self.bias_regularizer), + 'activity_regularizer': regularizers.serialize(self.activity_regularizer), + 'kernel_constraint': constraints.serialize(self.kernel_constraint), + 'bias_constraint': constraints.serialize(self.bias_constraint) + } + base_config = super(BlockSparse, self).get_config() + return dict(list(base_config.items()) + list(config.items())) \ No newline at end of file diff --git a/examples/keras/blocksparse_model.h5 b/examples/keras/blocksparse_model.h5 new file mode 100644 index 0000000..5c29782 Binary files /dev/null and b/examples/keras/blocksparse_model.h5 differ diff --git a/examples/keras/sparsity_pattern_initializers.py b/examples/keras/sparsity_pattern_initializers.py new file mode 100644 index 0000000..e0a1d9f --- /dev/null +++ b/examples/keras/sparsity_pattern_initializers.py @@ -0,0 +1,87 @@ +from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object +from tensorflow.python.keras.utils.generic_utils import serialize_keras_object +import networkx as nx +import numpy as np +import six + + +class SparsityMaskInitializer(object): + """Sparsity mask initializer base class: all sparsity mask initializers inherit from this 
class. + """ + + def __call__(self, shape): + raise NotImplementedError + + def get_config(self): + """Returns the configuration of the initializer as a JSON-serializable dict. + Returns: + A JSON-serializable Python dict. + """ + return {} + + +class BarabasiAlbert(SparsityMaskInitializer): + """Builds a sparsity mask from a Barabasi-Albert random graph, sliced to the required shape. + + For details see: + networkx.generators.random_graphs.barabasi_albert_graph in the networkx documentation + + Examples: + + ```python + # Generate a mask from a Barabasi-Albert graph with m = 1 and shape (5, 10) + ba = BarabasiAlbert(1) + mask = ba((5, 10)) + ``` + + Arguments: + m: Number of edges to attach from a new node to existing nodes + """ + + def __init__(self, m=5): + self.m = m + + def __call__(self, shape): + """ + Returns: + 2D numpy array containing the adjacency matrix of the generated + Barabasi-Albert graph, sliced to the requested shape + """ + n = max(shape[0], shape[1]) + g = nx.generators.barabasi_albert_graph(n=n, m=self.m) + # Adjacency matrix plus the identity, so every node also connects to itself + a = nx.adjacency_matrix(g).toarray().astype(np.int32) + np.eye(n, dtype=np.int32) + # The preferential-attachment process never creates edges among the m seed + # nodes, so connect that corner of the mask explicitly + a[0:self.m, 0:self.m] = 1 + return a[:shape[0], :shape[1]] + + def get_config(self): + return {'m': self.m} + +# Compatibility aliases +barabasi_albert = BarabasiAlbert + +# Utility functions +def serialize(initializer): + return serialize_keras_object(initializer) + + +def deserialize(config, custom_objects=None): + return deserialize_keras_object( + config, + module_objects=globals(), + custom_objects=custom_objects, + printable_module_name='initializer') + + +def get(identifier): + if identifier is None: + return None + if isinstance(identifier, dict): + return deserialize(identifier) + elif isinstance(identifier, six.string_types): + config = {'class_name': str(identifier), 'config': {}} + return deserialize(config) + elif callable(identifier): + return identifier + else: + raise ValueError('Could not interpret initializer identifier: ' + + str(identifier))
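The notebook above exercises the layer end to end; the sketch below condenses that workflow into a minimal standalone script showing both ways of specifying the sparsity pattern. This is a sketch under stated assumptions rather than part of the committed files: it assumes a GPU build of TensorFlow 1.x with OpenAI's `blocksparse` package installed and the two new modules on the Python path, and the layer widths and the checkerboard mask are purely illustrative.

```python
# Minimal usage sketch for the pieces added in this diff (illustrative only).
# Assumes a GPU build of TensorFlow 1.x, OpenAI's blocksparse package, and the
# two modules above (blocksparse_layer.py, sparsity_pattern_initializers.py).
import numpy as np
import tensorflow as tf

from blocksparse_layer import BlockSparse
from sparsity_pattern_initializers import BarabasiAlbert

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(32, 32, 3)),  # 3072 features, divisible by 32
    # Sparsity pattern drawn from a Barabasi-Albert graph, as in the notebook
    BlockSparse(units=512,
                sparsity_mask_initializer=BarabasiAlbert(9),
                blocksize=32,
                feature_axis=0,
                activation=tf.nn.relu),
])

# An explicit 0/1 mask is also accepted: one entry per (input block, output block)
mask = np.ones((512 // 32, 256 // 32), dtype=np.int32)  # shape (16, 8)
mask[1::2, 1::2] = 0  # prune a quarter of the blocks in a checkerboard pattern
model.add(BlockSparse(units=256, sparsity_mask=mask,
                      blocksize=32, feature_axis=0, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(10, activation=tf.nn.softmax))

model.compile(optimizer=tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.summary()
```

Note that the explicit mask is expressed in units of blocks, `(input_dim // blocksize, units // blocksize)`, which is why it is `(16, 8)` here rather than `(512, 256)`.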