diff --git a/data/derived/100_freq_embeddings.pkl b/data/derived/100_freq_embeddings.pkl new file mode 100644 index 0000000..242f4d9 Binary files /dev/null and b/data/derived/100_freq_embeddings.pkl differ diff --git a/data/derived/491_within_n_across_disps.csv b/data/derived/491_within_n_across_disps.csv new file mode 100644 index 0000000..aade892 --- /dev/null +++ b/data/derived/491_within_n_across_disps.csv @@ -0,0 +1,492 @@ +within_verbs,across_verbs +0.605480839972912,0.725044923255003 +0.49736222140377945,0.631654891899663 +0.47168608927506156,0.6138123822114263 +0.4742792113299598,0.7023129564307181 +0.4169654134273455,0.7616582546135822 +0.5073276350776147,0.7010983548259359 +0.5080952583325881,0.6184856455982031 +0.6336920196694058,0.7135379130646389 +0.5096374318767398,0.6247402327348868 +0.44965059748390324,0.6823137186224669 +0.5177768097930858,0.7348989092314762 +0.5264673201782423,0.5656085507436923 +0.45703008446181675,0.6753671791010405 +0.48451652957508223,0.675969522638134 +0.5793227974132165,0.7934024646556234 +0.4913358732252781,0.6095717139270582 +0.45817187199836795,0.64373781232785 +0.4859303505313909,0.6564623362610232 +0.5217150473881683,0.6633335815310539 +0.45370235145655574,0.7058189496751563 +0.520530325532616,0.6495729620128191 +0.4853643349500969,0.7018000450919247 +0.480950212116653,0.5894266788293532 +0.44550003460714416,0.62064761324277 +0.607334509814434,0.6488159916632716 +0.4877989698269567,0.5836701510547643 +0.4826012084755831,0.6052885169187038 +0.491130469858808,0.738665553634054 +0.5283830946696207,0.7423043754047103 +0.6058576848053016,0.6718158813486695 +0.49611379015129337,0.6865991262904046 +0.45995816364604886,0.74150820599475 +0.4693663152080771,0.5758253147021858 +0.465618984115237,0.6288827659106568 +0.4275132711943728,0.6064278036275366 +0.522436400030277,0.8633923120617873 +0.4874375235782103,0.6406178862813954 +0.4750879879671092,0.7103825562514424 +0.43435649184083336,0.7567480449667344 
+0.47588486149035236,0.6669337847163456 +0.5122148682566616,0.7088152200286199 +0.5278252347390088,0.7557996255172601 +0.5121233945081163,0.7106874805636786 +0.45985626317125283,0.7479672908574394 +0.49211643399141797,0.6178351055158131 +0.4663617291798036,0.690202691786898 +0.4726400200652983,0.6449681048260761 +0.4794069529899159,0.6749247992817111 +0.32895459201304594,0.8163349355607313 +0.4669423497568996,0.7065692348167673 +0.556727180680389,0.6460490110795182 +0.5118774918835025,0.6106099825723187 +0.45413602944260634,0.6372996361435752 +0.42567105798092547,0.7766913773120602 +0.5150247407300836,0.793609809691233 +0.5398529540916295,0.7383232563468183 +0.5028750299429781,0.6800528762773124 +0.3581015171984239,0.6377255102090702 +0.5227575192012937,0.6733601957479756 +0.5561123224654,0.6870150416476007 +0.5292444825767806,0.7455346450004855 +0.5340326012837339,0.7051355779081702 +0.4675557089419409,0.7572185625207623 +0.46524793800760417,0.6744336427288068 +0.4912697940303896,0.6656903958528677 +0.43641262967072025,0.6949632756319819 +0.4703460579253098,0.6917245639856182 +0.5019123156227828,0.6705063329217806 +0.48739246003339043,0.61614298564692 +0.4567562698403358,0.7723161076764565 +0.3929110063584417,0.8134693933213919 +0.4535871712949118,0.7127874987153471 +0.5557559228921515,0.6333975467102619 +0.48030355881692355,0.7895453070079386 +0.44636368750524097,0.7812104003247197 +0.4450164826288345,0.6219141498532731 +0.36083337421876777,0.6886393143430082 +0.33826329175074293,0.5892837933941131 +0.4491586433039609,0.6886572009849073 +0.3400201248591462,0.6693439430473546 +0.5571581929004141,0.7112815305580468 +0.4569447371083403,0.7650726020960796 +0.46033377732874714,0.7453510224901295 +0.44268074752299935,0.7030451135368131 +0.4746925279879096,0.7014440156208497 +0.5193301186750272,0.6274909187742723 +0.44953475131540727,0.6884262541419361 +0.4065687938934455,0.6521468055941153 +0.4815863944197614,0.5871107344493782 +0.5096991073374229,0.655158528103925 
+0.4708724170991446,0.6933973547197685 +0.47086235510365,0.7114821521021251 +0.4544683681732724,0.6915068684665931 +0.4407412055689817,0.6581853768507563 +0.3648591879017801,0.6384008360322971 +0.5617535875444779,0.7653312128224149 +0.3910765851239329,0.6466664058114333 +0.4111000573165128,0.5544918087487614 +0.4523271254093024,0.6336410628792969 +0.42161837033265237,0.6975447261250193 +0.4809417385156048,0.6860294254949358 +0.5508333930969999,0.6563558619536491 +0.3977168026923137,0.6488118554755715 +0.49910387271882295,0.7579884987929474 +0.37622353737507935,0.6361923092558474 +0.3574804529704139,0.6420250567314957 +0.42384896262923943,0.6654405685528438 +0.4317548645465034,0.5995038701009657 +0.5387502473946272,0.6909213572010845 +0.4947213495323156,0.7816281935184247 +0.44802125929586695,0.6376577751780176 +0.4846832750231929,0.6537935489957094 +0.4600624238588429,0.7053979706279919 +0.46100788617193844,0.6257681827834849 +0.377514311859985,0.6457564538259768 +0.4578190858733184,0.6995302979249203 +0.41874412472653233,0.6728141720798153 +0.473149252233069,0.7400684749188887 +0.5003692414253775,0.6579748519226128 +0.38671387150327524,0.6507910260223319 +0.4329288435011897,0.7554428872961404 +0.38937738950985945,0.6632714422558932 +0.507866448194171,0.6697989654204488 +0.4503577175594915,0.6436987778546469 +0.4004246379411161,0.7084190193924603 +0.5175674155634723,0.592000461758388 +0.49096655150152685,0.7179735738467282 +0.40157507962559785,0.724116842640191 +0.4996563962257568,0.7219602737166688 +0.4091447739402725,0.6607966997588274 +0.4908000549408518,0.654323415307034 +0.4029540032740767,0.7752269484038838 +0.4896063399150976,0.7151692301592091 +0.37365763622593734,0.6605479646305524 +0.45346212226665594,0.7753901076453064 +0.33296084571574036,0.614266466139344 +0.40595239179797177,0.7232822774566101 +0.4512160027410494,0.704067822452151 +0.387437481995008,0.690940315995756 +0.3712252979731263,0.6892111587672914 +0.5711661659061097,0.703213254888616 
+0.46938253599058066,0.6541600457644752 +0.4664196694872788,0.5938181337796151 +0.4682109685010753,0.6839917471696231 +0.4012940182279379,0.6963437212566033 +0.5566515181145882,0.6597889423241242 +0.5103069429820604,0.6485648974692793 +0.5547128212607425,0.7739386059086971 +0.40791332374782513,0.7189626112158645 +0.43831715375374275,0.713504833105968 +0.37238584111033374,0.6988144453295679 +0.48974065751301804,0.6958484714351436 +0.5692746243412625,0.7403417508986659 +0.41830905885228964,0.7634910496425027 +0.42199736012733086,0.7174980202292595 +0.49464193272381196,0.7361537788265963 +0.40907087347300763,0.6947946456557822 +0.5083428799414953,0.716828781048775 +0.4831330925746159,0.6247491194793605 +0.44865348138723904,0.5758531785946958 +0.4577811559637835,0.7379324357150788 +0.522405194942378,0.6737231435721743 +0.3862756721518123,0.6616181297653976 +0.4839168415317975,0.7271759685965089 +0.4110369138397671,0.7350427686567975 +0.40245055365606325,0.5796611914894637 +0.46790976219762026,0.672467715775991 +0.33134028011154776,0.7404073648550494 +0.4467874961511285,0.7531835861421509 +0.4088755433995198,0.7282914303989606 +0.4207078706101553,0.5930391708000566 +0.4117520706022664,0.6670040951943181 +0.42262382063582715,0.6521819959973684 +0.5139962578168572,0.6383414092367927 +0.3548675056073501,0.5756398781300174 +0.42737402271464386,0.6247226107248017 +0.4029245441915341,0.647720891088715 +0.4782163437043551,0.6493472650600458 +0.4084167466961997,0.6582527366417187 +0.4475291225009991,0.72742828667379 +0.5134000921112114,0.697318202211257 +0.3327686918306017,0.670927312357904 +0.542421551246306,0.7508850033882188 +0.41017350697467575,0.7686777031689295 +0.5334854982285328,0.6520753496640578 +0.5158682775368698,0.6457766299734807 +0.461559559918012,0.7704836184047354 +0.463496040926273,0.6401845701253146 +0.4665695906964886,0.7419659836170712 +0.4177874796925448,0.631331360120928 +0.4703878859858616,0.6346517952850421 +0.46087266397833115,0.700380940984194 
+0.411991398297398,0.6685049441165432 +0.4822643197177507,0.7051680221161787 +0.4517640130810059,0.7408881809307233 +0.4321050011953649,0.6223072091852055 +0.5107466717954819,0.6079034790413904 +0.35654712490855694,0.621035503229739 +0.47036771057183174,0.6607480721163265 +0.3970297011370373,0.5894988720629702 +0.5599467559265916,0.6983675149602062 +0.46535963592278895,0.6969992977512803 +0.41361443993831637,0.6222560626131873 +0.4401822208387065,0.6218675074291452 +0.4432148267348234,0.541456109615857 +0.3501775777648271,0.7014905938688097 +0.44790214814925716,0.6018382883180227 +0.3962874390994017,0.5915417550719401 +0.4863839234197195,0.6625220093412677 +0.44006434290460616,0.7511768119652019 +0.45087688544555077,0.6053088424500135 +0.5067792135238026,0.6456383855686012 +0.5015016746613962,0.681988995962167 +0.4772336166423342,0.6484208029970459 +0.4173796169176804,0.7072181642840298 +0.4165336020457162,0.680543214775675 +0.3661609696700325,0.6608371913564056 +0.391987820121063,0.7236968493302014 +0.3892719685118593,0.7788697431673581 +0.5301085138999533,0.7626243342095277 +0.40200739338602365,0.6926561696541189 +0.4455053646464521,0.6246017619993043 +0.37040012916985976,0.7379868002866075 +0.4467781146748516,0.7216118393512461 +0.37154841956817497,0.6817295372907666 +0.4283913961094055,0.6845269183355495 +0.4702484572609351,0.6914749284126309 +0.3250931779686398,0.709621715508677 +0.4475749519391271,0.5679401498619403 +0.3882477519584314,0.699734760648275 +0.47159840677368403,0.6802578520016836 +0.41909666112789046,0.6190248792008989 +0.4006203140881012,0.66626339061815 +0.43973851716347107,0.7711336279177594 +0.3828021992232982,0.5882801305744978 +0.48363001474834144,0.7406587809313476 +0.43276226247017857,0.605605136421127 +0.4134575158857422,0.5839298825616782 +0.44400261430935034,0.6465157306811863 +0.4469422220118166,0.6102083504533458 +0.42724207569630906,0.6624145462549679 +0.39839353854392734,0.7422760205833286 +0.4110615323138288,0.6776750744430879 
+0.5228710908838383,0.6133393705720591 +0.3718251851015608,0.6014913801460962 +0.47173545138570167,0.7384099486166896 +0.314055821103469,0.7420212466178266 +0.3568781263819845,0.7334176056287606 +0.36628422285571244,0.5335654862619377 +0.401689334888296,0.49793864189914905 +0.4394530071033062,0.6866140374925823 +0.4858423959549596,0.6015158082597444 +0.4010009213773241,0.7007909137939402 +0.4147436144935248,0.5780459509303775 +0.4325167875234476,0.7247228336126706 +0.5255723955954461,0.6369121713007524 +0.45691415833608584,0.6725831607799382 +0.48701232313631854,0.6231924690328337 +0.4104198083234973,0.6346867369681876 +0.4154368871178765,0.7859391905024669 +0.42651858714997465,0.6349205353041888 +0.48696750152345664,0.6741890539200615 +0.4651053785619632,0.6878417474715798 +0.30544723084808073,0.6411847677172069 +0.5303273754144491,0.6254257774800678 +0.4153165184126042,0.6132076719575392 +0.4273529921890888,0.7251932559382348 +0.4432850675573397,0.7039513592146693 +0.39259405956005106,0.6752339140851509 +0.41699584433004183,0.654170912515384 +0.4378219504936188,0.693439988514547 +0.31074410517129336,0.6433117321786577 +0.4379154266725257,0.7329514302320006 +0.41838542638802123,0.7552836784227681 +0.4258031164208473,0.6618401320220635 +0.41051809655809735,0.7408174411896121 +0.44127126368588554,0.6965947848753219 +0.3879306405891426,0.6900284837385406 +0.3526267071505626,0.6505614909934571 +0.4547952257966208,0.6419340168983827 +0.43942794506539207,0.6772180750347245 +0.43904410066043037,0.7204214852544447 +0.37094734277040536,0.6884043051213539 +0.4804032025669604,0.6361341500395522 +0.4748681695187831,0.7038381554545363 +0.3752978586523322,0.6245646819364572 +0.3604660202267007,0.7102459866672436 +0.4313042617858951,0.6414565748978879 +0.3491083548036878,0.6836895447655255 +0.5063967905582591,0.7198761534611415 +0.5022123443442756,0.6325419073307752 +0.4514332279853225,0.7527208421769097 +0.4539609570140348,0.6856336671542795 
+0.46024996276761243,0.670160098014123 +0.4032590875644408,0.6231571902786621 +0.5042226199579852,0.6581530974385711 +0.3749593199345647,0.5859377714566751 +0.474002091440451,0.5987310819504316 +0.3753882692708233,0.6293707706548016 +0.402447506553297,0.6165443835253203 +0.32945609530556363,0.7797694581363979 +0.33705929560993114,0.704550266120669 +0.3259150315399129,0.6806397994818961 +0.4074551472030225,0.6870636885356911 +0.4358242407512993,0.7380298780080875 +0.392500806676426,0.5452713712173056 +0.3725922709887964,0.6541188379541371 +0.38545822238503524,0.6738650951992445 +0.41487752507585124,0.6796227412409733 +0.4808943021452444,0.6848458788853872 +0.4709071139801141,0.7141419969359782 +0.37427889497038747,0.5914655527513522 +0.40032897214204166,0.684157089975876 +0.37819282887819583,0.6674277224501204 +0.36388799507810266,0.6709420745813102 +0.42700816749942666,0.7376558015111455 +0.4458915308306818,0.6856536890990212 +0.43924881033949437,0.6813977909855944 +0.40495675320442287,0.6911519270636368 +0.540182774807035,0.6416565314633761 +0.44648993367211515,0.5956692254378788 +0.4614870741199872,0.7765101484959016 +0.4437215349616229,0.6864939253515959 +0.3480917485915793,0.6471402273770314 +0.318324820026019,0.691188692620949 +0.3784229521039853,0.695009384357639 +0.3985385731774097,0.6323510482356004 +0.37871609434592923,0.6464530754139789 +0.420537273088553,0.7262314999568562 +0.4464314909253557,0.6809161443292338 +0.2937141301156161,0.5921438712272182 +0.332501602418597,0.6385528146702355 +0.2683299019117491,0.7170842964628498 +0.4768886642968768,0.6712431158219913 +0.4243744129628049,0.669719833562517 +0.4112139009604527,0.6874873347067759 +0.4258918483361846,0.5046472638000384 +0.3706097657119352,0.6529285814536034 +0.3720089400729066,0.6490451602352165 +0.2910956531664201,0.5823520420865577 +0.3288821702409864,0.6862355030747569 +0.4608462213552861,0.643857548934823 +0.38719954965418146,0.6683210889253145 +0.4385585385046104,0.6681024089102245 
+0.40826224600084876,0.5765348411523948 +0.502343610548671,0.6942562016732381 +0.4376420809921928,0.5858832581699851 +0.322886775456483,0.7642576280788437 +0.4629167293407771,0.7076195499950553 +0.25648699420986293,0.5819044791657042 +0.384952427932897,0.6619620804121928 +0.4825520353996003,0.7464263651272439 +0.39743901189155795,0.6390506192677402 +0.4491942924590527,0.7008493719991722 +0.44547117320838325,0.6059368917360985 +0.5020909391761998,0.7573940243942867 +0.4411911454870951,0.6594391839994898 +0.35972294488000905,0.680168975847298 +0.3495776972600118,0.6588697043517336 +0.4706264783666526,0.6822399413802366 +0.4575663986142285,0.6885515052243139 +0.3999867850487907,0.7143818747114001 +0.3868416440584384,0.6920739438946413 +0.4475814495780464,0.6728913256262967 +0.4211681455642312,0.726190895065533 +0.45139797371520496,0.6266613200385976 +0.42749077245206824,0.5824493413667338 +0.4363580018747338,0.6427561356450857 +0.43631441003877325,0.6121405361172033 +0.4210684257053611,0.6080818246820725 +0.43818950594823103,0.6212786570291743 +0.3902767156628187,0.6989424157102436 +0.40192102594779255,0.6927744174044579 +0.32510487049903625,0.5831071309501307 +0.4112817335428219,0.6036841773257691 +0.27578499551928565,0.6315093433253087 +0.3643091828333205,0.7013233681277065 +0.39773286987510925,0.6984213247218786 +0.4072921305080137,0.6605768914176604 +0.3805568403549015,0.6374875389036991 +0.39579939722026947,0.6368387996482997 +0.49078203854883956,0.6749652111273917 +0.43452132343707983,0.6897640352302 +0.3626592265552786,0.657312666632607 +0.33541375769575726,0.6809286383957143 +0.2872490782172016,0.7091035570866387 +0.3835868858416108,0.7170936023183017 +0.4055573008609051,0.7534918452836743 +0.36120720128130757,0.6604007209745034 +0.36458314732269825,0.694152449377389 +0.4964665907512039,0.7230510993199195 +0.30779629013493404,0.6286074241330063 +0.47585143764389853,0.6120085357977869 +0.28269331375125745,0.6564745204352749 
+0.44064895563133694,0.6413320087726189 +0.38767606692636314,0.6936527105292679 +0.4399126018576885,0.6458326449334223 +0.35960195987856514,0.7131802117187407 +0.4440495963243875,0.7163923836506433 +0.3938313617688598,0.6342305966232035 +0.3599330528518706,0.6348398261177491 +0.47262295305947055,0.6789989794489294 +0.3653900747625449,0.580416316921563 +0.39928074910856204,0.7833512251615905 +0.34715152754608775,0.7108957680562618 +0.3949075686072314,0.6840850731104757 +0.37490224472644007,0.588320271354348 +0.3393764315660381,0.6329918882347758 +0.3816604842909984,0.5978764250535629 +0.4739573512039859,0.7324331324123289 +0.3361253977731877,0.6195327388800251 +0.28878528715556,0.6532902310714384 +0.446517794219802,0.677241436838184 +0.3569465341673962,0.711526786848606 +0.09904138325179207,0.6888696627981844 +0.3694819894944994,0.7058785588490529 +0.3999237917162196,0.7264527799322226 +0.4620937048073881,0.6769205500448008 +0.3226403080004566,0.6508420498472192 +0.4478549504777613,0.605872134392586 +0.4344638475943639,0.6038073993896459 +0.384952427932897,0.6005010000979168 +0.2858920308035506,0.7575404253712411 +0.4619957231871289,0.6656746981333183 +0.3808520637291561,0.7459199590514725 +0.3626965609998325,0.6374949100276223 +0.3156695348284401,0.6221098998857035 +0.35216008587677455,0.6307232238446118 +0.31961147685815283,0.7036006987600455 +0.46593227898830486,0.6728317645770132 +0.4431557614474359,0.6842628680958961 +0.20771332545327162,0.7235942521697355 +0.29647574090828765,0.6791536723871815 +0.3622954466397087,0.6411763630136317 +0.2516864467436095,0.6829493055714859 +0.3525364278117469,0.6328894318238335 +0.3233488079793177,0.6371572986444344 +0.4433135163993962,0.7211199459527639 +0.3974874687136971,0.675857173560631 +0.36599666093940814,0.7177726764805604 +0.4136557902862971,0.6306205992381831 +0.324702451208668,0.6383390560921148 +0.36870173937243716,0.5336972007917005 +0.2634496678777473,0.7139459549497681 +0.40030148824628514,0.534450427009259 
+0.2823007282228538,0.685692783261218 +0.38571132627907767,0.6637856465714131 +0.34246705708069414,0.638110040498204 +0.4636648106328049,0.7922432549025404 +0.4471763032792981,0.6051288405638067 +0.451588483448622,0.7138932699404216 +0.3909026341623874,0.6534206819580887 +0.38639537554502185,0.6166020483188487 +0.3519221314944825,0.6761171167634745 +0.3977425985749451,0.6445348475529444 +0.41855020982737645,0.66109572546289 +0.4032770580314131,0.657534645430372 +0.40435384732883023,0.6547536355955197 +0.4496839908742663,0.6825280800867187 +0.30322485514096936,0.7408376104122038 +0.237710770696659,0.7125461958832695 +0.3625794963151563,0.6382131356543783 +0.3388994983211692,0.6568785394723037 +0.39594947555479904,0.6173107052375011 +0.3904468576682564,0.6625875297533569 +0.4074209783203813,0.7148809549422189 +0.37425050712880487,0.633072059317657 +0.36888240875524125,0.5804305191251875 +0.37977564981070994,0.6852620979660202 +0.40910389599122104,0.5175842960078423 +0.32872678270800365,0.724180941586436 +0.30937109739105845,0.673869753684424 +0.4626102897888014,0.6284981339699873 +0.3728933031274825,0.6996243304049029 +0.371293488940554,0.6504547061489132 +0.4644627104469697,0.6496135416358549 +0.3677529799107022,0.6103225382953412 +0.3420542335760234,0.7066333808873914 +0.44698164662538475,0.48824363447870156 +0.30546080222247857,0.5834806278985212 +0.39736913544922614,0.6870741770969471 +0.28332634253316963,0.576837929787108 +0.4030382823300008,0.6659948066609327 +0.45954305512528937,0.7131022324025604 +0.3283838166414001,0.6926718598484425 +0.3588280405958526,0.7186070394509304 +0.3516939832552742,0.6673542912886128 +0.283266582717373,0.6766299135883178 +0.3204124739472318,0.7247314078462532 +0.33904636400740623,0.7259311481840701 +0.2133236517188936,0.6614990678580702 diff --git a/data/raw/verb_counts.csv.gz b/data/raw/verb_counts.csv.gz new file mode 100644 index 0000000..2e311f9 Binary files /dev/null and b/data/raw/verb_counts.csv.gz differ diff --git 
a/pelinker/util.py b/pelinker/util.py index 6decbe4..73f9c9c 100644 --- a/pelinker/util.py +++ b/pelinker/util.py @@ -3,9 +3,10 @@ import re from string import punctuation, whitespace from transformers import AutoModel, AutoTokenizer -from sentence_transformers import SentenceTransformer +#from sentence_transformers import SentenceTransformer import faiss import torch +import numpy as np import pandas as pd from sklearn.metrics import accuracy_score @@ -25,7 +26,8 @@ def load_models(model_type, sentence=False): else: raise ValueError(f"{model_type} unsupported") if sentence: - tokenizer, model = None, SentenceTransformer(spec) + from sentence_transformers import SentenceTransformer # lazy import: only needed for sentence models, module-level import is disabled + tokenizer, model = None, SentenceTransformer(spec) else: tokenizer, model = ( AutoTokenizer.from_pretrained(spec), @@ -171,6 +173,78 @@ def process_text(text, tokenizer, model, nlp, max_length=None, extra_context=Fal return sents_agg, sent_spans, tt + +def get_verbal_embedding(text_df, column, vb, tokenizer, model, nlp, layers): + '''The texts in `column` of `text_df` should be capitalized, otherwise the sentence-splitter + won't work.
''' + + + # split texts into sentences and only choose those sentences with + # different tenses of the target verb (special cases: up/downregulate and overexpress are inflected via their stem) + if vb in ['upregulate', 'downregulate']: + vb_t = 'regulate' + vforms = np.unique([nlp(vb_t)[0]._.inflect(x) for x in ["VBZ", "VBG", "VBP", "VBD", "VBN"]]) + vforms = np.array([vb[:-len(vb_t)]+x for x in vforms]) # re-attach the verb's own prefix ('up' or 'down'), not always 'up' + elif vb=='overexpress': + vb_t = 'express' + vforms = np.unique([nlp(vb_t)[0]._.inflect(x) for x in ["VBZ", "VBG", "VBP", "VBD", "VBN"]]) + vforms = np.array(['over'+x for x in vforms]) + + else: + vforms = np.unique([nlp(vb)[0]._.inflect(x) for x in ["VBZ", "VBG", "VBP", "VBD", "VBN"]]) + + + # splitting into sentences + sent_splitter = lambda x: split_into_sentences(x) + split_df = text_df[column].apply(sent_splitter).explode() + # once split, we lower the sentences + split_df = split_df.str.lower() + contained_df = split_df[split_df.str.contains(fr"\b(?:{'|'.join(vforms)})\b")] + # removing too long texts + contained_df = contained_df[contained_df.str.len()<1500] + # removing texts with too many non-english characters + engmatch = re.compile(r'[^a-z0-9 .,-]') + bad_char_cnts = contained_df.apply(lambda x: len(engmatch.findall(x))) / contained_df.str.len() + contained_df = contained_df[bad_char_cnts < 0.5] + + if len(contained_df)==0: + return np.nan + + # we capitalize back the texts so that they could be processed + text = contained_df.str.capitalize() + text = ' '.join(text.tolist()) + + # getting the embeddings + sagg, sspan, tt = process_text(text, tokenizer, model, nlp) + + + emb_inds = [] + embs = [] + word_inds = [] + contexts = [] + for i in range(len(sspan)): + for j in range(len(sspan[i])): + # structure of sspan[i]: [( + # (string bounds of j-th verb in string sagg[i]), + # [indices of embeddings of (sub)token(s) of the j-th verb] + # )] + + # we only consider verbs that we are focusing + str_bounds = sspan[i][j][0] + sent_vb = sagg[i][str_bounds[0]:str_bounds[1]].lower() + split_vb
= sent_vb.split(' ') + if np.any(np.isin(split_vb, vforms)): + emb_inds += [(i, sspan[i][j][1])] + # first averaging across layers, then averaging across (sub)token indices + # at the end, we will get 768-D vector with ndim=1 + embs += [tt[layers, i, :, :][:,emb_inds[-1][1],:].mean(0).mean(0)] + + word_inds += [sspan[i][j][0]] + contexts += [sagg[i]] + + return word_inds, contexts, torch.stack(embs, axis=0) + + def sentence_ix(sent, nlp, token_offsets, extra_context=False): spans = get_vb_spans(nlp, sent, extra_context=extra_context) @@ -266,8 +340,25 @@ def embedding_to_dist(tt_x, tt_y): return dfa +def mean_pooling(model_output, attention_mask): + token_embeddings = model_output[0] + input_mask_expanded = ( + attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + ) + return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( + input_mask_expanded.sum(1), min=1e-9 + ) + + def encode(texts, tokenizer, model, ls): if ls == "sent": + # encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt') + # + # # Compute token embeddings + # with torch.no_grad(): + # model_output = model(**encoded_input) + # tt_labels = mean_pooling(model_output, encoded_input['attention_mask']) + tt_labels = model.encode(texts, normalize_embeddings=True) else: tt_labels_layered, labels_spans = text_to_tokens_embeddings( diff --git a/run/analysis/plot_discrim_dist.py b/run/analysis/plot_discrim_dist.py new file mode 100644 index 0000000..7595c5c --- /dev/null +++ b/run/analysis/plot_discrim_dist.py @@ -0,0 +1,82 @@ +# pylint: disable=E1120 + +import click +import faiss +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +import numpy as np + +from pelinker.util import load_models, encode +from pelinker.preprocess import pre_process_properties + + +@click.command() +@click.option( + "--model-type", + type=click.STRING, + default=["biobert"], + multiple=True, + help="run over BERT flavours", +) +def 
run(model_type): + model_type = sorted(model_type) + fig_path = "./figs" + + df0 = pd.read_csv("data/derived/properties.synthesis.csv") + + report = pre_process_properties(df0) + labels = report.pop("labels") + + layers = ( + # [[-x] for x in range(1, 6)] + + # [[-x for x in range(1, 4)]] + + [[-1]] + + + # [[-1, -2, -8, -9]] + + ["sent"] + ) + + df_agg = [] + + for mt in model_type: + tokenizer, model = load_models(mt, False) + _, smodel = load_models(mt, True) + for ls in layers[:]: + tt_labels = encode(labels, tokenizer, smodel if ls == "sent" else model, ls) + layers_str = ls if ls == "sent" else "_".join([str(x) for x in ls]) + + index = faiss.IndexFlatIP(tt_labels.shape[1]) + nb_nn = min([100, tt_labels.shape[0]]) + index.add(tt_labels) + + distance_matrix, nearest_neighbors_matrix = index.search(tt_labels, nb_nn) + ds = 1.0 - distance_matrix[:, 1] + + dfa = pd.DataFrame(ds, columns=["delta"]) + dfa["model_type"] = mt + dfa["layers"] = layers_str + df_agg += [dfa] + df0 = pd.concat(df_agg) + path = f"{fig_path}/discrim.pdf" + col_wrap = min([4, len(layers)]) + + sns.set_style("whitegrid") + g = sns.displot( + data=df0, + x="delta", + hue="model_type", + stat="density", + common_norm=False, + col="layers", + col_wrap=col_wrap, + bins=np.arange(0.0, 0.6, 0.05), + alpha=0.5, + facet_kws=dict(legend_out=False), + ) + sns.move_legend(g, "upper right") + plt.savefig(path, bbox_inches="tight", dpi=300) + + +if __name__ == "__main__": + run() diff --git a/run/analysis/plot_dispersion_hist.py b/run/analysis/plot_dispersion_hist.py new file mode 100644 index 0000000..296e11d --- /dev/null +++ b/run/analysis/plot_dispersion_hist.py @@ -0,0 +1,27 @@ +# pylint: disable=E1120 + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +import numpy as np + + +def run(): + df0 = pd.read_csv("./data/derived/491_within_n_across_disps.csv") + df0 = df0.rename(columns={"within_verbs": "std"}) + fig_path = "./figs" + + sns.set_style("whitegrid") + _ = 
sns.displot( + data=df0, + x="std", + stat="density", + common_norm=False, + bins=np.arange(0.1, 0.7, 0.02), + ) + path = f"{fig_path}/std.verb.pdf" + plt.savefig(path, bbox_inches="tight", dpi=300) + + +if __name__ == "__main__": + run() diff --git a/run/analysis/plot_interterm_dist.py b/run/analysis/plot_interterm_dist.py index e2b536a..9c934e8 100644 --- a/run/analysis/plot_interterm_dist.py +++ b/run/analysis/plot_interterm_dist.py @@ -20,6 +20,8 @@ help="run over BERT flavours", ) def run(model_type): + model_type = sorted(model_type) + fig_path = "./figs" df0 = pd.read_csv("data/derived/properties.synthesis.csv") @@ -28,10 +30,9 @@ def run(model_type): labels = report.pop("labels") layers = ( - [[-x] for x in range(1, 6)] - + [[-x for x in range(1, 4)]] - + [[-1, -2, -8, -9]] - + ["sent"] + # [[-x] for x in range(1, 6)] + + # [[-x for x in range(1, 4)]] + + [[-1], [-1, -2, -8, -9]] + ["sent"] ) df_agg = [] @@ -40,7 +41,7 @@ def run(model_type): tokenizer, model = load_models(mt, False) _, smodel = load_models(mt, True) for ls in layers[:]: - tt_labels = encode(labels, tokenizer, model, ls) + tt_labels = encode(labels, tokenizer, smodel if ls == "sent" else model, ls) layers_str = ls if ls == "sent" else "_".join([str(x) for x in ls]) index = faiss.IndexFlatIP(tt_labels.shape[1]) @@ -64,7 +65,10 @@ def run(model_type): dfa["layers"] = layers_str df_agg += [dfa] df0 = pd.concat(df_agg) - path = f"{fig_path}/interdist.png" + path = f"{fig_path}/interdist.new.pdf" + col_wrap = min([4, len(layers)]) + + sns.set_style("whitegrid") _ = sns.displot( data=df0, x="dist", @@ -72,7 +76,7 @@ def run(model_type): stat="density", common_norm=False, col="layers", - col_wrap=4, + col_wrap=col_wrap, bins=np.arange(0.0, 1.1, 0.05), alpha=0.5, )