From 7a65cf64e980d13b54bd61658b4fb6afb51decea Mon Sep 17 00:00:00 2001
From: "wuzhifan.wzf"
Date: Wed, 8 Feb 2023 08:29:56 +0000
Subject: [PATCH] add structured model probing pipeline for image classification
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add support for the structured model probing pipeline.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11376544
---
 ...ge_structured_model_probing_test_image.jpg |  Bin 0 -> 28297 bytes
 modelscope/metainfo.py                        |    2 +
 modelscope/models/cv/__init__.py              |   15 +-
 .../models/cv/image_probing_model/__init__.py |   24 ++
 .../models/cv/image_probing_model/backbone.py |  308 ++++++++++++++++++
 .../models/cv/image_probing_model/model.py    |   93 ++++++
 .../models/cv/image_probing_model/utils.py    |  148 +++++++++
 modelscope/pipelines/cv/__init__.py           |    4 +
 ...image_structured_model_probing_pipeline.py |   79 +++++
 .../test_image_structured_model_probing.py    |   29 ++
 10 files changed, 695 insertions(+), 7 deletions(-)
 create mode 100644 data/test/images/image_structured_model_probing_test_image.jpg
 create mode 100644 modelscope/models/cv/image_probing_model/__init__.py
 create mode 100644 modelscope/models/cv/image_probing_model/backbone.py
 create mode 100644 modelscope/models/cv/image_probing_model/model.py
 create mode 100644 modelscope/models/cv/image_probing_model/utils.py
 create mode 100644 modelscope/pipelines/cv/image_structured_model_probing_pipeline.py
 create mode 100644 tests/pipelines/test_image_structured_model_probing.py

diff --git a/data/test/images/image_structured_model_probing_test_image.jpg b/data/test/images/image_structured_model_probing_test_image.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..54f79fea80245ebcb61da62c3794dc5bb5d62e69
GIT binary patch
literal 28297
[base85-encoded JPEG data truncated]
ze6|>9QqdfhFKFR5D@$7-94Tn2C3My{Q(~mGv!CKAwbG=OlM|EstnGYNrK*(QDs5y{ zlCDb>+fx-es-#nEF;Jkqd}nNz&n(hfSGnJExMh{T-Q<(fW;8Slev3L-m>U6wEz+fdf%Z8U RvTbZ!Z_?lZ7E^5P|JlKaK`8(L literal 0 HcmV?d00001 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index f8a788f3..fa475ebb 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -80,6 +80,7 @@ class Models(object): image_casmvs_depth_estimation = 'image-casmvs-depth-estimation' vop_retrieval_model = 'vop-retrieval-model' ddcolor = 'ddcolor' + image_probing_model = 'image-probing-model' defrcn = 'defrcn' image_face_fusion = 'image-face-fusion' ddpm = 'ddpm' @@ -310,6 +311,7 @@ class Pipelines(object): video_panoptic_segmentation = 'video-panoptic-segmentation' vop_retrieval = 'vop-video-text-retrieval' ddcolor_image_colorization = 'ddcolor-image-colorization' + image_structured_model_probing = 'image-structured-model-probing' image_fewshot_detection = 'image-fewshot-detection' image_face_fusion = 'image-face-fusion' ddpm_image_semantic_segmentation = 'ddpm-image-semantic-segmentation' diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py index 44a6896f..0f4f33c2 100644 --- a/modelscope/models/cv/__init__.py +++ b/modelscope/models/cv/__init__.py @@ -9,13 +9,14 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints, image_denoise, image_inpainting, image_instance_segmentation, image_matching, image_mvs_depth_estimation, image_panoptic_segmentation, image_portrait_enhancement, - image_quality_assessment_mos, image_reid_person, - image_restoration, image_semantic_segmentation, - image_to_image_generation, image_to_image_translation, - language_guided_video_summarization, movie_scene_segmentation, - object_detection, panorama_depth_estimation, - pointcloud_sceneflow_estimation, product_retrieval_embedding, - realtime_object_detection, referring_video_object_segmentation, + image_probing_model, image_quality_assessment_mos, + image_reid_person, image_restoration, + image_semantic_segmentation, image_to_image_generation, + image_to_image_translation, language_guided_video_summarization, + movie_scene_segmentation, object_detection, + panorama_depth_estimation, pointcloud_sceneflow_estimation, + product_retrieval_embedding, realtime_object_detection, + referring_video_object_segmentation, robust_image_classification, salient_detection, shop_segmentation, super_resolution, video_frame_interpolation, video_object_segmentation, video_panoptic_segmentation, diff --git a/modelscope/models/cv/image_probing_model/__init__.py b/modelscope/models/cv/image_probing_model/__init__.py new file mode 100644 index 00000000..e97a1b77 --- /dev/null +++ b/modelscope/models/cv/image_probing_model/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
+ +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + + from .model import StructuredProbingModel + +else: + _import_structure = { + 'model': ['StructuredProbingModel'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_probing_model/backbone.py b/modelscope/models/cv/image_probing_model/backbone.py new file mode 100644 index 00000000..8f3ed5b6 --- /dev/null +++ b/modelscope/models/cv/image_probing_model/backbone.py @@ -0,0 +1,308 @@ +# The implementation is adopted from OpenAI-CLIP, +# made pubicly available under the MIT License at https://github.com/openai/CLIP + +import math +import sys +from collections import OrderedDict +from functools import reduce +from operator import mul + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +from torchvision import models + +from .utils import convert_weights, load_pretrained + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed + # after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, + # and the subsequent convolution has stride 1 + self.downsample = nn.Sequential( + OrderedDict([('-1', nn.AvgPool2d(stride)), + ('0', + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False)), + ('1', nn.BatchNorm2d(planes * self.expansion))])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Module): + + def __init__(self, + spacial_dim: int, + embed_dim: int, + num_heads: int, + output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter( + torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], + x.shape[2] * x.shape[3]).permute(2, 0, 1) + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) + x = x + self.positional_embedding[:, None, :].to(x.dtype) + x, _ = F.multi_head_attention_forward( + query=x, + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + 
in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False) + + return x[0] + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + + def __init__(self, + d_model: int, + n_head: int, + attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), + ('gelu', QuickGELU()), + ('c_proj', nn.Linear(d_model * 4, d_model))])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to( + dtype=x.dtype, + device=x.device) if self.attn_mask is not None else None + return self.attn( + x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor, idx): + features = {} + x_norm = self.ln_1(x) + features['layer_{}_pre_attn'.format(idx)] = x_norm.permute(1, 0, 2) + attn = self.attention(x_norm) + features['layer_{}_attn'.format(idx)] = attn.permute(1, 0, 2) + x = x + attn + mlp = self.mlp(self.ln_2(x)) + features['layer_{}_mlp'.format(idx)] = mlp.permute(1, 0, 2) + x = x + mlp + return x, features + + +class Transformer(nn.Module): + + def __init__(self, + width: int, + layers: int, + heads: int, + attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList() + for i in range(layers): + block = ResidualAttentionBlock(width, heads, attn_mask) + self.resblocks.append(block) + + def forward(self, x: torch.Tensor): + features = {} + for idx, block in enumerate(self.resblocks): + x, block_feats = block(x, idx) + features.update(block_feats) + return x, features + + +class VisualTransformer(nn.Module): + + def __init__(self, input_resolution: int, patch_size: int, width: int, + layers: int, heads: int, output_dim: int): + super().__init__() + print(input_resolution, patch_size, width, layers, heads, output_dim) + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn( + (input_resolution // patch_size)**2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor, return_all=True): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + zeros = torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) + # shape = [*, grid ** 2 + 1, width] + x = 
torch.cat([self.class_embedding.to(x.dtype) + zeros, x], dim=1) + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x, features = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if return_all: + features['pre_logits'] = x + return features + + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIPNet(nn.Module): + + def __init__(self, arch_name, pretrained, **kwargs): + super(CLIPNet, self).__init__() + + if arch_name == 'CLIP_ViTB32': + self.clip = VisualTransformer( + input_resolution=224, + patch_size=32, + width=768, + layers=12, + heads=12, + output_dim=512) + + elif arch_name in ('CLIP_ViTB16', 'CLIP_ViTB16_FP16'): + self.clip = VisualTransformer( + input_resolution=224, + patch_size=16, + width=768, + layers=12, + heads=12, + output_dim=512) + + elif arch_name in ('CLIP_ViTL14', 'CLIP_ViTL14_FP16'): + self.clip = VisualTransformer( + input_resolution=224, + patch_size=14, + width=1024, + layers=24, + heads=16, + output_dim=768) + + else: + raise KeyError(f'Unsupported arch_name for CLIP, {arch_name}') + + def forward(self, input_data): + output = self.clip(input_data) + return output + + +def CLIP(arch_name='CLIP_RN50', + use_pretrain=False, + load_from='', + state_dict=None, + **kwargs): + model = CLIPNet(arch_name=arch_name, pretrained=None, **kwargs) + if use_pretrain: + if arch_name.endswith('FP16'): + convert_weights(model.clip) + load_pretrained(model.clip, state_dict, load_from) + return model + + +class ProbingModel(torch.nn.Module): + + def __init__(self, feat_size, num_classes): + super(ProbingModel, self).__init__() + self.linear = torch.nn.Linear(feat_size, num_classes) + + def forward(self, x): + return self.linear(x) diff --git a/modelscope/models/cv/image_probing_model/model.py b/modelscope/models/cv/image_probing_model/model.py new file mode 100644 index 00000000..e7636f40 --- /dev/null +++ b/modelscope/models/cv/image_probing_model/model.py @@ -0,0 +1,93 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import os +from typing import Any, Dict + +import json +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import ModelFile, Tasks +from .backbone import CLIP, ProbingModel + + +@MODELS.register_module( + Tasks.image_classification, module_name=Models.image_probing_model) +class StructuredProbingModel(TorchModel): + """ + The implementation of 'Structured Model Probing: Empowering + Efficient Adaptation by Structured Regularization'. + """ + + def __init__(self, model_dir, *args, **kwargs): + """ + Initialize a probing model. 
+ Args: + model_dir: model id or path + """ + super(StructuredProbingModel, self).__init__() + model_dir = os.path.join(model_dir, 'food101-clip-vitl14-full.pt') + model_file = torch.load(model_dir) + self.feature_size = model_file['meta_info']['feature_size'] + self.num_classes = model_file['meta_info']['num_classes'] + self.backbone = CLIP( + 'CLIP_ViTL14_FP16', + use_pretrain=True, + state_dict=model_file['backbone_model_state_dict']) + self.probing_model = ProbingModel(self.feature_size, self.num_classes) + self.probing_model.load_state_dict( + model_file['probing_model_state_dict']) + + def forward(self, x): + """ + Forward Function of SMP. + Args: + x: the input images (B, 3, H, W) + """ + + keys = [] + for idx in range(0, 24): + keys.append('layer_{}_pre_attn'.format(idx)) + keys.append('layer_{}_attn'.format(idx)) + keys.append('layer_{}_mlp'.format(idx)) + keys.append('pre_logits') + features = self.backbone(x.half()) + features_agg = [] + for i in keys: + aggregated_feature = self.aggregate_token(features[i], 1024) + features_agg.append(aggregated_feature) + features_agg = torch.cat((features_agg), dim=1) + outputs = self.probing_model(features_agg.float()) + return outputs + + def aggregate_token(self, output, target_size): + """ + Aggregating features from tokens. + Args: + output: the output of intermidiant features + from a ViT model + target_size: target aggregated feature size + """ + if len(output.shape) == 3: + _, n_token, channels = output.shape + if channels >= target_size: + pool_size = 0 + else: + n_groups = target_size / channels + pool_size = int(n_token / n_groups) + + if pool_size > 0: + output = torch.permute(output, (0, 2, 1)) + output = torch.nn.AvgPool1d( + kernel_size=pool_size, stride=pool_size)( + output) + output = torch.flatten(output, start_dim=1) + else: + output = torch.mean(output, dim=1) + output = torch.nn.functional.normalize(output, dim=1) + return output diff --git a/modelscope/models/cv/image_probing_model/utils.py b/modelscope/models/cv/image_probing_model/utils.py new file mode 100644 index 00000000..c2b13ae5 --- /dev/null +++ b/modelscope/models/cv/image_probing_model/utils.py @@ -0,0 +1,148 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import re + +import torch +import torch.nn as nn + + +def load_pretrained(model: torch.nn.Module, + state_dict, + local_path: str, + map_location='cpu', + logger=None, + sub_level=None): + return load_pretrained_dict(model, state_dict, logger, sub_level=sub_level) + + +def load_pretrained_dict(model: torch.nn.Module, + state_dict: dict, + logger=None, + sub_level=None): + """ + Load parameters to model with + 1. Sub name by revise_keys For DataParallelModel or DistributeParallelModel. + 2. Load 'state_dict' again if possible by key 'state_dict' or 'model_state'. + 3. Take sub level keys from source, e.g. load 'backbone' part from a classifier into a backbone model. + 4. Auto remove invalid parameters from source. + 5. Log or warning if unexpected key exists or key misses. + + Args: + model (torch.nn.Module): + state_dict (dict): dict of parameters + logger (logging.Logger, None): + sub_level (str, optional): If not None, parameters with key startswith sub_level will remove the prefix + to fit actual model keys. This action happens if user want to load sub module parameters + into a sub module model. 
+ """ + revise_keys = [(r'^module\.', '')] + + if 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + if 'model_state' in state_dict: + state_dict = state_dict['model_state'] + + for p, r in revise_keys: + state_dict = {re.sub(p, r, k): v for k, v in state_dict.items()} + + if sub_level: + sub_level = sub_level if sub_level.endswith('.') else (sub_level + '.') + sub_level_len = len(sub_level) + state_dict = { + key[sub_level_len:]: value + for key, value in state_dict.items() if key.startswith(sub_level) + } + + state_dict = _auto_drop_invalid(model, state_dict, logger=logger) + + load_status = model.load_state_dict(state_dict, strict=False) + unexpected_keys = load_status.unexpected_keys + missing_keys = load_status.missing_keys + err_msgs = [] + if unexpected_keys: + err_msgs.append('unexpected key in source ' + f'state_dict: {", ".join(unexpected_keys)}\n') + if missing_keys: + err_msgs.append('missing key in source ' + f'state_dict: {", ".join(missing_keys)}\n') + err_msgs = '\n'.join(err_msgs) + + if len(err_msgs) > 0: + if logger: + logger.warning(err_msgs) + else: + import warnings + warnings.warn(err_msgs) + + +def convert_weights(model: nn.Module): + """ + Convert applicable model parameters to fp16. + """ + + def _convert_weights_to_fp16(layer): + if isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Linear)): + layer.weight.data = layer.weight.data.half() + if layer.bias is not None: + layer.bias.data = layer.bias.data.half() + + if isinstance(layer, nn.MultiheadAttention): + for attr in [ + *[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']], + 'in_proj_bias', 'bias_k', 'bias_v' + ]: + tensor = getattr(layer, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ['text_projection', 'proj']: + if hasattr(layer, name): + attr = getattr(layer, name) + if attr is not None: + attr.data = attr.data.half() + + for name in ['prompt_embeddings']: + if hasattr(layer, name): + attr = getattr(layer, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def _auto_drop_invalid(model: torch.nn.Module, state_dict: dict, logger=None): + """ + Strip unmatched parameters in state_dict, e.g. shape not matched, type not matched. + + Args: + model (torch.nn.Module): + state_dict (dict): + logger (logging.Logger, None): + + Returns: + A new state dict. + """ + ret_dict = state_dict.copy() + invalid_msgs = [] + for key, value in model.state_dict().items(): + if key in state_dict: + # Check shape + new_value = state_dict[key] + if value.shape != new_value.shape: + invalid_msgs.append( + f'{key}: invalid shape, dst {value.shape} vs. src {new_value.shape}' + ) + ret_dict.pop(key) + elif value.dtype != new_value.dtype: + invalid_msgs.append( + f'{key}: invalid dtype, dst {value.dtype} vs. 
src {new_value.dtype}'
+                )
+                ret_dict.pop(key)
+    if len(invalid_msgs) > 0:
+        warning_msg = 'ignore keys from source: \n' + '\n'.join(invalid_msgs)
+        if logger:
+            logger.warning(warning_msg)
+        else:
+            import warnings
+            warnings.warn(warning_msg)
+    return ret_dict
diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py
index d5839eab..c37a5630 100644
--- a/modelscope/pipelines/cv/__init__.py
+++ b/modelscope/pipelines/cv/__init__.py
@@ -86,6 +86,7 @@ if TYPE_CHECKING:
     from .image_mvs_depth_estimation_pipeline import ImageMultiViewDepthEstimationPipeline
     from .panorama_depth_estimation_pipeline import PanoramaDepthEstimationPipeline
     from .ddcolor_image_colorization_pipeline import DDColorImageColorizationPipeline
+    from .image_structured_model_probing_pipeline import ImageStructuredModelProbingPipeline
     from .video_colorization_pipeline import VideoColorizationPipeline
     from .image_defrcn_fewshot_pipeline import ImageDefrcnDetectionPipeline
     from .ddpm_semantic_segmentation_pipeline import DDPMImageSemanticSegmentationPipeline
@@ -207,6 +208,9 @@ else:
         'ddcolor_image_colorization_pipeline': [
             'DDColorImageColorizationPipeline'
         ],
+        'image_structured_model_probing_pipeline': [
+            'ImageStructuredModelProbingPipeline'
+        ],
         'video_colorization_pipeline': ['VideoColorizationPipeline'],
         'image_defrcn_fewshot_pipeline': ['ImageDefrcnDetectionPipeline'],
         'image_quality_assessment_mos_pipeline': [
diff --git a/modelscope/pipelines/cv/image_structured_model_probing_pipeline.py b/modelscope/pipelines/cv/image_structured_model_probing_pipeline.py
new file mode 100644
index 00000000..bc2561e2
--- /dev/null
+++ b/modelscope/pipelines/cv/image_structured_model_probing_pipeline.py
@@ -0,0 +1,79 @@
+# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+
+import math
+import os
+import os.path as osp
+from typing import Any, Dict
+
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+from mmcv.parallel import collate, scatter
+
+from modelscope.metainfo import Pipelines
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import LoadImage
+from modelscope.utils.config import Config
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+    Tasks.image_classification,
+    module_name=Pipelines.image_structured_model_probing)
+class ImageStructuredModelProbingPipeline(Pipeline):
+
+    def __init__(self, model: str, **kwargs):
+        """
+        use `model` to create an image structured model probing pipeline for prediction
+        Args:
+            model: model id on modelscope hub.
+ Example: + >>> from modelscope.pipelines import pipeline + >>> recognition_pipeline = pipeline(self.task, self.model_id) + >>> file_name = 'data/test/images/\ + image_structured_model_probing_test_image.jpg' + >>> result = recognition_pipeline(file_name) + >>> print(f'recognition output: {result}.') + """ + super().__init__(model=model, **kwargs) + self.model.eval() + model_dir = os.path.join(model, 'food101-clip-vitl14-full.pt') + model_file = torch.load(model_dir) + self.label_map = model_file['meta_info']['label_map'] + logger.info('load model done') + + self.transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711]) + ]) + + def preprocess(self, input: Input) -> Dict[str, Any]: + + img = LoadImage.convert_to_img(input) + + data = self.transform(img) + data = collate([data], samples_per_gpu=1) + if next(self.model.parameters()).is_cuda: + data = scatter(data, [next(self.model.parameters()).device])[0] + + return data + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + with torch.no_grad(): + results = self.model(input) + return results + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + scores = torch.softmax(inputs, dim=1).cpu() + labels = torch.argmax(scores, dim=1).cpu().tolist() + label_names = [self.label_map[label] for label in labels] + + return {OutputKeys.LABELS: label_names, OutputKeys.SCORES: scores} diff --git a/tests/pipelines/test_image_structured_model_probing.py b/tests/pipelines/test_image_structured_model_probing.py new file mode 100644 index 00000000..563e131c --- /dev/null +++ b/tests/pipelines/test_image_structured_model_probing.py @@ -0,0 +1,29 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class ImageStructuredModelProbingTest(unittest.TestCase, + DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.image_classification + self.model_id = 'damo/structured_model_probing' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + + recognition_pipeline = pipeline(self.task, self.model_id) + file_name = 'data/test/images/image_structured_model_probing_test_image.jpg' + result = recognition_pipeline(file_name) + + print(f'recognition output: {result}.') + + +if __name__ == '__main__': + unittest.main()
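
The snippet below is a minimal usage sketch of the new pipeline, mirroring the unit test and the pipeline's postprocess above. The model id 'damo/structured_model_probing' and the test image path are the ones referenced in this patch; indexing the result with OutputKeys.LABELS and OutputKeys.SCORES follows ImageStructuredModelProbingPipeline.postprocess.

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the pipeline from the model id used in the unit test above.
probing_pipeline = pipeline(Tasks.image_classification,
                            'damo/structured_model_probing')

# Run it on the test image added by this patch.
result = probing_pipeline(
    'data/test/images/image_structured_model_probing_test_image.jpg')

# postprocess() returns the predicted label names and the softmax scores.
print(result[OutputKeys.LABELS])
print(result[OutputKeys.SCORES])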
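
For reference, a self-contained sketch of the probing scheme implemented by StructuredProbingModel: intermediate transformer features are pooled over tokens, L2-normalized, concatenated, and fed to a single linear head. The tensor sizes, the three stand-in feature maps, and the helper name `aggregate` below are illustrative assumptions, not the actual CLIP ViT-L/14 backbone or its checkpoint.

import torch
import torch.nn.functional as F


def aggregate(feature: torch.Tensor) -> torch.Tensor:
    # [batch, n_tokens, channels] -> mean over tokens, then L2-normalize.
    return F.normalize(feature.mean(dim=1), dim=1)


batch, n_tokens, channels, num_classes = 2, 257, 1024, 101
# Stand-ins for features collected from several blocks of a frozen backbone.
layer_features = [torch.randn(batch, n_tokens, channels) for _ in range(3)]

# Concatenate the aggregated per-layer features and apply a linear probe.
probe_input = torch.cat([aggregate(f) for f in layer_features], dim=1)
linear_probe = torch.nn.Linear(probe_input.shape[1], num_classes)
logits = linear_probe(probe_input)
print(logits.shape)  # torch.Size([2, 101])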