2
0

pytorch.js 212 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
74778477947804781478247834784478547864787478847894790479147924793479447954796479747984799480048014802
  1. // Experimental
  2. var pytorch = pytorch || {};
  3. var python = python || require('./python');
  4. var base = base || require('./base');
  5. pytorch.ModelFactory = class {
  6. match(context) {
  7. return pytorch.Container.open(context);
  8. }
  9. open(context, match) {
  10. const identifier = context.identifier;
  11. return pytorch.Metadata.open(context).then((metadata) => {
  12. const container = match;
  13. try {
  14. container.metadata = metadata;
  15. container.exception = (error, fatal) => {
  16. const message = error && error.message ? error.message : error.toString();
  17. context.exception(new pytorch.Error(message.replace(/\.$/, '') + " in '" + identifier + "'."), fatal);
  18. };
  19. }
  20. catch (error) {
  21. const message = error && error.message ? error.message : error.toString();
  22. throw new pytorch.Error('File format is not PyTorch (' + message.replace(/\.$/, '') + ').');
  23. }
  24. return new pytorch.Model(metadata, container);
  25. });
  26. }
  27. };
  28. pytorch.Model = class {
  29. constructor(metadata, container) {
  30. this._format = container.format;
  31. this._producer = container.producer || '';
  32. this._graphs = container.graphs.map((graph) => new pytorch.Graph(metadata, graph, container));
  33. }
  34. get format() {
  35. return this._format;
  36. }
  37. get graphs() {
  38. return this._graphs;
  39. }
  40. };
  41. pytorch.Graph = class {
  42. constructor(metadata, graph, container) {
  43. this._nodes = [];
  44. this._inputs = [];
  45. this._outputs = [];
  46. this._groups = true;
  47. this._littleEndian = container.littleEndian;
  48. this._name = graph.name || '';
  49. const type = graph.type;
  50. switch (type) {
  51. case 'script': {
  52. const traced = graph.trace();
  53. const initializers = new Map();
  54. if (graph.constants) {
  55. for (const constant of graph.constants) {
  56. if (pytorch.Utility.isTensor(constant)) {
  57. constant.initializer = pytorch.Utility.createTensor(constant.__variable__, constant, this._littleEndian);
  58. initializers.set(constant.__variable__, constant);
  59. }
  60. else if (constant && constant.__class__ && constant.__class__.__module__ && constant.__class__.__name__) {
  61. const type = constant.__class__.__module__ + '.' + constant.__class__.__name__;
  62. switch (type) {
  63. case '__torch__.torch.classes.xnnpack.LinearOpContext':
  64. case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
  65. case '__torch__.torch.classes.quantized.LinearPackedParamsBase':
  66. case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase':
  67. for (const key of Object.keys(constant)) {
  68. const value = constant[key];
  69. if (pytorch.Utility.isTensor(value)) {
  70. value.initializer = pytorch.Utility.createTensor(value.__variable__, value, this._littleEndian);
  71. initializers.set(value.__variable__, value);
  72. }
  73. }
  74. break;
  75. default:
  76. throw new pytorch.Error("Unsupported constant context '" + type + "'.");
  77. }
  78. }
  79. else {
  80. throw new pytorch.Error('Unsupported constant.');
  81. }
  82. }
  83. }
  84. if (graph.data) {
  85. const queue = [ graph.data ];
  86. while (queue.length > 0) {
  87. const module = queue.shift();
  88. if (module.__class__ && module.__class__.__module__ === '__torch__.torch.classes._nnapi' && module.__class__.__name__ === 'Compilation') {
  89. continue;
  90. }
  91. for (const key of Object.keys(module)) {
  92. if (key !== '__module__' && key !== '__name__' && key !== '__class__' && key !== '__parent__') {
  93. const obj = module[key];
  94. if (!Array.isArray(obj) && obj === Object(obj)) {
  95. if (pytorch.Utility.isTensor(obj)) {
  96. const parameter = obj;
  97. parameter.__parent__ = module;
  98. if (!parameter.initializer && parameter.storage()) {
  99. parameter.initializer = pytorch.Utility.createTensor(parameter.name, parameter, this._littleEndian);
  100. }
  101. if (parameter.__variable__ && parameter.__count__ === 1) {
  102. initializers.set(parameter.__variable__, parameter);
  103. }
  104. }
  105. else if (obj && obj.__class__) {
  106. obj.__parent__ = module;
  107. if (!obj.__id__) {
  108. obj.__id__ = key;
  109. }
  110. queue.push(obj);
  111. }
  112. }
  113. }
  114. }
  115. }
  116. }
  117. if (traced) {
  118. if (graph.inputs) {
  119. for (const input of graph.inputs) {
  120. this._inputs.push(new pytorch.Parameter(input, true, [
  121. new pytorch.Argument(input, null, null)
  122. ]));
  123. }
  124. }
  125. if (graph.outputs) {
  126. for (const output of graph.outputs) {
  127. this._outputs.push(new pytorch.Parameter(output, true, [
  128. new pytorch.Argument(output, null, null)
  129. ]));
  130. }
  131. }
  132. if (graph.nodes) {
  133. for (const node of graph.nodes) {
  134. const item = {
  135. type: node.type,
  136. node: node
  137. };
  138. this._nodes.push(new pytorch.Node(metadata, '', item, initializers));
  139. }
  140. }
  141. }
  142. if (graph) {
  143. this._loadScriptModule(metadata, container, graph.data, initializers);
  144. }
  145. break;
  146. }
  147. case 'module': {
  148. this._type = (graph.data.__module__ && graph.data.__name__) ? (graph.data.__module__ + '.' + graph.data.__name__) : '';
  149. this._loadModule(metadata, graph.data, [], []);
  150. break;
  151. }
  152. case 'weights': {
  153. for (const state_group of graph.data) {
  154. const attributes = state_group.attributes || [];
  155. const inputs = state_group.states.map((parameter) => {
  156. return new pytorch.Parameter(parameter.name, true,
  157. parameter.arguments.map((state) => {
  158. const tensor = pytorch.Utility.createTensor(state.id, pytorch.Utility.toTensor(state.value), this._littleEndian);
  159. return new pytorch.Argument(state.id, null, tensor);
  160. }));
  161. });
  162. const obj = {
  163. name: state_group.name,
  164. type: state_group.type || 'torch.nn.Module',
  165. attributes: attributes,
  166. inputs: inputs,
  167. outputs: []
  168. };
  169. this._nodes.push(new pytorch.Node(metadata, '', obj, null));
  170. }
  171. break;
  172. }
  173. default: {
  174. throw new pytorch.Error("Unsupported container type '" + type + "'.");
  175. }
  176. }
  177. }
  178. _loadModule(metadata, current, groups, inputs) {
  179. if (current.__class__ && current.__class__.__module__ !== 'torch.nn.modules.container' && (!current._modules || current._modules.size == 0)) {
  180. this._createNode(metadata, groups, '', current, inputs, false);
  181. return [];
  182. }
  183. if (!current._modules) {
  184. throw new pytorch.Error('Module does not contain modules.');
  185. }
  186. const sequential = current.__class__ && current.__class__.__module__ === 'torch.nn.modules.container' && current.__class__.__name__ === 'Sequential';
  187. for (const pair of current._modules) {
  188. const key = pair[0];
  189. const value = pair[1];
  190. if (value) {
  191. const type = value.__class__.__module__ + '.' + value.__class__.__name__;
  192. switch (type) {
  193. case 'torch.nn.modules.container.Sequential':
  194. groups.push(key);
  195. inputs = this._loadModule(metadata, value, groups, sequential ? inputs : []);
  196. groups.pop(key);
  197. break;
  198. default: {
  199. inputs = this._createNode(metadata, groups, key, value, sequential ? inputs : [], sequential);
  200. break;
  201. }
  202. }
  203. }
  204. }
  205. return inputs;
  206. }
  207. _createNode(metadata, groups, key, obj, args, output) {
  208. const type = obj.__class__.__module__ + '.' + obj.__class__.__name__;
  209. const schema = metadata.type(type);
  210. let inputSchema = [ { name: 'input'} ];
  211. if (schema && schema.inputs && schema.inputs.length > 0) {
  212. inputSchema = schema.inputs.slice();
  213. }
  214. const inputName = inputSchema.shift().name;
  215. const inputs = [];
  216. if (args.length > 0) {
  217. inputs.push(new pytorch.Parameter(inputName, true, args.map((argument) => {
  218. return new pytorch.Argument(argument, null, null);
  219. })));
  220. }
  221. const parameters = obj._parameters || obj._buffers || [];
  222. for (const parameter of parameters) {
  223. const key = parameter[0];
  224. const value = pytorch.Utility.toTensor(parameter[1]);
  225. let visible = true;
  226. let inputName = '';
  227. if (inputSchema.length > 0) {
  228. const input = inputSchema.shift();
  229. inputName = input.name;
  230. visible = input.visible === false ? false : true;
  231. }
  232. if (value) {
  233. const initializer = pytorch.Utility.createTensor('', value, this._littleEndian);
  234. inputs.push(new pytorch.Parameter(inputName || key, visible, [ new pytorch.Argument('', null, initializer) ]));
  235. }
  236. }
  237. const group = groups.join('/');
  238. const name = group ? (group + '/' + key) : key;
  239. const outputs = output ? [ new pytorch.Parameter('output', true, [ new pytorch.Argument(name, null, null) ]) ] : [];
  240. const attributes = [];
  241. for (const name of Object.keys(obj)) {
  242. if (name.startsWith('_')) {
  243. continue;
  244. }
  245. attributes.push({ name: name, value: obj[name] });
  246. }
  247. const item = {
  248. name: name,
  249. type: type,
  250. attributes: attributes,
  251. children: obj._modules && obj._modules.size > 0 ? true : false,
  252. inputs: inputs,
  253. outputs: outputs
  254. };
  255. const node = new pytorch.Node(metadata, group, item, {});
  256. this._nodes.push(node);
  257. return [ node.name ];
  258. }
  259. _loadScriptModule(metadata, container, module, initializers) {
  260. if (module) {
  261. if (pytorch.Graph._getParameters(module).length > 0 && !module.__hide__) {
  262. const item = { module: module };
  263. this._nodes.push(new pytorch.Node(metadata, '', item, initializers));
  264. }
  265. const submodules = pytorch.Graph._getSubmodules(module);
  266. for (const submodule of submodules) {
  267. this._loadScriptModule(metadata, container, submodule, initializers);
  268. }
  269. }
  270. }
  271. static _getParameters(module) {
  272. const parameters = [];
  273. if (module && module.__class__.__module__ && module.__class__.__name__) {
  274. for (const key of Object.keys(module)) {
  275. if (pytorch.Utility.isTensor(module[key])) {
  276. const parameter = module[key];
  277. parameter.__id__ = key;
  278. parameters.push(parameter);
  279. }
  280. }
  281. }
  282. return parameters;
  283. }
  284. static _getSubmodules(module) {
  285. const submodules = [];
  286. if (module && module.__class__ && module.__class__.__module__ && module.__class__.__name__) {
  287. for (const key of Object.keys(module)) {
  288. if (!key.startsWith('__')) {
  289. const value = module[key];
  290. if (value && value.__class__ && value.__module__ && value.__name__ && !pytorch.Utility.isTensor(value)) {
  291. submodules.push(value);
  292. }
  293. }
  294. }
  295. }
  296. return submodules;
  297. }
  298. get type() {
  299. return this._type;
  300. }
  301. get name() {
  302. return this._name;
  303. }
  304. get groups() {
  305. return this._groups;
  306. }
  307. get inputs() {
  308. return this._inputs;
  309. }
  310. get outputs() {
  311. return this._outputs;
  312. }
  313. get nodes() {
  314. return this._nodes;
  315. }
  316. };
  317. pytorch.Parameter = class {
  318. constructor(name, visible, args) {
  319. this._name = name;
  320. this._visible = visible;
  321. this._arguments = args;
  322. }
  323. get name() {
  324. return this._name;
  325. }
  326. get visible() {
  327. return this._visible;
  328. }
  329. get arguments() {
  330. return this._arguments;
  331. }
  332. };
  333. pytorch.Argument = class {
  334. constructor(name, type, initializer) {
  335. if (typeof name !== 'string') {
  336. throw new pytorch.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
  337. }
  338. this._name = name;
  339. this._type = type;
  340. this._initializer = initializer;
  341. }
  342. get name() {
  343. return this._name;
  344. }
  345. get type() {
  346. if (this._initializer) {
  347. return this._initializer.type;
  348. }
  349. return this._type;
  350. }
  351. get initializer() {
  352. return this._initializer;
  353. }
  354. };
  pytorch.Node = class {

      /**
       * Builds a graph node from one traced item.
       *
       * @param {object} metadata - provides `type(name)` and `attribute(typeIdentifier, attributeName)`.
       * @param {string} group - group label; '' when omitted.
       * @param {object} item - either a container (neither `module` nor `node`; carries `type`,
       *     `children`, `inputs`, `outputs`, `attributes`) or an entry with `module` and/or `node`.
       * @param {Map} initializers - maps argument ids to parameter objects.
       */
      constructor(metadata, group, item, initializers) {
          this._group = group || '';
          this._name = item.name || '';

          // Stores the resolved type in `this._type`. nnapi sub-graphs are kept as-is; any other
          // name gets a copy of its metadata record (or a `{ name }` stub) whose `identifier`
          // keeps the full name and whose `name` is cut at the first ':'.
          const assignType = (typeName) => {
              if (typeName instanceof pytorch.nnapi.Graph) {
                  this._type = typeName;
                  return;
              }
              this._type = Object.assign({}, metadata.type(typeName) || { name: typeName });
              const identifier = this._type.name;
              this._type.identifier = identifier;
              const colon = identifier.indexOf(':');
              this._type.name = colon === -1 ? identifier : identifier.substring(0, colon);
          };

          // Container items already carry their children and input/output lists.
          if (!item.module && !item.node) {
              assignType(item.type);
              this._nodes = item.children;
              this._inputs = item.inputs;
              this._outputs = item.outputs;
              this._attributes = item.attributes.map((attribute) => {
                  const schema = metadata.attribute(this._type.identifier, attribute.name);
                  return new pytorch.Attribute(schema, attribute.name, attribute.value);
              });
              return;
          }

          this._attributes = [];
          this._inputs = [];
          this._outputs = [];
          let module = item.module;

          if (module) {
              this._type = { name: 'torch.nn.modules.module.Module' };
              // Every module parameter becomes an input; parameters bound to a variable are
              // additionally exposed as outputs.
              for (const parameter of pytorch.Graph._getParameters(module)) {
                  this._inputs.push(new pytorch.Parameter(parameter.__id__, true, [
                      new pytorch.Argument('', null, parameter.initializer || null)
                  ]));
                  if (parameter.__variable__) {
                      this._outputs.push(new pytorch.Parameter(parameter.__id__, true, [
                          new pytorch.Argument(parameter.__variable__, null, null)
                      ]));
                  }
              }
          }

          const node = item.node;
          if (node) {
              assignType(item.type);
              module = null;
              let consumed = 0;

              // Scans the node's arguments. Parameters sharing one parent module are attributed
              // to that module, 'CONSTANTS.c*' variables get their initializer inlined, and any
              // other initializer-backed argument aborts the scan (returns false).
              const claimArguments = () => {
                  for (const input of node.inputs) {
                      for (const argument of input) {
                          const parameter = initializers.get(argument.id);
                          if (!parameter) {
                              continue;
                          }
                          if (parameter.__parent__ && (module == null || module == parameter.__parent__)) {
                              module = parameter.__parent__;
                              consumed++;
                          }
                          else if (parameter.__variable__.startsWith('CONSTANTS.c')) {
                              argument.initializer = parameter.initializer;
                              consumed++;
                          }
                          else {
                              return false;
                          }
                      }
                  }
                  return true;
              };
              const claimed = claimArguments();

              if (module) {
                  const owned = pytorch.Graph._getParameters(module).filter((p) => p.__id__ !== 'num_batches_tracked');
                  if (owned.length === consumed && claimed) {
                      // The node consumes exactly the module's parameters: hide the module and
                      // attach the parameter initializers directly to the node's arguments.
                      module.__hide__ = true;
                      for (const input of node.inputs) {
                          for (const argument of input) {
                              const parameter = initializers.get(argument.id);
                              if (parameter && parameter.initializer) {
                                  argument.initializer = parameter.initializer;
                              }
                          }
                      }
                  }
                  else {
                      module = null;
                  }
              }

              // Schema-declared name at `index`, falling back to the positional index.
              const schemaName = (list, index) => list && list.length > index ? list[index].name : index.toString();
              const inputSchemas = this._type && this._type.inputs;
              const outputSchemas = this._type && this._type.outputs;

              for (let index = 0; index < node.inputs.length; index++) {
                  const inputName = schemaName(inputSchemas, index);
                  const args = node.inputs[index].map((input) => new pytorch.Argument(input.id, null, input.initializer || null));
                  this._inputs.push(new pytorch.Parameter(inputName, true, args));
              }
              for (let index = 0; index < node.outputs.length; index++) {
                  const outputName = schemaName(outputSchemas, index);
                  const args = node.outputs[index].map((output) => new pytorch.Argument(output.id, null, null));
                  this._outputs.push(new pytorch.Parameter(outputName, true, args));
              }
              for (const attribute of node.attributes) {
                  const { name, value } = attribute;
                  const schema = metadata.attribute(this._type.identifier, name);
                  this._attributes.push(new pytorch.Attribute(schema, name, value));
              }
          }

          // Module-backed nodes are named by their dotted module path, stopping at a root
          // that has neither a parent nor an id.
          if (module && module.__id__) {
              let current = module;
              let qualified = current.__id__;
              while (current.__parent__ != null) {
                  current = current.__parent__;
                  if (!current.__parent__ && !current.__id__) {
                      break;
                  }
                  qualified = [ current.__id__, qualified ].join('.');
              }
              this._name = qualified;
          }
      }

      get name() { return this._name; }

      get group() { return this._group; }

      get type() { return this._type; }

      get attributes() { return this._attributes; }

      get inputs() { return this._inputs; }

      get outputs() { return this._outputs; }

      get nodes() { return this._nodes; }
  };
  pytorch.Attribute = class {

      /**
       * A named attribute value attached to a node.
       *
       * @param {object} metadata - optional attribute schema; reads `type`, `visible` and `default`.
       * @param {string} name - attribute name. 'training' is always hidden and typed 'boolean'.
       * @param {*} value - attribute value. A non-empty list whose items all have a
       *     `__class__.__module__` starting with 'torch.nn' is displayed as '?'.
       */
      constructor(metadata, name, value) {
          this._name = name;
          this._value = value;
          if (this._name === 'training') {
              this._visible = false;
              this._type = 'boolean';
          }
          else if (metadata) {
              if (metadata.type) {
                  this._type = metadata.type;
              }
              if (metadata.visible === false) {
                  this._visible = false;
              }
              else if (metadata.default !== undefined) {
                  // Attributes equal to their schema default are hidden. Loose `==` is kept for
                  // item comparisons to preserve the existing matching rules.
                  const defaultValue = metadata.default;
                  if (Array.isArray(value)) {
                      if (Array.isArray(defaultValue)) {
                          // Compare lengths first (previously the length was compared against the
                          // default array itself, which is always unequal), then element-wise.
                          this._visible = value.length !== defaultValue.length || !value.every((item, index) => item == defaultValue[index]);
                      }
                      else {
                          // A scalar default applies to every list element.
                          this._visible = !value.every((item) => item == defaultValue);
                      }
                  }
                  else {
                      this._visible = value !== defaultValue;
                  }
              }
          }
          if (Array.isArray(value) && value.length > 0 && value.every((obj) => obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__module__.startsWith('torch.nn'))) {
              this._value = '?';
          }
      }

      get type() {
          return this._type;
      }

      get name() {
          return this._name;
      }

      get value() {
          return this._value;
      }

      // Visible unless explicitly hidden (only `false` hides; `undefined` means visible).
      get visible() {
          return this._visible !== false;
      }
  };
  pytorch.Tensor = class {

      /**
       * @param {string} name - tensor name; '' when omitted.
       * @param {object} type - tensor type exposing `dataType` and `shape.dimensions`.
       * @param {Uint8Array|object} data - raw bytes, or an object whose `peek()` returns them.
       * @param {boolean} littleEndian - byte order used for multi-byte element reads.
       */
      constructor(name, type, data, littleEndian) {
          this._name = name || '';
          this._type = type;
          this._data = data;
          this._littleEndian = littleEndian;
      }

      get kind() {
          return 'Tensor';
      }

      get name() {
          return this._name;
      }

      get type() {
          return this._type;
      }

      // null when the tensor can be decoded, otherwise a message explaining why it cannot.
      get state() {
          return this._context().state;
      }

      // Fully decoded (nested-array) elements, or null when the tensor cannot be decoded.
      get value() {
          const context = this._context();
          if (context.state) {
              return null;
          }
          context.limit = Number.MAX_SAFE_INTEGER;
          return this._decode(context, 0);
      }

      // Indented text rendering; decoding stops with '...' after about 10000 elements.
      toString() {
          const context = this._context();
          if (context.state) {
              return '';
          }
          context.limit = 10000;
          const value = this._decode(context, 0);
          return pytorch.Tensor._stringify(value, '', ' ');
      }

      // Validates type, shape and data and prepares a DataView over the bytes. Problems are
      // reported through `context.state` instead of being thrown.
      _context() {
          const context = {};
          context.state = null;
          context.index = 0;
          context.count = 0;
          if (!this._type.dataType) {
              context.state = 'Tensor has no data type.';
              return context;
          }
          switch (this._type.dataType) {
              case 'boolean':
              case 'uint8':
              case 'qint8':
              case 'int8':
              case 'int16':
              case 'int32':
              case 'int64':
              case 'float16':
              case 'float32':
              case 'float64':
              case 'bfloat16':
              case 'complex64':
              case 'complex128':
                  break;
              default:
                  context.state = "Tensor data type '" + this._type.dataType + "' is not implemented.";
                  return context;
          }
          if (!this._type.shape) {
              context.state = 'Tensor has no dimensions.';
              return context;
          }
          if (!this._data) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          try {
              context.data = this._data instanceof Uint8Array ? this._data : this._data.peek();
          }
          catch (err) {
              context.state = err.message;
              return context;
          }
          context.dataType = this._type.dataType;
          context.dimensions = this._type.shape.dimensions;
          context.view = new DataView(context.data.buffer, context.data.byteOffset, context.data.byteLength);
          return context;
      }

      // Recursively decodes `dimension` and everything inside it. `context.index` is the running
      // byte offset and `context.count` the number of elements read so far. A rank-0 tensor is
      // decoded as one element and returned unwrapped.
      _decode(context, dimension) {
          const results = [];
          const dimensions = (context.dimensions.length === 0) ? [ 1 ] : context.dimensions;
          const size = dimensions[dimension];
          if (dimension === dimensions.length - 1) {
              const view = context.view;
              const littleEndian = this._littleEndian;
              for (let i = 0; i < size; i++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  // Every element is read at the running byte offset.
                  const offset = context.index;
                  switch (context.dataType) {
                      case 'boolean':
                          results.push(view.getUint8(offset) !== 0);
                          context.index += 1;
                          break;
                      case 'uint8':
                          results.push(view.getUint8(offset));
                          context.index += 1;
                          break;
                      case 'qint8':
                      case 'int8':
                          results.push(view.getInt8(offset));
                          context.index += 1;
                          break;
                      case 'int16':
                          results.push(view.getInt16(offset, littleEndian));
                          context.index += 2;
                          break;
                      case 'int32':
                          results.push(view.getInt32(offset, littleEndian));
                          context.index += 4;
                          break;
                      case 'int64':
                          results.push(view.getInt64(offset, littleEndian));
                          context.index += 8;
                          break;
                      case 'float16':
                          results.push(view.getFloat16(offset, littleEndian));
                          context.index += 2;
                          break;
                      case 'float32':
                          results.push(view.getFloat32(offset, littleEndian));
                          context.index += 4;
                          break;
                      case 'float64':
                          results.push(view.getFloat64(offset, littleEndian));
                          context.index += 8;
                          break;
                      case 'bfloat16':
                          results.push(view.getBfloat16(offset, littleEndian));
                          context.index += 2;
                          break;
                      case 'complex64':
                          // Previously read at `i << 3`. `i` restarts for every innermost row, so
                          // rank > 1 tensors repeated the first row.
                          results.push(view.getComplex64(offset, littleEndian));
                          context.index += 8;
                          break;
                      case 'complex128':
                          results.push(view.getComplex128(offset, littleEndian));
                          context.index += 16;
                          break;
                      default:
                          throw new pytorch.Error("Unsupported tensor data type '" + context.dataType + "'.");
                  }
                  context.count++;
              }
          }
          else {
              for (let j = 0; j < size; j++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  results.push(this._decode(context, dimension + 1));
              }
          }
          if (context.dimensions.length === 0) {
              return results[0];
          }
          return results;
      }

      // Renders nested arrays one item per line, each nesting level adding `indent`.
      static _stringify(value, indentation, indent) {
          if (Array.isArray(value)) {
              const result = [];
              result.push(indentation + '[');
              const items = value.map((item) => pytorch.Tensor._stringify(item, indentation + indent, indent));
              if (items.length > 0) {
                  result.push(items.join(',\n'));
              }
              result.push(indentation + ']');
              return result.join('\n');
          }
          switch (typeof value) {
              case 'string':
                  return indentation + value;
              case 'number':
                  if (value === Infinity) {
                      return indentation + 'Infinity';
                  }
                  if (value === -Infinity) {
                      return indentation + '-Infinity';
                  }
                  if (Number.isNaN(value)) {
                      return indentation + 'NaN';
                  }
                  return indentation + value.toString();
              default:
                  if (value && value.toString) {
                      return indentation + value.toString();
                  }
                  return indentation + '(undefined)';
          }
      }
  };
  pytorch.TensorType = class {

      /**
       * Pairs an element data type with a shape.
       * @param {string} dataType - element type name, e.g. 'float32'.
       * @param {object} shape - shape object whose toString() renders its dimensions.
       */
      constructor(dataType, shape) {
          this._dataType = dataType;
          this._shape = shape;
      }

      get dataType() { return this._dataType; }

      get shape() { return this._shape; }

      // Data type immediately followed by the shape text, e.g. 'float32[1,3]'.
      toString() {
          return `${this._dataType}${this._shape.toString()}`;
      }
  };
  pytorch.TensorShape = class {

      /**
       * @param {Array} dimensions - dimension values; an empty list when omitted.
       */
      constructor(dimensions) {
          this._dimensions = dimensions || [];
      }

      get dimensions() { return this._dimensions; }

      // '[d0,d1,...]' for a non-empty shape, '' otherwise.
      toString() {
          const dimensions = this._dimensions;
          if (!(dimensions && dimensions.length > 0)) {
              return '';
          }
          const rendered = dimensions.map((dimension) => dimension.toString());
          return `[${rendered.join(',')}]`;
      }
  };
  787. pytorch.Execution = class extends python.Execution {
  788. constructor(sources, exceptionCallback) {
  789. super(sources, exceptionCallback);
  790. this.register('ops');
  791. this.register('ops.torchvision');
  792. this.register('ops.torchaudio');
  793. const torch = this.register('torch');
  794. const torch_storage = this.register('torch.storage');
  795. const torch_nn_parameter = this.register('torch.nn.parameter');
  796. this.register('torchvision');
  797. this.register('__torch__');
  798. this.context.setx('ops._caffe2',{ __name__: 'torch', __class__: this._builtins.module });
  799. const self = this;
  800. this.registerType('builtins.number', class {});
  this.registerType('__torch__.torch.classes._nnapi.Compilation', class {
      constructor() {
          // Flag checked elsewhere; presumably keeps this object out of the rendered graph.
          this.__hide__ = true;
      }
      __init__() {
      }
      // Keeps the serialized NNAPI model tensor and parameter buffers, and decodes their
      // storage bytes into a SerializedModel.
      init(serialized_model_tensor, parameter_buffers) {
          this.serialized_model_tensor = serialized_model_tensor;
          this.parameter_buffers = parameter_buffers;
          const buffers = parameter_buffers.map((buffer) => buffer.__source__.storage().data);
          const serialized_model = serialized_model_tensor.storage().data;
          this.serialized_model = new pytorch.nnapi.SerializedModel(serialized_model, buffers);
      }
      // Records a single node whose type is the NNAPI sub-graph, wired to the given
      // input and output tensors by their variable ids.
      run(inputs, outputs) {
          const modelTensor = this.serialized_model_tensor;
          modelTensor.__variable__ = modelTensor.__variable__ || self.variable();
          modelTensor.__count__ = (modelTensor.__count__ || 0) + 1;
          const toArgument = (tensor) => ({ id: tensor.__variable__ });
          self.push({
              type: new pytorch.nnapi.Graph(this.serialized_model),
              attributes: [],
              inputs: [ inputs.map(toArgument) ],
              outputs: [ outputs.map(toArgument) ]
          });
      }
  });
  this.registerType('__torch__.torch.classes.quantized.Conv2dPackedParamsBase', class {
      // Restores a quantized 2-D convolution from its pickled state (version '2' only).
      // The first tensor is a packed config read as
      // [_, stride x2, padding x2, dilation x2, output_padding x2, groups].
      __setstate__(state) {
          const version = state[0];
          if (version !== '2') {
              throw new pytorch.Error(`Unsupported pack version '${version.toString()}'.`);
          }
          const tensors = state[1];
          const optionalTensors = state[2];
          const config = pytorch.Utility.createTensor('', tensors[0], true).value;
          const pair = (offset) => [ config[offset], config[offset + 1] ];
          this.weight = tensors[1];
          this.bias = optionalTensors[0];
          this.stride = pair(1);
          this.padding = pair(3);
          this.dilation = pair(5);
          this.output_padding = pair(7);
          this.groups = config[9];
      }
  });
  this.registerType('__torch__.torch.classes.quantized.Conv3dPackedParamsBase', class {
      // Restores a quantized 3-D convolution from its pickled state (version '2' only).
      // PyTorch packs the conv config as an int64 vector
      // [spatial_dim, stride..., padding..., dilation..., output_padding..., groups, ...]
      // with spatial_dim entries per parameter, so a 3-D conv uses three entries for each.
      __setstate__(state) {
          const pack_version = state[0];
          if (pack_version !== '2') {
              throw new pytorch.Error("Unsupported pack version '" + pack_version.toString() + "'.");
          }
          const tensors = state[1];
          const opt_tensors = state[2];
          const packed_config = pytorch.Utility.createTensor('', tensors[0], true).value;
          this.weight = tensors[1];
          this.bias = opt_tensors[0];
          // Previously used the 2-D offsets, which mixed stride/padding/dilation entries
          // and read the wrong slot for groups.
          this.stride = [ packed_config[1], packed_config[2], packed_config[3] ];
          this.padding = [ packed_config[4], packed_config[5], packed_config[6] ];
          this.dilation = [ packed_config[7], packed_config[8], packed_config[9] ];
          this.output_padding = [ packed_config[10], packed_config[11], packed_config[12] ];
          this.groups = packed_config[13];
      }
  });
  this.registerType('__torch__.torch.classes.quantized.LinearPackedParamsBase', class {
      // Pickled state is read positionally as (weight, bias).
      __setstate__(state) {
          Object.assign(this, { weight: state[0], bias: state[1] });
      }
  });
  this.registerType('__torch__.torch.classes.xnnpack.Conv2dOpContext', class {
      // Copies the positional pickled state into named fields, in this order.
      __setstate__(state) {
          const fields = [ 'weight', 'bias', 'stride', 'padding', 'dilation', 'groups', 'output_min', 'output_max' ];
          fields.forEach((field, position) => {
              this[field] = state[position];
          });
      }
  });
  this.registerType('__torch__.torch.classes.xnnpack.LinearOpContext', class {
      // Copies the positional pickled state into named fields, in this order.
      __setstate__(state) {
          const fields = [ 'weight', 'bias', 'output_min', 'output_max' ];
          fields.forEach((field, position) => {
              this[field] = state[position];
          });
      }
  });
  893. this.registerType('torch.ao.quantization.observer._PartialWrapper', class {});
  894. this.registerType('torch.ao.quantization.qconfig.QConfig', class {});
  895. this.registerType('torch.ao.quantization.stubs.DeQuantStub', class {});
  896. this.registerType('torch.ao.quantization.stubs.QuantStub', class {});
  897. this.registerType('torch.autograd.variable.Variable', class {});
  898. this.registerType('torch.backends.cudnn.rnn.Unserializable', class {});
  899. this.registerType('torch.distributions.bernoulli.Bernoulli', class {});
  900. this.registerType('torch.distributions.constraints._LowerCholesky', class {});
  901. this.registerType('torch.distributions.constraints._Real', class {});
  902. this.registerType('torch.distributions.multivariate_normal.MultivariateNormal', class {});
  903. this.registerType('torch.distributions.normal.Normal', class {});
  904. this.registerType('torch.distributions.transforms.LowerCholeskyTransform', class {});
  905. this.registerType('torch.distributions.uniform.Uniform', class {});
  906. this.registerType('torch.nn.backends.thnn._get_thnn_function_backend', class {});
  907. this.registerType('torch.nn.intrinsic.modules.fused.ConvBnReLU2d', class {});
  908. this.registerType('torch.nn.intrinsic.modules.fused.ConvReLU2d', class {});
  909. this.registerType('torch.nn.intrinsic.modules.fused.BNReLU2d', class {});
  910. this.registerType('torch.nn.intrinsic.qat.modules.conv_fused.ConvBnReLU2d', class {});
  911. this.registerType('torch.nn.intrinsic.qat.modules.conv_fused.ConvReLU2d', class {});
  912. this.registerType('torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU2d', class {});
  913. this.registerType('torch.nn.intrinsic.quantized.modules.linear_relu.LinearReLU', class {});
  914. this.registerType('torch.nn.modules.activation.CELU', class {});
  915. this.registerType('torch.nn.modules.activation.ELU', class {});
  916. this.registerType('torch.nn.modules.activation.GELU', class {});
  917. this.registerType('torch.nn.modules.activation.GLU', class {});
  918. this.registerType('torch.nn.modules.activation.Hardtanh', class {});
  919. this.registerType('torch.nn.modules.activation.Hardswish', class {});
  920. this.registerType('torch.nn.modules.activation.Hardsigmoid', class {});
  921. this.registerType('torch.nn.modules.activation.LeakyReLU', class {});
  922. this.registerType('torch.nn.modules.activation.LogSigmoid', class {});
  923. this.registerType('torch.nn.modules.activation.LogSoftmax', class {});
  924. this.registerType('torch.nn.modules.activation.Mish', class {});
  925. this.registerType('torch.nn.modules.activation.MultiheadAttention', class {});
  926. this.registerType('torch.nn.modules.activation.ReLU', class {});
  927. this.registerType('torch.nn.modules.activation.ReLU6', class {});
  928. this.registerType('torch.nn.modules.activation.PReLU', class {});
  929. this.registerType('torch.nn.modules.activation.RReLU', class {});
  930. this.registerType('torch.nn.modules.activation.SELU', class {});
  931. this.registerType('torch.nn.modules.activation.Sigmoid', class {});
  932. this.registerType('torch.nn.modules.activation.SiLU', class {});
  933. this.registerType('torch.nn.modules.activation.Softmax', class {});
  934. this.registerType('torch.nn.modules.activation.Softmax2d', class {});
  935. this.registerType('torch.nn.modules.activation.Softplus', class {});
  936. this.registerType('torch.nn.modules.activation.Tanh', class {});
  937. this.registerType('torch.nn.modules.activation.Tanhshrink', class {});
  938. this.registerType('torch.nn.modules.activation.Threshold', class {});
  939. this.registerType('torch.nn.modules.batchnorm.BatchNorm1d', class {});
  940. this.registerType('torch.nn.modules.batchnorm.BatchNorm2d', class {});
  941. this.registerType('torch.nn.modules.batchnorm.BatchNorm3d', class {});
  942. this.registerType('torch.nn.modules.batchnorm.LazyBatchNorm1d', class {});
  943. this.registerType('torch.nn.modules.batchnorm.SyncBatchNorm', class {});
  944. this.registerType('torch.nn.modules.container.ModuleDict', class {});
  945. this.registerType('torch.nn.modules.container.ModuleList', class {});
  946. this.registerType('torch.nn.modules.container.ParameterDict', class {});
  947. this.registerType('torch.nn.modules.container.ParameterList', class {});
  948. this.registerType('torch.nn.modules.container.Sequential', class {});
  949. this.registerType('torch.nn.modules.conv.Conv1d', class {});
  950. this.registerType('torch.nn.modules.conv.Conv2d', class {});
  951. this.registerType('torch.nn.modules.conv.Conv3d', class {});
  952. this.registerType('torch.nn.modules.conv.ConvTranspose1d', class {});
  953. this.registerType('torch.nn.modules.conv.ConvTranspose2d', class {});
  954. this.registerType('torch.nn.modules.conv.ConvTranspose3d', class {});
  955. this.registerType('torch.nn.modules.distance.CosineSimilarity', class {});
  956. this.registerType('torch.nn.modules.dropout.AlphaDropout', class {});
  957. this.registerType('torch.nn.modules.dropout.Dropout', class {});
  958. this.registerType('torch.nn.modules.dropout.Dropout2d', class {});
  959. this.registerType('torch.nn.modules.dropout.Dropout3d', class {});
  960. this.registerType('torch.nn.modules.fold.Fold', class {});
  961. this.registerType('torch.nn.modules.fold.Unfold', class {});
  962. this.registerType('torch.nn.modules.flatten.Flatten', class {});
  963. this.registerType('torch.nn.modules.flatten.Unflatten', class {});
  964. this.registerType('torch.nn.modules.instancenorm.InstanceNorm1d', class {});
  965. this.registerType('torch.nn.modules.instancenorm.InstanceNorm2d', class {});
  966. this.registerType('torch.nn.modules.instancenorm.InstanceNorm3d', class {});
  967. this.registerType('torch.nn.modules.linear._LinearWithBias', class {});
  968. this.registerType('torch.nn.modules.linear.Bilinear', class {});
  969. this.registerType('torch.nn.modules.linear.Identity', class {});
  970. this.registerType('torch.nn.modules.linear.LazyLinear', class {});
  971. this.registerType('torch.nn.modules.linear.Linear', class {});
  972. this.registerType('torch.nn.modules.linear.NonDynamicallyQuantizableLinear', class {});
  973. this.registerType('torch.nn.modules.loss.BCELoss', class {});
  974. this.registerType('torch.nn.modules.loss.BCEWithLogitsLoss', class {});
  975. this.registerType('torch.nn.modules.loss.CrossEntropyLoss', class {});
  976. this.registerType('torch.nn.modules.loss.CTCLoss', class {});
  977. this.registerType('torch.nn.modules.loss.KLDivLoss', class {});
  978. this.registerType('torch.nn.modules.loss.L1Loss', class {});
  979. this.registerType('torch.nn.modules.loss.MarginRankingLoss', class {});
  980. this.registerType('torch.nn.modules.loss.MSELoss', class {});
  981. this.registerType('torch.nn.modules.loss.NLLLoss', class {});
  982. this.registerType('torch.nn.modules.loss.NLLLoss2d', class {});
  983. this.registerType('torch.nn.modules.loss.SmoothL1Loss', class {});
  984. this.registerType('torch.nn.modules.module._IncompatibleKeys', class {});
  985. this.registerType('torch.nn.modules.module.Module', class {});
  986. this.registerType('torch.nn.modules.module.PatchForward', class {});
  987. this.registerType('torch.nn.modules.normalization.CrossMapLRN2d', class {});
  988. this.registerType('torch.nn.modules.normalization.GroupNorm', class {});
  989. this.registerType('torch.nn.modules.normalization.LayerNorm', class {});
  990. this.registerType('torch.nn.modules.normalization.LocalResponseNorm', class {});
  991. this.registerType('torch.nn.modules.padding.ReflectionPad1d', class {});
  992. this.registerType('torch.nn.modules.padding.ReflectionPad2d', class {});
  993. this.registerType('torch.nn.modules.padding.ReplicationPad1d', class {});
  994. this.registerType('torch.nn.modules.padding.ReplicationPad2d', class {});
  995. this.registerType('torch.nn.modules.padding.ReplicationPad3d', class {});
  996. this.registerType('torch.nn.modules.padding.ZeroPad2d', class {});
  997. this.registerType('torch.nn.modules.padding.ConstantPad1d', class {});
  998. this.registerType('torch.nn.modules.padding.ConstantPad2d', class {});
  999. this.registerType('torch.nn.modules.padding.ConstantPad3d', class {});
  1000. this.registerType('torch.nn.modules.pixelshuffle.PixelShuffle', class {});
  1001. this.registerType('torch.nn.modules.pixelshuffle.PixelUnshuffle', class {});
  1002. this.registerType('torch.nn.modules.pooling.AdaptiveAvgPool1d', class {});
  1003. this.registerType('torch.nn.modules.pooling.AdaptiveAvgPool2d', class {});
  1004. this.registerType('torch.nn.modules.pooling.AdaptiveAvgPool3d', class {});
  1005. this.registerType('torch.nn.modules.pooling.AdaptiveMaxPool1d', class {});
  1006. this.registerType('torch.nn.modules.pooling.AdaptiveMaxPool2d', class {});
  1007. this.registerType('torch.nn.modules.pooling.AdaptiveMaxPool3d', class {});
  1008. this.registerType('torch.nn.modules.pooling.AvgPool1d', class {});
  1009. this.registerType('torch.nn.modules.pooling.AvgPool2d', class {});
  1010. this.registerType('torch.nn.modules.pooling.AvgPool3d', class {});
  1011. this.registerType('torch.nn.modules.pooling.FractionalMaxPool2d', class {});
  1012. this.registerType('torch.nn.modules.pooling.LPPool2d', class {});
  1013. this.registerType('torch.nn.modules.pooling.MaxPool1d', class {});
  1014. this.registerType('torch.nn.modules.pooling.MaxPool2d', class {});
  1015. this.registerType('torch.nn.modules.pooling.MaxPool3d', class {});
  1016. this.registerType('torch.nn.modules.pooling.MaxUnpool1d', class {});
  1017. this.registerType('torch.nn.modules.pooling.MaxUnpool2d', class {});
  1018. this.registerType('torch.nn.modules.pooling.MaxUnpool3d', class {});
  1019. this.registerType('torch.nn.modules.rnn.GRU', class {});
  1020. this.registerType('torch.nn.modules.rnn.GRUCell', class {});
  1021. this.registerType('torch.nn.modules.rnn.LSTM', class {});
  1022. this.registerType('torch.nn.modules.rnn.LSTMCell', class {});
  1023. this.registerType('torch.nn.modules.rnn.RNN', class {});
  1024. this.registerType('torch.nn.modules.sparse.Embedding', class {});
  1025. this.registerType('torch.nn.modules.sparse.EmbeddingBag', class {});
  1026. this.registerType('torch.nn.modules.transformer.Transformer', class {});
  1027. this.registerType('torch.nn.modules.transformer.TransformerDecoder', class {});
  1028. this.registerType('torch.nn.modules.transformer.TransformerDecoderLayer', class {});
  1029. this.registerType('torch.nn.modules.transformer.TransformerEncoder', class {});
  1030. this.registerType('torch.nn.modules.transformer.TransformerEncoderLayer', class {});
  1031. this.registerType('torch.nn.modules.upsampling.Upsample', class {});
  1032. this.registerType('torch.nn.modules.upsampling.UpsamplingBilinear2d', class {});
  1033. this.registerType('torch.nn.modules.upsampling.UpsamplingNearest2d', class {});
  1034. this.registerType('torch.nn.parallel.data_parallel.DataParallel', class {});
  1035. this.registerType('torch.nn.parallel.distributed._DDPUnevenInputsConfig', class {});
  1036. this.registerType('torch.nn.parallel.distributed.DistributedDataParallel', class {});
  1037. this.registerType('torch.nn.qat.modules.conv.Conv2d', class {});
  1038. this.registerType('torch.nn.qat.modules.linear.Linear', class {});
  1039. this.registerType('torch.nn.quantized.modules.activation.ReLU', class {});
  1040. this.registerType('torch.nn.quantized.modules.activation.LeakyReLU', class {});
  1041. this.registerType('torch.nn.quantized.dynamic.modules.linear.Linear', class {});
  1042. this.registerType('torch.nn.quantized.dynamic.modules.rnn.GRU', class {});
  1043. this.registerType('torch.nn.quantized.dynamic.modules.rnn.LSTM', class {});
  1044. this.registerType('torch.nn.quantized.dynamic.modules.rnn.PackedParameter', class {});
  1045. this.registerType('torch.nn.quantized.modules.activation.ReLU6', class {});
  1046. this.registerType('torch.nn.quantized.modules.batchnorm.BatchNorm2d', class {});
  1047. this.registerType('torch.nn.quantized.modules.conv.Conv1d', class {});
  1048. this.registerType('torch.nn.quantized.modules.conv.Conv2d', class {});
  1049. this.registerType('torch.nn.quantized.modules.conv.ConvTranspose2d', class {});
  1050. this.registerType('torch.nn.quantized.modules.DeQuantize', class {});
  1051. this.registerType('torch.nn.quantized.modules.dropout.Dropout', class {});
  1052. this.registerType('torch.nn.quantized.modules.functional_modules.FloatFunctional', class {});
  1053. this.registerType('torch.nn.quantized.modules.functional_modules.QFunctional', class {});
  1054. this.registerType('torch.nn.quantized.modules.linear.Linear', class {});
  1055. this.registerType('torch.nn.quantized.modules.linear.LinearPackedParams', class {});
  1056. this.registerType('torch.nn.quantized.modules.normalization.InstanceNorm2d', class {});
  1057. this.registerType('torch.nn.quantized.modules.Quantize', class {});
  1058. this.registerType('torch.nn.utils.prune.L1Unstructured', class {});
  1059. this.registerType('torch.nn.utils.spectral_norm.SpectralNorm', class {});
  1060. this.registerType('torch.nn.utils.spectral_norm.SpectralNormStateDictHook', class {});
  1061. this.registerType('torch.nn.utils.spectral_norm.SpectralNormLoadStateDictPreHook', class {});
  1062. this.registerType('torch.nn.utils.weight_norm.WeightNorm', class {});
  1063. this.registerType('torch.optim.adam.Adam', class {});
  1064. this.register('torch.optim').Adam = this._registry.get('torch.optim.adam').Adam;
  1065. this.registerType('torch.optim.adamw.AdamW', class {});
  // Placeholder classes: registering these fully-qualified names lets pickled objects of
  // these types be reconstructed as plain instances; none of them carries behavior.
  for (const typeName of [
      'torch.optim.adagrad.Adagrad',
      'torch.optim.adadelta.Adadelta',
      'torch.optim.lr_scheduler.CosineAnnealingLR',
      'torch.optim.lr_scheduler.CyclicLR',
      'torch.optim.lr_scheduler.ExponentialLR',
      'torch.optim.lr_scheduler.LambdaLR',
      'torch.optim.lr_scheduler.MultiStepLR',
      'torch.optim.lr_scheduler.OneCycleLR',
      'torch.optim.lr_scheduler.ReduceLROnPlateau',
      'torch.optim.lr_scheduler.StepLR',
      'torch.optim.optimizer._RequiredParameter',
      'torch.optim.rmsprop.RMSprop',
      'torch.optim.sgd.SGD',
      'torch.quantization.fake_quantize.FakeQuantize',
      'torch.quantization.observer._PartialWrapper',
      'torch.quantization.observer.MinMaxObserver',
      'torch.quantization.observer.MovingAverageMinMaxObserver',
      'torch.quantization.observer.MovingAveragePerChannelMinMaxObserver',
      'torch.quantization.qconfig.QConfig',
      'torch.quantization.stubs.DeQuantStub',
      'torch.quantization.stubs.QuantStub',
      'torch.utils.data.dataloader._MultiProcessingDataLoaderIter',
      'torch.utils.data.dataloader.DataLoader',
      'torch.utils.data.dataset.Subset',
      'torch.utils.data.dataset.ConcatDataset',
      'torch.utils.data.dataset.TensorDataset',
      'torch.utils.data.sampler.BatchSampler',
      'torch.utils.data.sampler.RandomSampler',
      'torch.utils.data.sampler.SequentialSampler',
      'torchvision.datasets.folder.ImageFolder',
      'torchvision.datasets.mnist.MNIST',
      'torchvision.datasets.vision.StandardTransform',
      'torchvision.models.alexnet.AlexNet',
      'torchvision.models.densenet.DenseNet',
      'torchvision.models.densenet._DenseBlock',
      'torchvision.models.densenet._DenseLayer',
      'torchvision.models.densenet._Transition',
      'torchvision.models.detection._utils.BalancedPositiveNegativeSampler',
      'torchvision.models.detection._utils.BoxCoder',
      'torchvision.models.detection._utils.Matcher',
      'torchvision.models.detection._utils.SSDMatcher',
      'torchvision.models.detection.anchor_utils.AnchorGenerator',
      'torchvision.models.detection.anchor_utils.DefaultBoxGenerator',
      'torchvision.models.detection.backbone_utils.BackboneWithFPN',
      'torchvision.models.detection.faster_rcnn.FasterRCNN',
      'torchvision.models.detection.faster_rcnn.FastRCNNPredictor',
      'torchvision.models.detection.faster_rcnn.TwoMLPHead',
      'torchvision.models.detection.keypoint_rcnn.KeypointRCNN',
      'torchvision.models.detection.keypoint_rcnn.KeypointRCNNHeads',
      'torchvision.models.detection.keypoint_rcnn.KeypointRCNNPredictor',
      'torchvision.models.detection.mask_rcnn.MaskRCNN',
      'torchvision.models.detection.mask_rcnn.MaskRCNNHeads',
      'torchvision.models.detection.mask_rcnn.MaskRCNNPredictor',
      'torchvision.models.detection.retinanet.RetinaNetClassificationHead',
      'torchvision.models.detection.retinanet.RetinaNetHead',
      'torchvision.models.detection.retinanet.RetinaNetRegressionHead',
      'torchvision.models.detection.roi_heads.RoIHeads',
      'torchvision.models.detection.rpn.AnchorGenerator',
      'torchvision.models.detection.rpn.RegionProposalNetwork',
      'torchvision.models.detection.rpn.RPNHead',
      'torchvision.models.detection.ssd.SSD',
      'torchvision.models.detection.ssdlite.SSDLiteClassificationHead',
      'torchvision.models.detection.ssdlite.SSDLiteFeatureExtractorMobileNet',
      'torchvision.models.detection.ssdlite.SSDLiteHead',
      'torchvision.models.detection.ssdlite.SSDLiteRegressionHead',
      'torchvision.models.detection.transform.GeneralizedRCNNTransform',
      'torchvision.models.efficientnet.EfficientNet',
      'torchvision.models.efficientnet.MBConv',
      'torchvision.models.googlenet.BasicConv2d',
      'torchvision.models.googlenet.GoogLeNet',
      'torchvision.models.googlenet.Inception',
      'torchvision.models.googlenet.InceptionAux',
      'torchvision.models.inception.BasicConv2d',
      'torchvision.models.inception.Inception3',
      'torchvision.models.inception.InceptionAux',
      'torchvision.models.inception.InceptionA',
      'torchvision.models.inception.InceptionB',
      'torchvision.models.inception.InceptionC',
      'torchvision.models.inception.InceptionD',
      'torchvision.models.inception.InceptionE',
      'torchvision.models.mnasnet._InvertedResidual',
      'torchvision.models.mnasnet.MNASNet',
      'torchvision.models.mobilenet.ConvBNReLU',
      'torchvision.models.mobilenet.MobileNetV2',
      'torchvision.models.mobilenet.InvertedResidual',
      'torchvision.models.mobilenetv2.ConvBNActivation',
      'torchvision.models.mobilenetv2.InvertedResidual',
      'torchvision.models.mobilenetv2.MobileNetV2',
      'torchvision.models.mobilenetv3.InvertedResidual',
      'torchvision.models.mobilenetv3.MobileNetV3',
      'torchvision.models.mobilenetv3.SqueezeExcitation',
      'torchvision.models.resnet.Bottleneck',
      'torchvision.models.resnet.BasicBlock',
      'torchvision.models.quantization.mobilenet.QuantizableInvertedResidual',
      'torchvision.models.quantization.mobilenet.QuantizableMobileNetV2',
      'torchvision.models.quantization.mobilenetv2.QuantizableInvertedResidual',
      'torchvision.models.quantization.mobilenetv2.QuantizableMobileNetV2',
      'torchvision.models.quantization.resnet.QuantizableBasicBlock',
      'torchvision.models.quantization.resnet.QuantizableBottleneck',
      'torchvision.models.quantization.resnet.QuantizableResNet',
      'torchvision.models.segmentation.deeplabv3.ASPP',
      'torchvision.models.segmentation.deeplabv3.ASPPConv',
      'torchvision.models.segmentation.deeplabv3.ASPPPooling',
      'torchvision.models.segmentation.deeplabv3.DeepLabHead',
      'torchvision.models.segmentation.deeplabv3.DeepLabV3',
      'torchvision.models.segmentation.fcn.FCN',
      'torchvision.models.segmentation.fcn.FCNHead',
      'torchvision.models.shufflenetv2.ShuffleNetV2',
      'torchvision.models.shufflenetv2.InvertedResidual',
      'torchvision.models.squeezenet.Fire',
      'torchvision.models.squeezenet.SqueezeNet',
      'torchvision.models.resnet.ResNet',
      'torchvision.models.vgg.VGG',
      'torchvision.models.video.resnet.BasicBlock',
      'torchvision.models.video.resnet.BasicStem',
      'torchvision.models.video.resnet.Conv2Plus1D',
      'torchvision.models.video.resnet.Conv3DNoTemporal',
      'torchvision.models.video.resnet.Conv3DSimple',
      'torchvision.models.video.resnet.R2Plus1dStem',
      'torchvision.models.video.resnet.VideoResNet',
      'torchvision.models._utils.IntermediateLayerGetter',
      'torchvision.ops.deform_conv.DeformConv2d',
      'torchvision.ops.feature_pyramid_network.FeaturePyramidNetwork',
      'torchvision.ops.feature_pyramid_network.LastLevelMaxPool',
      'torchvision.ops.feature_pyramid_network.LastLevelP6P7',
      'torchvision.ops.misc.ConvNormActivation',
      'torchvision.ops.misc.ConvTranspose2d',
      'torchvision.ops.misc.FrozenBatchNorm2d',
      'torchvision.ops.misc.SqueezeExcitation',
      'torchvision.ops.poolers.LevelMapper',
      'torchvision.ops.poolers.MultiScaleRoIAlign',
      'torchvision.ops.stochastic_depth.StochasticDepth',
      'torchvision.transforms.functional.InterpolationMode',
      'torchvision.transforms.transforms.Compose',
      'torchvision.transforms.transforms.CenterCrop',
      'torchvision.transforms.transforms.Grayscale',
      'torchvision.transforms.transforms.Normalize',
      'torchvision.transforms.transforms.RandomAffine',
      'torchvision.transforms.transforms.RandomCrop',
      'torchvision.transforms.transforms.RandomHorizontalFlip',
      'torchvision.transforms.transforms.Resize',
      'torchvision.transforms.transforms.Scale',
      'torchvision.transforms.transforms.ToPILImage',
      'torchvision.transforms.transforms.ToTensor'
  ]) {
      // Each name gets its own, distinct empty class.
      this.registerType(typeName, class {});
  }
  // builtins.annotate(type, value): coerce a value toward the annotated builtin type.
  // Values that do not fit an int/float annotation become NaN; anything else passes through.
  this.registerFunction('builtins.annotate', function(type, value) {
      switch (type) {
          case self._builtins.int:
              return Number.isInteger(value) ? value : NaN;
          case self._builtins.float:
              return typeof value === 'number' ? value : NaN;
          case self._builtins.number:
              if (pytorch.Utility.isTensor(value)) {
                  // A tensor annotated as a number is reshaped in place to a 0-d tensor.
                  value.resize_([]);
              }
              return value;
          default:
              return value;
      }
  });
  // Type-level casts are not checked; the value is returned untouched.
  this.registerFunction('builtins.unchecked_cast', function(type, value) {
      return value;
  });
  this.registerFunction('ops.prim.data', function(tensor) {
      return tensor;
  });
  this.registerFunction('ops.prim.device', function(tensor) {
      const device = tensor.device;
      return device;
  });
  // Returns the numeric scalar-type code of the tensor's dtype.
  this.registerFunction('ops.prim.dtype', function(tensor) {
      const dtype = tensor.dtype;
      return dtype.scalar_type();
  });
  // A falsy tensor is returned as-is; otherwise reports whether it is flagged as quantized.
  this.registerFunction('ops.prim.is_quantized', function(tensor) {
      if (!tensor) {
          return tensor;
      }
      return tensor.__quantized__ === true;
  });
  this.registerFunction('ops.prim.unchecked_unwrap_optional', function(value) {
      return value;
  });
  this.registerFunction('ops.prim.NumToTensor', function(value) {
      const tensor = self.invoke('torch.Tensor', []);
      // TODO: only the raw number is attached to the new tensor object.
      tensor.value = value;
      return tensor;
  });
  // prim.min / prim.max accept either one list argument or a variadic list of numbers.
  for (const [ name, pick ] of [ [ 'ops.prim.min', Math.min ], [ 'ops.prim.max', Math.max ] ]) {
      this.registerFunction(name, function(value) {
          const values = Array.isArray(value) ? value : arguments;
          return pick.apply(null, values);
      });
  }
  // Returns tensor.size() when the object provides it, otherwise undefined.
  this.registerFunction('ops.prim.shape', function(tensor) {
      if (tensor && tensor.size) {
          return tensor.size();
      }
      return undefined;
  });
  // The *_prepack ops build a packed-params object of the given TorchScript class and
  // record the prepack arguments on it as plain properties.
  this.registerFunction('ops.quantized.conv_prepack', function(weight, bias, stride, padding, dilation, groups) {
      const params = self.invoke('__torch__.torch.classes.quantized.Conv2dPackedParamsBase', []);
      return Object.assign(params, { weight, bias, stride, padding, dilation, groups });
  });
  // conv1d packs into the same Conv2dPackedParamsBase class as conv2d.
  this.registerFunction('ops.quantized.conv1d_prepack', function(weight, bias, stride, padding, dilation, groups) {
      const params = self.invoke('__torch__.torch.classes.quantized.Conv2dPackedParamsBase', []);
      return Object.assign(params, { weight, bias, stride, padding, dilation, groups });
  });
  this.registerFunction('ops.quantized.conv2d_prepack', function(weight, bias, stride, padding, dilation, groups) {
      const params = self.invoke('__torch__.torch.classes.quantized.Conv2dPackedParamsBase', []);
      return Object.assign(params, { weight, bias, stride, padding, dilation, groups });
  });
  this.registerFunction('ops.quantized.conv3d_prepack', function(weight, bias, stride, padding, dilation, groups) {
      const params = self.invoke('__torch__.torch.classes.quantized.Conv3dPackedParamsBase', []);
      return Object.assign(params, { weight, bias, stride, padding, dilation, groups });
  });
  // Transposed convolution additionally records output_padding.
  this.registerFunction('ops.quantized.conv_transpose2d_prepack', function(weight, bias, stride, padding, output_padding, dilation, groups) {
      const params = self.invoke('__torch__.torch.classes.quantized.Conv2dPackedParamsBase', []);
      return Object.assign(params, { weight, bias, stride, padding, output_padding, dilation, groups });
  });
  this.registerFunction('ops.quantized.linear_prepack', function(weight, bias) {
      const params = self.invoke('__torch__.torch.classes.quantized.LinearPackedParamsBase', []);
      return Object.assign(params, { weight, bias });
  });
  this.registerFunction('ops.prim.RaiseException', function(message) {
      throw new pytorch.Error(message);
  });
  // builtins.range with Python semantics: range(stop) or range(start, stop[, step]).
  // Known integer arguments yield an iterator over the values; an unknown (NaN-like)
  // single argument yields an empty list; anything else is rejected.
  this.registerFunction('builtins.range', function(start, stop, step) {
      if (stop === undefined && step === undefined) {
          if (Number.isInteger(start)) {
              // range(n) is empty for n <= 0; Array(n) itself throws a RangeError for negative n.
              return Array(Math.max(0, start)).keys();
          }
          if (isNaN(start)) {
              return [];
          }
      }
      else if (Number.isInteger(start) && Number.isInteger(stop) && (step === undefined || Number.isInteger(step))) {
          const increment = step === undefined ? 1 : step;
          if (increment === 0) {
              throw new pytorch.Error('range() arg 3 must not be zero');
          }
          // Lazily generated so the result is an iterator, like the single-argument form.
          const iterate = function* () {
              for (let value = start; increment > 0 ? value < stop : value > stop; value += increment) {
                  yield value;
              }
          };
          return iterate();
      }
      throw new pytorch.Error('Unsupported function range(' + JSON.stringify(start) + ', ' + JSON.stringify(stop) + ', ' + JSON.stringify(step) + ')');
  });
  {
      // Shared by the _rebuild_* unpickling helpers: instantiates the tensor class that pairs
      // with the storage's class (e.g. '<module>.FloatStorage' -> '<module>.FloatTensor') and
      // restores its view of the storage.
      const rebuildTensor = (storage, storage_offset, size, stride) => {
          const storageClass = storage.__class__;
          const tensorName = storageClass.__module__ + '.' + storageClass.__name__.replace('Storage', 'Tensor');
          const tensor = self.invoke(tensorName, []);
          tensor.__setstate__([ storage, storage_offset, size, stride ]);
          return tensor;
      };
      this.registerFunction('torch._utils._rebuild_tensor', function (storage, storage_offset, size, stride) {
          return rebuildTensor(storage, storage_offset, size, stride);
      });
      this.registerFunction('torch._utils._rebuild_tensor_v2', function (storage, storage_offset, size, stride, requires_grad, backward_hooks) {
          const tensor = rebuildTensor(storage, storage_offset, size, stride);
          tensor.requires_grad = requires_grad;
          tensor.backward_hooks = backward_hooks;
          return tensor;
      });
      this.registerFunction('torch._utils._rebuild_parameter', function(data, requires_grad, backward_hooks) {
          const parameter = self.invoke('torch.nn.parameter.Parameter', [ data, requires_grad ]);
          parameter.backward_hooks = backward_hooks;
          return parameter;
      });
      this.registerFunction('torch._utils._rebuild_qtensor', function(storage, storage_offset, size, stride, quantizer_params, requires_grad, backward_hooks) {
          const tensor = rebuildTensor(storage, storage_offset, size, stride);
          tensor.quantizer_params = quantizer_params;
          tensor.requires_grad = requires_grad;
          tensor.backward_hooks = backward_hooks;
          return tensor;
      });
  }
  // Stores value under key in dict (mutates dict; returns nothing).
  this.registerFunction('torch._set_item', function(dict, key, value) {
      dict[key] = value;
  });
  this.registerFunction('torch.__and__', function(left, right) {
      return left && right;
  });
  this.registerFunction('torch.__contains__', function(dict, key) {
      const entry = dict[key];
      return entry !== undefined;
  });
  // Value of the index-th iteration of a loop that starts at start and advances by step.
  this.registerFunction('torch.__derive_index', function(index, start, step) {
      const offset = index * step;
      return start + offset;
  });
  // Identity checks are only decided here when at least one operand is null.
  this.registerFunction('torch.__is__', function(left, right) {
      const leftIsNull = left === null;
      const rightIsNull = right === null;
      if (leftIsNull || rightIsNull) {
          return leftIsNull && rightIsNull;
      }
      throw new pytorch.Error("Unsupported 'torch.__is__' expression type.");
  });
  this.registerFunction('torch.__isnot__', function(left, right) {
      const leftIsNull = left === null;
      const rightIsNull = right === null;
      if (leftIsNull || rightIsNull) {
          return leftIsNull !== rightIsNull;
      }
      throw new pytorch.Error("Unsupported 'torch.__isnot__' expression type.");
  });
  this.registerFunction('torch.__not__', function(value) {
      if (typeof value !== 'boolean') {
          throw new pytorch.Error("Unsupported 'torch.__not__' expression type.");
      }
      return !value;
  });
  // Number of values produced by range(lo, hi, step), as used by compiled TorchScript loops.
  // Uses floor division like Python's `//`; plain `/` would return fractional lengths
  // (e.g. range(0, 6, 4) has 2 values, not 2.25).
  this.registerFunction('torch.__range_length', function(lo, hi, step) {
      if (step === 0) {
          throw new pytorch.Error('range() arg 3 must not be zero');
      }
      if (step > 0 && lo < hi) {
          return 1 + Math.floor((hi - 1 - lo) / step);
      }
      else if (step < 0 && lo > hi) {
          return 1 + Math.floor((lo - 1 - hi) / (0 - step));
      }
      return 0;
  });
  this.registerFunction('torch._unwrap_optional', function(value) {
      return value; // TODO
  });
  // torch.add for scalar values: numbers are summed, lists concatenated, strings joined.
  this.registerFunction('torch.add', function(left, right) {
      if (typeof left === 'number' && typeof right === 'number') {
          // Addition, not multiplication.
          return left + right;
      }
      if (Array.isArray(left) && Array.isArray(right)) {
          return left.concat(right);
      }
      if (typeof left === 'string' && typeof right === 'string') {
          return left + right;
      }
      throw new pytorch.Error('Unsupported torch.add expression type.');
  });
  // list.append(value): mutates list and returns the appended value.
  this.registerFunction('torch.append', function(list, value) {
      list.push(value);
      return value;
  });
  // list.extend(values): appends every item of value to list.
  this.registerFunction('torch.extend', function(list, value) {
      list.push(...value);
  });
  // list.insert(index, value): inserts before index and returns the inserted value.
  this.registerFunction('torch.insert', function(list, index, value) {
      list.splice(index, 0, value);
      return value;
  });
  // list.clear() / dict.clear(): empties the container in place; primitives are ignored.
  this.registerFunction('torch.clear', function(value) {
      if (Array.isArray(value)) {
          // Truncate instead of deleting indices, which would leave holes and keep the old length.
          value.length = 0;
          return;
      }
      if (Object(value) === value) {
          for (const key of Object.keys(value)) {
              delete value[key];
          }
      }
  });
  // TODO: string replacement is not evaluated; the input is returned unchanged.
  this.registerFunction('torch.replace', function(value) {
      return value;
  });
  // Builds a plain object from a list of [key, value] pairs; no argument yields an empty object.
  this.registerFunction('torch.dict', function(args) {
      const obj = {};
      if (!args) {
          return obj;
      }
      if (!Array.isArray(args)) {
          throw new pytorch.Error("'torch.dict' arguments not supported.");
      }
      for (const pair of args) {
          obj[pair[0]] = pair[1];
      }
      return obj;
  });
  // Rank of the tensor, or NaN when its size is unavailable.
  this.registerFunction('torch.dim', function(tensor) {
      const size = tensor && tensor.size ? tensor.size() : null;
      return size ? size.length : NaN; // TODO
  });
  // Product of all dimensions, or NaN when the size is unavailable.
  this.registerFunction('torch.numel', function(tensor) {
      const size = tensor && tensor.size ? tensor.size() : null;
      if (!size) {
          return NaN;
      }
      return size.reduce((product, extent) => product * extent, 1);
  });
  // Equality for strings, numbers (two NaNs compare equal), and lists (element-wise ===).
  // An undefined operand is treated as equal.
  this.registerFunction('torch.eq', function(left, right) {
      if (typeof left === 'string' && typeof right === 'string') {
          return left === right;
      }
      if (typeof left === 'number' && typeof right === 'number') {
          return (isNaN(left) && isNaN(right)) || left === right;
      }
      if (left === undefined || right === undefined) {
          return true;
      }
      if (Array.isArray(left) && Array.isArray(right)) {
          if (left.length !== right.length) {
              return false;
          }
          return !left.some((item, index) => item !== right[index]);
      }
      throw new pytorch.Error("Unsupported 'torch.eq' expression type.");
  });
  for (const [ name, roundFn ] of [ [ 'torch.floor', Math.floor ], [ 'torch.ceil', Math.ceil ] ]) {
      this.registerFunction(name, function(value) {
          return roundFn(value);
      });
  }
  // Python-style floor division (rounds toward negative infinity).
  this.registerFunction('torch.floordiv', function(left, right) {
      const quotient = left / right;
      return Math.floor(quotient);
  });
  // str.format subset: each '{}' or '{}D' placeholder in the first argument is replaced by the
  // next argument. Lists render as '[a, b]'; missing or unknown (NaN) values render as '?'.
  this.registerFunction('torch.format', function() {
      const args = Array.from(arguments);
      const list = args.shift().split(/({}D?)/);
      return list.map((text) => {
          if (text === '{}' || text === '{}D') {
              const arg = args.shift();
              if (Array.isArray(arg)) {
                  return '[' + arg.map((item) => item.toString()).join(', ') + ']';
              }
              // Only absent/unknown values become '?'; falsy but real values such as 0 must print.
              const unknown = arg === undefined || arg === null || (typeof arg === 'number' && isNaN(arg));
              return unknown ? '?' : arg.toString();
          }
          return text;
      }).join('');
  });
  // Greater-than for known numbers. An unknown (NaN-like) left operand against a known
  // right operand yields true.
  this.registerFunction('torch.gt', function(left, right) {
      if (typeof left === 'number' && typeof right === 'number') {
          if (!isNaN(left) && !isNaN(right)) {
              return left > right;
          }
      }
      if (isNaN(left) && !isNaN(right)) {
          return true;
      }
      throw new pytorch.Error("Unsupported 'torch.gt' expression type.");
  });
  // Greater-or-equal; same unknown-operand rule as torch.gt.
  this.registerFunction('torch.ge', function(left, right) {
      if (typeof left === 'number' && typeof right === 'number') {
          if (!isNaN(left) && !isNaN(right)) {
              // >= (not >): equal operands satisfy ge.
              return left >= right;
          }
      }
      if (isNaN(left) && !isNaN(right)) {
          return true;
      }
      throw new pytorch.Error("Unsupported 'torch.ge' expression type.");
  });
  // True when the tensor's dtype is a real floating-point type.
  // c10 ScalarType codes: 5 = float16, 6 = float32, 7 = float64, 15 = bfloat16.
  this.registerFunction('torch.is_floating_point', function(tensor) {
      const type = tensor.dtype.scalar_type();
      return (type === 5 || type === 6 || type === 7 || type === 15);
  });
  // Gradient mode always reads as disabled; setting it is accepted and ignored.
  this.registerFunction('torch.is_grad_enabled', function() {
      return false;
  });
  this.registerFunction('torch.set_grad_enabled', function(/* value */) {
  });
  // The typed-list builders used by the TorchScript pickler return their data unchanged.
  for (const builder of [ 'build_boollist', 'build_doublelist', 'build_intlist', 'build_tensorlist' ]) {
      this.registerFunction('torch.jit._pickle.' + builder, function(data) {
          return data;
      });
  }
  // Looks the tensor up in the CONSTANTS table under the key 'c<id>'.
  this.registerFunction('torch.jit._pickle.build_tensor_from_id', function(data) {
      const constants = self.context.getx('CONSTANTS');
      const key = 'c' + data.toString();
      return constants[key];
  });
  this.registerFunction('torch.jit._pickle.restore_type_tag', function(value /*, type_str */) {
      return value;
  });
  this.registerFunction('torch.keys', function(dict) {
      return Object.keys(dict);
  });
  // Length of a list, or of an object exposing both shape and __len__(); NaN otherwise.
  this.registerFunction('torch.len', function(value) {
      if (Array.isArray(value)) {
          return value.length;
      }
      const sized = value && value.shape && value.__len__;
      return sized ? value.__len__() : NaN;
  });
  // Less-or-equal: NaN on either numeric side gives false; an undefined operand gives true.
  this.registerFunction('torch.le', function(left, right) {
      const numeric = typeof left === 'number' && typeof right === 'number';
      if (numeric) {
          return !isNaN(left) && !isNaN(right) && left <= right;
      }
      if (left === undefined || right === undefined) {
          return true;
      }
      throw new pytorch.Error("Unsupported 'torch.le' expression type.");
  });
  this.registerFunction('torch.list', function(args) {
      return args;
  });
  // Only the size argument is returned; defaults are ignored.
  this.registerFunction('torch.list_with_default', function(size /*, defaults */) {
      return size;
  });
  this.registerFunction('torch.lt', function(left, right) {
      if (typeof left === 'number' && typeof right === 'number') {
          return left < right;
      }
      throw new pytorch.Error("Unsupported 'torch.lt' expression type.");
  });
  // Multiplication of numbers, or of a list of numbers by a number.
  this.registerFunction('torch.mul', function(left, right) {
      if (typeof left === 'number' && typeof right === 'number') {
          return left * right;
      }
      // Must precede the NaN guard: isNaN() coerces an array through Number(), so every list
      // with two or more elements looks like NaN and would never reach this branch.
      if (Array.isArray(left) && left.every((value) => typeof value === 'number') && typeof right === 'number') {
          // NOTE(review): TorchScript `list * int` (aten::mul.left_t) repeats the list;
          // this maps element-wise instead — confirm which behavior callers rely on.
          return left.map((value) => value * right);
      }
      if (isNaN(left) || isNaN(right)) {
          return NaN;
      }
      throw new pytorch.Error("Unsupported 'torch.mul' expression type.");
  });
  // True division; NaN-like operands produce NaN.
  this.registerFunction('torch.div', function(left, right) {
      if (typeof left === 'number' && typeof right === 'number') {
          return left / right;
      }
      if (isNaN(left) || isNaN(right)) {
          return NaN;
      }
      throw new pytorch.Error("Unsupported 'torch.div' expression type.");
  });
  // Rounds half to even like Python's round() (2.5 -> 2, 3.5 -> 4); Math.round alone
  // always rounds .5 upward.
  this.registerFunction('torch.round', function(value) {
      if (typeof value === 'number') {
          if (value - Math.floor(value) === 0.5) {
              // value / 2 never has a .5 fraction here, so Math.round picks the even neighbour.
              return Math.round(value * 0.5) * 2;
          }
          return Math.round(value);
      }
      if (isNaN(value)) {
          return value;
      }
      throw new pytorch.Error("Unsupported 'torch.round' expression type.");
  });
  // Python-style modulo: the result takes the sign of the divisor (-1 % 3 == 2).
  // JavaScript's % takes the sign of the dividend, so adjust when the signs disagree.
  this.registerFunction('torch.remainder', function(left, right) {
      if (typeof left === 'number' && typeof right === 'number') {
          const remainder = left % right;
          if (remainder !== 0 && (remainder < 0) !== (right < 0)) {
              return remainder + right;
          }
          return remainder;
      }
      if (isNaN(left) || isNaN(right)) {
          return NaN;
      }
      throw new pytorch.Error("Unsupported 'torch.remainder' expression type.");
  });
  // Inequality. Unknown (NaN) numbers never compare as unequal; an undefined operand does.
  this.registerFunction('torch.ne', function(left, right) {
      if (typeof left === 'boolean' && typeof right === 'boolean') {
          return left !== right;
      }
      if (typeof left === 'number' && typeof right === 'number') {
          if (isNaN(left) || isNaN(right)) {
              return false;
          }
          return left !== right;
      }
      if (Array.isArray(left) && Array.isArray(right)) {
          if (left.length !== right.length) {
              return true;
          }
          // Element-wise, applying the same NaN-as-unknown rule as the scalar case.
          return left.some((item, index) => {
              const other = right[index];
              if (typeof item === 'number' && typeof other === 'number') {
                  return !isNaN(item) && !isNaN(other) && item !== other;
              }
              return item !== other;
          });
      }
      if (typeof left === 'string' && typeof right === 'string') {
          return left !== right;
      }
      if (left === undefined || right === undefined) {
          return true;
      }
      throw new pytorch.Error("Unsupported 'torch.ne' expression type.");
  });
  this.registerFunction('torch.neg', function(value) {
      if (typeof value !== 'number') {
          throw new pytorch.Error("Unsupported 'torch.neg' expression type.");
      }
      return -value;
  });
  // TODO: the quantization scale is not tracked; a -1 sentinel is returned.
  this.registerFunction('torch.q_scale', function(/* tensor */) {
      return -1;
  });
  // Transpose is not modeled; the tensor object is returned unchanged.
  this.registerFunction('torch.t', function(tensor) {
      return tensor;
  });
  // tensor.size() or tensor.size(dim). Negative dims count from the end, so the valid range is
  // [-rank, rank). When the size is unknown, returns NaN for a dim query and [] otherwise.
  this.registerFunction('torch.size', function(tensor, dim) {
      if (tensor && tensor.size) {
          const size = tensor.size();
          if (Array.isArray(size)) {
              if (dim === undefined) {
                  return size;
              }
              if (Number.isInteger(dim)) {
                  if (dim >= 0 && dim < size.length) {
                      return size[dim];
                  }
                  // -rank is valid (it addresses dimension 0), hence <= rather than <.
                  if (dim < 0 && -dim <= size.length) {
                      return size[size.length + dim];
                  }
              }
              throw new pytorch.Error('Dimension out of range (expected to be in range of ' + JSON.stringify(size) + ', but got ' + JSON.stringify(dim) + ').');
          }
      }
      if (Number.isInteger(dim)) {
          return NaN;
      }
      return [];
  });
  // list[start:end] with step 1. Negative indices count from the end; an omitted end
  // (undefined/null) means "to the end of the list".
  this.registerFunction('torch.slice', function(l, start, end, step) {
      if (!Array.isArray(l)) {
          throw new pytorch.Error('Slicing expected array');
      }
      step = step || 1;
      if (step !== 1) {
          throw new pytorch.Error('Slicing only supports step=1');
      }
      start = Math.max(0, start >= 0 ? start : l.length + start);
      // An explicit end of 0 must yield an empty slice, so test for absence rather than falsiness.
      end = (end === undefined || end === null) ? l.length : Math.min(l.length, end);
      return l.slice(start, end);
  });
  this.registerFunction('torch.sub', function(left, right) {
      if (typeof left !== 'number' || typeof right !== 'number') {
          throw new pytorch.Error("Unsupported 'torch.sub' expression type.");
      }
      return left - right;
  });
  // Values of the dict's own enumerable keys, in key order.
  this.registerFunction('torch.values', function(dict) {
      return Object.values(dict);
  });
  // Warnings are accepted and discarded.
  this.registerFunction('torch.warn', function() {
  });
  this.registerFunction('builtins.uninitialized', function(/* type */) {
      return undefined;
  });
  // torch.device(type[, index]); the index property is only set when an index is given.
  this.registerType('torch.device', class {
      constructor(type, index) {
          this.type = type;
          // Index 0 is a valid device ordinal, so test for absence rather than truthiness.
          if (index !== undefined && index !== null) {
              this.index = index;
          }
      }
  });
  // torch.dtype wraps a numeric scalar-type code together with the metadata record
  // (name, itemsize) that pytorch.Utility.getScalarType returns for it.
  this.registerType('torch.dtype', class {
      constructor(type) {
          this._type = type;
          this._data = pytorch.Utility.getScalarType(type);
      }
      scalar_type() {
          return this._type;
      }
      // Size of one element in bytes.
      itemsize() {
          return this._data.itemsize;
      }
      __reduce__() {
          return this._data.name;
      }
      __str__() {
          return `torch.${this._data.name}`;
      }
  });
  // Restored from a pickled (hooks_dict_ref, id) state; a missing hooks dict becomes an empty Map.
  this.registerType('torch.utils.hooks.RemovableHandle', class {
      __setstate__(state) {
          const hooks = state[0];
          this.hooks_dict_ref = hooks || new Map();
          this.id = state[1];
      }
  });
  // Base class of the typed storages: an element count, a torch.dtype, and the raw bytes.
  this.registerType('torch.storage._StorageBase', class {
      constructor(size, dtype) {
          this._size = size;
          this._dtype = dtype;
          this._device = null;
      }
      get device() {
          return this._device;
      }
      get dtype() {
          return this._dtype;
      }
      // Bytes per element. The torch.dtype class exposes this through itemsize() and has no
      // element_size member, so read it from there.
      element_size() {
          return this._dtype.itemsize();
      }
      // Number of elements (not bytes).
      size() {
          return this._size;
      }
      get data() {
          return this._cdata;
      }
      // Attaches the raw bytes; their length must equal size() * itemsize().
      _set_cdata(data) {
          const length = this.size() * this.dtype.itemsize();
          if (length !== data.length) {
              throw new pytorch.Error('Storage data size mismatch.');
          }
          this._cdata = data;
      }
      // Reads an int64 element count (which must match size()) followed by the element bytes.
      _set_from_file(unpickler) {
          const buffer = unpickler.read(8);
          const reader = new base.BinaryReader(buffer);
          const size = reader.int64();
          if (size !== this.size()) {
              throw new pytorch.Error('Storage size mismatch.');
          }
          const itemsize = this.dtype.itemsize();
          const data = unpickler.stream(itemsize * size);
          this._set_cdata(data);
      }
      // Like _set_from_file, but constructs a new storage of the calling subclass sized from the file.
      static _new_with_file(unpickler) {
          const buffer = unpickler.read(8);
          const reader = new base.BinaryReader(buffer);
          const size = reader.int64();
          const storage = new this(size);
          const itemsize = storage.dtype.itemsize();
          const data = unpickler.stream(itemsize * size);
          storage._set_cdata(data);
          return storage;
      }
  });
  // Placeholders for the newer storage classes: the names are registered, but
  // constructing any of them raises python.Error.
  this.registerType('torch.storage._UntypedStorage', class extends torch_storage._StorageBase {
    constructor() {
      super();
      const message = '_UntypedStorage not implemented.';
      throw new python.Error(message);
    }
  });
  this.registerType('torch.storage._TypedStorage', class {
    constructor() {
      const message = '_TypedStorage not implemented.';
      throw new python.Error(message);
    }
  });
  // Registered after _TypedStorage because it extends it.
  this.registerType('torch.storage._LegacyStorage', class extends torch_storage._TypedStorage {
    constructor() {
      super();
      const message = '_LegacyStorage not implemented.';
      throw new python.Error(message);
    }
  });
  {
    // Legacy typed storages (torch.FloatStorage, ...). Each class only fixes the dtype handed
    // to _StorageBase. The dtype is looked up on `torch` when an instance is constructed,
    // which is why the torch.<dtype> constants can be assigned after these registrations.
    const typedStorages = [
      [ 'ComplexFloatStorage', 'complex64' ],
      [ 'ComplexDoubleStorage', 'complex128' ],
      [ 'BoolStorage', 'bool' ],
      [ 'ByteStorage', 'uint8' ],
      [ 'CharStorage', 'int8' ],
      [ 'ShortStorage', 'int16' ],
      [ 'IntStorage', 'int32' ],
      [ 'LongStorage', 'int64' ],
      [ 'HalfStorage', 'float16' ],
      [ 'FloatStorage', 'float32' ],
      [ 'DoubleStorage', 'float64' ],
      [ 'QInt8Storage', 'qint8' ],
      [ 'QUInt8Storage', 'quint8' ],
      [ 'QInt32Storage', 'qint32' ],
      [ 'BFloat16Storage', 'bfloat16' ]
    ];
    for (const [ storageName, dtypeName ] of typedStorages) {
      this.registerType('torch.' + storageName, class extends torch_storage._StorageBase {
        constructor(size) {
          super(size, torch[dtypeName]);
        }
      });
    }
  }
  this.registerType('torch.Size', class extends Array {
    /** @param {*} size - array-like of dimensions, copied element by element into this array. */
    constructor(size) {
      const count = size.length;
      super(count);
      let index = 0;
      while (index < count) {
        this[index] = size[index];
        index += 1;
      }
    }
    // Python len(): number of dimensions.
    __len__() {
      const { length } = this;
      return length;
    }
  });
  this.registerType('torch.Tensor', class {
    constructor() {
    }
    get device() {
      return this.storage().device;
    }
    get dtype() {
      return this.storage().dtype;
    }
    get shape() {
      return this._shape;
    }
    size() {
      return this._shape;
    }
    /**
     * Returns the backing storage, creating an empty one on first access.
     * A plain torch.Tensor maps to FloatStorage; other tensor classes map by name
     * (e.g. LongTensor -> LongStorage), looked up in the tensor class's own module.
     */
    storage() {
      if (!this._storage) {
        const type = this.__class__;
        // Derive the storage name from the tensor's class name. The previous code read
        // `this.__storage__.__name__`; nothing in this section assigns `__storage__`, so
        // that lookup threw a TypeError for every typed tensor.
        const name = type.__name__ === 'Tensor' ? 'FloatStorage' : type.__name__.replace('Tensor', 'Storage');
        this._storage = self.invoke(type.__module__ + '.' + name, []);
      }
      return this._storage;
    }
    storage_offset() {
      return this._storage_offset;
    }
    stride() {
      return this._stride;
    }
    resize_(shape) {
      this._shape = shape;
    }
    __len__() {
      return this._shape[0];
    }
    // Legacy tensor state: (storage, storage_offset, shape, stride).
    __setstate__(state) {
      this._storage = state[0];
      this._storage_offset = state[1];
      this._shape = state[2];
      this._stride = state[3];
    }
    __bool__() {
      return true;
    }
    // Python int(): only a single int64 element (8 bytes of data) is decoded; otherwise NaN.
    __int__() {
      const storage = this.storage();
      if (storage && storage.dtype.__reduce__() === 'int64' && storage.data.length === 8) {
        const buffer = storage.data;
        const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
        return view.getInt64(0, true);
      }
      return NaN;
    }
    // Python float(): only a single float32 element (4 bytes of data) is decoded; otherwise NaN.
    __float__() {
      const storage = this.storage();
      if (storage && storage.dtype.__reduce__() === 'float32') {
        if (storage.size() !== undefined && storage.data.length === 4) {
          const buffer = storage.data;
          const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
          return view.getFloat32(0, true);
        }
      }
      return NaN;
    }
    __str__() {
      return 'tensor(...)';
    }
  });
  this.registerType('torch.nn.parameter.Parameter', class extends torch.Tensor {
    /**
     * @param {torch.Tensor} [data] - wrapped tensor; a new empty torch.Tensor is used when falsy.
     * @param {boolean} [requires_grad] - defaults to true when omitted.
     */
    constructor(data, requires_grad) {
      super();
      this.data = data ? data : self.invoke('torch.Tensor', [[]]);
      this.requires_grad = requires_grad === undefined ? true : requires_grad;
    }
    /**
     * Accepts 3-, 4- or 5-element pickled states. A 3-element state carries no data;
     * the 4- and 5-element forms hold the data in slot 0.
     * @throws {pytorch.Error} for any other state length.
     */
    __setstate__(state) {
      const length = state.length;
      switch (length) {
        case 3:
          this.data = null;
          break;
        case 4:
        case 5:
          this.data = state[0];
          break;
        default:
          throw new pytorch.Error("Unsupported parameter state length '" + length + "'.");
      }
    }
  });
  // Parameter created without data; device and dtype arguments are accepted but ignored.
  this.registerType('torch.nn.parameter.UninitializedParameter', class extends torch_nn_parameter.Parameter {
    constructor(requires_grad /*, device, dtype */) {
      const data = undefined;
      super(data, requires_grad);
    }
  });
  {
    // Typed tensor classes: each is registered as an empty subclass of torch.Tensor.
    const typedTensors = [
      'torch.BoolTensor',
      'torch.ByteTensor',
      'torch.CharTensor',
      'torch.ShortTensor',
      'torch.IntTensor',
      'torch.LongTensor',
      'torch.HalfTensor',
      'torch.FloatTensor',
      'torch.DoubleTensor',
      'torch.ComplexFloatTensor',
      'torch.ComplexDoubleTensor',
      'torch.QInt8Tensor',
      'torch.QUInt8Tensor',
      'torch.QInt32Tensor',
      'torch.BFloat16Tensor',
      'torch.cuda.FloatTensor',
      'torch.cuda.DoubleTensor'
    ];
    for (const tensorName of typedTensors) {
      this.registerType(tensorName, class extends torch.Tensor {});
    }
  }
  {
    // One torch.dtype instance per scalar type, published as torch.<name>.
    // Note that torch.bool is backed by pytorch.ScalarType.boolean.
    const scalarTypes = [
      [ 'uint8', 'uint8' ],
      [ 'int8', 'int8' ],
      [ 'int16', 'int16' ],
      [ 'int32', 'int32' ],
      [ 'int64', 'int64' ],
      [ 'float16', 'float16' ],
      [ 'float32', 'float32' ],
      [ 'float64', 'float64' ],
      [ 'complex32', 'complex32' ],
      [ 'complex64', 'complex64' ],
      [ 'complex128', 'complex128' ],
      [ 'bool', 'boolean' ],
      [ 'qint8', 'qint8' ],
      [ 'quint8', 'quint8' ],
      [ 'qint32', 'qint32' ],
      [ 'bfloat16', 'bfloat16' ]
    ];
    for (const [ dtypeName, scalarTypeName ] of scalarTypes) {
      torch[dtypeName] = new torch.dtype(pytorch.ScalarType[scalarTypeName]);
    }
  }
  2039. }
  2040. debug(file) {
  2041. const buffer = this.source(file + '.debug_pkl');
  2042. if (buffer) {
  2043. return null;
  2044. // const unpickler = python.Unpickler.open(buffer, this);
  2045. // return unpickler.load();
  2046. }
  2047. return null;
  2048. }
  2049. };
  pytorch.Container = class {
    /**
     * Picks the container for a model file, trying zip archives first, then legacy pickle
     * streams, then legacy tar archives. Each probe runs only if the previous one failed.
     * @returns {?Object} the first matching container, or null.
     */
    static open(context) {
      const probes = [
        () => pytorch.Container.Zip.open(context.entries('zip')),
        () => pytorch.Container.Pickle.open(context.stream),
        () => pytorch.Container.Tar.open(context.entries('tar'))
      ];
      for (const probe of probes) {
        const container = probe();
        if (container) {
          return container;
        }
      }
      return null;
    }
  };
  /**
   * Legacy tar checkpoint ('PyTorch v0.1.1') with 'sys_info', 'storages', 'tensors' and
   * 'pickle' members. Decoding is deferred until graphs or littleEndian is first read.
   */
  pytorch.Container.Tar = class {
    static open(entries) {
      if (entries.has('pickle')) {
        return new pytorch.Container.Tar(entries);
      }
      return null;
    }
    constructor(entries) {
      this._entries = entries;
      this._graphs = [ this ];
    }
    set metadata(value) {
      this._metadata = value;
    }
    set exception(value) {
      // Stored under the name _unpickle() reads. The field used to be misspelled
      // '_exceptionCallack', so the callback never reached the Execution.
      this._exceptionCallback = value;
    }
    get format() {
      return 'PyTorch v0.1.1';
    }
    get graphs() {
      this._unpickle();
      return this._graphs;
    }
    get littleEndian() {
      this._unpickle();
      return this._littleEndian;
    }
    /**
     * Decodes the archive once; later calls return immediately because _entries is cleared.
     * @throws {pytorch.Error} on an unsupported protocol version or type sizes, or when
     *   the pickle holds neither a root module nor a state dictionary.
     */
    _unpickle() {
      if (!this._entries) {
        return;
      }
      this._type = '';
      this._data = null;
      this._littleEndian = true;
      const execution = new pytorch.Execution(null, this._exceptionCallback);
      const entries = {};
      for (const entry of this._entries) {
        const key = entry[0];
        const value = entry[1];
        entries[key] = value.peek();
      }
      this._exceptionCallback = null;
      this._entries = null;
      if (entries.sys_info) {
        const unpickler = python.Unpickler.open(entries.sys_info, execution);
        const sys_info = unpickler.load();
        if (sys_info.protocol_version != 1000) {
          throw new pytorch.Error("Unsupported protocol version '" + sys_info.protocol_version + "'.");
        }
        if (sys_info.type_sizes &&
          ((sys_info.type_sizes.int && sys_info.type_sizes.int != 4) ||
          (sys_info.type_sizes.long && sys_info.type_sizes.long != 4) ||
          (sys_info.type_sizes.short && sys_info.type_sizes.short != 2))) {
          throw new pytorch.Error('Unsupported type sizes.');
        }
        this._littleEndian = sys_info.little_endian;
      }
      const deserialized_objects = {};
      if (entries.storages) {
        const data = entries.storages;
        const unpickler = python.Unpickler.open(data, execution);
        const num_storages = unpickler.load();
        for (let i = 0; i < num_storages; i++) {
          // Each record is (key, location, storage_type) followed by the raw storage body.
          const args = unpickler.load();
          const key = args[0];
          const storage_type = execution.type(args[2]);
          const obj = storage_type._new_with_file(unpickler);
          deserialized_objects[key] = obj;
        }
        // TODO: storage views are not restored. Reference algorithm:
        //   storage_views = unpickler.load()
        //   for target_cdata, root_cdata, offset, size in storage_views:
        //     deserialized_objects[target_cdata] = deserialized_objects[root_cdata][offset:offset + size]
      }
      if (entries.tensors) {
        const data = entries.tensors;
        const unpickler = python.Unpickler.open(data, execution);
        // Fixed-width integer readers for the binary tensor headers (loop-invariant).
        const int32 = (reader) => new base.BinaryReader(reader.read(4)).int32();
        const int64 = (reader) => new base.BinaryReader(reader.read(8)).int64();
        const num_tensors = unpickler.load();
        for (let i = 0; i < num_tensors; i++) {
          const args = unpickler.load();
          const key = args[0];
          const storage_id = args[1];
          const storage = deserialized_objects[storage_id];
          const ndim = int32(unpickler);
          unpickler.read(4); // 4 bytes that follow ndim are skipped
          const shape = new Array(ndim);
          for (let j = 0; j < ndim; j++) {
            shape[j] = int64(unpickler);
          }
          const stride = new Array(ndim);
          for (let j = 0; j < ndim; j++) {
            stride[j] = int64(unpickler);
          }
          const storage_offset = int64(unpickler);
          const tensor = execution.invoke('torch._utils._rebuild_tensor', [ storage, storage_offset, shape, stride ]);
          deserialized_objects[key] = tensor;
        }
      }
      if (entries.pickle) {
        const unpickler = python.Unpickler.open(entries.pickle, execution);
        unpickler.persistent_load = (saved_id) => deserialized_objects[saved_id];
        const obj = unpickler.load();
        const weights = pytorch.Utility.findWeights(obj);
        if (!weights) {
          throw new pytorch.Error('File does not contain root module or state dictionary.');
        }
        this._graphs = weights;
        for (const graph of this._graphs) {
          graph.type = 'weights';
        }
      }
    }
  };
  /**
   * Legacy torch.save stream ('PyTorch v0.1.10'): successive pickles (magic number,
   * protocol version, sys_info, root object, storage keys) followed by raw storage bytes.
   */
  pytorch.Container.Pickle = class {
    static open(stream) {
      // Pickle protocol header (any protocol byte), then the pickled magic-number long.
      const signature = [ 0x80, undefined, 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19 ];
      if (stream && signature.length <= stream.length && stream.peek(signature.length).every((value, index) => signature[index] === undefined || signature[index] === value)) {
        return new pytorch.Container.Pickle(stream);
      }
      return null;
    }
    constructor(stream) {
      this._stream = stream;
      this._graphs = [ this ];
    }
    set metadata(value) {
      this._metadata = value;
    }
    set exception(value) {
      this._exceptionCallback = value;
    }
    get format() {
      return 'PyTorch v0.1.10';
    }
    get graphs() {
      this._unpickle();
      return this._graphs;
    }
    get littleEndian() {
      this._unpickle();
      return this._littleEndian;
    }
    /**
     * Decodes the stream once; later calls return immediately because _stream is cleared.
     * @throws {pytorch.Error} on an unsupported protocol, storage type, persistent id
     *   type or missing storage, or when the root object is empty or 'None'.
     */
    _unpickle() {
      if (!this._stream) {
        return;
      }
      // Streams of 0x7ffff000 bytes or more are passed through without peeking them into one buffer.
      const data = this._stream.length < 0x7ffff000 ? this._stream.peek() : this._stream;
      const execution = new pytorch.Execution(null, this._exceptionCallback);
      const unpickler = python.Unpickler.open(data, execution);
      this._stream = null;
      this._exceptionCallback = null;
      unpickler.load(); // magic_number
      const protocol_version = unpickler.load();
      if (protocol_version != 1001) {
        throw new pytorch.Error("Unsupported protocol version '" + protocol_version + "'.");
      }
      const sys_info = unpickler.load();
      if (sys_info.protocol_version != 1001) {
        throw new pytorch.Error("Unsupported protocol version '" + sys_info.protocol_version + "'.");
      }
      this._littleEndian = sys_info.little_endian;
      const module_source_map = new Map();
      const deserialized_objects = new Map();
      unpickler.persistent_load = (saved_id) => {
        const typename = saved_id.shift();
        const data = saved_id;
        switch (typename) {
          case 'module': {
            const module = data[0];
            const source = data[2];
            module_source_map.set(module, source);
            return data[0];
          }
          case 'storage': {
            const name = data.shift();
            const storage_type = execution.type(name);
            // Same check as pytorch.Container.Zip.Script: report unknown storage classes
            // instead of failing later with "storage_type is not a constructor".
            if (!storage_type) {
              throw new pytorch.Error("Unsupported persistent load data type '" + name + "'.");
            }
            const root_key = data.shift();
            data.shift(); // location
            const size = data.shift();
            const view_metadata = data.shift();
            if (!deserialized_objects.has(root_key)) {
              const obj = new storage_type(size);
              deserialized_objects.set(root_key, obj);
            }
            if (view_metadata) {
              const view_key = view_metadata.shift();
              view_metadata.shift(); // view_offset
              view_metadata.shift(); // view_size
              if (!deserialized_objects.has(view_key)) {
                const view = null; // storage.slice(view_offset, view_offset + view_size);
                deserialized_objects.set(view_key, view);
              }
              return deserialized_objects.get(view_key);
            }
            return deserialized_objects.get(root_key);
          }
          default: {
            throw new pytorch.Error("Unsupported persistent load type '" + typename + "'.");
          }
        }
      };
      const obj = unpickler.load();
      if (!obj) {
        throw new pytorch.Error('File format is not PyTorch.');
      }
      if (obj === 'None') {
        throw new pytorch.Error("File contains 'None' root object.");
      }
      // The raw bytes of each storage follow in the order given by this key list.
      const deserialized_storage_keys = unpickler.load();
      for (const deserialized_storage_key of deserialized_storage_keys) {
        const storage = deserialized_objects.get(deserialized_storage_key);
        if (!storage) {
          throw new pytorch.Error("Missing storage '" + deserialized_storage_key + "'.");
        }
        storage._set_from_file(unpickler);
      }
      this._graphs = pytorch.Utility.find(obj);
    }
  };
  /**
   * Base container for zip-based PyTorch archives.
   * See https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/docs/serialization.md
   */
  pytorch.Container.Zip = class {
    /**
     * Strips the directory prefix shared by all entries, then selects a container:
     * 'model.json' with a mainModule -> Json, 'data.pkl' -> Pickle, '.data/' -> Package.
     * @param {Map} entries - zip entries keyed by path.
     * @returns {?pytorch.Container.Zip} the matching container, or null.
     */
    static open(entries) {
      if (entries.size > 0) {
        // Gather the leading directory names common to every entry (e.g. 'archive/').
        // A path's last component is its file name and is never consumed, so an archive
        // whose entries share an identical path (e.g. a single entry) keeps its key
        // instead of having it stripped to ''.
        const components = [];
        const paths = Array.from(entries.keys()).map((path) => path.split('/').reverse());
        for (;;) {
          const set = new Set(paths.map((path) => path.length > 1 ? path.pop() : null));
          if (set.size !== 1 || set.keys().next().value === null) {
            break;
          }
          components.push(set.keys().next().value);
        }
        const joined = components.join('/');
        const prefix = joined.length > 0 ? joined + '/' : joined;
        const stripped = new Map(Array.from(entries).map((entry) => [ entry[0].substring(prefix.length), entry[1] ]));
        if (stripped.has('model.json')) {
          try {
            const stream = stripped.get('model.json');
            const buffer = stream.peek();
            const decoder = new TextDecoder('utf-8');
            const content = decoder.decode(buffer);
            const model = JSON.parse(content);
            if (model.mainModule) {
              return new pytorch.Container.Zip.Json(stripped, model);
            }
          }
          catch (error) {
            // Best effort: an unreadable or invalid model.json falls through to the
            // other archive layouts below.
          }
        }
        if (stripped.has('data.pkl')) {
          return new pytorch.Container.Zip.Pickle(stripped);
        }
        if (Array.from(stripped.keys()).find((name) => name.startsWith('.data/'))) {
          return new pytorch.Container.Zip.Package(stripped);
        }
      }
      return null;
    }
    constructor(entries) {
      this._entries = entries;
      this._producer = '';
    }
    set metadata(value) {
      this._metadata = value;
    }
    set exception(value) {
      this._exceptionCallback = value;
    }
    get producer() {
      return this._producer;
    }
    get littleEndian() {
      return true;
    }
    /**
     * Maps the first line of the archive entry `name` to a PyTorch release label.
     * Unknown values are reported through the exception callback when one is set, and
     * returned as 'v-<value>'.
     * @returns {string} the release label, or '' when the entry is absent.
     */
    version(name) {
      const stream = this._entries.get(name);
      if (stream) {
        const decoder = new TextDecoder('utf-8');
        const buffer = stream.peek();
        const text = decoder.decode(buffer);
        const value = text.split('\n').shift();
        // https://github.com/pytorch/pytorch/blob/master/caffe2/serialize/inline_container.h
        // kProducedFileFormatVersion
        const versions = new Map([
          [ '1', 'v1.3' ],
          [ '2', 'v1.5' ], // 7a2889b014ce36fcc333b2c6de6f29f976652f84 (#28122)
          [ '3', 'v1.6' ], // 2ec6a30722b0ef85632a2f3e7ce6f80da403008a (#36085)
          [ '4', 'v1.6' ], // 95489b590f00801bdee7f41783f30874883cf6bb (#38620)
          [ '5', 'v1.7' ], // cb26661fe4faf26386703180a9045e6ac6d157df (#40364)
          [ '6', 'v1.9' ], // 3ee7637ffa50df0d9b231c7b40778ac1c390bf4a (#59714)
          [ '7', 'v1.10' ] // 880098a7e34a20628f960daa8eab0eb1ad566c39 (#63651)
        ]);
        // Guard the callback: calling it while unset used to throw a TypeError.
        if (!versions.has(value) && this._exceptionCallback) {
          this._exceptionCallback(new pytorch.Error("Unsupported PyTorch Zip version '" + value + "'."));
        }
        return versions.get(value) || 'v-' + value.toString();
      }
      return '';
    }
  };
  /**
   * TorchScript module stored in a zip archive. The Python sources under the code directory
   * are loaded into the execution; trace() runs `forward` to collect inputs, outputs and nodes.
   * Subclasses provide the `data` getter (the module object).
   */
  pytorch.Container.Zip.Script = class {
    /**
     * @param {Map} entries - archive entries keyed by path; values expose peek().
     * @param {*} execution - interpreter that receives and runs the sources.
     * @param {Object} [location] - optional overrides; `location.code` replaces 'code/'.
     * @param {string} [name]
     */
    constructor(entries, execution, location, name) {
      this._entries = entries;
      this._execution = execution;
      this._location = location || {};
      this._name = name || '';
    }
    get name() {
      return this._name;
    }
    get type() {
      return 'script';
    }
    /**
     * Calls data.forward with placeholder arguments built from its parameter annotations,
     * then records tensor inputs/outputs by variable name and the executed nodes.
     * @returns {boolean} true once tracing completes.
     * @throws {pytorch.Error} if the module has no forward or a parameter type is unsupported.
     */
    trace() {
      this._inputs = [];
      this._outputs = [];
      this.execution.reset();
      if (this.data.forward) {
        const args = [ this.data ]; // self
        if (this.data.forward.__code__ && this.data.forward.__code__.parameters) {
          for (const parameter of this.data.forward.__code__.parameters) {
            // Builds a placeholder value for a parameter annotation; tensors are tagged
            // with their variable name so they can be matched up as graph inputs.
            const defaultValue = (type, name) => {
              if (type.type === 'type' && type.name.type) {
                switch (type.name.value) {
                  case 'Tensor': {
                    const tensor = this.execution.invoke('torch.Tensor', []);
                    tensor.__variable__ = name;
                    tensor.__origin__ = 'graph-input';
                    return tensor;
                  }
                  case 'Tuple': {
                    return type.arguments.map((type, index) => defaultValue(type, name + '[' + index.toString() + ']'));
                  }
                  case 'List': {
                    return type.arguments.map((type, index) => defaultValue(type, name + '[' + index.toString() + ']'));
                  }
                  case 'Dict': {
                    if (type.arguments[1].name.value === 'Tensor') {
                      // Map that materializes a tensor placeholder for any key on first access.
                      const Dict = class extends Map {
                        get(key) {
                          if (!super.has(key)) {
                            super.set(key, defaultValue(type.arguments[1], name + ':' + key));
                          }
                          return super.get(key);
                        }
                      };
                      return new Dict();
                    }
                    return new Map();
                  }
                  case 'int': {
                    return 0;
                  }
                  case 'float': {
                    return 0.0;
                  }
                  case 'bool': {
                    return false;
                  }
                  case 'Optional': {
                    return undefined;
                  }
                  case 'str':
                    return '';
                  default: {
                    break;
                  }
                }
              }
              throw new pytorch.Error("Unsupported function parameter type '" + JSON.stringify(type) + "'.");
            };
            if (parameter.name !== 'self') {
              const type = parameter.parameterType;
              const value = defaultValue(type, parameter.name);
              if (pytorch.Utility.isTensor(value)) {
                value.__variable__ = parameter.name;
                value.__origin__ = 'graph-input';
                this._inputs.push(parameter.name);
              }
              args.push(value);
            }
          }
        }
        const result = this.data.forward.__call__(args);
        if (Array.isArray(result)) {
          for (const output of result) {
            if (pytorch.Utility.isTensor(output)) {
              this._outputs.push(output.__variable__);
            }
          }
        }
        else if (pytorch.Utility.isTensor(result)) {
          this._outputs.push(result.__variable__);
        }
        else if (Object(result) === result) {
          for (const key of Object.keys(result)) {
            const value = result[key];
            if (Array.isArray(value)) {
              for (const output of value) {
                if (pytorch.Utility.isTensor(output)) {
                  this._outputs.push(output.__variable__);
                }
              }
            }
            else if (pytorch.Utility.isTensor(value)) {
              this._outputs.push(value.__variable__);
            }
          }
        }
        this._nodes = this.execution.nodes;
        return true;
      }
      throw new pytorch.Error("Module 'forward' not implemented.");
    }
    /**
     * Adds every '.py' entry under the code directory to the execution, imports each as a
     * module, and publishes 'Tensor' and the CONSTANTS table (c0, c1, ...) in its context.
     * NOTE(review): this runs on every access.
     * @throws {pytorch.Error} on a duplicate source file.
     */
    get execution() {
      const directory = this._location.code || 'code/';
      const sources = new Map();
      for (const entry of this._entries) {
        const name = entry[0];
        if (name.startsWith(directory) && name.endsWith('.py')) {
          const file = name.substring(directory.length);
          if (sources.has(file)) {
            throw new pytorch.Error("Duplicate source file '" + file + "'.");
          }
          const stream = entry[1];
          const buffer = stream.peek();
          this._execution.add(file, buffer);
          sources.set(file, buffer);
        }
      }
      for (const entry of sources) {
        const name = entry[0].replace(/\.py$/, '').split('/').join('.');
        const module = this._execution.import(name);
        this._execution.context.setx(name, module);
      }
      const torch = this._execution.import('torch');
      this._execution.context.setx('Tensor', torch.Tensor);
      const constants = {};
      for (let i = 0; i < this.constants.length; i++) {
        constants['c' + i.toString()] = this.constants[i];
      }
      this._execution.context.set('CONSTANTS', constants);
      return this._execution;
    }
    /**
     * Unpickles `data`, resolving ('storage', ...) persistent ids against `storage_map`.
     * Each root storage is created once and cached.
     * @param {*} data - pickle bytes.
     * @param {Map} storage_map - storage key -> raw bytes.
     * @throws {pytorch.Error} for unknown storage classes or persistent id types.
     */
    _unpickle(data, storage_map) {
      const loaded_storages = new Map();
      const execution = this.execution;
      const unpickler = python.Unpickler.open(data, execution);
      unpickler.persistent_load = (saved_id) => {
        const typename = saved_id.shift();
        switch (typename) {
          case 'storage': {
            const name = saved_id.shift();
            const storage_type = execution.type(name);
            if (!storage_type) {
              throw new pytorch.Error("Unsupported persistent load data type '" + name + "'.");
            }
            const root_key = saved_id.shift();
            /* const location = */ saved_id.shift();
            const size = saved_id.shift();
            if (!loaded_storages.has(root_key)) {
              const storage = new storage_type(size);
              storage._set_cdata(storage_map.get(root_key));
              loaded_storages.set(root_key, storage);
            }
            const storage = loaded_storages.get(root_key);
            const view_metadata = saved_id.shift();
            if (view_metadata) {
              const view_key = view_metadata.shift();
              view_metadata.shift(); // view_offset
              view_metadata.shift(); // view_size
              // Views are not materialized yet: each view key caches a null placeholder
              // (storage.slice(view_offset, view_offset + view_size) is not implemented).
              // The cache is read by view_key; it used to read root_key, which returned the
              // root storage for repeated references but null for the first one.
              if (!loaded_storages.has(view_key)) {
                loaded_storages.set(view_key, null);
              }
              return loaded_storages.get(view_key);
            }
            return storage;
          }
          default: {
            throw new pytorch.Error("Unsupported persistent load type '" + typename + "'.");
          }
        }
      };
      return unpickler.load();
    }
    /**
     * Constant table from 'constants.pkl' (storage bytes under 'constants/'), computed once.
     * Tensors, and the weight/bias tensors of the supported packed-parameter contexts, are
     * tagged with their 'CONSTANTS.c<i>' variable name.
     * @throws {pytorch.Error} for unsupported constant types.
     */
    get constants() {
      if (this._constants === undefined) {
        this._constants = [];
        const stream = this._entries.get('constants.pkl');
        if (stream) {
          const buffer = stream.peek();
          this._constants = this._unpickle(buffer, this._storage('constants/'));
          for (let i = 0; i < this._constants.length; i++) {
            const constant = this._constants[i];
            const variable = 'CONSTANTS.c' + i.toString();
            if (pytorch.Utility.isTensor(constant)) {
              constant.__variable__ = variable;
            }
            else if (constant && constant.__class__ && constant.__class__.__module__ && constant.__class__.__name__) {
              const type = constant.__class__.__module__ + '.' + constant.__class__.__name__;
              switch (type) {
                case '__torch__.torch.classes.xnnpack.LinearOpContext':
                case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
                case '__torch__.torch.classes.quantized.LinearPackedParamsBase':
                case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase':
                  if (pytorch.Utility.isTensor(constant.weight)) {
                    constant.weight.__variable__ = variable + '.weight';
                  }
                  if (pytorch.Utility.isTensor(constant.bias)) {
                    constant.bias.__variable__ = variable + '.bias';
                  }
                  break;
                default:
                  throw new pytorch.Error("Unsupported constant context '" + type + "'.");
              }
            }
            else {
              throw new pytorch.Error('Unsupported constant.');
            }
          }
        }
      }
      return this._constants;
    }
    /**
     * Collects the entries under `dirname`, keyed by their path relative to it.
     * @returns {Map} relative path -> peeked bytes.
     */
    _storage(dirname) {
      const map = new Map();
      const prefix = dirname;
      for (const entry of this._entries) {
        if (entry[0].startsWith(prefix)) {
          const key = entry[0].substring(prefix.length);
          const buffer = entry[1].peek();
          map.set(key, buffer);
        }
      }
      return map;
    }
    get inputs() {
      return this._inputs;
    }
    get outputs() {
      return this._outputs;
    }
    get nodes() {
      return this._nodes;
    }
  };
  /**
   * TorchScript archive described by a 'model.json' manifest. The format is v1.1 when
   * 'attributes.pkl' is present and v1.0 otherwise.
   */
  pytorch.Container.Zip.Json = class extends pytorch.Container.Zip {
    /**
     * @param {Map} entries - archive entries with the common prefix already stripped.
     * @param {Object} model - parsed model.json; producerName/producerVersion feed `producer`.
     */
    constructor(entries, model) {
      super(entries);
      const version = model && model.producerVersion ? ' v' + model.producerVersion : '';
      this._producer = model && model.producerName ? model.producerName + version : '';
      this._model = model;
    }
    get format() {
      const attributes = this._entries.get('attributes.pkl');
      if (attributes) {
        return 'TorchScript v1.1';
      }
      return 'TorchScript v1.0';
    }
    /**
     * Built lazily: a single script graph when the main module has `forward`, otherwise
     * whatever pytorch.Utility.find extracts from the module data.
     */
    get graphs() {
      if (this._graphs) {
        return this._graphs;
      }
      const execution = new pytorch.Container.Zip.Execution(null, this._exceptionCallback, this._metadata);
      const script = new pytorch.Container.Zip.Json.Script(this._entries, execution, this._model);
      this._graphs = script.data.forward ? [ script ] : pytorch.Utility.find(script.data);
      return this._graphs;
    }
  };
  2649. pytorch.Container.Zip.Json.Script = class extends pytorch.Container.Zip.Script {
  2650. constructor(entries, execution, model) {
  2651. super(entries);
  2652. this._execution = execution;
  2653. this._model = model;
  2654. this._name = model.mainModule.name || '';
  2655. }
  2656. get name() {
  2657. return this._name;
  2658. }
  2659. get data() {
  2660. if (!this._data) {
  2661. this._data = this._model.mainModule || {};
  2662. const queue = [ this._data ];
  2663. const entries = new Map();
  2664. for (const entry of this._entries) {
  2665. const name = entry[0];
  2666. const stream = entry[1];
  2667. const buffer = stream.peek();
  2668. entries.set(name, buffer);
  2669. }
  2670. const tensorTypeMap = new Map([
  2671. [ 'FLOAT', 'Float' ],
  2672. [ 'FLOAT16', 'Half' ],
  2673. [ 'DOUBLE', 'Double' ],
  2674. [ 'INT8', 'Char' ],
  2675. [ 'INT32', 'Int' ],
  2676. [ 'INT64', 'Long' ]
  2677. ]);
  2678. const constants = this._model.tensors || [];
  2679. this._constants = constants.map((constant) => {
  2680. const key = constant.data.key;
  2681. if (!tensorTypeMap.has(constant.dataType)) {
  2682. throw new pytorch.Error("Unsupported tensor data type '" + constant.dataType + "'.");
  2683. }
  2684. const type = tensorTypeMap.get(constant.dataType);
  2685. const shape = constant.dims ? constant.dims.map((dim) => parseInt(dim, 10)) : null;
  2686. const storage_type = this.execution.type('torch.' + type + 'Storage');
  2687. const size = (shape || []).reduce((a, b) => a * b, 1);
  2688. const offset = parseInt(constant.offset, 10) || 0;
  2689. const storage = new storage_type([ size ]);
  2690. const itemsize = storage.dtype.itemsize();
  2691. const buffer = entries.get(key);
  2692. const length = size * itemsize;
  2693. const data = buffer.slice(offset, offset + length);
  2694. storage._set_cdata(data);
  2695. const tensor = this.execution.invoke('torch._utils._rebuild_tensor', [ storage, 0, shape, 0 ]);
  2696. tensor.name = constant.data.key;
  2697. return tensor;
  2698. });
  2699. this._attributes = [];
  2700. const stream = this._entries.get('attributes.pkl');
  2701. if (stream) {
  2702. const buffer = stream.peek();
  2703. const unpickler = python.Unpickler.open(buffer, this.execution);
  2704. this._attributes.push(...unpickler.load());
  2705. }
  2706. while (queue.length > 0) {
  2707. const module = queue.shift();
  2708. if (!module.__class__) {
  2709. module.__class__ = {
  2710. __module__: 'torch.nn.modules.module',
  2711. __name__: 'Module'
  2712. };
  2713. }
  2714. if (module.name) {
  2715. module.__id__ = module.name;
  2716. }
  2717. if (module.submodules) {
  2718. for (const submodule of module.submodules) {
  2719. module[submodule.name] = submodule;
  2720. submodule.__parent__ = module;
  2721. queue.push(submodule);
  2722. }
  2723. delete module.submodules;
  2724. }
  2725. const attributes = [];
  2726. if (module.attributes) {
  2727. attributes.push(...module.attributes);
  2728. delete module.attributes;
  2729. }
  2730. const parameters = [];
  2731. if (module.parameters) {
  2732. parameters.push(...module.parameters);
  2733. delete module.parameters;
  2734. }
  2735. if (module.arguments) {
  2736. parameters.push(...module.arguments);
  2737. delete module.arguments;
  2738. }
  2739. for (const parameter of parameters) {
  2740. const tensor = this._constants[parameter.tensorId];
  2741. module[parameter.name] = tensor;
  2742. if (!parameter.__class__) {
  2743. parameter.__class__ = {
  2744. __module__: 'torch',
  2745. __name__: 'Tensor'
  2746. };
  2747. }
  2748. }
  2749. for (const attribute of attributes) {
  2750. module[attribute.name] = this._attributes[attribute.id];
  2751. }
  2752. }
  2753. const code = this._data.torchscriptArena;
  2754. if (code && code.key && code.key.startsWith('code/')) {
  2755. const file = code.key.substring('code/'.length);
  2756. const program = this.execution.parse(file);
  2757. for (const statement of program.body) {
  2758. if (statement.type == 'def') {
  2759. const self = this;
  2760. const globals = this.execution.context;
  2761. const func = {
  2762. __class__: this.execution._builtins.function,
  2763. __name__: statement.name,
  2764. __code__: statement,
  2765. __call__: function(args) {
  2766. return self.execution.apply(this.__code__, args, globals);
  2767. }
  2768. };
  2769. this._data[statement.name] = func;
  2770. }
  2771. }
  2772. }
  2773. delete this._model;
  2774. }
  2775. return this._data;
  2776. }
  2777. };
  2778. pytorch.Container.Zip.Pickle = class extends pytorch.Container.Zip {
  2779. constructor(entries) {
  2780. super(entries);
  2781. }
  2782. get format() {
  2783. return (this._entries.get('constants.pkl') ? 'TorchScript' : 'PyTorch') + ' ' + this.version('version');
  2784. }
  2785. get graphs() {
  2786. if (!this._graphs) {
  2787. const execution = new pytorch.Container.Zip.Execution(null, this._exceptionCallback, this._metadata);
  2788. const graph = new pytorch.Container.Zip.Pickle.Script(this._entries, execution);
  2789. this._graphs = graph.data.forward ? [ graph ] : pytorch.Utility.find(graph.data);
  2790. }
  2791. return this._graphs;
  2792. }
  2793. };
  2794. pytorch.Container.Zip.Pickle.Script = class extends pytorch.Container.Zip.Script {
  2795. constructor(entries, execution, location, name) {
  2796. super(entries, execution, location, name);
  2797. }
  2798. get data() {
  2799. if (!this._data) {
  2800. const stream = this._entries.get(this._location.model || 'data.pkl');
  2801. const buffer = stream.peek();
  2802. this._data = this._unpickle(buffer, this._storage(this._location.data || 'data/'));
  2803. }
  2804. return this._data;
  2805. }
  2806. };
  2807. pytorch.Container.Zip.Package = class extends pytorch.Container.Zip {
  2808. constructor(entries) {
  2809. super(entries);
  2810. }
  2811. get format() {
  2812. return 'PyTorch Package' + ' ' + this.version('.data/version');
  2813. }
  2814. get graphs() {
  2815. if (!this._graphs) {
  2816. this._graphs = [];
  2817. const entries = Array.from(this._entries).filter((entry) => !entry[0].startsWith('.data/') && !entry[0].endsWith('py'));
  2818. for (const entry of entries) {
  2819. const name = entry[0];
  2820. const stream = entry[1];
  2821. const loaded_reduces = new Map();
  2822. const loaded_storages = new Map();
  2823. const execution = new pytorch.Container.Zip.Execution(null, this._exceptionCallback, this._metadata);
  2824. execution.registerFunction('torch.jit._script.unpackage_script_module', function(script_module_id) {
  2825. return "torch.jit._script.RecursiveScriptModule('" + script_module_id + "')";
  2826. });
  2827. for (const entry of this._entries) {
  2828. if (!entry[0].startsWith('.data/') && entry[0].endsWith('.py')) {
  2829. const name = entry[0];
  2830. const stream = entry[1];
  2831. const buffer = stream.peek();
  2832. execution.add(name, buffer);
  2833. }
  2834. }
  2835. const unpickler = python.Unpickler.open(stream, execution);
  2836. unpickler.persistent_load = (saved_id) => {
  2837. const typename = saved_id.shift();
  2838. switch (typename) {
  2839. case 'storage': {
  2840. const storage_type = execution.type(saved_id[0]);
  2841. const key = saved_id[1];
  2842. /* const location = saved_id[2]; */
  2843. const size = saved_id[3];
  2844. if (!loaded_storages.has(key)) {
  2845. const storage = new storage_type(size);
  2846. const stream = this._entries.get('.data/' + key + '.storage');
  2847. const buffer = stream.peek();
  2848. storage._set_cdata(buffer);
  2849. loaded_storages.set(key, storage);
  2850. }
  2851. return loaded_storages.get(key);
  2852. }
  2853. case 'reduce_package': {
  2854. if (saved_id.left === 2) {
  2855. const func = saved_id[0];
  2856. const args = saved_id[1];
  2857. return execution.invoke(func, args);
  2858. }
  2859. const reduce_id = saved_id[0];
  2860. const func = saved_id[1];
  2861. const args = saved_id[2];
  2862. if (!loaded_reduces.has(reduce_id)) {
  2863. const value = execution.invoke(func, args);
  2864. loaded_reduces.set(reduce_id, value);
  2865. }
  2866. return loaded_reduces.get(reduce_id);
  2867. }
  2868. default: {
  2869. throw new python.Error("Unknown package typename '" + typename + "'.");
  2870. }
  2871. }
  2872. };
  2873. const root = unpickler.load();
  2874. this._graphs.push({
  2875. name: name,
  2876. type: 'module',
  2877. data: root
  2878. });
  2879. }
  2880. }
  2881. return this._graphs;
  2882. }
  2883. };
  2884. pytorch.Container.Zip.Execution = class extends pytorch.Execution {
  2885. constructor(sources, exceptionCallback, metadata) {
  2886. super(sources, exceptionCallback);
  2887. this._metadata = metadata;
  2888. this.reset();
  2889. }
  2890. reset() {
  2891. this._nodes = [];
  2892. this._variableIndex = 0;
  2893. }
  2894. get nodes() {
  2895. return this._nodes;
  2896. }
  2897. call(target, name, args, context) {
  2898. let resolvedTarget = pytorch.Utility.target(target);
  2899. let outputTypes = null;
  2900. if (resolvedTarget && resolvedTarget + '.' + name === 'ops.prim.NumToTensor' &&
  2901. args.length === 1 && args[0].type === 'call' && args[0].target.member.type == 'id') {
  2902. const innerCall = args[0];
  2903. resolvedTarget = pytorch.Utility.target(innerCall.target.target);
  2904. name = innerCall.target.member.value;
  2905. args = innerCall.arguments;
  2906. outputTypes = [ 'int64' ];
  2907. }
  2908. if (resolvedTarget) {
  2909. const type = resolvedTarget + '.' + name;
  2910. // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
  2911. let schemas = this._metadata.type(type);
  2912. if (schemas) {
  2913. schemas = !Array.isArray(schemas) ? [ schemas ] : schemas;
  2914. const evalArgs = args.map((argument) => argument.type === '=' && argument.target && argument.target.type === 'id' ? this.expression(argument.expression, context) : this.expression(argument, context));
  2915. for (const schema of schemas) {
  2916. const copyArgs = Array.prototype.slice.call(args);
  2917. const copyEvalArgs = Array.prototype.slice.call(evalArgs);
  2918. const node = {
  2919. type: schema.name,
  2920. inputs: [],
  2921. attributes: [],
  2922. outputs: []
  2923. };
  2924. const referencedParameters = [];
  2925. let next = false;
  2926. const parameters = Array.prototype.slice.call(schema.inputs || []).concat(Array.prototype.slice.call(schema.attributes || []));
  2927. let op_context = null;
  2928. while (copyEvalArgs.length > 0 || (op_context && parameters.length > 0)) {
  2929. if (parameters.length <= 0) {
  2930. next = true;
  2931. break;
  2932. }
  2933. const arg = copyEvalArgs[0];
  2934. if (arg && arg.__class__ && arg.__class__.__module__ && arg.__class__.__name__) {
  2935. const type = arg.__class__.__module__ + '.' + arg.__class__.__name__;
  2936. switch (type) {
  2937. case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase':
  2938. case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase':
  2939. case '__torch__.torch.classes.quantized.LinearPackedParamsBase':
  2940. case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
  2941. case '__torch__.torch.classes.xnnpack.LinearOpContext':
  2942. op_context = arg;
  2943. copyArgs.shift();
  2944. copyEvalArgs.shift();
  2945. continue;
  2946. default:
  2947. break;
  2948. }
  2949. }
  2950. if (op_context && parameters[0]) {
  2951. const parameter = parameters[0];
  2952. const name = parameter.name;
  2953. if (name in op_context && parameter.context) {
  2954. copyArgs.unshift({ type: null });
  2955. copyEvalArgs.unshift(op_context[name]);
  2956. }
  2957. }
  2958. if (copyArgs.every((arg) => arg.type === '=' && arg.target && arg.target.type === 'id') &&
  2959. parameters.every((parameter) => parameter.type !== 'Tensor' && parameter.type !== 'Tensor[]')) {
  2960. const map = new Map(parameters.map((parameter) => [ parameter.name, parameter ]));
  2961. while (copyArgs.length > 0) {
  2962. const argument = copyArgs.shift();
  2963. const value = copyEvalArgs.shift();
  2964. const parameter = map.get(argument.target.value);
  2965. if (!parameter) {
  2966. next = true;
  2967. break;
  2968. }
  2969. if (!pytorch.Utility.isType(value, parameter.type)) {
  2970. if (parameter.optional) {
  2971. continue;
  2972. }
  2973. next = true;
  2974. break;
  2975. }
  2976. node.attributes.push({ name: parameter.name, value: value });
  2977. }
  2978. continue;
  2979. }
  2980. if (next) {
  2981. break;
  2982. }
  2983. const parameter = parameters.shift();
  2984. const argument = copyEvalArgs[0];
  2985. if (parameter.type === 'Tensor' || (parameter.type === 'Scalar' && pytorch.Utility.isTensor(argument))) {
  2986. if (Array.isArray(argument) || (!pytorch.Utility.isTensor(argument) && argument !== null && argument !== undefined)) {
  2987. if (parameter.optional) {
  2988. if (argument === undefined) {
  2989. copyArgs.shift();
  2990. copyEvalArgs.shift();
  2991. }
  2992. continue;
  2993. }
  2994. next = true;
  2995. }
  2996. else {
  2997. copyArgs.shift();
  2998. copyEvalArgs.shift();
  2999. const item = (argument === null || argument === undefined) ? {} : argument;
  3000. item.__variable__ = item.__variable__ || this.variable();
  3001. const inputs = [];
  3002. inputs.push({ id: item.__variable__ });
  3003. referencedParameters.push(item);
  3004. node.inputs.push(inputs);
  3005. }
  3006. }
  3007. else if (parameter.type === 'Tensor[]') {
  3008. const argument = copyEvalArgs[0];
  3009. if (!Array.isArray(argument) || !argument.every((item) => pytorch.Utility.isTensor(item) || item === null)) {
  3010. if (parameter.optional) {
  3011. continue;
  3012. }
  3013. next = true;
  3014. }
  3015. else {
  3016. copyArgs.shift();
  3017. copyEvalArgs.shift();
  3018. const inputs = [];
  3019. for (let item of argument) {
  3020. if (item === null) {
  3021. item = {};
  3022. }
  3023. item.__variable__ = item.__variable__ || this.variable();
  3024. inputs.push({ id: item.__variable__ });
  3025. referencedParameters.push(item);
  3026. }
  3027. node.inputs.push(inputs);
  3028. }
  3029. }
  3030. else {
  3031. const arg = copyArgs[0];
  3032. if (!pytorch.Utility.isType(argument, parameter.type) && argument !== null) {
  3033. if (parameter.optional) {
  3034. continue;
  3035. }
  3036. next = true;
  3037. }
  3038. else if (arg.type !== '=') {
  3039. copyArgs.shift();
  3040. copyEvalArgs.shift();
  3041. node.attributes.push({ name: parameter.name, value: argument });
  3042. }
  3043. else {
  3044. throw new pytorch.Error('Expected named argument.');
  3045. }
  3046. }
  3047. if (next) {
  3048. break;
  3049. }
  3050. }
  3051. if (next) {
  3052. continue;
  3053. }
  3054. const result = [];
  3055. for (let i = 0; i < schema.outputs.length; i++) {
  3056. const parameter = schema.outputs[i];
  3057. switch (parameter.type) {
  3058. case 'Tensor': {
  3059. const parameter = this.invoke('torch.Tensor', []);
  3060. parameter.__origin__ = type;
  3061. if (i === 0) {
  3062. switch (type) {
  3063. case 'torch.conv1d':
  3064. case 'torch.embedding': {
  3065. parameter.resize_([ NaN, NaN, NaN ]);
  3066. break;
  3067. }
  3068. case 'torch.cat':
  3069. case 'torch.conv2d':
  3070. case 'torch.dropout':
  3071. case 'torch.flatten':
  3072. case 'torch.max_pool2d':
  3073. case 'torch.adaptive_avg_pool2d':
  3074. case 'torch.avg_pool2d':
  3075. case 'torch.quantize_per_tensor':
  3076. case 'torch.relu_':
  3077. case 'torch.hardtanh_':
  3078. case 'torch.upsample_bilinear2d':
  3079. case 'ops.prepacked.conv2d_clamp_run': {
  3080. parameter.resize_([ NaN, NaN, NaN, NaN ]);
  3081. break;
  3082. }
  3083. case 'torch.slice': {
  3084. const input = evalArgs[0];
  3085. if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
  3086. const size = input.size();
  3087. parameter.resize_(size);
  3088. }
  3089. break;
  3090. }
  3091. case 'torch.to': {
  3092. const input = evalArgs[0];
  3093. if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
  3094. const size = input.size();
  3095. parameter.resize_(size);
  3096. }
  3097. break;
  3098. }
  3099. case 'torch.conv3d': {
  3100. parameter.resize_([ NaN, NaN, NaN, NaN, NaN ]);
  3101. break;
  3102. }
  3103. case 'torch.detach':
  3104. case 'torch.mean':
  3105. case 'torch.mul':
  3106. case 'torch.div':
  3107. case 'torch.batch_norm':
  3108. case 'torch.gelu':
  3109. case 'torch.relu':
  3110. case 'torch.clamp_':
  3111. case 'torch.hardswish_': {
  3112. const input = evalArgs[0];
  3113. if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
  3114. parameter.resize_(input.size());
  3115. }
  3116. break;
  3117. }
  3118. case 'torch.add':
  3119. case 'torch.sub': {
  3120. const input = evalArgs[0];
  3121. if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
  3122. parameter.resize_(input.size());
  3123. }
  3124. else {
  3125. const other = evalArgs[1];
  3126. if (pytorch.Utility.isTensor(other) && Array.isArray(other.size())) {
  3127. parameter.resize_(other.size());
  3128. }
  3129. }
  3130. break;
  3131. }
  3132. case 'torch.select': {
  3133. const input = evalArgs[0];
  3134. if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
  3135. parameter.resize_(Array(input.size().length - 1).fill(NaN));
  3136. }
  3137. break;
  3138. }
  3139. case 'torch.layer_norm': {
  3140. const input = evalArgs[0];
  3141. const normalized_shape = evalArgs[1];
  3142. if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
  3143. const shape = input.size();
  3144. if (Array.isArray(normalized_shape) && normalized_shape.length === 1) {
  3145. shape[shape.length - 1] = normalized_shape[0];
  3146. }
  3147. parameter.resize_(shape);
  3148. }
  3149. break;
  3150. }
  3151. case 'torch.empty':
  3152. case 'torch.ones':
  3153. case 'torch.zeros':
  3154. case 'torch.zeros_like': {
  3155. parameter.resize_(evalArgs[0]);
  3156. break;
  3157. }
  3158. case 'torch.view':
  3159. case 'torch.reshape':
  3160. case 'torch.new_full': {
  3161. parameter.resize_(evalArgs[1]);
  3162. break;
  3163. }
  3164. case 'torch.squeeze': {
  3165. const input = evalArgs[0];
  3166. const size = input.size();
  3167. if (Array.isArray(size)) {
  3168. switch (evalArgs.length) {
  3169. case 1: {
  3170. parameter.resize_(size.filter((value) => value !== 1));
  3171. break;
  3172. }
  3173. case 2: {
  3174. const dim = evalArgs[1];
  3175. parameter.resize_(size.filter((value, index) => (value !== 1 && !isNaN(value)) || index !== dim));
  3176. break;
  3177. }
  3178. default: {
  3179. break;
  3180. }
  3181. }
  3182. }
  3183. break;
  3184. }
  3185. case 'torch.unsqueeze': {
  3186. const input = evalArgs[0];
  3187. const size = input.size();
  3188. const dim = evalArgs[1];
  3189. if (Array.isArray(size) && dim !== undefined) {
  3190. const shape = size.slice();
  3191. shape.splice(dim, 0, 1);
  3192. parameter.resize_(shape);
  3193. }
  3194. else {
  3195. parameter.resize_([ NaN, NaN, NaN, NaN ]);
  3196. }
  3197. break;
  3198. }
  3199. case 'torch.transpose': {
  3200. const input = evalArgs[0];
  3201. let dim0 = evalArgs[1];
  3202. let dim1 = evalArgs[2];
  3203. if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
  3204. const size = input.size().slice();
  3205. dim0 = dim0 >= 0 ? dim0 : size.length + dim0;
  3206. dim1 = dim1 >= 0 ? dim1 : size.length + dim1;
  3207. const value = size[dim0];
  3208. size[dim0] = size[1];
  3209. size[dim1] = value;
  3210. parameter.resize_(size);
  3211. }
  3212. break;
  3213. }
  3214. case 'ops.quantized.cat':
  3215. case 'ops.quantized.cat_relu':
  3216. case 'ops.quantized.linear':
  3217. case 'ops.quantized.conv2d':
  3218. case 'ops.quantized.conv2d_relu':
  3219. case 'ops.quantized.add':
  3220. case 'ops.quantized.add_relu':
  3221. parameter.resize_([ NaN, NaN, NaN, NaN ]);
  3222. parameter.__quantized__ = true;
  3223. break;
  3224. case 'torch.contiguous':
  3225. parameter.__source__ = evalArgs[0];
  3226. break;
  3227. default:
  3228. break;
  3229. }
  3230. }
  3231. parameter.__variable__ = this.variable();
  3232. result.push(parameter);
  3233. node.outputs.push([ { id: parameter.__variable__ } ]);
  3234. break;
  3235. }
  3236. case 'Tensor[]': {
  3237. let count = 1;
  3238. switch (type) {
  3239. case 'torch.chunk':
  3240. count = node.attributes.filter((attribute) => attribute.name == 'chunks')[0].value;
  3241. break;
  3242. case 'torch.meshgrid':
  3243. count = node.inputs[0].length;
  3244. break;
  3245. case 'torch.unbind':
  3246. count = args[0].__tuple__ || count;
  3247. break;
  3248. case 'torch.broadcast_tensors':
  3249. case 'torch.split':
  3250. case 'torch.split_with_sizes':
  3251. if (context.target.length > 0) {
  3252. count = context.target[context.target.length - 1].length;
  3253. }
  3254. break;
  3255. default:
  3256. break;
  3257. }
  3258. const tensors = [];
  3259. const outputs = [];
  3260. for (let i = 0; i < count; i ++) {
  3261. const tensor = this.invoke('torch.Tensor', []);
  3262. tensor.__origin__ = type;
  3263. tensor.__variable__ = this.variable();
  3264. tensors.push(tensor);
  3265. outputs.push({ id: tensor.__variable__ });
  3266. }
  3267. result.push(tensors);
  3268. node.outputs.push(outputs);
  3269. break;
  3270. }
  3271. default: {
  3272. if (!outputTypes || schema.outputs.length !== 1 || schema.outputs[0].type !== outputTypes[0]) {
  3273. next = true;
  3274. break;
  3275. }
  3276. const tensor = this.invoke('torch.Tensor', []);
  3277. tensor.resize_([]);
  3278. tensor.__origin__ = type;
  3279. tensor.__variable__ = this.variable();
  3280. result.push(tensor);
  3281. node.outputs.push([ { id: tensor.__variable__ } ]);
  3282. break;
  3283. }
  3284. }
  3285. }
  3286. if (next) {
  3287. continue;
  3288. }
  3289. for (const parameter of referencedParameters) {
  3290. parameter.__count__ = (parameter.__count__ || 0) + 1;
  3291. }
  3292. this.push(node);
  3293. if (result.length > 1) {
  3294. return result;
  3295. }
  3296. return result[0];
  3297. }
  3298. }
  3299. }
  3300. return super.call(target, name, args, context);
  3301. }
  3302. block(statements, context) {
  3303. statements = Array.prototype.slice.call(statements);
  3304. while (statements.length > 0) {
  3305. if (statements.length > 1) {
  3306. const assign = statements[0];
  3307. const condition = statements[1];
  3308. // _x = torch.ne(torch.len(torch.size(input)), 5)
  3309. // if _x:
  3310. // ops.prim.RaiseException(...)
  3311. if (assign.type === '=' &&
  3312. condition.type === 'if' &&
  3313. pytorch.Utility.isEqual(assign.target, condition.condition) &&
  3314. pytorch.Utility.isCall(assign.expression, 'torch.ne', 2) &&
  3315. pytorch.Utility.isCall(assign.expression.arguments[0], 'torch.len', 1) &&
  3316. pytorch.Utility.isCall(assign.expression.arguments[0].arguments[0], 'torch.size', 1) &&
  3317. condition.then.statements.length == 1 &&
  3318. pytorch.Utility.isCall(condition.then.statements[0], 'ops.prim.RaiseException', 1)) {
  3319. const tensor = this.expression(assign.expression.arguments[0].arguments[0].arguments[0], context);
  3320. if (pytorch.Utility.isTensor(tensor) && tensor.size) {
  3321. const number = this.expression(assign.expression.arguments[1], context);
  3322. const size = tensor.size();
  3323. if (number >= 3 && number <= 5) {
  3324. if (!Array.isArray(size) || size.length !== number) {
  3325. tensor.resize_(Array(number).fill(NaN));
  3326. }
  3327. }
  3328. }
  3329. }
  3330. // _x = torch.ne(torch.dim(input), 5)
  3331. // if _x:
  3332. // ops.prim.RaiseException(...)
  3333. if (assign.type === '=' &&
  3334. condition.type === 'if' &&
  3335. pytorch.Utility.isEqual(assign.target, condition.condition) &&
  3336. pytorch.Utility.isCall(assign.expression, 'torch.ne', 2) &&
  3337. pytorch.Utility.isCall(assign.expression.arguments[0], 'torch.dim', 1) &&
  3338. condition.then.statements.length > 0 &&
  3339. pytorch.Utility.isCall(condition.then.statements[condition.then.statements.length - 1], 'ops.prim.RaiseException', 1)) {
  3340. const tensor = this.expression(assign.expression.arguments[0].arguments[0], context);
  3341. if (pytorch.Utility.isTensor(tensor)) {
  3342. const size = this.expression(assign.expression.arguments[1], context);
  3343. tensor.resize_(Array(size).fill(NaN));
  3344. }
  3345. }
  3346. // _0 = torch.eq(torch.len(torch.size(x)), 2)
  3347. // if _0:
  3348. // pass
  3349. // else:
  3350. // ops.prim.RaiseException("AssertionError: ")
  3351. if (assign.type === '=' &&
  3352. condition.type === 'if' &&
  3353. pytorch.Utility.isEqual(assign.target, condition.condition) &&
  3354. pytorch.Utility.isCall(assign.expression, 'torch.eq', 2) &&
  3355. pytorch.Utility.isCall(assign.expression.arguments[0], 'torch.len', 1) &&
  3356. pytorch.Utility.isCall(assign.expression.arguments[0].arguments[0], 'torch.size', 1) &&
  3357. condition.else.statements.length == 1 &&
  3358. pytorch.Utility.isCall(condition.else.statements[0], 'ops.prim.RaiseException', 1)) {
  3359. const tensor = this.expression(assign.expression.arguments[0].arguments[0].arguments[0], context);
  3360. if (pytorch.Utility.isTensor(tensor) && tensor.shape === undefined) {
  3361. const number = this.expression(assign.expression.arguments[1], context);
  3362. tensor.resize_(Array(number).fill(NaN));
  3363. }
  3364. }
  3365. // val = torch.slice(torch.size(img), -2)
  3366. // if torch.eq(torch.len(val), 2):
  3367. // pass
  3368. // else:
  3369. // ops.prim.RaiseException("AssertionError: ")
  3370. if (assign.type === '=' &&
  3371. condition.type === 'if' &&
  3372. pytorch.Utility.isCall(assign.expression, 'torch.slice', 2) &&
  3373. pytorch.Utility.isCall(assign.expression.arguments[0], 'torch.size', 1) &&
  3374. pytorch.Utility.isCall(condition.condition, 'torch.eq', 2) &&
  3375. pytorch.Utility.isCall(condition.condition.arguments[0], 'torch.len', 1) &&
  3376. pytorch.Utility.isEqual(condition.condition.arguments[0].arguments[0], assign.target) &&
  3377. condition.else.statements.length == 1 &&
  3378. pytorch.Utility.isCall(condition.else.statements[0], 'ops.prim.RaiseException', 1)) {
  3379. const tensor = this.expression(assign.expression.arguments[0].arguments[0], context);
  3380. if (pytorch.Utility.isTensor(tensor) && tensor.shape === undefined) {
  3381. const start = this.expression(assign.expression.arguments[1], context);
  3382. const value = this.expression(condition.condition.arguments[1], context);
  3383. if (Number.isInteger(start) && start < 0 && Number.isInteger(value) && value > 0) {
  3384. tensor.resize_(Array(value - start).fill(NaN));
  3385. }
  3386. }
  3387. }
  3388. }
  3389. if (statements.length > 1) {
  3390. // getattr_1 = torch.size(x)
  3391. // getitem = torch.slice(getattr_1, -2, 9223372036854775807, 1)
  3392. const size = statements[0];
  3393. const statement = statements[1];
  3394. if (size.type === '=' && statement.type === '=' &&
  3395. size.target.type === 'id' &&
  3396. pytorch.Utility.isCall(size.expression, 'torch.size', 1) &&
  3397. pytorch.Utility.isCall(statement.expression, 'torch.slice', 4) &&
  3398. statement.expression.arguments[0].type === 'id' && size.target.value === statement.expression.arguments[0].value) {
  3399. const tensor = this.expression(size.expression.arguments[0], context);
  3400. if (pytorch.Utility.isTensor(tensor) && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
  3401. tensor.resize_([ 1, 3, 299, 299 ]);
  3402. }
  3403. }
  3404. }
  3405. if (statements.length > 1) {
  3406. // _0 = torch.split_with_sizes(...)
  3407. // a, a_1, a_2, = _0
  3408. const statement = statements[0];
  3409. const tuple = statements[1];
  3410. if (statement.type === '=' && statement.target.type === 'id' && statement.expression.type == 'call' &&
  3411. tuple.type === '=' && tuple.target.type === 'tuple' &&
  3412. tuple.target.value.every((item) => item.type === 'id') &&
  3413. tuple.expression.value === statement.target.value) {
  3414. const containsVariableReference = (queue, value) => {
  3415. while (queue.length > 0) {
  3416. const obj = queue.shift();
  3417. if (obj && obj.type === 'id' && obj.value === value) {
  3418. return true;
  3419. }
  3420. else if (Array.isArray(obj)) {
  3421. for (const item of obj) {
  3422. if (Array.isArray(item) || (Object(item) === item && item.type)) {
  3423. queue.push(item);
  3424. }
  3425. }
  3426. }
  3427. else if (Object(obj) === obj) {
  3428. for (const entry of Object.entries(obj)) {
  3429. const key = entry[0];
  3430. const value = entry[1];
  3431. if (key === 'location') {
  3432. continue;
  3433. }
  3434. if (Array.isArray(value)) {
  3435. for (const item of value) {
  3436. if (Array.isArray(item) || (Object(item) === item && item.type)) {
  3437. queue.push(item);
  3438. }
  3439. }
  3440. }
  3441. else if (Object(value) === value && value.type) {
  3442. queue.push(value);
  3443. }
  3444. }
  3445. }
  3446. }
  3447. return false;
  3448. };
  3449. if (!containsVariableReference(statements.slice(2, statements.length - 1), statement.target.value)) {
  3450. statements[0] = Object.assign({}, statement);
  3451. statements[0].target = tuple.target;
  3452. statements.splice(1, 1);
  3453. }
  3454. }
  3455. }
  3456. const statement = statements.shift();
  3457. // input_shape = torch.slice(torch.size(x), -2, 9223372036854775807, 1)
  3458. if (statement.type === '=' &&
  3459. pytorch.Utility.isCall(statement.expression, 'torch.slice', 4) &&
  3460. pytorch.Utility.isCall(statement.expression.arguments[0], 'torch.size', 1)) {
  3461. const tensor = this.expression(statement.expression.arguments[0].arguments[0], context);
  3462. if (pytorch.Utility.isTensor(tensor) && tensor.shape === undefined) {
  3463. tensor.resize_([ 1, 3, 299, 299 ]);
  3464. }
  3465. }
  3466. // torch.slice(ops.prim.shape(input), 0, 2, 1)
  3467. if (statement.type === '=' &&
  3468. pytorch.Utility.isCall(statement.expression, 'torch.slice', 4) &&
  3469. pytorch.Utility.isCall(statement.expression.arguments[0], 'ops.prim.shape', 1)) {
  3470. const tensor = this.expression(statement.expression.arguments[0].arguments[0], context);
  3471. if (pytorch.Utility.isTensor(tensor) && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
  3472. tensor.resize_([ NaN, NaN, NaN, NaN ]);
  3473. }
  3474. }
  3475. // _3 = torch.le(xxxx, torch.dim(f0))
  3476. if (statement.type === '=' &&
  3477. pytorch.Utility.isCall(statement.expression, 'torch.le', 2) &&
  3478. pytorch.Utility.isCall(statement.expression.arguments[1], 'torch.dim', 1)) {
  3479. const tensor = this.expression(statement.expression.arguments[1].arguments[0], context);
  3480. if (pytorch.Utility.isTensor(tensor) && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
  3481. tensor.resize_([ NaN, NaN, NaN, NaN ]);
  3482. }
  3483. }
  3484. // if torch.ne(torch.dim(image), 3):
  3485. // xxxx
  3486. // ops.prim.RaiseException(_7)
  3487. if (statement.type === 'if' &&
  3488. pytorch.Utility.isCall(statement.condition, 'torch.ne', 2) &&
  3489. pytorch.Utility.isCall(statement.condition.arguments[0], 'torch.dim', 1) &&
  3490. statement.then.statements.length > 0 &&
  3491. pytorch.Utility.isCall(statement.then.statements.slice(-1).pop(), 'ops.prim.RaiseException', 1)) {
  3492. const tensor = this.expression(statement.condition.arguments[0].arguments[0], context);
  3493. const size = this.expression(statement.condition.arguments[1], context);
  3494. if (pytorch.Utility.isTensor(tensor) && Number.isInteger(size) && size < 10) {
  3495. tensor.resize_(Array.isArray(tensor.shape) && tensor.shape.length > size ? tensor.shape.slice(-size) : Array(size).fill(NaN));
  3496. }
  3497. }
  3498. // if bool(...):
  3499. // ops.prim.RaiseException(torch.format(_1, dtype))
  3500. // else:
  3501. // pass
  3502. if (statement.type === 'if' &&
  3503. pytorch.Utility.isCall(statement.condition, 'bool', 1) &&
  3504. statement.then.statements.length > 0 &&
  3505. pytorch.Utility.isCall(statement.then.statements.slice(-1).pop(), 'ops.prim.RaiseException', 1)) {
  3506. statement.condition = { type: 'id', value: 'False' };
  3507. }
  3508. // dim = torch.sub(torch.dim(input), 2)
  3509. if (statement.type === '=' &&
  3510. statement.target.type === 'id' && statement.target.value === 'dim' &&
  3511. pytorch.Utility.isCall(statement.expression, 'torch.sub', 2) &&
  3512. pytorch.Utility.isCall(statement.expression.arguments[0], 'torch.dim', 1)) {
  3513. const tensor = this.expression(statement.expression.arguments[0].arguments[0], context);
  3514. if (pytorch.Utility.isTensor(tensor) && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
  3515. tensor.resize_([ NaN, NaN, NaN, NaN ]);
  3516. }
  3517. }
  3518. // a, b = torch.unbind(size, 0)
  3519. if (statement.type === '=' &&
  3520. statement.target.type === 'tuple' &&
  3521. (pytorch.Utility.isCall(statement.expression, 'torch.unbind', 1) ||
  3522. pytorch.Utility.isCall(statement.expression, 'torch.unbind', 2))) {
  3523. statement.expression.arguments[0].__tuple__ = statement.target.value.length;
  3524. }
  3525. // a, b, c = torch.size(input)
  3526. if (statement.type === '=' &&
  3527. statement.target.type === 'tuple' &&
  3528. pytorch.Utility.isCall(statement.expression, 'torch.size', 1)) {
  3529. const tensor = this.expression(statement.expression.arguments[0], context);
  3530. if (pytorch.Utility.isTensor(tensor) && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
  3531. const dim = statement.target.value.length;
  3532. tensor.resize_(Array(dim).fill(NaN));
  3533. }
  3534. }
  3535. // x = torch.len(input)
  3536. if (statement.type === '=' &&
  3537. statement.target.type === 'id' &&
  3538. pytorch.Utility.isCall(statement.expression, 'torch.len', 1)) {
  3539. const tensor = this.expression(statement.expression.arguments[0], context);
  3540. if (pytorch.Utility.isTensor(tensor) && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
  3541. tensor.resize_([ NaN, NaN, NaN, NaN ]);
  3542. }
  3543. }
  3544. if (statement.type === '=' &&
  3545. statement.expression.type === 'call' && statement.expression.arguments.length > 0 &&
  3546. pytorch.Utility.isCall(statement.expression.arguments[0], 'torch.size', 2)) {
  3547. const tensor = this.expression(statement.expression.arguments[0].arguments[0], context);
  3548. const dim = this.expression(statement.expression.arguments[0].arguments[1], context);
  3549. if (pytorch.Utility.isTensor(tensor) && Number.isInteger(dim)) {
  3550. if (tensor.shape === undefined) {
  3551. tensor.resize_(Array(dim + 1).fill(NaN));
  3552. }
  3553. else if (Array.isArray(tensor.shape) && tensor.shape.length <= dim) {
  3554. tensor.resize_(tensor.shape.concat(Array(dim + 1 - tensor.shape.length).fill(NaN)));
  3555. }
  3556. }
  3557. }
  3558. if (statement.type === '=' && statement.target.type === 'tuple' &&
  3559. statement.expression.type === 'call' && statement.expression.arguments.length > 0 &&
  3560. pytorch.Utility.isCall(statement.expression, 'torch.size', 1)) {
  3561. const tensor = this.expression(statement.expression.arguments[0], context);
  3562. if (pytorch.Utility.isTensor(tensor) && tensor.__origin__ === 'graph-input') {
  3563. if (tensor.shape === undefined) {
  3564. tensor.resize_(Array(statement.target.value.length).fill(NaN));
  3565. }
  3566. }
  3567. }
  3568. const value = this.statement(statement, context);
  3569. if (value !== undefined) {
  3570. return value;
  3571. }
  3572. }
  3573. return undefined;
  3574. }
  3575. push(node) {
  3576. this._nodes.push(node);
  3577. }
  3578. variable() {
  3579. this._variableIndex++;
  3580. return this._variableIndex.toString();
  3581. }
  3582. };
  // Integer ids of tensor element (scalar) types. Each id is the position of the
  // same-named descriptor in the table built by pytorch.Utility.getScalarType.
  pytorch.ScalarType = {
      uint8: 0, int8: 1, int16: 2, int32: 3, int64: 4,
      float16: 5, float32: 6, float64: 7,
      complex32: 8, complex64: 9, complex128: 10,
      boolean: 11,
      qint8: 12, quint8: 13, qint32: 14,
      bfloat16: 15,
      quint4x2: 16
  };
  // Integer ids of tensor memory formats.
  pytorch.MemoryFormat = { Contiguous: 0, Preserve: 1, ChannelsLast: 2, ChannelsLast3d: 3 };
  // Integer ids of tensor storage layouts.
  pytorch.Layout = { Strided: 0, Sparse: 1, Mkldnn: 2 };
  /**
   * Static helpers used by the PyTorch loader: scalar-type lookup, TorchScript
   * expression matching, tensor detection and argument type checks, and discovery
   * of a root module or a weights (state dictionary) container inside a loaded object.
   */
  pytorch.Utility = class {

      /**
       * Looks up the descriptor of a scalar type id.
       * @param {number} scalarType - index into the descriptor table (same order as pytorch.ScalarType).
       * @returns {{ name: string, itemsize?: number }} descriptor; 'quint4x2' carries no itemsize.
       * @throws {pytorch.Error} when the id does not name a table entry (including negative ids).
       */
      static getScalarType(scalarType) {
          if (!pytorch.Utility._scalarTypes) {
              pytorch.Utility._scalarTypes = [
                  { name: 'uint8', itemsize: 1 },
                  { name: 'int8', itemsize: 1 },
                  { name: 'int16', itemsize: 2 },
                  { name: 'int32', itemsize: 4 },
                  { name: 'int64', itemsize: 8 },
                  { name: 'float16', itemsize: 2 },
                  { name: 'float32', itemsize: 4 },
                  { name: 'float64', itemsize: 8 },
                  { name: 'complex32', itemsize: 4 },
                  { name: 'complex64', itemsize: 8 },
                  { name: 'complex128', itemsize: 16 },
                  { name: 'boolean', itemsize: 1 },
                  { name: 'qint8', itemsize: 1 },
                  { name: 'quint8', itemsize: 1 },
                  { name: 'qint32', itemsize: 4 },
                  { name: 'bfloat16', itemsize: 2 },
                  { name: 'quint4x2' }
              ];
          }
          const scalarTypes = pytorch.Utility._scalarTypes;
          // Range-check both ends so negative or fractional ids raise instead of yielding undefined.
          const entry = scalarType >= 0 && scalarType < scalarTypes.length ? scalarTypes[scalarType] : undefined;
          if (entry) {
              return entry;
          }
          throw new pytorch.Error("Unsupported scalar type '" + scalarType + "'.");
      }

      /**
       * Converts an 'id' expression or a '.' member chain into a dotted name such as 'torch.slice'.
       * @returns {string|null} null when the expression is neither kind.
       */
      static target(expression) {
          if (expression.type === 'id') {
              return expression.value;
          }
          if (expression.type === '.') {
              return pytorch.Utility.target(expression.target) + '.' + pytorch.Utility.target(expression.member);
          }
          return null;
      }

      /**
       * True for objects whose __class__ is a torch / torch.cuda '*Tensor' class or
       * torch.nn.parameter.Parameter.
       */
      static isTensor(obj) {
          const name = obj && obj.__class__ ? obj.__class__.__module__ : null;
          switch (name) {
              case 'torch':
              case 'torch.cuda':
                  return obj.__class__.__name__.endsWith('Tensor');
              case 'torch.nn.parameter':
                  return obj.__class__.__name__ === 'Parameter';
              default:
                  return false;
          }
      }

      /**
       * Like isTensor, but returns the tensor itself (a Parameter's `.data`) or null.
       */
      static toTensor(obj) {
          const name = obj && obj.__class__ ? obj.__class__.__module__ : null;
          switch (name) {
              case 'torch':
              case 'torch.cuda':
                  return obj.__class__.__name__.endsWith('Tensor') ? obj : null;
              case 'torch.nn.parameter':
                  return obj.__class__.__name__ === 'Parameter' ? obj.data : null;
              default:
                  return null;
          }
      }

      /**
       * Wraps a loaded tensor (exposing storage() and size()) into a pytorch.Tensor.
       */
      static createTensor(name, tensor, littleEndian) {
          const storage = tensor.storage();
          const size = tensor.size();
          const type = new pytorch.TensorType(storage.dtype.__reduce__(), new pytorch.TensorShape(size));
          return new pytorch.Tensor(name || '', type, storage.data, littleEndian);
      }

      /**
       * Checks whether a value is acceptable for a schema argument type.
       * NaN and undefined are accepted where a symbolic (unknown) integer may appear.
       * Unknown type names are accepted.
       */
      static isType(obj, type) {
          switch (type) {
              case 'Tensor':
                  return !Array.isArray(obj) && (pytorch.Utility.isTensor(obj) || obj === null);
              case 'Tensor[]':
                  return Array.isArray(obj) && obj.length > 0 && obj.every((tensor) => pytorch.Utility.isTensor(tensor) || tensor === null);
              case 'Scalar':
                  return (obj !== null && obj !== Object(obj)) || (pytorch.Utility.isTensor(obj) && Array.isArray(obj.size()) && obj.size().length === 0);
              case 'boolean':
                  return obj === true || obj === false;
              case 'int64':
                  return Number.isInteger(obj) || obj instanceof base.Int64 || Number.isNaN(obj);
              case 'int64[]':
                  return Array.isArray(obj) && obj.every((item) => Number.isInteger(item) || Number.isNaN(item) || item === undefined);
              case 'int64[1]':
                  return pytorch.Utility.isType(obj, 'int64') || pytorch.Utility.isType(obj, 'int64[]');
              case 'float32':
              case 'float64':
                  return obj !== null && obj !== Object(obj);
              case 'string[][]':
                  return Array.isArray(obj) && obj.every((item) => Array.isArray(item) && item.every((item) => typeof item === 'string'));
              case 'Layout':
              case 'ScalarType':
              case 'MemoryFormat':
                  return Number.isInteger(obj) || obj === null;
              case 'Device':
                  return obj === null || obj === Object(obj);
              default:
                  return true;
          }
      }

      /**
       * True when `expression` is a call to the dotted function `name` with exactly `size` arguments.
       */
      static isCall(expression, name, size) {
          return expression.type === 'call' &&
              expression.arguments.length === size &&
              pytorch.Utility.target(expression.target) === name;
      }

      /**
       * True when both expressions are identifiers with the same name.
       */
      static isEqual(a, b) {
          return (a.type === 'id' && b.type === 'id' && a.value === b.value);
      }

      /**
       * Locates the graphs to display: a root module if one exists, otherwise weights.
       * Each returned graph gets `type` set to 'module' or 'weights'.
       * @throws {pytorch.Error} when neither can be found.
       */
      static find(data) {
          const root = pytorch.Utility.findModule(data);
          if (root) {
              for (const graph of root) {
                  graph.type = 'module';
              }
              return root;
          }
          const weights = pytorch.Utility.findWeights(data);
          if (weights) {
              for (const graph of weights) {
                  graph.type = 'weights';
              }
              return weights;
          }
          throw new pytorch.Error('File does not contain root module or state dictionary.');
      }

      /**
       * Searches `root`, then root.model and root.net, for objects with `_modules`.
       * Throws when a torch2trt 'engine' blob is found, since its content is not supported.
       * @returns {Array<{name: string, data: Object}>|null}
       */
      static findModule(root) {
          if (root) {
              const keys = [ '', 'model', 'net' ];
              for (const key of keys) {
                  const obj = key === '' ? root : root[key];
                  if (obj && obj instanceof Map && obj.has('engine')) {
                      // https://github.com/NVIDIA-AI-IOT/torch2trt/blob/master/torch2trt/torch2trt.py
                      const data = obj.get('engine');
                      const signature = [ 0x70, 0x74, 0x72, 0x74 ]; // ptrt
                      if (data instanceof Uint8Array && data.length > signature.length && signature.every((value, index) => value === data[index])) {
                          const buffer = data.slice(0, 24);
                          const content = Array.from(buffer).map((c) => (c < 16 ? '0' : '') + c.toString(16)).join('');
                          throw new pytorch.Error("Invalid file content. File contains undocumented PyTorch TensorRT engine data (" + content.substring(8) + ").");
                      }
                  }
                  if (obj) {
                      if (obj._modules) {
                          return [ { name: '', data: obj } ];
                      }
                      // Several sibling modules (e.g. generator + discriminator) become separate graphs.
                      const objKeys = Object.keys(obj).filter((key) => obj[key] && obj[key]._modules);
                      if (objKeys.length > 1) {
                          return objKeys.map((key) => { return { name: key, data: obj[key] }; });
                      }
                  }
              }
          }
          return null;
      }

      /**
       * Searches for weights under a set of well-known container keys, trying a single
       * tensor, an object list and a state dictionary for each candidate.
       * @returns {Array<{name?: string, data: Iterable}>|null}
       */
      static findWeights(root) {
          if (!root) {
              return null;
          }
          let container = root;
          if (root instanceof Map) {
              container = {};
              for (const pair of root) {
                  container[pair[0]] = pair[1];
              }
          }
          // A container with exactly one key is probed through that key first; otherwise only
          // the well-known keys below are probed. '' stands for the container itself and comes last.
          const ownKeys = Array.isArray(container) ? [] : Object.keys(container);
          const keys = ownKeys.length === 1 ? ownKeys : [];
          keys.push(...[
              'state_dict', 'state', 'model_state', 'model', 'model_state_dict', 'model_dict', 'net_dict', 'params', 'generator', 'module', 'weights',
              'discriminator', 'g_state', 'network', 'net', 'netG', 'net_states', 'state_dict_stylepredictor', 'state_dict_ghiasi', 'runner', ''
          ]);
          for (const key of keys) {
              const obj = key === '' ? container : container[key];
              const graphs =
                  pytorch.Utility._convertTensor(obj) ||
                  pytorch.Utility._convertObjectList(obj) ||
                  pytorch.Utility._convertStateDict(obj);
              if (graphs) {
                  return graphs;
              }
          }
          return null;
      }

      /**
       * A bare tensor becomes one graph with a single layer holding it as state 'value'.
       */
      static _convertTensor(obj) {
          if (obj && pytorch.Utility.isTensor(obj)) {
              const layers = [];
              const argument = { id: '', value: obj };
              const parameter = { name: 'value', arguments: [ argument ] };
              layers.push({ states: [ parameter ] });
              return [ { data: layers } ];
          }
          return null;
      }

      /**
       * Converts an array of numbers/strings (one layer) or an array of objects that each
       * hold at least one tensor (one layer per item) into a graph; otherwise null.
       */
      static _convertObjectList(obj) {
          if (obj && Array.isArray(obj)) {
              if (obj.every((item) => typeof item === 'number' || typeof item === 'string')) {
                  const layers = [];
                  const type = obj.__class__ ? obj.__class__.__module__ + '.' + obj.__class__.__name__ : '?';
                  const layer = { type: type, states: [], attributes: [] };
                  for (let i = 0; i < obj.length; i++) {
                      const key = i.toString();
                      const value = obj[i];
                      if (pytorch.Utility.isTensor(value)) {
                          layer.states.push({ name: key, arguments: [ { id: '', value: value } ] });
                      }
                      else {
                          layer.attributes.push({ name: key, value: value });
                      }
                  }
                  layers.push(layer);
                  return [ { data: layers } ];
              }
              if (obj.every((item) => item && Object.values(item).filter((value) => pytorch.Utility.isTensor(value)).length > 0)) {
                  const layers = [];
                  for (const item of obj) {
                      const type = item.__class__ ? item.__class__.__module__ + '.' + item.__class__.__name__ : '?';
                      const layer = { type: type, states: [], attributes: [] };
                      if (item instanceof Map) {
                          return null;
                      }
                      for (const entry of Object.entries(item)) {
                          const key = entry[0];
                          const value = entry[1];
                          if (pytorch.Utility.isTensor(value)) {
                              layer.states.push({ name: key, arguments: [ { id: '', value: value } ] });
                          }
                          else {
                              layer.attributes.push({ name: key, value: value });
                          }
                      }
                      layers.push(layer);
                  }
                  return [ { data: layers } ];
              }
          }
          return null;
      }

      /**
       * Interprets `obj` as one or more state dictionaries ('layer.param' -> tensor, with '|'
       * accepted as separator when a key has no '.') and groups entries into layers.
       * @returns {Array<{name: string, data: Iterator}>|null} one graph per dictionary found.
       */
      static _convertStateDict(obj) {
          // Drops entries of a plain object that are not expected to hold weights
          // (optimizer state, scalars, class/datetime metadata, vocabularies, ...).
          // Arrays and Maps are returned unchanged.
          const clean = (obj) => {
              if (obj && Array.isArray(obj)) {
                  return obj;
              }
              if (obj && obj instanceof Map) {
                  return obj;
              }
              if (obj && Object(obj) === obj) {
                  const target = {};
                  const map_count = Object.entries(obj).filter((entry) => entry[1] instanceof Map).length;
                  for (const entry of Object.entries(obj)) {
                      const key = entry[0];
                      const value = entry[1];
                      // 'opt' also matches 'optim', 'optimizer', ... Missing values are skipped
                      // here rather than throwing on the property access.
                      if (key.indexOf('opt') !== -1 && (value === null || value === undefined || (value.state && value.param_groups))) {
                          continue;
                      }
                      if (map_count > 2 && key.endsWith('_avg') && pytorch.Utility.isTensor(value)) {
                          continue;
                      }
                      if (typeof value === 'number' || typeof value === 'string' || typeof value === 'boolean') {
                          continue;
                      }
                      if (key === '__class__' && value && value.__module__ && value.__name__) {
                          continue;
                      }
                      if (Array.isArray(value) && (key.indexOf('loss') !== -1 || value.length === 0)) {
                          continue;
                      }
                      if (value && value.__class__ && value.__class__.__module__ === 'datetime' && value.__class__.__name__ === 'datetime') {
                          continue;
                      }
                      if ((key.startsWith('dico_') && Object(value) === value) ||
                          (key === 'args' && Object(value) === value) ||
                          (key.startsWith('params') && Object(value) === value && (value.id2lang || value.lang2id)) ||
                          (key.startsWith('spk_dict_') && Object(value) === value && Object.keys(value).length === 0)) {
                          continue;
                      }
                      target[key] = value;
                  }
                  return target;
              }
              return obj;
          };
          // True for a Map holding at least one tensor (or tensor list) and otherwise only
          // primitives, nulls, '_metadata' or '_packed_params' entries.
          const validate = (map) => {
              let tensor = false;
              if (map && map instanceof Map) {
                  for (const pair of map) {
                      const key = pair[0];
                      const value = pair[1];
                      const separator = key.indexOf('.') === -1 && key.indexOf('|') !== -1 ? '|' : '.';
                      const keys = key.split(separator);
                      if (keys[keys.length - 1] === '_metadata') {
                          continue;
                      }
                      else if (keys.length >= 2 && keys[keys.length - 2] === '_packed_params') {
                          continue;
                      }
                      else if (pytorch.Utility.isTensor(value)) {
                          tensor = true;
                          continue;
                      }
                      else if (value && Array.isArray(value) && value.every((item) => pytorch.Utility.isTensor(item))) {
                          tensor = true;
                          continue;
                      }
                      else if (typeof value === 'string' || typeof value === 'number' || typeof value === 'boolean') {
                          continue;
                      }
                      else if (value === null) {
                          continue;
                      }
                      return false;
                  }
              }
              return tensor;
          };
          // Recursively flattens nested objects into one Map keyed by dotted paths, or
          // returns null when some branch does not end in a valid state dictionary.
          const flatten = (obj) => {
              if (!obj || Array.isArray(obj) || ArrayBuffer.isView(obj)) {
                  return null;
              }
              if (obj instanceof Map) {
                  if (validate(obj)) {
                      return obj;
                  }
                  return null;
              }
              if (Object(obj) !== obj) {
                  return null;
              }
              const map = new Map(Object.keys(obj).map((key) => [ key, obj[key] ]));
              if (validate(map)) {
                  return map;
              }
              map.clear();
              for (const key of Object.keys(obj)) {
                  const value = flatten(obj[key]);
                  if (value && value instanceof Map) {
                      for (const pair of value) {
                          map.set(key + '.' + pair[0], pair[1]);
                      }
                      continue;
                  }
                  return null;
              }
              return map;
          };
          if (!obj) {
              return null;
          }
          obj = clean(obj);
          const map = new Map();
          if (Array.isArray(obj) && obj.every((item) => validate(item))) {
              for (let i = 0; i < obj.length; i++) {
                  map.set(i.toString(), flatten(obj[i]));
              }
          }
          else if (obj instanceof Map && validate(obj)) {
              map.set('', flatten(obj));
          }
          else if (Object(obj) === obj && Object.entries(obj).every((entry) => validate(entry[1]))) {
              for (const entry of Object.entries(obj)) {
                  map.set(entry[0], entry[1]);
              }
          }
          else if (Object(obj) === obj && Object.entries(obj).every((entry) => pytorch.Utility.isTensor(entry[1]))) {
              map.set('', new Map(Object.keys(obj).map((key) => [ key, obj[key] ])));
          }
          else {
              const value = flatten(obj);
              if (value) {
                  map.set('', value);
              }
          }
          if (map.size === 0) {
              return null;
          }
          const graphs = [];
          for (const entry of map) {
              const graph_key = entry[0];
              const layer_map = entry[1];
              const layers = new Map();
              for (const item of layer_map) {
                  const key = item[0];
                  const value = item[1];
                  let parameter = '';
                  const separator = key.indexOf('.') === -1 && key.indexOf('|') !== -1 ? '|' : '.';
                  const keys = key.split(separator);
                  if (keys[keys.length - 1] === '_metadata') {
                      continue;
                  }
                  if (keys.length >= 2 && keys[keys.length - 2] === '_packed_params') {
                      // Keep '_packed_params.<name>' together as the parameter name.
                      parameter = keys.slice(-2).join(separator);
                      keys.pop();
                      keys.pop();
                  }
                  else {
                      parameter = keys.pop();
                  }
                  // An empty prefix (single-segment key) yields the layer name ''.
                  const layerName = keys.join(separator);
                  if (!layers.has(layerName)) {
                      layers.set(layerName, { name: layerName, states: [], attributes: [] });
                  }
                  const layer = layers.get(layerName);
                  if (pytorch.Utility.isTensor(value)) {
                      layer.states.push({ name: parameter, arguments: [ { id: key, value: value } ] });
                      // More than 12 un-prefixed tensors: not treated as a state dictionary.
                      if (layer.name === '' && layer.states.length > 12) {
                          return null;
                      }
                  }
                  else if (value && Array.isArray(value) && value.every((item) => pytorch.Utility.isTensor(item))) {
                      layer.states.push({ name: parameter, arguments: value.map((item) => { return { id: '', value: item }; }) });
                  }
                  else if (typeof value === 'string' || typeof value === 'number' || typeof value === 'boolean') {
                      layer.attributes.push({ name: parameter, value: value });
                  }
              }
              graphs.push({
                  name: graph_key,
                  data: layers.values()
              });
          }
          return graphs;
      }
  };
  // Namespace for the NNAPI serialized-model support.
  pytorch.nnapi = {};

  /**
   * Parses a version-1 NNAPI serialized model blob.
   *
   * Sections are read in this order:
   * 1. header counts;
   * 2. operand records, then value records, then operation records;
   * 3. operand dimensions;
   * 4. value payloads;
   * 5. operation input/output operand indices;
   * 6. model input/output operand indices.
   *
   * @param {*} serialized_model - bytes handed to base.BinaryReader.
   * @param {Array} buffers - indexed by the buffer number of 'numbered buffer' values; each entry
   *     must support slice(begin, end) (e.g. a Uint8Array).
   * @throws {pytorch.Error} on an unsupported version, operand type or value source, a missing
   *     buffer, or trailing/missing bytes.
   */
  pytorch.nnapi.SerializedModel = class {

      constructor(serialized_model, buffers) {
          const reader = new base.BinaryReader(serialized_model);
          this.version = reader.int32();
          if (this.version !== 1) {
              throw new pytorch.Error('Invalid NNAPI serialized model version.');
          }
          // Header: element counts of the sections that follow.
          const operands = new Array(reader.int32());
          const values = new Array(reader.int32());
          this.operations = new Array(reader.int32());
          this.inputs = new Array(reader.int32());
          this.outputs = new Array(reader.int32());
          // Operand type code -> readable name; unknown codes are kept as numbers.
          const data_types = new Map([
              [ 0, 'float32' ],
              [ 1, 'int32' ],
              [ 2, 'uint32' ],
              [ 3, 'float32[]' ],
              [ 4, 'int32[]' ],
              [ 5, 'quant8_asymm[]' ],
              [ 6, 'boolean' ],
              [ 7, 'quant16_symm[]' ],
              [ 8, 'float16[]' ],
              [ 9, 'boolean[]' ],
              [ 10, 'float16' ],
              [ 11, 'quant8_symm_per_channel[]' ],
              [ 12, 'quant16_asymm[]' ],
              [ 13, 'quant8_symm[]' ],
              [ 14, 'quant8_asymm_signed[]' ],
              [ 16, 'model' ]
          ]);
          for (let i = 0; i < operands.length; i++) {
              const data_type = reader.int32();
              operands[i] = {
                  index: i,
                  data_type: data_types.has(data_type) ? data_types.get(data_type) : data_type,
                  dimensions: new Array(reader.uint32()),
                  scale: reader.float32(),
                  zero_point: reader.int32()
              };
          }
          for (let i = 0; i < values.length; i++) {
              values[i] = {
                  index: reader.int32(),
                  source_type: reader.int32(),
                  source_length: reader.uint32()
              };
          }
          for (let i = 0; i < this.operations.length; i++) {
              this.operations[i] = {
                  index: reader.int32(),
                  location: i,
                  inputs: new Array(reader.uint32()),
                  outputs: new Array(reader.uint32())
              };
          }
          for (const operand of operands) {
              for (let i = 0; i < operand.dimensions.length; i++) {
                  operand.dimensions[i] = reader.uint32();
              }
          }
          for (const value of values) {
              const operand = operands[value.index];
              switch (value.source_type) {
                  case 0: { // immediate
                      switch (operand.data_type) {
                          case 'boolean':
                              // One value byte followed by 3 bytes that are skipped.
                              operand.value = reader.byte() ? true : false;
                              reader.skip(3);
                              break;
                          case 'int32':
                              operand.value = reader.int32();
                              break;
                          case 'uint32':
                              operand.value = reader.uint32();
                              break;
                          case 'float32':
                              operand.value = reader.float32();
                              break;
                          case 'int32[]':
                          case 'float32[]':
                              operand.data = reader.read(value.source_length);
                              break;
                          default:
                              throw new pytorch.Error("Unsupported NNAPI operand type '" + operand.data_type.toString() + "'.");
                      }
                      break;
                  }
                  case 2: { // numbered buffer
                      if (value.source_length !== 12) {
                          throw new pytorch.Error('Invalid NNAPI numbered buffer source length.');
                      }
                      const number = reader.uint32();
                      const offset = reader.uint32();
                      const operand_length = reader.uint32();
                      const buffer = buffers ? buffers[number] : undefined;
                      if (!buffer) {
                          throw new pytorch.Error("Invalid NNAPI numbered buffer '" + number.toString() + "'.");
                      }
                      // The third field is a length, not an end position, so the view ends at offset + length.
                      operand.data = buffer.slice(offset, offset + operand_length);
                      break;
                  }
                  case 3: { // numbered memory
                      throw new pytorch.Error('NNAPI numbered memory buffer not implemented.');
                  }
                  default: {
                      throw new pytorch.Error('Unsupported NNAPI value source type.');
                  }
              }
          }
          for (const operation of this.operations) {
              for (let i = 0; i < operation.inputs.length; i++) {
                  operation.inputs[i] = operands[reader.uint32()];
              }
              for (let i = 0; i < operation.outputs.length; i++) {
                  operation.outputs[i] = operands[reader.uint32()];
              }
          }
          for (let i = 0; i < this.inputs.length; i++) {
              this.inputs[i] = operands[reader.uint32()];
          }
          for (let i = 0; i < this.outputs.length; i++) {
              this.outputs[i] = operands[reader.uint32()];
          }
          // The whole blob must be consumed exactly.
          if (reader.position !== reader.length) {
              throw new pytorch.Error('Invalid NNAPI serialized model length.');
          }
      }
  };
  /**
   * Operation type descriptors for NNAPI operation codes. A code may have several
   * registered variants (e.g. explicit vs. scheme-based padding); type() picks one.
   */
  pytorch.nnapi.Metadata = class {

      constructor() {
          // operation code -> array of candidate type descriptors
          this._types = new Map();
          // https://developer.android.com/ndk/reference/group/neural-networks
          // https://github.com/pytorch/pytorch/commits/master/torch/backends/_nnapi/serializer.py
          this.register( 0, 'ADD', '', [ 'A', 'B' ], [ [ 'activation', 'int32'] ], [ 'C' ]);
          this.register( 1, 'AVERAGE_POOL_2D', 'Pool', [ 'input' ], [ [ 'padding_left', 'int32' ], [ 'padding_right', 'int32' ], [ 'padding_top', 'int32' ], [ 'padding_bottom', 'int32' ], [ 'stride_x', 'int32' ], [ 'stride_y', 'int32' ], [ 'filter_x', 'int32' ], [ 'filter_y', 'int32' ], [ 'activation', 'int32' ], [ 'nchw', 'boolean' ] ], [ 'output' ]);
          this.register( 1, 'AVERAGE_POOL_2D', 'Pool', [ 'input' ], [ [ 'padding_scheme', 'int32' ], [ 'stride_x', 'int32' ], [ 'stride_y', 'int32' ], [ 'filter_x', 'int32' ], [ 'filter_y', 'int32' ], [ 'activation', 'int32' ], [ 'nchw', 'boolean' ] ], [ 'output' ]);
          this.register( 2, 'CONCATENATION');
          this.register( 3, 'CONV_2D', 'Layer', [ 'input', 'weights', 'bias' ], [ [ 'padding_left', 'int32' ], [ 'padding_right', 'int32' ], [ 'padding_top', 'int32' ], [ 'padding_bottom', 'int32' ], [ 'stride_x', 'int32' ], [ 'stride_y', 'int32' ], [ 'activation', 'int32' ], [ 'nchw', 'boolean' ], [ 'dilation_width', 'int32' ], [ 'dilation_height', 'int32' ] ], [ 'output' ]);
          this.register( 3, 'CONV_2D', 'Layer', [ 'input', 'weights', 'bias' ], [ [ 'padding_scheme', 'int32' ], [ 'stride_x', 'int32' ], [ 'stride_y', 'int32' ], [ 'activation', 'int32' ], [ 'nchw', 'boolean' ], [ 'dilation_width', 'int32' ], [ 'dilation_height', 'int32' ] ], [ 'output' ]);
          this.register( 4, 'DEPTHWISE_CONV_2D', 'Layer', [ 'input', 'weights', 'bias' ], [ [ 'padding_left', 'int32' ], [ 'padding_right', 'int32' ], [ 'padding_top', 'int32' ], [ 'padding_bottom', 'int32' ], [ 'stride_x', 'int32' ], [ 'stride_y', 'int32' ], [ 'activation', 'int32' ], [ 'nchw', 'boolean' ], [ 'dilation_width', 'int32' ], [ 'dilation_height', 'int32' ] ], [ 'output' ]);
          this.register( 4, 'DEPTHWISE_CONV_2D', 'Layer', [ 'input', 'weights', 'bias' ], [ [ 'padding_scheme', 'int32' ], [ 'stride_x', 'int32' ], [ 'stride_y', 'int32' ], [ 'activation', 'int32' ], [ 'nchw', 'boolean' ], [ 'dilation_width', 'int32' ], [ 'dilation_height', 'int32' ] ], [ 'output' ]);
          this.register( 5, 'DEPTH_TO_SPACE');
          this.register( 6, 'DEQUANTIZE');
          this.register( 7, 'EMBEDDING_LOOKUP');
          this.register( 8, 'FLOOR');
          this.register( 9, 'FULLY_CONNECTED', 'Layer', [ 'input', 'weights', 'bias' ], [ [ 'activation', 'int32' ] ], [ 'output' ]);
          this.register(10, 'HASHTABLE_LOOKUP');
          this.register(11, 'L2_NORMALIZATION');
          this.register(12, 'L2_POOL_2D', 'Pool');
          this.register(13, 'LOCAL_RESPONSE_NORMALIZATION');
          this.register(14, 'LOGISTIC');
          this.register(15, 'LSH_PROJECTION');
          this.register(16, 'LSTM', 'Layer');
          this.register(17, 'MAX_POOL_2D', 'Pool');
          this.register(18, 'MUL');
          this.register(19, 'RELU', 'Activation', [ 'input' ], [], [ 'output' ]);
          this.register(20, 'RELU1', 'Activation');
          this.register(21, 'RELU6', 'Activation');
          this.register(22, 'RESHAPE', 'Shape', [ 'input', 'shape' ], [], [ 'output' ]);
          this.register(23, 'RESIZE_BILINEAR');
          this.register(24, 'RNN', 'Layer');
          this.register(25, 'SOFTMAX', 'Activation');
          this.register(26, 'SPACE_TO_DEPTH');
          this.register(27, 'SVDF');
          this.register(28, 'TANH');
          this.register(29, 'BATCH_TO_SPACE_ND');
          this.register(30, 'DIV');
          this.register(31, 'MEAN');
          this.register(32, 'PAD');
          this.register(33, 'SPACE_TO_BATCH_ND');
          this.register(34, 'SQUEEZE');
          this.register(35, 'STRIDED_SLICE');
          this.register(36, 'SUB');
          this.register(37, 'TRANSPOSE');
          this.register(38, 'ABS');
          this.register(39, 'ARGMAX');
          this.register(40, 'ARGMIN');
          this.register(41, 'AXIS_ALIGNED_BBOX_TRANSFORM');
          this.register(42, 'BIDIRECTIONAL_SEQUENCE_LSTM');
          this.register(43, 'BIDIRECTIONAL_SEQUENCE_RNN');
          this.register(44, 'BOX_WITH_NMS_LIMIT');
          this.register(45, 'CAST');
          this.register(46, 'CHANNEL_SHUFFLE');
          this.register(47, 'DETECTION_POSTPROCESSING');
          this.register(48, 'EQUAL');
          this.register(49, 'EXP');
          this.register(50, 'EXPAND_DIMS');
          this.register(51, 'GATHER');
          this.register(52, 'GENERATE_PROPOSALS');
          this.register(53, 'GREATER');
          this.register(54, 'GREATER_EQUAL');
          this.register(55, 'GROUPED_CONV_2D');
          this.register(56, 'HEATMAP_MAX_KEYPOINT');
          this.register(57, 'INSTANCE_NORMALIZATION');
          this.register(58, 'LESS');
          this.register(59, 'LESS_EQUAL');
          this.register(60, 'LOG');
          this.register(61, 'LOGICAL_AND');
          this.register(62, 'LOGICAL_NOT');
          this.register(63, 'LOGICAL_OR');
          this.register(64, 'LOG_SOFTMAX');
          this.register(65, 'MAXIMUM');
          this.register(66, 'MINIMUM');
          this.register(67, 'NEG');
          this.register(68, 'NOT_EQUAL');
          this.register(69, 'PAD_V2');
          this.register(70, 'POW');
          this.register(71, 'PRELU');
          this.register(72, 'QUANTIZE');
          this.register(73, 'QUANTIZED_16BIT_LSTM');
          this.register(74, 'RANDOM_MULTINOMIAL');
          this.register(75, 'REDUCE_ALL');
          this.register(76, 'REDUCE_ANY');
          this.register(77, 'REDUCE_MAX');
          this.register(78, 'REDUCE_MIN');
          this.register(79, 'REDUCE_PROD');
          this.register(80, 'REDUCE_SUM');
          this.register(81, 'ROI_ALIGN');
          this.register(82, 'ROI_POOLING');
          this.register(83, 'RSQRT');
          this.register(84, 'SELECT');
          this.register(85, 'SIN');
          this.register(86, 'SLICE');
          this.register(87, 'SPLIT');
          this.register(88, 'SQRT');
          this.register(89, 'TILE');
          this.register(90, 'TOPK_V2');
          this.register(91, 'TRANSPOSE_CONV_2D', 'Layer');
          this.register(92, 'UNIDIRECTIONAL_SEQUENCE_LSTM', 'Layer');
          this.register(93, 'UNIDIRECTIONAL_SEQUENCE_RNN', 'Layer');
          this.register(94, 'RESIZE_NEAREST_NEIGHBOR');
          this.register(95, 'QUANTIZED_LSTM', 'Layer');
          this.register(96, 'IF');
          this.register(97, 'WHILE');
          this.register(98, 'ELU', 'Activation');
          this.register(99, 'HARD_SWISH', 'Activation');
          this.register(100, 'FILL');
          this.register(101, 'RANK');
      }

      /**
       * Adds a candidate descriptor for an operation code.
       * @param {number} index - operation code.
       * @param {string} name - operation name.
       * @param {string} [category] - optional display category; omitted when falsy.
       * @param {string[]} [inputs] - tensor input names.
       * @param {Array<[string, string]>} [attributes] - [name, type] pairs for scalar operands.
       * @param {string[]} [outputs] - tensor output names.
       */
      register(index, name, category, inputs, attributes, outputs) {
          inputs = inputs || [];
          outputs = outputs || [];
          attributes = attributes || [];
          const type = {
              name: name,
              inputs: inputs.map((name) => { return { name: name, type: 'Tensor' }; }),
              outputs: outputs.map((name) => { return { name: name, type: 'Tensor' }; }),
              attributes: attributes.map((pair) => { return { name: pair[0], type: pair[1] }; })
          };
          if (category) {
              type.category = category;
          }
          if (!this._types.has(index)) {
              this._types.set(index, []);
          }
          this._types.get(index).push(type);
      }

      /**
       * Selects the descriptor for an operation code.
       * @param {number} index - operation code.
       * @param {Array} signature - per-operand values compared position-wise with each candidate's
       *     declared input/attribute types (presumably operand data_type names; confirm at the caller).
       * @returns {Object} the first candidate that matches, otherwise the first registered one.
       *     Unknown codes get a placeholder named after the code.
       */
      type(index, signature) {
          if (!this._types.has(index)) {
              // Stored as a one-element list like every registered entry; a bare object here
              // made the iteration below throw "not iterable".
              this._types.set(index, [ { name: index.toString(), inputs: [], outputs: [], attributes: [] } ]);
          }
          const types = this._types.get(index);
          for (const type of types) {
              const inputs = type.inputs.concat(type.attributes);
              // NOTE(review): only signatures strictly shorter than the declared list are
              // considered; kept as-is, confirm against the caller's signature format.
              if (signature.length < inputs.length) {
                  let match = true;
                  for (let i = 0; i < inputs.length; i++) {
                      const input = inputs[i];
                      if (input.type === undefined || input.type === 'Tensor' || input.type === signature[i]) {
                          continue;
                      }
                      match = false;
                      break;
                  }
                  if (match) {
                      return type;
                  }
              }
          }
          return types[0];
      }
  };
  /**
   * Graph view over a deserialized NNAPI model.
   *
   * The graph has one node per operation. Graph inputs and outputs are named by
   * position ('0', '1', ...). Operands are wrapped in Argument objects cached by
   * operand index, so every reference to the same index shares one instance.
   */
  pytorch.nnapi.Graph = class {

      constructor(model) {
          // Argument wrappers keyed by operand index.
          const argumentsByIndex = new Map();
          const arg = (operand) => {
              const key = operand.index;
              if (!argumentsByIndex.has(key)) {
                  argumentsByIndex.set(key, new pytorch.nnapi.Argument(operand));
              }
              return argumentsByIndex.get(key);
          };
          // Wraps each operand in a Parameter named after its position.
          const toParameters = (operands) => {
              const parameters = [];
              for (let position = 0; position < operands.length; position++) {
                  const argument = arg(operands[position]);
                  parameters.push(new pytorch.Parameter(position.toString(), true, [ argument ]));
              }
              return parameters;
          };
          this._nodes = [];
          const metadata = new pytorch.nnapi.Metadata();
          for (const operation of model.operations) {
              this._nodes.push(new pytorch.nnapi.Node(metadata, operation, arg));
          }
          this._inputs = toParameters(model.inputs);
          this._outputs = toParameters(model.outputs);
      }

      get name() {
          return 'torch.classes._nnapi.Compilation';
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get nodes() {
          return this._nodes;
      }
  };
  /**
   * Wraps an NNAPI operand as a graph argument named after its operand index.
   *
   * The tensor type is built from the operand's data_type with its first '[]'
   * removed. An operand that carries `data` also exposes it as an initializer
   * tensor.
   */
  pytorch.nnapi.Argument = class {

      constructor(operand) {
          this._name = operand.index.toString();
          const shape = new pytorch.TensorShape(operand.dimensions);
          this._type = new pytorch.TensorType(operand.data_type.replace('[]', ''), shape);
          this._initializer = operand.data ? new pytorch.Tensor(this._name, this._type, operand.data, true) : null;
          this._scale = operand.scale;
          this._zeroPoint = operand.zero_point;
      }

      get name() {
          return this._name;
      }

      get type() {
          return this._type;
      }

      /**
       * Human-readable dequantization formula, e.g. '0.5 * (q - 128)' or '0.5 * q'.
       * @returns {string|null} null when both scale and zero point are 0 or absent.
       */
      get quantization() {
          // Absent (null/undefined) parameters count as 0 instead of crashing
          // on `undefined.toString()`.
          const scale = this._scale == null ? 0 : this._scale;
          const zeroPoint = this._zeroPoint == null ? 0 : this._zeroPoint;
          if (scale === 0 && zeroPoint === 0) {
              return null;
          }
          const quantized = zeroPoint === 0 ? 'q' : '(q - ' + zeroPoint.toString() + ')';
          return scale.toString() + ' * ' + quantized;
      }

      get initializer() {
          return this._initializer;
      }
  };
  /**
   * Graph node for one NNAPI operation.
   *
   * Input operands with at least one dimension become node inputs. Dimensionless
   * operands become attributes, except one named 'activation': when its value is
   * 1, 2 or 3 it becomes a chained activation node instead.
   *
   * Inputs take their names, by position, from the resolved type's inputs followed
   * by its attributes. Outputs take their names from the type's declared outputs.
   * Both fall back to the position index.
   *
   * @param {object} metadata - provides type(index, signature) for overload resolution.
   * @param {object} operation - has `index` and optional `inputs`, `outputs`, `location`.
   * @param {function} arg - maps an operand to its Argument; only called when the operation has operands.
   */
  pytorch.nnapi.Node = class {

      constructor(metadata, operation, arg) {
          const signature = (operation.inputs || []).map((input) => input.data_type);
          this._type = metadata.type(operation.index, signature);
          this._inputs = [];
          this._outputs = [];
          this._attributes = [];
          this._chain = [];
          if (operation.location !== undefined) {
              this._location = operation.location.toString();
          }
          const inputs = this._type.inputs.concat(this._type.attributes);
          if (operation.inputs) {
              for (let i = 0; i < operation.inputs.length; i++) {
                  const name = i < inputs.length ? inputs[i].name : i.toString();
                  const operand = operation.inputs[i];
                  if (operand.dimensions.length > 0) {
                      const argument = arg(operand);
                      this._inputs.push(new pytorch.Parameter(name, true, [ argument ]));
                  }
                  else if (name === 'activation') {
                      // Maps activation values 1/2/3 to operation codes 19/20/21 for the
                      // chained node; any other value means no chained activation.
                      // Presumably NNAPI fuse codes -> RELU/RELU1/RELU6; verify against registrations.
                      const activation = new Map([ [ 1, 19 ], [ 2, 20 ], [ 3, 21 ] ]).get(operand.value) || 0;
                      if (activation !== 0) {
                          this._chain.push(new pytorch.nnapi.Node(metadata, { index: activation }));
                      }
                  }
                  else {
                      this._attributes.push(new pytorch.nnapi.Attribute(name, operand));
                  }
              }
          }
          // Output names must come from the declared outputs. Using the input list
          // here labeled outputs with input names.
          const outputs = this._type.outputs;
          if (operation.outputs) {
              for (let i = 0; i < operation.outputs.length; i++) {
                  const name = i < outputs.length ? outputs[i].name : i.toString();
                  const argument = arg(operation.outputs[i]);
                  this._outputs.push(new pytorch.Parameter(name, true, [ argument ]));
              }
          }
      }

      get type() {
          return this._type;
      }

      get location() {
          return this._location;
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get attributes() {
          return this._attributes;
      }

      get chain() {
          return this._chain;
      }
  };
  /**
   * Wraps an operand as a named node attribute.
   * Its type is the operand's data_type and its value is the operand's value.
   * It is always reported as not visible.
   */
  pytorch.nnapi.Attribute = class {

      constructor(name, operand) {
          const { data_type: dataType, value } = operand;
          this._name = name;
          this._type = dataType;
          this._value = value;
      }

      get name() {
          return this._name;
      }

      get type() {
          return this._type;
      }

      get value() {
          return this._value;
      }

      get visible() {
          return false;
      }
  };
  /**
   * Minimal tensor holder: stores a type and raw data.
   * Its state always reports that decoding is not implemented.
   */
  pytorch.nnapi.Tensor = class {

      constructor(type, data) {
          this._type = type;
          this._data = data;
      }

      get state() {
          return 'Not implemented.';
      }

      get type() {
          return this._type;
      }
  };
  /**
   * Operator metadata loaded from 'pytorch-metadata.json'.
   *
   * Entries are indexed by their `name`. For a name containing ':', the part
   * before the first ':' also maps to the list of full names sharing that prefix.
   */
  pytorch.Metadata = class {

      /**
       * Returns the shared Metadata instance, loading it on first use.
       * If the request fails, or constructing from its data throws, an empty
       * Metadata is cached instead.
       * @returns {Promise<pytorch.Metadata>}
       */
      static open(context) {
          const remember = (metadata) => {
              pytorch.Metadata._metadata = metadata;
              return metadata;
          };
          if (pytorch.Metadata._metadata) {
              return Promise.resolve(pytorch.Metadata._metadata);
          }
          return context.request('pytorch-metadata.json', 'utf-8', null)
              .then((data) => remember(new pytorch.Metadata(data)))
              .catch(() => remember(new pytorch.Metadata(null)));
      }

      constructor(data) {
          this._types = new Map();
          this._attributes = new Map();
          if (!data) {
              return;
          }
          for (const item of JSON.parse(data)) {
              const fullName = item.name;
              this._types.set(fullName, item);
              const separator = fullName.indexOf(':');
              if (separator === -1) {
                  continue;
              }
              const prefix = fullName.substring(0, separator);
              if (!this._types.has(prefix)) {
                  this._types.set(prefix, []);
              }
              this._types.get(prefix).push(fullName);
          }
      }

      /**
       * @returns {object|object[]|null} the entry for `name`; for a prefix key, the
       *     array of entries grouped under it; null when unknown.
       */
      type(name) {
          const schema = this._types.get(name);
          if (!schema) {
              return null;
          }
          if (Array.isArray(schema)) {
              return schema.map((key) => this._types.get(key));
          }
          return schema;
      }

      /**
       * Looks up input/attribute metadata for `type`:`name`.
       *
       * The first lookup for a key caches every input and attribute of the type.
       * Attributes override inputs of the same name. A miss is cached as null.
       * Grouped (array) schemas have no `inputs`/`attributes`, so their lookups
       * yield null.
       */
      attribute(type, name) {
          const key = type + ':' + name;
          if (!this._attributes.has(key)) {
              this._attributes.set(key, null);
              const schema = this.type(type);
              if (schema) {
                  for (const members of [ schema.inputs, schema.attributes ]) {
                      if (!members) {
                          continue;
                      }
                      for (const member of members) {
                          this._attributes.set(type + ':' + member.name, member);
                      }
                  }
              }
          }
          return this._attributes.get(key);
      }
  };
  /**
   * Error subclass whose `name` is 'Error loading PyTorch model.'.
   * The message is passed through to Error unchanged.
   */
  pytorch.Error = class extends Error {

      constructor(text) {
          super(text);
          this.name = 'Error loading PyTorch model.';
      }
  };
  // Expose the model factory only when a CommonJS `module.exports` object exists.
  // In other environments (e.g. ES modules, browsers) this is a no-op.
  if (typeof module !== 'undefined') {
      if (typeof module.exports === 'object') {
          module.exports.ModelFactory = pytorch.ModelFactory;
      }
  }