From 194a1b0715418b6361614977ea5c494e2a6de94a Mon Sep 17 00:00:00 2001 From: zxiaomzxm Date: Wed, 17 May 2017 16:45:51 +0800 Subject: [PATCH 1/3] fix a bug about KL loss term --- draw.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/draw.py b/draw.py index eefff2a..c8fdf2c 100644 --- a/draw.py +++ b/draw.py @@ -188,7 +188,7 @@ def binary_crossentropy(t,o): mu2=tf.square(mus[t]) sigma2=tf.square(sigmas[t]) logsigma=logsigmas[t] - kl_terms[t]=0.5*tf.reduce_sum(mu2+sigma2-2*logsigma,1)-T*.5 # each kl term is (1xminibatch) + kl_terms[t]=0.5*tf.reduce_sum(mu2+sigma2-2*logsigma-1,1) # each kl term is (1xminibatch) KL=tf.add_n(kl_terms) # this is 1xminibatch, corresponding to summing kl_terms from 1:T Lz=tf.reduce_mean(KL) # average over minibatches From 86e2467162a7c745445abd05a553cf01a583e51d Mon Sep 17 00:00:00 2001 From: zhaoxm Date: Fri, 19 May 2017 17:43:03 +0800 Subject: [PATCH 2/3] bug fixed. --- attention_test.py | 77 ++++++++++++++++++++++++++++++++++++++++++++++ draw.py | 53 +++++++++++++++++-------------- img/elephant.jpg | Bin 0 -> 10648 bytes plot_data.py | 11 ++++--- 4 files changed, 113 insertions(+), 28 deletions(-) create mode 100644 attention_test.py create mode 100644 img/elephant.jpg diff --git a/attention_test.py b/attention_test.py new file mode 100644 index 0000000..be7abb4 --- /dev/null +++ b/attention_test.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python2 +# -*- coding: utf-8 -*- +""" +Created on Fri May 19 14:15:42 2017 + +@author: zhaoxm +""" +import tensorflow as tf +import numpy as np +import matplotlib.pyplot as plt +from scipy.misc import imresize + +A, B, C = 100, 75, 3 +eps = 1e-8 + +def filterbank(gx, gy, sigma2,delta, N): + grid_i = tf.reshape(tf.cast(tf.range(N), tf.float32), [1, -1]) + mu_x = gx + (grid_i - N / 2.0 + 0.5) * delta # eq 19 + mu_y = gy + (grid_i - N / 2.0 + 0.5) * delta # eq 20 + a = tf.reshape(tf.cast(tf.range(A), tf.float32), [1, 1, -1]) + b = tf.reshape(tf.cast(tf.range(B), tf.float32), [1, 1, -1]) + mu_x = tf.reshape(mu_x, [-1, N, 1]) + mu_y = tf.reshape(mu_y, [-1, N, 1]) + sigma2 = tf.reshape(sigma2, [-1, 1, 1]) + Fx = tf.exp(-tf.square((a - mu_x)) / (2*sigma2)) # 2*sigma2? + Fy = tf.exp(-tf.square((b - mu_y)) / (2*sigma2)) # batch x N x B + # normalize, sum over A and B dims + Fx=Fx/tf.maximum(tf.reduce_sum(Fx,2,keep_dims=True),eps) + Fy=Fy/tf.maximum(tf.reduce_sum(Fy,2,keep_dims=True),eps) + return Fx,Fy + +img = plt.imread('./img/elephant.jpg') +img = imresize(img, [B, A]) +plt.figure(1) +plt.imshow(img) +img = tf.cast(img, dtype=tf.float32) +gx, gy, sigma2, N = 75, 50, 2.0, 12 #(A-1)/2.0, (B-1)/2.0, 1.0, 5 +delta = 3 #(max(A,B)-1)/(N-1) +Fx,Fy = filterbank(gx, gy, sigma2,delta, N) +Fx = tf.tile(Fx, [3, 1, 1]) +Fy = tf.tile(Fy, [3, 1, 1]) +Fxt=tf.transpose(Fx,perm=[0,2,1]) +img=tf.transpose(img, [2, 0, 1]) +glimpse = tf.matmul(Fy,tf.matmul(img,Fxt)) +patch = tf.transpose(glimpse, [1, 2, 0]) +Fyt = tf.transpose(Fy, [0, 2, 1]) +recons = tf.matmul(Fyt,tf.matmul(glimpse,Fx)) +recons = tf.transpose(recons, [1, 2, 0]) + +sess = tf.Session() +atte_patch, recons_img, FX, FY = sess.run([patch, recons, Fx, Fy]) +FX, FY = FX[0], FY[0] +plt.figure(2) +plt.imshow(atte_patch/atte_patch.max()) +plt.figure(3) +plt.imshow(recons_img/recons_img.max()) +plt.figure(4) +filter = np.zeros([B, A]) +for i in range(N): + for j in range(N): + filter += np.outer(FY[i,:].T, FX[j,:]) +plt.imshow(filter/filter.max()) + + + + + + + + + + + + + + + diff --git a/draw.py b/draw.py index c8fdf2c..b82737f 100644 --- a/draw.py +++ b/draw.py @@ -13,7 +13,8 @@ from tensorflow.examples.tutorials import mnist import numpy as np import os - +os.environ["CUDA_VISIBLE_DEVICES"] = '2' +tf.reset_default_graph() tf.flags.DEFINE_string("data_dir", "", "") tf.flags.DEFINE_boolean("read_attn", True, "enable attention for reader") tf.flags.DEFINE_boolean("write_attn",True, "enable attention for writer") @@ -25,15 +26,15 @@ img_size = B*A # the canvas size enc_size = 256 # number of hidden units / output size in LSTM dec_size = 256 -read_n = 5 # read glimpse grid width/height +read_n = 2 # read glimpse grid width/height write_n = 5 # write glimpse grid width/height read_size = 2*read_n*read_n if FLAGS.read_attn else 2*img_size write_size = write_n*write_n if FLAGS.write_attn else img_size -z_size=10 # QSampler output size -T=10 # MNIST generation sequence length +z_size=100 # QSampler output size +T=64 # MNIST generation sequence length batch_size=100 # training minibatch size -train_iters=10000 -learning_rate=1e-3 # learning rate for optimizer +train_iters=200000 +init_learning_rate=1e-3 # learning rate for optimizer eps=1e-8 # epsilon for numerical stability ## BUILD MODEL ## @@ -56,15 +57,15 @@ def linear(x,output_dim): def filterbank(gx, gy, sigma2,delta, N): grid_i = tf.reshape(tf.cast(tf.range(N), tf.float32), [1, -1]) - mu_x = gx + (grid_i - N / 2 - 0.5) * delta # eq 19 - mu_y = gy + (grid_i - N / 2 - 0.5) * delta # eq 20 + mu_x = gx + (grid_i - N / 2.0 + 0.5) * delta # eq 19 + mu_y = gy + (grid_i - N / 2.0 + 0.5) * delta # eq 20 a = tf.reshape(tf.cast(tf.range(A), tf.float32), [1, 1, -1]) b = tf.reshape(tf.cast(tf.range(B), tf.float32), [1, 1, -1]) mu_x = tf.reshape(mu_x, [-1, N, 1]) mu_y = tf.reshape(mu_y, [-1, N, 1]) sigma2 = tf.reshape(sigma2, [-1, 1, 1]) - Fx = tf.exp(-tf.square((a - mu_x) / (2*sigma2))) # 2*sigma2? - Fy = tf.exp(-tf.square((b - mu_y) / (2*sigma2))) # batch x N x B + Fx = tf.exp(-tf.square((a - mu_x)) / (2*sigma2)) # 2*sigma2? + Fy = tf.exp(-tf.square((b - mu_y)) / (2*sigma2)) # batch x N x B # normalize, sum over A and B dims Fx=Fx/tf.maximum(tf.reduce_sum(Fx,2,keep_dims=True),eps) Fy=Fy/tf.maximum(tf.reduce_sum(Fy,2,keep_dims=True),eps) @@ -75,8 +76,8 @@ def attn_window(scope,h_dec,N): params=linear(h_dec,5) # gx_,gy_,log_sigma2,log_delta,log_gamma=tf.split(1,5,params) gx_,gy_,log_sigma2,log_delta,log_gamma=tf.split(params,5,1) - gx=(A+1)/2*(gx_+1) - gy=(B+1)/2*(gy_+1) + gx=(A-1)/2.0*(gx_+1) + gy=(B-1)/2.0*(gy_+1) sigma2=tf.exp(log_sigma2) delta=(max(A,B)-1)/(N-1)*tf.exp(log_delta) # batch x N return filterbank(gx,gy,sigma2,delta,N)+(tf.exp(log_gamma),) @@ -195,13 +196,14 @@ def binary_crossentropy(t,o): cost=Lx+Lz ## OPTIMIZER ## - +global_step = tf.Variable(0, trainable=False) +learning_rate = tf.train.exponential_decay(init_learning_rate, global_step, 10000, 0.1, staircase=True) optimizer=tf.train.AdamOptimizer(learning_rate, beta1=0.5) grads=optimizer.compute_gradients(cost) for i,(g,v) in enumerate(grads): if g is not None: grads[i]=(tf.clip_by_norm(g,5),v) # clip gradients -train_op=optimizer.apply_gradients(grads) +train_op=optimizer.apply_gradients(grads, global_step=global_step) ## RUN TRAINING ## @@ -211,7 +213,7 @@ def binary_crossentropy(t,o): train_data = mnist.input_data.read_data_sets(data_directory, one_hot=True).train # binarized (0-1) mnist data fetches=[] -fetches.extend([Lx,Lz,train_op]) +fetches.extend([Lx,Lz,train_op,learning_rate]) Lxs=[0]*train_iters Lzs=[0]*train_iters @@ -219,18 +221,23 @@ def binary_crossentropy(t,o): saver = tf.train.Saver() # saves variables learned during training tf.global_variables_initializer().run() -#saver.restore(sess, "/tmp/draw/drawmodel.ckpt") # to restore from model, uncomment this line - +#saver.restore(sess, "./models/model-10000") # to restore from model, uncomment this line +x_binary = np.zeros([batch_size, img_size]) for i in range(train_iters): xtrain,_=train_data.next_batch(batch_size) # xtrain is (batch_size x img_size) +# p = np.random.uniform() +# x_binary[xtrain > p] = 1.0 feed_dict={x:xtrain} results=sess.run(fetches,feed_dict) - Lxs[i],Lzs[i],_=results + Lxs[i],Lzs[i],_,lr=results if i%100==0: - print("iter=%d : Lx: %f Lz: %f" % (i,Lxs[i],Lzs[i])) + print("iter=%d, lr=%f : Lx: %f Lz: %f" % (i,lr,Lxs[i],Lzs[i])) + if (i+1)%1000==0: + saver.save(sess, './models/model', global_step=i+1) ## TRAINING FINISHED ## - +saver.restore(sess, "./models/model-200000") +#saver.restore(sess, "./trial/t32z10/model-30000") canvases=sess.run(cs,feed_dict) # generate some examples canvases=np.array(canvases) # T x batch x img_size @@ -238,7 +245,7 @@ def binary_crossentropy(t,o): np.save(out_file,[canvases,Lxs,Lzs]) print("Outputs saved in file: %s" % out_file) -ckpt_file=os.path.join(FLAGS.data_dir,"drawmodel.ckpt") -print("Model saved in file: %s" % saver.save(sess,ckpt_file)) +#ckpt_file=os.path.join(FLAGS.data_dir,"./models/drawmodel.ckpt") +#print("Model saved in file: %s" % saver.save(sess,ckpt_file)) -sess.close() +#sess.close() diff --git a/img/elephant.jpg b/img/elephant.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a21e822ddffced361ede5e84dbfa0cde46be9231 GIT binary patch literal 10648 zcmbW6cQ{=w@^g%!tnDL`{?sW<-lFY8ag;A)@!*B7_8^6Fpim z7$PKE5>i~h_r1@(&vXB}_pbAt{mi{MlEwmN@2m}CtR~O)N z8K426qND^uhEAaV*yFcmco?Nx_XCIA@_1R^5`QBaVRU-b^TdJZ6GreNWd zRHtM$asc!Dut`Ov7EuY@Y3yS+{`eg(?TC$}rnz>V1HySjP)JxrR7O@#9)VQQxT~q9 zjY8{~m}1P#EiA2^oLwF~baivb`9AUU4+sp3dKMiM8yBCDmY$KBmHi?o_jPeeX<2zi zWmQviOKV$uN9UXOANmId2}8pp(=)Sk^9zfgmNqttTiZLN-MueI$0w&}KYso?zxWRq z5CHmbtgH6F!Tt{}<|{5Ta&izk_&;1gGXJX^#7s`XCrQboZUlDlVda;Kpklj|TGZG_ zEg)_Do!t@pk>(m)X5+@uf6)Gm?Eeld^8btMe}VmPt|b5+2zWJlAZCCnATfBX2I%z% zm>ZVGRki+uK6$z#Cr$rZ#Tmz6QE%o2#c-95nn%7TEX2{c*O@#WLQmv}RIr`mHTnS$ zAb;InJp{&?`Ot&t3vt?WiAxD)A0hp>$)h z;kdtRb7;tL=9*{s$!S?gNt&`ILtc-e!f8nYj28?zB~7WqowC76`D9bGIEA{R+6|I;v1+l=^o3hW^tg2Va zkJO)phfzX#r@QERG0!pKDRj7UmbL$eZoQAnaPzg`mvotLo zdplxcfSa*^nmVUWckaG5+7Xj59HI0^jk~GWE*Y$*r>Hm(WQX>~eQ#%;&UT&z%{Imy zy7P%+>ztlEE~*gh447auGk&xCP+5T!Av0|)OK#Gr?>4Aw7ChMq?f*T*4pB@pdz!`F zte3En)5*a$|0dn{=S?EptK7s+{=6L~L{(!D-+>&x`~|FgmP zrGxQhxhI*=ABI_#XbxNU1V2RlMeOUj$<@QQDnLusElI6~$l=XA@qwUAz}bM2yvXMj zmPG&~R;+URxn_wbjU2}OWibSv{?{X+>rt9IQm{+W(kb@C5BD_sBo3j+X38gD6++pR zwYg;lIvg$jysqZ%_DRr6#|1~7lu#MKHzObWEUuff_*V;YS5Y&`Pp2?JnW$8)e2t?7 zr@}7IZf6yFQ#h|t(s#>K+F@Jexs%YK;w|Q6A5;ylZ;Mf<)U9Gz&g=nQc#so#zP7kd zT9@{_os#uAiQ{kB#as}*mRjH2$v0YIqW$9Np~EnR47s_B0@o%Hz2d4UYk;sL*xTd5 zq=lIa9!n(@8MIJ%^MltAV!+jMCEK3SZZ3bXEXm32LVCs`8fsh&ns08|@AG%A3qAf* zd&cf6!A_F-<|McJ)sm0?&VU6ICB4MxQgI;0^kTMU_))E3iGB8n(uvtd=k4WMM~ds$ zg=T~1Q^ru&m8L!>Ixlw?zuio&le149FitR6)4)D8`LpSy7=3e`^`>-lV!e)gHn)3? z$re$`^cw;MJ}FrSGuv&aHXzwfQ<~CEQ-^UYrps|W8CL7Xjs`6fn+eX?Xd9`mkD9R~ zfh-03hk=9ccHs?5ChUkcM&%QIlW=t?k2d>PIwL;15RtV9Y6Pn!{h9IGe#A%-3WBrb zH}&3q!z^dyPKMq)$iX(9JN>}S%#H@SnS=KkM>{h} z&^z7Nb&S*o-O%(JD(sSt!pi5x{mvi{WdlMbggzN-GQ@J>83pSBx(+d9#BJXB>JM@U zD)Jy;_+~9BXx+c;S$|Sn+n0P)M;~;CMM_^1b;K@EE6ExS=1_KbMVp;7NE&KRfFWnq zMnvSd(^n6TUabY104B(>-li893ubf<4HnL)^{Q=7A8vDbw6vuPwI|%)6ez6qSx9<0 z@$9h9W|6Zf_cJWjMDWX_~YYNWu840DpN⪚YJlfES*=ejxara_ zV}Ztt-uA1E{4XN%8?bIe63fxb$)&QzA{Ng?0O5h`d|bu*LuJ>NK9y)G=B?G1e0<%^ zhH}zPp3l%3KY5t&&?ZVhn5wsQB*m(~>nA|{^(#}w02w0(o(Bro7zrA=Ol`QX)G6cO zQuLsf0{-;XSDRP#fA(jj1SlPt4XRWAE{9_0VIMRcFPdMxREmU z*xCtXMdC9NyMHa=9&P3b<*?d2ge}@Xx^BN4s=rOYSY!xP6(!2_y4oBw1~cOJ4JX6@ z4IzHYy0L{TvpqbjxcjuvX)c)!s6JIUrErY?S)HpCD>YC+$q^=+;5!`50O8@z9N5c!Q% z7p5kr31y=GCs6o49Gc4A7os?uMD^&$&$7dRe1ryBaDi!o9zZAy^Ih3}nkoEZ?!F=p zgb#^Ex)^$+Cf~8&>1=df7vWa+D4h}-W4S}P!6mR~2TRMyW3bPAWN4PF=^G68IUrbjjSD(s4lo+@os+Vbln)>o6^}eF16FwjjAn-wd~L)p(QAp-dqY4U%h8Q zo}8`3oN5?te=m}nfP`*l+DV|`sk1a^lsXKCCqqAWa)HY)S1wE<$hVvd=w&=_X@>;? zv`FMU#ktWh>3`Y6InR zQ`q8Hghr0jd}wOK)c#2HuR6B6R?AXywyz@~)Gl*)FKMjEC1RYt#&Ob01PAHxXx!#{ z7q89$4%;jhx6H%?l$doPTFVm(^}MApwXdd5h#}a*Ycy%l=&$6d-lq5%1d5?Zmks}PiKBFaHDHOlsi z{XAb}P_T89@-nonR)Bt-k5@F2E|(I;ZXb`)vg`nN${carJ?}gNbyg71UyKkyuM--8 z6Mn>%T(!G~$uEdURn3G6gZn{meew(cy}JZh73HXtM@lXNbJZMeG+TH}&8NRz8#QKx z@mT+5c z!=9s3Dbno&Cwq3>CV|7XFsjbZzrsL)HrX+CZL3^< zN+=1H;?#==Zw3$(9x*k>=?VfK?#ji0hy1yUvlfbY^o2=e+9Kuap;`sTAqR9sA}8iN zXzKhC^MDxtBiQ{s?#fBdGq(2x1lwkR!cF^T4TZweTRM1!87k-C>1Gxw3dJNvDt~$S zM|E&F+j-PW9^I}HKtBs3cSDGXZeF;{Iw{3nRX!U6f&LBI73aA#`hk6zW8W@fOh_Ak zRKPf&t%xSa_hfWMK$P_apmgi@367lRd~wQeOwf~{IuYxbL3hC$$&mAoOMq%Lh>FV{ zMJ!Tc+%`yxYS48tsqd_syv1?18&^14Y8`(~f_;dGlmO`DY*YeH5{)Z{REm3^&N8nv zM;ipw?Ubz+puu=r{Xbp2gETEltHw9QIL>j83+{Ee`E_(-U^yYL;=KRTdzBSH`3M zjfS1=85VxBFBY@S->(!7QLDm<*|Q}r-m+{&W!?MfNk1ki&wVA23fzm+ zg+e1+^tZCa;iF81Ix$Od5UFPC6PD?yAoHX42#q}guS(-4q;htfu3LdeR2t79y|3%~ zac0}JGxEXE;(Ge2>PGJA4bwE(`Zj7vrFBPwMf*f{_YE!NG*14cRVRkK7g1+(*H?F% zbJybdAFky=ew8}xU4V>)0t=p1NSN_27tjy^#xZeG3bu1`Sq*=Yg%TOBK z#e4&9W-+vrI*0ApH{ox%QzO-TTK3k9!S4tyDfm@Glm$`3B;%1j;eP6K-RQ#f{bv|u zC9aur>1j@+yhFOCd!CIq6`RpSEUwo|?PJMkA3%-t>Q=&zX$ zGouAqo%W8{g0v9%S5O@_JFB!Ojk30mD~@jn`rQ>QkoD#16sX<9sdH+%RfiCq;eeD;&J`1TYkhSu0Kp=a=f|bi!9pprzwH7vVqX|_SK!!c8 z@PjI8t%N*0=;&Wqsm9OySNqd2h^M9Wxep$Pa?C<1SH@P#{3|rE>fxGUD6iFk*bG2U zcy|>;14;#7*9~+j)YG|M0zS!6J~aA~mneM+Xi}tEq4THv>puNnGJ@cavyQf(+ZjGl zY%WieLr*F|*l57p5-@90S2*rf{jqgH>1 zg;}^4A@dCZTPn_<#*n!AAFNKwH2%Wzp*jB<3JOs1ksQt@!%82n?X_Dee{6eWCw+_p zj}Vu-t&ot+co6bnsrrlAolAg00e!{xR5bkzg1BdG(r=Jwi$-@R--59Z?d!VqBDgOB zlc#kB_bGHzZ;fVJ-uW}4Z(8UyUp=~^rylCOv`$2*ZrJ`An)zhV7@h+z)2{U|V16TQ zG3rGkB}Ltyx&#!?cT%@_>dzg|V~KU^U9ndor7ifTqP4AFhoZ^^vAlsd)5R8N;wNNr z=+9~k_sMTqH*rdZ!`f}0)ambIQbyw=jP>Q_@{}!aWV^S+%wGx*3kX@8l%z(6%oyc` zYqRUxfT0C*<`Z0_lzS!*=3i}ns=8HtVi~Y%EPm8%pK3!~kM=U0nMBkq$$C5n;@J^rBDv&$Qu!;ZC0agW=i4MVEEQ1OiLb zV3+CXPJ%mC@Yu0niulpJ7EiBt8kc~Dauj$tO^XULSyXkn7*6M%14!M)B~KTM`7Rl( z4Uu&h8|y71l+H*;MR^=cs|ZL1=9j9(Ac8QT12002ZrwhY=?$Ky^Ox1){j49bUQj|z z!2ei3Knx4E$}9#~dwT-~vXouPR#0S#PP|3~=-H~oGtzWAw%oS1&{6cRV_$wafCFGi zRxfVQL!s02_Cc|F8kjh)gMjc(gmbFxcb(N;0($*=%nD_Pp(DS@e}XOn^||ZWxDG=p zbZ@=gwD0Ez*B?K=%`thYkm~ns7ISWJvD4aZXNeVN=gP}dyfJMguw;=x4EHf6a1Q1B zq@J4>jRCt~IuO;t1aId;x%XLr>ui4A8ZxMu)FEZ&22xK7&9}hf`oZJV-5fJUW@;s_&r2uA z`g!taveu!G^4u=wL;6z{l0J)YqpEh9wjQe{saR(oFvKW+Qsa0N{F^iH;HR=#NWY)# zc;%;$@dF5(p?omh^fz49J*lz6=AMG4KvT>kXU?u$WmA1kw^M)wJF#r}q;Br`UWb|a z8m`(8;`)axa~1IDgL~IGGq!m1cdF?fH&vQ=f@WDzTu4}MvBK8TbfGBUYzRXncaWJ4 zC6Zgxq^?U!3eB(eRh{M-27el~z%%hNza}wu7W1O@`kUmDw_M@BoA522ij2$h#R{<< zGq-P;$mle&Mjnx941zi1edE7+<)Q|+1pIT$MK!VCZ6D=VwP`=SXrg|? z4)U8>Dh}|MT!~Vqa4T`AWrvibAs9MStfJ|O?V^Ek^gw;~c0_6>Wx|_JPi$9+5eJ$y zy5`!GSgqeSWP$Lr{eqTZrK#RZnCFW0e!|NVff*R@*bFDezE=XU9!l07iBiuoZ8}v} zD^vN3o7uBk3X47Sv{K7zYuY$p``{QFwqd&y^?h^jQ(|s&NRvaj0C`;BzlVR0EN1&n z%Db4i#$BFBQ4`xTYcs>$8`h&1v)XJDIoC%y zmRr1dD+U6--hFFN0wi+gvUfLm&GX(_P8AvA75?3Gx*4RU!0^t*C~zc>s{#t+9)~}= zb1XGJTo-Z8-~J&u*9;4WsOo1-tcI2qaJiEz%Ea<`*juj`z`{}3)I!PTgPF5|p!`~~ zmx(%+(5DNm-!FvU{$7rOeBzAwmHv-?jT8pi2p~K zFWi5QSGZ$!2zV`S8jTnSbt;t$o5dZrzgnVc*yw4sEV+sVEOA?t-`npyEk)Q&n{iA?qx$4n3q%`I z);DtmP7On@UScDLs#Q2;J8K9Oks|1HufmMm9hpA3lIq%)|H?BJTn!0g(#4K}Wy2 zHGK2Yos1fGIUC(-wiY$w=aSszf0s4~O@0P`3|<)-n~4}3paxe$XA4#%hKqegbK-jp z1IA$iV~={H3UIk6fzR3ml5FTwkR(u|>0`;y?d|pyls@qHGK%V?FQWIeDe`XY?#vkZ zeoRgJil$Exxf=jK+&Ux57!}DOUi4hwou8;z9-n((CzTf&bFu#LtS-;s`(cv3oaNs4 zyNl#1;1t9FRgX2< zXWvHfQe|ylhCI*$H?DFoZpyilF_98_b((BV>+VBK3l01nWsH6TgcyI#lUUmVU2q;% zWO%~Azd92-H?Iizv2`o5QE+8gV#Q||r25zWc(<1^tu*hxAJ<%C=ETrj( z(d-iXi=z+I(Hu?-fFPPYS5TdHsUMYlxN510*QX*pPu5Lae#x%ApdEa9_T{nMe4K{ljub9($Vcr;h|Bc#OSxI^l0en<; z-W3~ERsxL?LT;8!D3}|GXQZr^Oyho{&JB$8SXg^|CtHz*&j==mI5m1vb&{%;1Cl4E zOfi9|sGEo%tJD2I&_}KipvC_~* z;uCf+YO{d;>J&!kS=Ax%03+X;zfo*E$_LB{Qvl4Xiqv!sNZAZZ=974>KKiYcL}7Fv z+^fDy{zGQzo=mtYv-jf3Mt?RR=oesIzklIYj50s-1RKr3?HvO>r^s%or)XCmic#c5 zZQF#0d~uS!;j74b&Ap2;vsVblY8&=!C+i-e%^c7HF7$d=>eKd2EQULdI5l5Z82ZEN znK|>{ZTt6`sFMLgM@>*dPw)Padg}Py=EgQZP4Ul9n1U`oUb|g{W(vpW(CLRTJ}xJj zv8?h7ms3PJM|@V^AMn`@N>b09OD?E-b;uh)EVsIn9CUg5>(b%U$tMF^a9Z#2pIScQ z(rcxGY?Mo(_Glr9s*X&j5t=!$PKzu(Sw~c6gp8wUi-g$MSPX{ zZ|9#mO@f!QjAs}#8T{5+SHe=wxL;F^0ATBQ1?GQxx@6~W?XF4%;PbtN(i=~cQ+*Nj zapJ{avzB8pXxLLC`v*_jNnLXjVtmy81EokOHOASKzaAM*r|e%hirtN5f>^cm)pkX( zYDO4-Z-tXReJQ%F{NHAqaKB+fP#MRhFJKt*ezs}2O)fNW-Si(%yS&wSi@h`Cdh%eI z{s>VyVQ#&3=HMX8MIhcJ*~rYPtmSlGCXbEe<(&x{-CdcRPzl;xlS|8&LSJ&z(x?3x9G+ zt&9|I7coC|m(XI5$6!v7b*3;CNi)qgepnWtgao6S!oSR7@~#L!UIFIB-wR~esJH66 zo40y+S8fcvvHpC`jPT-X7mBo)JKPoV^_$(iUe_jE(qs^4Z72Tjt?h$*0bbnKZg!CH zIXj&@nJ<5i^Z7q=nS?Kts_T`CNaZQ>zW1-Y5Y{*e)HejpfFKwXX8 zM&y3DI{XL!3vMB_^Cs-A#G&-sOQpCXx|edN;B*U_!R^ximr*_JvRVM3*qe0DL2u}P}t&LHpHhMlgbiD^N_vTR1NpV#E$eVHhDz@CmVtkJAmjIhE z(TVO>b@$1R+)mLTZhh-xJ7Q@Y0;GR;r%ZQTpvy zQ-(}&szqb_ZR7%x!!o~KOAWokhrOtGvl1O>ue2TH66l6a(4(eMjA8^sP3r(Mb|6;g z&-y6@QPd!H%Xtg7i>xc`yEZLX(k$w=RDxGm=BqaNm_tn@R(yolu1GWjl0zmHlfaxq2?aWf(>0Uqq}VU^G0J~@IF-#j@Z zs3w_w=C5kb&-blEeY7c4IP{G*NgV63CV}9r3+=~~B=hl$=NqDO+xBu>#6OpSi|V*G z%jb%$6B6KVHw~HCEj!CN8t`eQZ)ShVIKA!siRpBoQ!9hMw$J2#LVCa44pChkYj=(Q zTYLFgP4#nXr&T57t3JV648yl2Mw7Myz1oA{dE&(*Bw-Vk+*4!F{iye1UhQ+jPKFQ$C zjOSfNmb>NN74^E&8K>lON3`&pPnxP#GkwOjZycHe@I>d_WY>Co?A=haDQigY>Ape2 z>LJmGcdZYk(jnMFr(U}V96B&-`E3MMu|N0LSN?UiaArrU(2KE#Cponx$$^;=@8v-D z8>zykc;voa%o`hj@8)qR)}a9;tQspj>CR~pk50H@)Ws{jp-w38W~$0-)bSUiqzD&M z+04PgTn&^Z$1U1tcHH_kamPF%->*u>;uk1G!q_=}wwJWiE5!5RVs69@`8(YxF? z$!woS&k0-5pXVmrQh)Bg?RX07dIDC#*cuKPuXBPKjC{K8b!!97ZBkd}cHp>OAeVHU(hg>Q5| zDOl;ftF@#vn{3Z9gfVJqP?Z>Wl6~lNV+FNl|L}$6YUB%t)DO1}beU8|o&Eid-|Fb= zV=bUOuR=9s9sDTz;4wB50(+=<&a!ot^IoriANB_H249X39U4h-ss!vRBmQMz3);Eptw0_(#O!QSQ=bf=?rm6G# zA!1}HAGC|T7wiZn5P>H36#A!UFz1 z$#r)75g6;^I?B3|U--RhCy4*nTsqrGsNt>#8KUwaeWJDCp`sd}R*1nRp!%$`v_1Ji z#~Uroc;DUBW|?hxX)8^6uan`2i}Q%pwmBpP%b?y%JkK@f!*DmBJOp#Qf3faQjBkfs z{A1U!jQ5-~Q#v4DXU+F!JS%CXY~h0bfgovKwG*!#VnAW=_h{xeXgh&FMBe~(#$T~C zG7Q2~IUA~zekHlN`FZPce{_LS=mXbDx$c7A=0S7`7&mEA2&9urdiHeO)p~gtt}&)p zQt)CBOf%XmXM&>8QeP z&U=iY$gvSp#`2n{Y_3OUk~YkiQazaXONt_Bux74m;{fq9ob zvH}1#3}S<3PhK%K;eihMp|V^sYt$q>efF=LZMW83N=oB#zXO%-p{f;6fBOr#Lpafq zC;W8p*W8nwAIt>M9}k@yEUT~I@wK!5qKOrZGbl=@&oDcF=*1rJ6Csq4T2ovI7$n1+ zhDn<_-`JvmPXYtHk!52CED|P*#~%6*&P2Ms)ZX|cxEL;m?G0E7kgn=t2wtlj@cRtt zr4C3f)lFg=#vFnHtY2IdTE*Vl@W^-@*(lve<9Fe=Q<*Y<2))a367A6YQ{bpBL#vAo z7!`@zGpeCtQBPd=A{F3tq^4#*tqlH#FR*-5a&}2p1Gmz`J20>;o1~KO*#A^O9-Rw6 zhZo-{(9}rJr}HZx@Z?>ZV?e=p^h3f1vs);*Yg}4x%gH|ziiQbuAl8pgOMgq(+~E}` zv!y14iUp{1+BysG<(ryEjpn>RYDM#gSK#J985uiHU@NJ{|*H8U7MLNZ-bZbzHNS#o$^Q(HPHY1HKM G)BgbAh~0Vs literal 0 HcmV?d00001 diff --git a/plot_data.py b/plot_data.py index 121c623..3fcbf43 100644 --- a/plot_data.py +++ b/plot_data.py @@ -2,7 +2,6 @@ # example usage: python plot_data.py noattn /tmp/draw/draw_data.npy import matplotlib -import sys import numpy as np interactive=False # set to False if you want to write images to file @@ -37,8 +36,10 @@ def xrecons_grid(X,B,A): return img if __name__ == '__main__': - prefix=sys.argv[1] - out_file=sys.argv[2] +# prefix=sys.argv[1] +# out_file=sys.argv[2] + prefix='atten' + out_file='draw_data.npy' [C,Lxs,Lzs]=np.load(out_file) T,batch_size,img_size=C.shape X=1.0/(1.0+np.exp(-C)) # x_recons=sigmoid(canvas) @@ -53,7 +54,7 @@ def xrecons_grid(X,B,A): arr[t].set_yticks([]) else: plt.matshow(img,cmap=plt.cm.gray) - imgname='%s_%d.png' % (prefix,t) # you can merge using imagemagick, i.e. convert -delay 10 -loop 0 *.png mnist.gif + imgname='%s_%d.png' % ('img/'+prefix,t) # you can merge using imagemagick, i.e. convert -delay 10 -loop 0 *.png mnist.gif plt.savefig(imgname) print(imgname) f=plt.figure() @@ -64,4 +65,4 @@ def xrecons_grid(X,B,A): if interactive: plt.show() else: - plt.savefig('%s_loss.png' % (prefix)) + plt.savefig('%s_loss.png' % ('img/' + prefix)) From 146c152be9aa3d12812a289bea8ecbb59b31bf31 Mon Sep 17 00:00:00 2001 From: zhaoxm Date: Fri, 19 May 2017 19:42:01 +0800 Subject: [PATCH 3/3] Revert "bug fixed." This reverts commit 86e2467162a7c745445abd05a553cf01a583e51d. --- attention_test.py | 77 ---------------------------------------------- draw.py | 53 ++++++++++++++----------------- img/elephant.jpg | Bin 10648 -> 0 bytes plot_data.py | 11 +++---- 4 files changed, 28 insertions(+), 113 deletions(-) delete mode 100644 attention_test.py delete mode 100644 img/elephant.jpg diff --git a/attention_test.py b/attention_test.py deleted file mode 100644 index be7abb4..0000000 --- a/attention_test.py +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env python2 -# -*- coding: utf-8 -*- -""" -Created on Fri May 19 14:15:42 2017 - -@author: zhaoxm -""" -import tensorflow as tf -import numpy as np -import matplotlib.pyplot as plt -from scipy.misc import imresize - -A, B, C = 100, 75, 3 -eps = 1e-8 - -def filterbank(gx, gy, sigma2,delta, N): - grid_i = tf.reshape(tf.cast(tf.range(N), tf.float32), [1, -1]) - mu_x = gx + (grid_i - N / 2.0 + 0.5) * delta # eq 19 - mu_y = gy + (grid_i - N / 2.0 + 0.5) * delta # eq 20 - a = tf.reshape(tf.cast(tf.range(A), tf.float32), [1, 1, -1]) - b = tf.reshape(tf.cast(tf.range(B), tf.float32), [1, 1, -1]) - mu_x = tf.reshape(mu_x, [-1, N, 1]) - mu_y = tf.reshape(mu_y, [-1, N, 1]) - sigma2 = tf.reshape(sigma2, [-1, 1, 1]) - Fx = tf.exp(-tf.square((a - mu_x)) / (2*sigma2)) # 2*sigma2? - Fy = tf.exp(-tf.square((b - mu_y)) / (2*sigma2)) # batch x N x B - # normalize, sum over A and B dims - Fx=Fx/tf.maximum(tf.reduce_sum(Fx,2,keep_dims=True),eps) - Fy=Fy/tf.maximum(tf.reduce_sum(Fy,2,keep_dims=True),eps) - return Fx,Fy - -img = plt.imread('./img/elephant.jpg') -img = imresize(img, [B, A]) -plt.figure(1) -plt.imshow(img) -img = tf.cast(img, dtype=tf.float32) -gx, gy, sigma2, N = 75, 50, 2.0, 12 #(A-1)/2.0, (B-1)/2.0, 1.0, 5 -delta = 3 #(max(A,B)-1)/(N-1) -Fx,Fy = filterbank(gx, gy, sigma2,delta, N) -Fx = tf.tile(Fx, [3, 1, 1]) -Fy = tf.tile(Fy, [3, 1, 1]) -Fxt=tf.transpose(Fx,perm=[0,2,1]) -img=tf.transpose(img, [2, 0, 1]) -glimpse = tf.matmul(Fy,tf.matmul(img,Fxt)) -patch = tf.transpose(glimpse, [1, 2, 0]) -Fyt = tf.transpose(Fy, [0, 2, 1]) -recons = tf.matmul(Fyt,tf.matmul(glimpse,Fx)) -recons = tf.transpose(recons, [1, 2, 0]) - -sess = tf.Session() -atte_patch, recons_img, FX, FY = sess.run([patch, recons, Fx, Fy]) -FX, FY = FX[0], FY[0] -plt.figure(2) -plt.imshow(atte_patch/atte_patch.max()) -plt.figure(3) -plt.imshow(recons_img/recons_img.max()) -plt.figure(4) -filter = np.zeros([B, A]) -for i in range(N): - for j in range(N): - filter += np.outer(FY[i,:].T, FX[j,:]) -plt.imshow(filter/filter.max()) - - - - - - - - - - - - - - - diff --git a/draw.py b/draw.py index b82737f..c8fdf2c 100644 --- a/draw.py +++ b/draw.py @@ -13,8 +13,7 @@ from tensorflow.examples.tutorials import mnist import numpy as np import os -os.environ["CUDA_VISIBLE_DEVICES"] = '2' -tf.reset_default_graph() + tf.flags.DEFINE_string("data_dir", "", "") tf.flags.DEFINE_boolean("read_attn", True, "enable attention for reader") tf.flags.DEFINE_boolean("write_attn",True, "enable attention for writer") @@ -26,15 +25,15 @@ img_size = B*A # the canvas size enc_size = 256 # number of hidden units / output size in LSTM dec_size = 256 -read_n = 2 # read glimpse grid width/height +read_n = 5 # read glimpse grid width/height write_n = 5 # write glimpse grid width/height read_size = 2*read_n*read_n if FLAGS.read_attn else 2*img_size write_size = write_n*write_n if FLAGS.write_attn else img_size -z_size=100 # QSampler output size -T=64 # MNIST generation sequence length +z_size=10 # QSampler output size +T=10 # MNIST generation sequence length batch_size=100 # training minibatch size -train_iters=200000 -init_learning_rate=1e-3 # learning rate for optimizer +train_iters=10000 +learning_rate=1e-3 # learning rate for optimizer eps=1e-8 # epsilon for numerical stability ## BUILD MODEL ## @@ -57,15 +56,15 @@ def linear(x,output_dim): def filterbank(gx, gy, sigma2,delta, N): grid_i = tf.reshape(tf.cast(tf.range(N), tf.float32), [1, -1]) - mu_x = gx + (grid_i - N / 2.0 + 0.5) * delta # eq 19 - mu_y = gy + (grid_i - N / 2.0 + 0.5) * delta # eq 20 + mu_x = gx + (grid_i - N / 2 - 0.5) * delta # eq 19 + mu_y = gy + (grid_i - N / 2 - 0.5) * delta # eq 20 a = tf.reshape(tf.cast(tf.range(A), tf.float32), [1, 1, -1]) b = tf.reshape(tf.cast(tf.range(B), tf.float32), [1, 1, -1]) mu_x = tf.reshape(mu_x, [-1, N, 1]) mu_y = tf.reshape(mu_y, [-1, N, 1]) sigma2 = tf.reshape(sigma2, [-1, 1, 1]) - Fx = tf.exp(-tf.square((a - mu_x)) / (2*sigma2)) # 2*sigma2? - Fy = tf.exp(-tf.square((b - mu_y)) / (2*sigma2)) # batch x N x B + Fx = tf.exp(-tf.square((a - mu_x) / (2*sigma2))) # 2*sigma2? + Fy = tf.exp(-tf.square((b - mu_y) / (2*sigma2))) # batch x N x B # normalize, sum over A and B dims Fx=Fx/tf.maximum(tf.reduce_sum(Fx,2,keep_dims=True),eps) Fy=Fy/tf.maximum(tf.reduce_sum(Fy,2,keep_dims=True),eps) @@ -76,8 +75,8 @@ def attn_window(scope,h_dec,N): params=linear(h_dec,5) # gx_,gy_,log_sigma2,log_delta,log_gamma=tf.split(1,5,params) gx_,gy_,log_sigma2,log_delta,log_gamma=tf.split(params,5,1) - gx=(A-1)/2.0*(gx_+1) - gy=(B-1)/2.0*(gy_+1) + gx=(A+1)/2*(gx_+1) + gy=(B+1)/2*(gy_+1) sigma2=tf.exp(log_sigma2) delta=(max(A,B)-1)/(N-1)*tf.exp(log_delta) # batch x N return filterbank(gx,gy,sigma2,delta,N)+(tf.exp(log_gamma),) @@ -196,14 +195,13 @@ def binary_crossentropy(t,o): cost=Lx+Lz ## OPTIMIZER ## -global_step = tf.Variable(0, trainable=False) -learning_rate = tf.train.exponential_decay(init_learning_rate, global_step, 10000, 0.1, staircase=True) + optimizer=tf.train.AdamOptimizer(learning_rate, beta1=0.5) grads=optimizer.compute_gradients(cost) for i,(g,v) in enumerate(grads): if g is not None: grads[i]=(tf.clip_by_norm(g,5),v) # clip gradients -train_op=optimizer.apply_gradients(grads, global_step=global_step) +train_op=optimizer.apply_gradients(grads) ## RUN TRAINING ## @@ -213,7 +211,7 @@ def binary_crossentropy(t,o): train_data = mnist.input_data.read_data_sets(data_directory, one_hot=True).train # binarized (0-1) mnist data fetches=[] -fetches.extend([Lx,Lz,train_op,learning_rate]) +fetches.extend([Lx,Lz,train_op]) Lxs=[0]*train_iters Lzs=[0]*train_iters @@ -221,23 +219,18 @@ def binary_crossentropy(t,o): saver = tf.train.Saver() # saves variables learned during training tf.global_variables_initializer().run() -#saver.restore(sess, "./models/model-10000") # to restore from model, uncomment this line -x_binary = np.zeros([batch_size, img_size]) +#saver.restore(sess, "/tmp/draw/drawmodel.ckpt") # to restore from model, uncomment this line + for i in range(train_iters): xtrain,_=train_data.next_batch(batch_size) # xtrain is (batch_size x img_size) -# p = np.random.uniform() -# x_binary[xtrain > p] = 1.0 feed_dict={x:xtrain} results=sess.run(fetches,feed_dict) - Lxs[i],Lzs[i],_,lr=results + Lxs[i],Lzs[i],_=results if i%100==0: - print("iter=%d, lr=%f : Lx: %f Lz: %f" % (i,lr,Lxs[i],Lzs[i])) - if (i+1)%1000==0: - saver.save(sess, './models/model', global_step=i+1) + print("iter=%d : Lx: %f Lz: %f" % (i,Lxs[i],Lzs[i])) ## TRAINING FINISHED ## -saver.restore(sess, "./models/model-200000") -#saver.restore(sess, "./trial/t32z10/model-30000") + canvases=sess.run(cs,feed_dict) # generate some examples canvases=np.array(canvases) # T x batch x img_size @@ -245,7 +238,7 @@ def binary_crossentropy(t,o): np.save(out_file,[canvases,Lxs,Lzs]) print("Outputs saved in file: %s" % out_file) -#ckpt_file=os.path.join(FLAGS.data_dir,"./models/drawmodel.ckpt") -#print("Model saved in file: %s" % saver.save(sess,ckpt_file)) +ckpt_file=os.path.join(FLAGS.data_dir,"drawmodel.ckpt") +print("Model saved in file: %s" % saver.save(sess,ckpt_file)) -#sess.close() +sess.close() diff --git a/img/elephant.jpg b/img/elephant.jpg deleted file mode 100644 index a21e822ddffced361ede5e84dbfa0cde46be9231..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 10648 zcmbW6cQ{=w@^g%!tnDL`{?sW<-lFY8ag;A)@!*B7_8^6Fpim z7$PKE5>i~h_r1@(&vXB}_pbAt{mi{MlEwmN@2m}CtR~O)N z8K426qND^uhEAaV*yFcmco?Nx_XCIA@_1R^5`QBaVRU-b^TdJZ6GreNWd zRHtM$asc!Dut`Ov7EuY@Y3yS+{`eg(?TC$}rnz>V1HySjP)JxrR7O@#9)VQQxT~q9 zjY8{~m}1P#EiA2^oLwF~baivb`9AUU4+sp3dKMiM8yBCDmY$KBmHi?o_jPeeX<2zi zWmQviOKV$uN9UXOANmId2}8pp(=)Sk^9zfgmNqttTiZLN-MueI$0w&}KYso?zxWRq z5CHmbtgH6F!Tt{}<|{5Ta&izk_&;1gGXJX^#7s`XCrQboZUlDlVda;Kpklj|TGZG_ zEg)_Do!t@pk>(m)X5+@uf6)Gm?Eeld^8btMe}VmPt|b5+2zWJlAZCCnATfBX2I%z% zm>ZVGRki+uK6$z#Cr$rZ#Tmz6QE%o2#c-95nn%7TEX2{c*O@#WLQmv}RIr`mHTnS$ zAb;InJp{&?`Ot&t3vt?WiAxD)A0hp>$)h z;kdtRb7;tL=9*{s$!S?gNt&`ILtc-e!f8nYj28?zB~7WqowC76`D9bGIEA{R+6|I;v1+l=^o3hW^tg2Va zkJO)phfzX#r@QERG0!pKDRj7UmbL$eZoQAnaPzg`mvotLo zdplxcfSa*^nmVUWckaG5+7Xj59HI0^jk~GWE*Y$*r>Hm(WQX>~eQ#%;&UT&z%{Imy zy7P%+>ztlEE~*gh447auGk&xCP+5T!Av0|)OK#Gr?>4Aw7ChMq?f*T*4pB@pdz!`F zte3En)5*a$|0dn{=S?EptK7s+{=6L~L{(!D-+>&x`~|FgmP zrGxQhxhI*=ABI_#XbxNU1V2RlMeOUj$<@QQDnLusElI6~$l=XA@qwUAz}bM2yvXMj zmPG&~R;+URxn_wbjU2}OWibSv{?{X+>rt9IQm{+W(kb@C5BD_sBo3j+X38gD6++pR zwYg;lIvg$jysqZ%_DRr6#|1~7lu#MKHzObWEUuff_*V;YS5Y&`Pp2?JnW$8)e2t?7 zr@}7IZf6yFQ#h|t(s#>K+F@Jexs%YK;w|Q6A5;ylZ;Mf<)U9Gz&g=nQc#so#zP7kd zT9@{_os#uAiQ{kB#as}*mRjH2$v0YIqW$9Np~EnR47s_B0@o%Hz2d4UYk;sL*xTd5 zq=lIa9!n(@8MIJ%^MltAV!+jMCEK3SZZ3bXEXm32LVCs`8fsh&ns08|@AG%A3qAf* zd&cf6!A_F-<|McJ)sm0?&VU6ICB4MxQgI;0^kTMU_))E3iGB8n(uvtd=k4WMM~ds$ zg=T~1Q^ru&m8L!>Ixlw?zuio&le149FitR6)4)D8`LpSy7=3e`^`>-lV!e)gHn)3? z$re$`^cw;MJ}FrSGuv&aHXzwfQ<~CEQ-^UYrps|W8CL7Xjs`6fn+eX?Xd9`mkD9R~ zfh-03hk=9ccHs?5ChUkcM&%QIlW=t?k2d>PIwL;15RtV9Y6Pn!{h9IGe#A%-3WBrb zH}&3q!z^dyPKMq)$iX(9JN>}S%#H@SnS=KkM>{h} z&^z7Nb&S*o-O%(JD(sSt!pi5x{mvi{WdlMbggzN-GQ@J>83pSBx(+d9#BJXB>JM@U zD)Jy;_+~9BXx+c;S$|Sn+n0P)M;~;CMM_^1b;K@EE6ExS=1_KbMVp;7NE&KRfFWnq zMnvSd(^n6TUabY104B(>-li893ubf<4HnL)^{Q=7A8vDbw6vuPwI|%)6ez6qSx9<0 z@$9h9W|6Zf_cJWjMDWX_~YYNWu840DpN⪚YJlfES*=ejxara_ zV}Ztt-uA1E{4XN%8?bIe63fxb$)&QzA{Ng?0O5h`d|bu*LuJ>NK9y)G=B?G1e0<%^ zhH}zPp3l%3KY5t&&?ZVhn5wsQB*m(~>nA|{^(#}w02w0(o(Bro7zrA=Ol`QX)G6cO zQuLsf0{-;XSDRP#fA(jj1SlPt4XRWAE{9_0VIMRcFPdMxREmU z*xCtXMdC9NyMHa=9&P3b<*?d2ge}@Xx^BN4s=rOYSY!xP6(!2_y4oBw1~cOJ4JX6@ z4IzHYy0L{TvpqbjxcjuvX)c)!s6JIUrErY?S)HpCD>YC+$q^=+;5!`50O8@z9N5c!Q% z7p5kr31y=GCs6o49Gc4A7os?uMD^&$&$7dRe1ryBaDi!o9zZAy^Ih3}nkoEZ?!F=p zgb#^Ex)^$+Cf~8&>1=df7vWa+D4h}-W4S}P!6mR~2TRMyW3bPAWN4PF=^G68IUrbjjSD(s4lo+@os+Vbln)>o6^}eF16FwjjAn-wd~L)p(QAp-dqY4U%h8Q zo}8`3oN5?te=m}nfP`*l+DV|`sk1a^lsXKCCqqAWa)HY)S1wE<$hVvd=w&=_X@>;? zv`FMU#ktWh>3`Y6InR zQ`q8Hghr0jd}wOK)c#2HuR6B6R?AXywyz@~)Gl*)FKMjEC1RYt#&Ob01PAHxXx!#{ z7q89$4%;jhx6H%?l$doPTFVm(^}MApwXdd5h#}a*Ycy%l=&$6d-lq5%1d5?Zmks}PiKBFaHDHOlsi z{XAb}P_T89@-nonR)Bt-k5@F2E|(I;ZXb`)vg`nN${carJ?}gNbyg71UyKkyuM--8 z6Mn>%T(!G~$uEdURn3G6gZn{meew(cy}JZh73HXtM@lXNbJZMeG+TH}&8NRz8#QKx z@mT+5c z!=9s3Dbno&Cwq3>CV|7XFsjbZzrsL)HrX+CZL3^< zN+=1H;?#==Zw3$(9x*k>=?VfK?#ji0hy1yUvlfbY^o2=e+9Kuap;`sTAqR9sA}8iN zXzKhC^MDxtBiQ{s?#fBdGq(2x1lwkR!cF^T4TZweTRM1!87k-C>1Gxw3dJNvDt~$S zM|E&F+j-PW9^I}HKtBs3cSDGXZeF;{Iw{3nRX!U6f&LBI73aA#`hk6zW8W@fOh_Ak zRKPf&t%xSa_hfWMK$P_apmgi@367lRd~wQeOwf~{IuYxbL3hC$$&mAoOMq%Lh>FV{ zMJ!Tc+%`yxYS48tsqd_syv1?18&^14Y8`(~f_;dGlmO`DY*YeH5{)Z{REm3^&N8nv zM;ipw?Ubz+puu=r{Xbp2gETEltHw9QIL>j83+{Ee`E_(-U^yYL;=KRTdzBSH`3M zjfS1=85VxBFBY@S->(!7QLDm<*|Q}r-m+{&W!?MfNk1ki&wVA23fzm+ zg+e1+^tZCa;iF81Ix$Od5UFPC6PD?yAoHX42#q}guS(-4q;htfu3LdeR2t79y|3%~ zac0}JGxEXE;(Ge2>PGJA4bwE(`Zj7vrFBPwMf*f{_YE!NG*14cRVRkK7g1+(*H?F% zbJybdAFky=ew8}xU4V>)0t=p1NSN_27tjy^#xZeG3bu1`Sq*=Yg%TOBK z#e4&9W-+vrI*0ApH{ox%QzO-TTK3k9!S4tyDfm@Glm$`3B;%1j;eP6K-RQ#f{bv|u zC9aur>1j@+yhFOCd!CIq6`RpSEUwo|?PJMkA3%-t>Q=&zX$ zGouAqo%W8{g0v9%S5O@_JFB!Ojk30mD~@jn`rQ>QkoD#16sX<9sdH+%RfiCq;eeD;&J`1TYkhSu0Kp=a=f|bi!9pprzwH7vVqX|_SK!!c8 z@PjI8t%N*0=;&Wqsm9OySNqd2h^M9Wxep$Pa?C<1SH@P#{3|rE>fxGUD6iFk*bG2U zcy|>;14;#7*9~+j)YG|M0zS!6J~aA~mneM+Xi}tEq4THv>puNnGJ@cavyQf(+ZjGl zY%WieLr*F|*l57p5-@90S2*rf{jqgH>1 zg;}^4A@dCZTPn_<#*n!AAFNKwH2%Wzp*jB<3JOs1ksQt@!%82n?X_Dee{6eWCw+_p zj}Vu-t&ot+co6bnsrrlAolAg00e!{xR5bkzg1BdG(r=Jwi$-@R--59Z?d!VqBDgOB zlc#kB_bGHzZ;fVJ-uW}4Z(8UyUp=~^rylCOv`$2*ZrJ`An)zhV7@h+z)2{U|V16TQ zG3rGkB}Ltyx&#!?cT%@_>dzg|V~KU^U9ndor7ifTqP4AFhoZ^^vAlsd)5R8N;wNNr z=+9~k_sMTqH*rdZ!`f}0)ambIQbyw=jP>Q_@{}!aWV^S+%wGx*3kX@8l%z(6%oyc` zYqRUxfT0C*<`Z0_lzS!*=3i}ns=8HtVi~Y%EPm8%pK3!~kM=U0nMBkq$$C5n;@J^rBDv&$Qu!;ZC0agW=i4MVEEQ1OiLb zV3+CXPJ%mC@Yu0niulpJ7EiBt8kc~Dauj$tO^XULSyXkn7*6M%14!M)B~KTM`7Rl( z4Uu&h8|y71l+H*;MR^=cs|ZL1=9j9(Ac8QT12002ZrwhY=?$Ky^Ox1){j49bUQj|z z!2ei3Knx4E$}9#~dwT-~vXouPR#0S#PP|3~=-H~oGtzWAw%oS1&{6cRV_$wafCFGi zRxfVQL!s02_Cc|F8kjh)gMjc(gmbFxcb(N;0($*=%nD_Pp(DS@e}XOn^||ZWxDG=p zbZ@=gwD0Ez*B?K=%`thYkm~ns7ISWJvD4aZXNeVN=gP}dyfJMguw;=x4EHf6a1Q1B zq@J4>jRCt~IuO;t1aId;x%XLr>ui4A8ZxMu)FEZ&22xK7&9}hf`oZJV-5fJUW@;s_&r2uA z`g!taveu!G^4u=wL;6z{l0J)YqpEh9wjQe{saR(oFvKW+Qsa0N{F^iH;HR=#NWY)# zc;%;$@dF5(p?omh^fz49J*lz6=AMG4KvT>kXU?u$WmA1kw^M)wJF#r}q;Br`UWb|a z8m`(8;`)axa~1IDgL~IGGq!m1cdF?fH&vQ=f@WDzTu4}MvBK8TbfGBUYzRXncaWJ4 zC6Zgxq^?U!3eB(eRh{M-27el~z%%hNza}wu7W1O@`kUmDw_M@BoA522ij2$h#R{<< zGq-P;$mle&Mjnx941zi1edE7+<)Q|+1pIT$MK!VCZ6D=VwP`=SXrg|? z4)U8>Dh}|MT!~Vqa4T`AWrvibAs9MStfJ|O?V^Ek^gw;~c0_6>Wx|_JPi$9+5eJ$y zy5`!GSgqeSWP$Lr{eqTZrK#RZnCFW0e!|NVff*R@*bFDezE=XU9!l07iBiuoZ8}v} zD^vN3o7uBk3X47Sv{K7zYuY$p``{QFwqd&y^?h^jQ(|s&NRvaj0C`;BzlVR0EN1&n z%Db4i#$BFBQ4`xTYcs>$8`h&1v)XJDIoC%y zmRr1dD+U6--hFFN0wi+gvUfLm&GX(_P8AvA75?3Gx*4RU!0^t*C~zc>s{#t+9)~}= zb1XGJTo-Z8-~J&u*9;4WsOo1-tcI2qaJiEz%Ea<`*juj`z`{}3)I!PTgPF5|p!`~~ zmx(%+(5DNm-!FvU{$7rOeBzAwmHv-?jT8pi2p~K zFWi5QSGZ$!2zV`S8jTnSbt;t$o5dZrzgnVc*yw4sEV+sVEOA?t-`npyEk)Q&n{iA?qx$4n3q%`I z);DtmP7On@UScDLs#Q2;J8K9Oks|1HufmMm9hpA3lIq%)|H?BJTn!0g(#4K}Wy2 zHGK2Yos1fGIUC(-wiY$w=aSszf0s4~O@0P`3|<)-n~4}3paxe$XA4#%hKqegbK-jp z1IA$iV~={H3UIk6fzR3ml5FTwkR(u|>0`;y?d|pyls@qHGK%V?FQWIeDe`XY?#vkZ zeoRgJil$Exxf=jK+&Ux57!}DOUi4hwou8;z9-n((CzTf&bFu#LtS-;s`(cv3oaNs4 zyNl#1;1t9FRgX2< zXWvHfQe|ylhCI*$H?DFoZpyilF_98_b((BV>+VBK3l01nWsH6TgcyI#lUUmVU2q;% zWO%~Azd92-H?Iizv2`o5QE+8gV#Q||r25zWc(<1^tu*hxAJ<%C=ETrj( z(d-iXi=z+I(Hu?-fFPPYS5TdHsUMYlxN510*QX*pPu5Lae#x%ApdEa9_T{nMe4K{ljub9($Vcr;h|Bc#OSxI^l0en<; z-W3~ERsxL?LT;8!D3}|GXQZr^Oyho{&JB$8SXg^|CtHz*&j==mI5m1vb&{%;1Cl4E zOfi9|sGEo%tJD2I&_}KipvC_~* z;uCf+YO{d;>J&!kS=Ax%03+X;zfo*E$_LB{Qvl4Xiqv!sNZAZZ=974>KKiYcL}7Fv z+^fDy{zGQzo=mtYv-jf3Mt?RR=oesIzklIYj50s-1RKr3?HvO>r^s%or)XCmic#c5 zZQF#0d~uS!;j74b&Ap2;vsVblY8&=!C+i-e%^c7HF7$d=>eKd2EQULdI5l5Z82ZEN znK|>{ZTt6`sFMLgM@>*dPw)Padg}Py=EgQZP4Ul9n1U`oUb|g{W(vpW(CLRTJ}xJj zv8?h7ms3PJM|@V^AMn`@N>b09OD?E-b;uh)EVsIn9CUg5>(b%U$tMF^a9Z#2pIScQ z(rcxGY?Mo(_Glr9s*X&j5t=!$PKzu(Sw~c6gp8wUi-g$MSPX{ zZ|9#mO@f!QjAs}#8T{5+SHe=wxL;F^0ATBQ1?GQxx@6~W?XF4%;PbtN(i=~cQ+*Nj zapJ{avzB8pXxLLC`v*_jNnLXjVtmy81EokOHOASKzaAM*r|e%hirtN5f>^cm)pkX( zYDO4-Z-tXReJQ%F{NHAqaKB+fP#MRhFJKt*ezs}2O)fNW-Si(%yS&wSi@h`Cdh%eI z{s>VyVQ#&3=HMX8MIhcJ*~rYPtmSlGCXbEe<(&x{-CdcRPzl;xlS|8&LSJ&z(x?3x9G+ zt&9|I7coC|m(XI5$6!v7b*3;CNi)qgepnWtgao6S!oSR7@~#L!UIFIB-wR~esJH66 zo40y+S8fcvvHpC`jPT-X7mBo)JKPoV^_$(iUe_jE(qs^4Z72Tjt?h$*0bbnKZg!CH zIXj&@nJ<5i^Z7q=nS?Kts_T`CNaZQ>zW1-Y5Y{*e)HejpfFKwXX8 zM&y3DI{XL!3vMB_^Cs-A#G&-sOQpCXx|edN;B*U_!R^ximr*_JvRVM3*qe0DL2u}P}t&LHpHhMlgbiD^N_vTR1NpV#E$eVHhDz@CmVtkJAmjIhE z(TVO>b@$1R+)mLTZhh-xJ7Q@Y0;GR;r%ZQTpvy zQ-(}&szqb_ZR7%x!!o~KOAWokhrOtGvl1O>ue2TH66l6a(4(eMjA8^sP3r(Mb|6;g z&-y6@QPd!H%Xtg7i>xc`yEZLX(k$w=RDxGm=BqaNm_tn@R(yolu1GWjl0zmHlfaxq2?aWf(>0Uqq}VU^G0J~@IF-#j@Z zs3w_w=C5kb&-blEeY7c4IP{G*NgV63CV}9r3+=~~B=hl$=NqDO+xBu>#6OpSi|V*G z%jb%$6B6KVHw~HCEj!CN8t`eQZ)ShVIKA!siRpBoQ!9hMw$J2#LVCa44pChkYj=(Q zTYLFgP4#nXr&T57t3JV648yl2Mw7Myz1oA{dE&(*Bw-Vk+*4!F{iye1UhQ+jPKFQ$C zjOSfNmb>NN74^E&8K>lON3`&pPnxP#GkwOjZycHe@I>d_WY>Co?A=haDQigY>Ape2 z>LJmGcdZYk(jnMFr(U}V96B&-`E3MMu|N0LSN?UiaArrU(2KE#Cponx$$^;=@8v-D z8>zykc;voa%o`hj@8)qR)}a9;tQspj>CR~pk50H@)Ws{jp-w38W~$0-)bSUiqzD&M z+04PgTn&^Z$1U1tcHH_kamPF%->*u>;uk1G!q_=}wwJWiE5!5RVs69@`8(YxF? z$!woS&k0-5pXVmrQh)Bg?RX07dIDC#*cuKPuXBPKjC{K8b!!97ZBkd}cHp>OAeVHU(hg>Q5| zDOl;ftF@#vn{3Z9gfVJqP?Z>Wl6~lNV+FNl|L}$6YUB%t)DO1}beU8|o&Eid-|Fb= zV=bUOuR=9s9sDTz;4wB50(+=<&a!ot^IoriANB_H249X39U4h-ss!vRBmQMz3);Eptw0_(#O!QSQ=bf=?rm6G# zA!1}HAGC|T7wiZn5P>H36#A!UFz1 z$#r)75g6;^I?B3|U--RhCy4*nTsqrGsNt>#8KUwaeWJDCp`sd}R*1nRp!%$`v_1Ji z#~Uroc;DUBW|?hxX)8^6uan`2i}Q%pwmBpP%b?y%JkK@f!*DmBJOp#Qf3faQjBkfs z{A1U!jQ5-~Q#v4DXU+F!JS%CXY~h0bfgovKwG*!#VnAW=_h{xeXgh&FMBe~(#$T~C zG7Q2~IUA~zekHlN`FZPce{_LS=mXbDx$c7A=0S7`7&mEA2&9urdiHeO)p~gtt}&)p zQt)CBOf%XmXM&>8QeP z&U=iY$gvSp#`2n{Y_3OUk~YkiQazaXONt_Bux74m;{fq9ob zvH}1#3}S<3PhK%K;eihMp|V^sYt$q>efF=LZMW83N=oB#zXO%-p{f;6fBOr#Lpafq zC;W8p*W8nwAIt>M9}k@yEUT~I@wK!5qKOrZGbl=@&oDcF=*1rJ6Csq4T2ovI7$n1+ zhDn<_-`JvmPXYtHk!52CED|P*#~%6*&P2Ms)ZX|cxEL;m?G0E7kgn=t2wtlj@cRtt zr4C3f)lFg=#vFnHtY2IdTE*Vl@W^-@*(lve<9Fe=Q<*Y<2))a367A6YQ{bpBL#vAo z7!`@zGpeCtQBPd=A{F3tq^4#*tqlH#FR*-5a&}2p1Gmz`J20>;o1~KO*#A^O9-Rw6 zhZo-{(9}rJr}HZx@Z?>ZV?e=p^h3f1vs);*Yg}4x%gH|ziiQbuAl8pgOMgq(+~E}` zv!y14iUp{1+BysG<(ryEjpn>RYDM#gSK#J985uiHU@NJ{|*H8U7MLNZ-bZbzHNS#o$^Q(HPHY1HKM G)BgbAh~0Vs diff --git a/plot_data.py b/plot_data.py index 3fcbf43..121c623 100644 --- a/plot_data.py +++ b/plot_data.py @@ -2,6 +2,7 @@ # example usage: python plot_data.py noattn /tmp/draw/draw_data.npy import matplotlib +import sys import numpy as np interactive=False # set to False if you want to write images to file @@ -36,10 +37,8 @@ def xrecons_grid(X,B,A): return img if __name__ == '__main__': -# prefix=sys.argv[1] -# out_file=sys.argv[2] - prefix='atten' - out_file='draw_data.npy' + prefix=sys.argv[1] + out_file=sys.argv[2] [C,Lxs,Lzs]=np.load(out_file) T,batch_size,img_size=C.shape X=1.0/(1.0+np.exp(-C)) # x_recons=sigmoid(canvas) @@ -54,7 +53,7 @@ def xrecons_grid(X,B,A): arr[t].set_yticks([]) else: plt.matshow(img,cmap=plt.cm.gray) - imgname='%s_%d.png' % ('img/'+prefix,t) # you can merge using imagemagick, i.e. convert -delay 10 -loop 0 *.png mnist.gif + imgname='%s_%d.png' % (prefix,t) # you can merge using imagemagick, i.e. convert -delay 10 -loop 0 *.png mnist.gif plt.savefig(imgname) print(imgname) f=plt.figure() @@ -65,4 +64,4 @@ def xrecons_grid(X,B,A): if interactive: plt.show() else: - plt.savefig('%s_loss.png' % ('img/' + prefix)) + plt.savefig('%s_loss.png' % (prefix))