From 634ef86a2c0a753c7999373b27d5d1607a2aad98 Mon Sep 17 00:00:00 2001 From: Alexander Engelsberger Date: Tue, 20 Jun 2023 17:42:36 +0200 Subject: [PATCH] fix: example test fixed --- examples/siamese_glvq_iris.py | 3 +-- examples/siamese_gtlvq_iris.py | 3 +-- examples/warm_starting.py | 4 +++- glvq_iris.ckpt | Bin 0 -> 5331 bytes iris.pth | Bin 0 -> 4927 bytes prototorch/models/glvq.py | 20 -------------------- prototorch/models/unsupervised.py | 2 +- 7 files changed, 6 insertions(+), 26 deletions(-) create mode 100644 glvq_iris.ckpt create mode 100644 iris.pth diff --git a/examples/siamese_glvq_iris.py b/examples/siamese_glvq_iris.py index b00e308..ff63da9 100644 --- a/examples/siamese_glvq_iris.py +++ b/examples/siamese_glvq_iris.py @@ -51,8 +51,7 @@ if __name__ == "__main__": # Hyperparameters hparams = dict( distribution=[1, 2, 3], - proto_lr=0.01, - bb_lr=0.01, + lr=0.01, ) # Initialize the backbone diff --git a/examples/siamese_gtlvq_iris.py b/examples/siamese_gtlvq_iris.py index 4f036d1..35405e1 100644 --- a/examples/siamese_gtlvq_iris.py +++ b/examples/siamese_gtlvq_iris.py @@ -51,8 +51,7 @@ if __name__ == "__main__": # Hyperparameters hparams = dict( distribution=[1, 2, 3], - proto_lr=0.01, - bb_lr=0.01, + lr=0.01, input_dim=2, latent_dim=1, ) diff --git a/examples/warm_starting.py b/examples/warm_starting.py index a9e17dc..22acf51 100644 --- a/examples/warm_starting.py +++ b/examples/warm_starting.py @@ -55,7 +55,9 @@ if __name__ == "__main__": # Setup trainer for GNG trainer = pl.Trainer( - max_epochs=1000, + accelerator="cpu", + max_epochs=50 if args.fast_dev_run else + 1000, # 10 epochs fast dev run reproducible DIV error. callbacks=[ es, ], diff --git a/glvq_iris.ckpt b/glvq_iris.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..45793132ee27222167aa307aac09d5a032853016 GIT binary patch literal 5331 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfvL8TK_h~dfq@~l zAU`>ykkOklf}4SXAw4HQDKRI$xFofpkja}dLXv@jp`fxPzbH9FHzzYaqa-ggFFn32 zwWv5VKd+D(Y=V)Vfu3<8O9U4K14D61Vo7RzN@j9NA!~ASeojtma!F=>UNM({QA%o2 zYKlu{atT)fuMe$`uT!ow(8WHSZ3-gN-(^Ct%AX>QGa`F>Pf=j@HT!q{bj0_A6 z28BEk%nS?+$pxi_yxtNeg?xeD4BkxMJZ**i-aOunZG{4z86Ysr@uuwDuY@`^-NFHydwnA}lCXkU55F;f^ z3Z;Syr6YvVjm$|*O3f)Qlz|urb-ho1UOL?IvS7y>70PAsdh?VN%7blaEmZJkY%NrT zn66Y(s2o(NQd(Q68o|lHz!0CCT9OEg_ChtVhtxG9Kq;h9qqa~J99*!N)AD94j^JZp zU`WXUy@&vm{Vw5TWAO27o{eqR2JHMGlGK=6x}(gC8;Te4&IF5a4RlJ zECTTzp(+ZB@{?1Gi&IkyoxB-K!Ag=#i;7b7N(!B83tcoa?7bN?9K0Db9K9Jcz)E-- z7#K2(<8u;=OCatqbnPsSkVW%taz<)OX-;Yp!p&~Ah3*;|wzY*Ga6?NooNEg`5!})U z9aLiyOH1-|6H79aVqR%t4o<&D)D}iY7!lS2_7FHZA$%KETNn*?SPaM&3>gmI3}8pa zBDkdya;RpPCFVdb|%yoCW7ruf?JZ|=*^Jf1hqREi!dlH zfwD6s?ZXp%N^M~(*r+s+9@G>Gijs6p;nLCwP*N>R%*jjvmF3|4nNeGq3C^Ec;6ek< z!P&KiIbd6ILB@dt3B|#AScF|WOCgD%B()e?2;|om7Jw}&%*d`SECQQYj76(+ZD9!( zVX%oj;9@ovQ~`nGps=*IuncTyImplou%VS$bUW7;R$&U4LLHf$n3I!~n4Dc)SY2CK z6H(2;z~GyolA7b3k(!)cke``XQmvkwpO;yZU!-oO;Fq75s-vKuo1c=ZZl$1}o0+Gs zqoAHzmReL9pBGgQzBoC*D7CP`FM=Ocydbe8qp&eU102iHG7GA> z$q!tRrDPU?6*WirFfcIaXXNLm>L=!;RwU-7q!#JBtdeHj|1 z#o(|4XLnHf6%Wp*#fAN`g%fHECqgO&uvp=w+QP}+47G(*G%^BE>@_Kz3bA(@lD*SQ z3TFfr&IAWyYDGbOVp)3OEEH{~g|i{eg*o8L9MoKx3u!LQgE)PDN#TN^!iC_r6x4|L z;=;m3D5jVdE{2%01a8Vwh$+h;rYtWhToF{bva~n?oF5X45_99zi}FhgiVIi87Ot)> zT%!@e1gfQr3fH>Zx40EIUcdbnTtg?NmLwJzu5-74a|Q&?y4%11e(t7NyLMaQdT<*m zwV=3ggS(w;W#1!%9xJf<<*A^CVSGwza$;rSMsEhN{}XeI!6o~qPH;`1n^*y=hf)hS zcY>wU@{3XvlQRmp_<_BeoLEp&T9lZSlUlg76I{fkWTvI1g6h*uu;{iY@8Ok?IIGPzeC?adF|!*uq`j42h+QrHSBlms1p9469&^ z3wOsB?x`)@3yJRZ#N6D(!hP=c@7IZd3<5hkDX};;KBuU-aDQy!0k{o`U>Aea6*$8d z9)z|L%2JDx@{3ao4|Rf@(%>u}pPXNsS5kP`n<+vH)ILj3Es4)5iU*Y^sVVX0nI##} zvY_w?B%mM-gq)(nqp^j@;Bf=;0Vpan3cx;#&r8frEj$hmNvKfa32>s$s4M_A9zd#6 zOHzx93s2S-p3(qE6R4$El$lftYO@udt}Q&H0d{6yX>NRSPGWI!YH{INZ)R|A1DOF9 zD?I1TQknruBcMPLsRp-@fyuvH)pcW|zgRL(qO3W)xOD&2|%PYJJ4)x5G)Vz|+lFGts;7Efj z042@B>+TLJAiWF>rKNfS-s~Kg>XuwN#>2n>!nk`yd<=Fdy&`=BHzzB!J`qX>i2w)U z1X4&3dsNj9cBp&?TVy`SXEzt}rYA5kFn};VpQ)hw%+SqAfuPSsnin|KXYH}qFkIkp zQE4{@y~uT+!;Qrq4%tE;;;OiAjUp-zo2A4RSjOli+?M zV%(1$R@m5}ULQyt7Y!=>K=Qa~ko${|-I@n-I|!5Deq-X?-%t;-;iR55W2;98Q|^(4OHkGpzB30H&FF{K-X)C zt`|AaqUznihi0!4x?beGh^n`MA5E_@x?bdbgsS%nx?U4>y~yblRd0v@n!TpzdXduz zs$K&@G`(i%dXZxv)JH(X(FY{G&@>g`&Bg{AWB@6WW7Y*xUU+YyxBmbbPR={u|5U{22g<<;LQqZq%m+Za58|@L(~EQljpZf literal 0 HcmV?d00001 diff --git a/iris.pth b/iris.pth new file mode 100644 index 0000000000000000000000000000000000000000..7e485c0d5d8cca1092c35f3ddc1f6b4573e25669 GIT binary patch literal 4927 zcmWIWW@cev;NW1u0K5#M44Fll#ri3UC5d_k**R`bf(%jUpn)m5peVm2za+mXIYTcu zKP5G%ST8-NtdPsy*C#BHtB^snv5>L0kVzu~WJgI+VrE`uUV0&OM+7$m14Dd4Vo_pl zYDsEQaUn}`a(+%uYH~?teqJ$`e^E+mQEG}yW^xHvA**IkAzK6+NJD;ZYI@55M6fe3Fcg>M7bT{r7IHzf zaJl8=Czb?51i1>iBN!PN7z_$|BA6K%7?KN03wgZ-N(%V`y&1e&yjj```Mp`Z8QTg4 zIx|4-6f7we3Mv%t0CPmz3q^wp#frh6k54L1O9S~-94sdhA;Q4G013Uyg4Fo(%)I!b z#FEVX;zG#`P7RF=t_+CPhJ{iYyxu$|h0-7^nY`KB3T3?6Kvv2^tduJ$ln*LYD2|W@ zg-~98UVK4nQE_H*NormR%+mPc)RIEQ7pj4ss~*9RA`g+%0Ly7ch=JrXi{sNub8_P0W`NWb zYWYR*fn?M2i@*kh1d9u`!5VcUq%fr7%QH(d;z5o`FD}#tE7OY*LRN-wnm$S=vV5LS8f*_^EC5a`e@hO?fB~U*agJn&;86ZhOFD0|Mq$o3~v?Mh}FTOYf0w7GPzT-i%OZ0cMR4ZX0fUHG#3djhqLOXEg z1jj>eeoARhYN34xIQv0)4jGK4NzmkmZbCtRF)UCW!M<^VBz}n2;zDPzpo>NXC{chM znV6TH8lRI`nOaonid>lF<>{4!{aRGa<&&73l#=KJQp8p0rrB8NUR&s)k-^-N!2&kU zGXs*RyfPFslwu3LTMK=H3VkymDc3Ipl1lwEwEQw4=`J7xlCT0ZASoy)1Cl&~Gav~e zBm)v*p&5`E3Cn=^B|HP-{D=r91_p+-yuwIi7boWzrGirq$V81Qsg*m8Q_(h{Z5GIX|}`KQA?}1QaAmsX3`B&IpmhIL*ew z_}an*q<~2T8yb($SD2K+fl|5{6((o!dP|fPrhrNpHc)|<3MtUiAf-!sNnu7%VJ4(N z%W5yo4l2wkhK6@8xcrRINlZ%3DK5-|7z{~(Tt4}E>F`1{A6$qU7Zzmjdb5=j7J}_+ zEiCe8Y%MH?*i%wcSQ=DVRt$;4^47wNpu$RM6jni_uo@ACHP9%mg+^f=Gz#mXQP=>D z!bWHmHbJAX85)HxrAgqj02DN-C7GaF2O8|HSc4s^rkE={#L=0nuuZeEu)Vgh11YRK zp@Gr`@j-WMVNXzDFVqKpP#^Rod@up(gNaZdOoIAgGSmlCpgx!i^}#f#52izXFrzdH zoT75_i;D|qVmKNkoLbD~?h_X1lV4oSRX9ttv2b>6;T$9<&xJa99>mG>TMHKi6)uE2 zc@fmfixEy<0(J6IsFRmLoxB|CC(lhf4*Sfb^NP7LQ&RIvGD|AC3fF_ZpOjjXShxX_^q@7S zN#Vu}UT?;d!cBqR44Rsd^tHL9a7$3(RgOYMT9)PlsK)Ob*1 z2UL5;=jErQ7H;o|Fk)a}NX*H}FORQEEy|D2NzF?y$p9+@6?~~hkUA!}G^Zr9ASX3G zCABOwIkmWOM`wfps4~dOPmj;2EJ!T^H*|^%cS6e1_>_{$g4DuY(5#XMYH%1EaTV@{ z&jN-Zip=m&05#HUsy=7QXknO9I+5}#O9 zlvr7KC$eCsPJ?Lq?QBwIKC*gur#wMwKzV#2v+G8o~bQ7 z3#rj_Q%j06lj9Rhpp|WL;W@tuUXbCSY?+ypS`?pLQc-x`4-zhEiAhD7$%PmEAi5!3 znC+mJ)y3GtOA!4TVC#z^F1`%b4izfA0tq%4_iA!MCAc*ipOcxMQ37h7>VcY-nI)O2 z#d-yaMa7wU>0FL5Cqi1J*EAaouWLkbGB7Yef;lBKxA2BHO9U4K14B+Cxa9>By6Mdl z!Nb7704nOi?TO6%yuw?xg|{`pX+EzsH$FKhvA8(3xbTiQGuXTWkf%Xng?GJ~O2KZ0 zRtiOh_uTDUsvf$p6RD11Wnf@PN{Y`ZD!h*n1l3A8MTHL#BA}Sc%*!mvOw7rwN(D9R z@=NkTEt-PV!iUiG6`Wa>%2oKto29kzackj|pu(pah`gDx4(tS2t}A>74qtd?1?Rr! z?hY!Dyj%)NGSIM%huU2DA_Ej)FEui5c(Y{O^k&JpRa^KSVS)|B{7%F5kwTe0S9qTQG9W7Mk=VeQdIaBMJ=)vmum&23e8N+ z@d@H8dS4IN z50YEw-e}KxUD*aC2Ets|m2E&`AZJIfwMlSbU|;}Ye9mS^a<-wH6En!!kXXk?de}om z9qb@9gDr>#*(aEy{lx&(rD9-Uz-OO0l6^*QPQoDj&;vrad4YowM~s8=?Va|2Ru;0Qkk1_eiunLE0qw}QG5AdJtC z&PaYVb#t;O#*gTJM&^V3MhOPF-7LAc6qI&B7`NMb7}y|gF9Y?B^Yh%Cm>8m>jJZH2 zT%4C43JPHm4)A6KQSdPll28l!7OPB^F<51?o?LDz^JYp5DU1dwes zMc0TNS)i&F5vmtp8lh1f;LXOS1C^Fz)`hYd7(gRYAP$T^4b5#rps@jPp~MCn7Gk&y a8g*b`U;s%6c(a1WpBOk8I2b_cA!-3*4~)0~ literal 0 HcmV?d00001 diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index 2ca9834..8f77fdd 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -123,26 +123,6 @@ class SiameseGLVQ(GLVQ): self.backbone = backbone self.both_path_gradients = both_path_gradients - def configure_optimizers(self): - proto_opt = self.optimizer(self.proto_layer.parameters(), - lr=self.hparams["proto_lr"]) - # Only add a backbone optimizer if backbone has trainable parameters - bb_params = list(self.backbone.parameters()) - if (bb_params): - bb_opt = self.optimizer(bb_params, lr=self.hparams["bb_lr"]) - optimizers = [proto_opt, bb_opt] - else: - optimizers = [proto_opt] - if self.lr_scheduler is not None: - schedulers = [] - for optimizer in optimizers: - scheduler = self.lr_scheduler(optimizer, - **self.lr_scheduler_kwargs) - schedulers.append(scheduler) - return optimizers, schedulers - else: - return optimizers - def compute_distances(self, x): protos, _ = self.proto_layer() x, protos = (arr.view(arr.size(0), -1) for arr in (x, protos)) diff --git a/prototorch/models/unsupervised.py b/prototorch/models/unsupervised.py index 8833de2..80d7d85 100644 --- a/prototorch/models/unsupervised.py +++ b/prototorch/models/unsupervised.py @@ -63,7 +63,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel): strict=False, ) - def training_epoch_end(self, training_step_outputs): + def on_training_epoch_end(self, training_step_outputs): self._sigma = self.hparams.sigma * np.exp( -self.current_epoch / self.trainer.max_epochs)