From 634ef86a2c0a753c7999373b27d5d1607a2aad98 Mon Sep 17 00:00:00 2001 From: Alexander Engelsberger Date: Tue, 20 Jun 2023 17:42:36 +0200 Subject: [PATCH] fix: example test fixed --- examples/siamese_glvq_iris.py | 3 +-- examples/siamese_gtlvq_iris.py | 3 +-- examples/warm_starting.py | 4 +++- glvq_iris.ckpt | Bin 0 -> 5331 bytes iris.pth | Bin 0 -> 4927 bytes prototorch/models/glvq.py | 20 -------------------- prototorch/models/unsupervised.py | 2 +- 7 files changed, 6 insertions(+), 26 deletions(-) create mode 100644 glvq_iris.ckpt create mode 100644 iris.pth diff --git a/examples/siamese_glvq_iris.py b/examples/siamese_glvq_iris.py index b00e308..ff63da9 100644 --- a/examples/siamese_glvq_iris.py +++ b/examples/siamese_glvq_iris.py @@ -51,8 +51,7 @@ if __name__ == "__main__": # Hyperparameters hparams = dict( distribution=[1, 2, 3], - proto_lr=0.01, - bb_lr=0.01, + lr=0.01, ) # Initialize the backbone diff --git a/examples/siamese_gtlvq_iris.py b/examples/siamese_gtlvq_iris.py index 4f036d1..35405e1 100644 --- a/examples/siamese_gtlvq_iris.py +++ b/examples/siamese_gtlvq_iris.py @@ -51,8 +51,7 @@ if __name__ == "__main__": # Hyperparameters hparams = dict( distribution=[1, 2, 3], - proto_lr=0.01, - bb_lr=0.01, + lr=0.01, input_dim=2, latent_dim=1, ) diff --git a/examples/warm_starting.py b/examples/warm_starting.py index a9e17dc..22acf51 100644 --- a/examples/warm_starting.py +++ b/examples/warm_starting.py @@ -55,7 +55,9 @@ if __name__ == "__main__": # Setup trainer for GNG trainer = pl.Trainer( - max_epochs=1000, + accelerator="cpu", + max_epochs=50 if args.fast_dev_run else + 1000, # 10 epochs fast dev run reproducible DIV error. callbacks=[ es, ], diff --git a/glvq_iris.ckpt b/glvq_iris.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..45793132ee27222167aa307aac09d5a032853016 GIT binary patch literal 5331 zcmbtYd5{!W8SmZYoGeGU5~2x6PzHgWJpc_5cHF=&+GMi~iMG+TdDHWDy4jxY`MO^( zEZBh0~2O0>WwW+zO z4JK$Q-0KS_xL&Xf0TT_wwQX()-07!o@(SGJg$=lpf=NqrFgdNwk>H*yTrpVY`VPy{ zJ%+n|m|~E`COr_A?Wc5)=L5?w=z=@G>*)inDVUm0Ym-T!>#-t-RK$|%v|T2)B0CaO zaG6$zdzzt9YrwZr4q%#krGV*MR7I^<8~4KuwNb6x4>N~Nsng{GX60bEcEx$cbej!w zuSqxDQrUI5BYc=+qG&T^M$j!V*Cfb!60%XP-w#(HGD2P%LCzO&RSw>w&8~&CS)SWI zyfuQ0wcqSI#ftJ*lk%-_4N7Pfuz(~C!9ule2;LT9|y|Fl3IwBTHdf2+K`cl-WSRiC0#H&=jp)jw{k3 zTv#%;@p>kN=4c%=Q?Xs&hn5gpi4^#G$`spm%RaP)urjU5-0ic3E@`kTgw<(tVhHcj z<|3VPl2)3k66)O{Tt~}#o(aR`zAj8AC|Aq80BfX4GGD8>X`_y26g`d#tqoyal=YZ% z%!cbjSRc(j&I*Ijju|iAD6O-(K+r5wqBLWkoWCdDfsh6wBBevE3uucGcphd5+z`Ts zw24qAGMmU?(w1q6Yx(FbUqt%D#^FG_I;pRZW)%XPdm}n`hR~HZQPWKo(7;3)?;V*2 z+VUg}3xr$3qh#oASy&}&mpoQZwJ+Znf~qxN5*I;?=(9P58_7vAH20u8gqzZu3_I10 zo)ET>#<>v{rN&!FrU4=B;-0G)Jyw95L%4-l-e1`xJx7LqAV~@WUt5|~yu_RUi`ICy zY9V}3YrQy2WT-S@#M)a!_z)@fVX1|JAyr7J-jQjbU6aHfKwV^3SNr=TA?VtoarV_@ z>9!CU$;?-TpfeOqqe!+f23h8kJkC+>R}#Ju0w<_G$&<`T8Is~?au8@TRnZ6w3klW* zc|Oe$EIOatNlJ~v{UO*Sr8I_SD8(6rj0#6dAUGZf65@tXCLv%VqDP4S7_5jW#vl`7 zBdT7;dj_wW2oy`>0Dw)4#B_hEMkAJulFlpEaww{I}H z(O-5gN9pCd?6AF9A`^Pma_z>9Bc2O?&!i}c+<%=^KrGgYK4L) zi7iNsw&CgJI>K`g0}?}r5QcDDbp6FbVNbU<6P3zKXG*zvg&VHNVQ;rKBS|kaVZ!ZZ znxY)nta$Usx~Y&AERQzbp>0OijOmtmhS_{4bMOk!Y;cXBByYmZ4%h4V%go@JK97|+ zO*3?5(Cr=R;b*X<#Xj<3A4w%G`{Q*O3SmE8kNv{bs|+9=?nFgpMV4#mfq=tZiTk$= zhVgc+6Yz1o9hWYRZpBXsxH|`*B+IGZMjj!+d(1Qz5b3K*{lZvaqWkcvUifqfpNTF6 zqQPfFxK~v|_*~lTNyuFZ_eF9)UnTbo0=}4o`^kxXXIW${e z%lF&8@SPA2r?q-?f$hO}yW014`nNy-=YP=+oyY#cha+9>FP)U*R9E}U|9R%zygQfe zhojU(Vb;PUUF$av-t)KSyVnr@4lWx*y}%7N2#=}?+0RNoRrbe*>83BSo$?;y@c1w- z_hE7}>^r~LO{N;GECP?^Z4Tcbj{B0nK5Q5rVG;F-VN$OzkXP;xA|?yZnj^Jl%PY*d7QineM@0<+n)J(FI8lj>v_OW^j4aa~Ja6zT@?=bcUsF<>*hypa zAnnKg1y%n|2)|7euk7?ZD~~NTcJ99m;bfZBbb^v@pfx_8fKzG%oi>R=?MqLcTA1P>+4Kg+1`L68Zh&u?Hdeq5N&fNCU*%Rm*`SgtYRWw~$ zpZF@uGKd>~z95?k;(UYpvJo3rbabLIN#xwP!up=rf!NR-coEL)tktfjMK z;YBT*bH~B#p6yS!r&|tYPcM3E)cb*r$FqNW@UH9=|2&p`Y0K#M=|#_F|MA+R*%zNZ zlYQ;s(Qm2$@Pf{@Z$SUnOVJ-)a={_L`s5hzEsLKVn}{9ww{Jgm_zlQk^=9S&ukYNklgH*C zN3+KqLo-($zwgz<ZYOB<~|QO~`Gn<*h!8N!~KH&nRgyf>rc z6LMG8@>b_alD8TErAhFvuH~(czWhcQ8Lw3F#yQoaPMWkl-f>OCippNuo6B+jy&$fr hxmgh%`Q+ax3dsNZaL9t(qfV4f8uCn4Jp;=ZKWo39-ma?m z>i2&2-mBL=m~ClIB-+~(r?jbw*$KyYf~A%b8maO&x3714;_Ak>yV)%JUWk9b$x=nn zvfUt6aEC$3tn1&nUV+4-JJHYh?w>2p0J+#os?n5;G7oDP_xwcy<>M=&)zErheyXA{}h zZ0lY)8*fc}VcLF1IH!j&BL_3b_`|GyFgpk51l)K%9}NwWrE~b$x!Nq$C)gV)+xm8= zr27V1AHeGvLSqUOOLxL~Y-)B&2y^S8i)QMfP%Z_;LxblxzxZ8vlQkfw6U9yF3TpQ*I3A&@8qM< zwo-b)kdPGrwJlvZkkX(+f#rcVyFqsi=p|hLTZFmjURE4JZQ5J;w6|rOB#cstVk&a& zK;9L7lW&yEwzU?iRp3&LuD4fUMN-2kPH?(czGe$(*+wX&LV+G0oO~a|@{k_m;U}L| z))eKXyESSkdqI`dRot3h;r!ALSj{^xOLC_mnvIf)IE)e7hc)qJg0^qxw)%n6Zxr*E z(NB{U=tHA36et9u6%G4)P%0Yi_@jzWJ9)-mC^159}7>5xW@)29N`KkHtb3!c4Q+HwrpYozNR%J-$Nz1DvsFnd^=S&DNuN-LMIa@DsVO5 zq!Q*k*vtv84(X`qnkdivcs`JHY&k`_=@rXf3B#Hw$=j}Nt*-UJyO0>Y0$XFm7$VKq z#|E1n3!G84Z))9(nbl*771-gWS2q+EC-HAXz$(!+j8IrlJ6qVp9mMR1W*!+ zWftxCN`;e2)Z z*o$26tH4-{>;00H8-xY#AB7L(;6`b|P11s!YZiP^TJRxh!H1;wNXTCkSp-wNQ-V7)2%b0Yq*xo20z4ln z8V<#+!9-wSRWM9vBUT2Eh5qk^BMRKb&ClDR0bj54>1A*j{j4d3yV1`UEfW6qjS%k1 z!Mz-TVOhFWhHvhV@)6J4m-_u>YRKU}T zrcvUFuTF#_MM z<7YX15Bb3&=KBhhR8mf1 z>A1+R8Sa2z4hUX`j66<*;Y-#9v6!XQ6?M}!C{W?|*%q#kypQ+6)7j>TW0i-34}Zv{uiyDh z=FqHN+62^^&+D!af2?)Vsf0a$s`b!LVaXjjhU@Ia>P4u+S_a+k8R=gea72MWXIn?% z*-`jQ4xVGR*kp$|glg2mUpc=ggOx|$@r-&Qu7fQSE>mS+t_*);MA+Yx3@KvB2}270 zkz{BxBO_VOE>^Q9;&e#gLjOJ?;HG#_%CL;0vWkA-c}}tG>w(E^s+N5CXWiWRkg{fn zI6{$;{+t5;;&pF)Mz!{Qp#uL-a-IuDv1q`Hne+>XW)c0|orY~c?*%rzv_FdSsln{T z_Lmo%Cnlly(lt2li`C>rPwmBOX-8l0rKi4DHNIwz`+R%3l2)Er@DN>-<-OWBNj}rq zq0F8%KChrY*}sDN2=?YXS8P>r@8~)c?1?ehoqfG61RF)hX=~GBK<$weR??l;n{Hk7 zpAP(=t~0GSC$?T!U+>J*(KPeA18P_MHR^(=?oYpZ5PBAN9#A)1o7JQ9A50gI@9H_Q z<`H%8iB9#&+s4$%&3B~ldHG@Wu1u%;{Q5C<%L50~N9P|?m)yNlJ^spJ_1L;*^|>RD zt1TOQ)K?B3R$o2Xq!Qj|Z<)AzD`KVVOn4W@@OJn0E;t?D#*IXkb&Xy%sa5Y3(A27T z!qfKROPT+bzh&d_mq?tyigez0!`yqRmZ0lQHm;7@xV*17{aS2nv^oAvw*5biuwP={ z7^$d)uJN%?Nwf*<7`Fjt;lAEx1mCTYhNHI@t|W%(8q79vva8=j;z3O(Q+=!ZLeeN2 z4=noHQEOMBrjw#eI~oP!5mGNW4PT{ef}M?maaXMu++8Qw)hHPEsd~W<3e;%WNofQ^O+vlHn1Q=%RJXug>HUt