pytorch.js 164 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
7377837793780378137823783378437853786378737883789379037913792379337943795379637973798379938003801380238033804380538063807380838093810381138123813381438153816381738183819382038213822382338243825382638273828382938303831383238333834383538363837383838393840384138423843384438453846384738483849385038513852385338543855385638573858385938603861386238633864386538663867
  1. /* jshint esversion: 6 */
  2. // Experimental
  3. var pytorch = pytorch || {};
  4. var python = python || require('./python');
  5. var base = base || require('./base');
  /**
   * Entry point used by the host application to detect and open PyTorch files.
   */
  pytorch.ModelFactory = class {

      /**
       * Reports whether the given context can be opened as a PyTorch container.
       * @param {object} context - host-provided file context.
       * @returns {boolean} true when `pytorch.Container.open(context)` yields a truthy container.
       */
      match(context) {
          return pytorch.Container.open(context) ? true : false;
      }

      /**
       * Loads metadata, opens the container and wraps it in a `pytorch.Model`.
       * Non-fatal problems reported by the container are forwarded to
       * `context.exception` tagged with the file identifier.
       * @param {object} context - host-provided file context (`identifier`, `exception`).
       * @returns {Promise<pytorch.Model>}
       * @throws {pytorch.Error} (as a rejection) when the container cannot be opened.
       */
      open(context) {
          const identifier = context.identifier;
          // Turns any thrown or reported value into a message string. String() is
          // used rather than error.toString() so that a null/undefined error cannot
          // raise a TypeError of its own and hide the original failure.
          const messageOf = (error) => error && error.message ? error.message : String(error);
          return pytorch.Metadata.open(context).then((metadata) => {
              let container = null;
              try {
                  container = pytorch.Container.open(context, metadata, (error, fatal) => {
                      const message = messageOf(error);
                      context.exception(new pytorch.Error(message.replace(/\.$/, '') + " in '" + identifier + "'."), fatal);
                  });
              }
              catch (error) {
                  const message = messageOf(error);
                  throw new pytorch.Error('File format is not PyTorch (' + message.replace(/\.$/, '') + ').');
              }
              return new pytorch.Model(metadata, container);
          });
      }
  };
  /**
   * A loaded PyTorch model: exposes the container's format and the graphs built from it.
   */
  pytorch.Model = class {

      /**
       * @param {object} metadata - operator metadata, passed through to every graph.
       * @param {object} container - opened container; `format`, `producer`, `type` and `data` are read.
       */
      constructor(metadata, container) {
          this._format = container.format;
          this._producer = container.producer || '';
          const kind = container.type;
          if (kind === 'script') {
              // A 'script' container yields exactly one graph from its whole data payload.
              this._graphs = [ new pytorch.Graph(metadata, kind, container.data, container) ];
          }
          else if (kind === 'module' || kind === 'weights') {
              // 'module' / 'weights' containers carry a list; each entry becomes its own graph.
              this._graphs = [ ...container.data ].map((entry) => new pytorch.Graph(metadata, kind, entry, container));
          }
          else {
              // Unknown container types produce no graphs.
              this._graphs = [];
          }
      }

      get format() {
          return this._format;
      }

      get graphs() {
          return this._graphs;
      }
  };
  /**
   * A graph of a PyTorch model. Depending on `type` it is built from a
   * TorchScript container ('script'), a pickled nn.Module tree ('module') or a
   * list of state-dict layers ('weights').
   */
  pytorch.Graph = class {

      /**
       * @param {object} metadata - operator metadata; `metadata.type(name)` is queried for schemas.
       * @param {string} type - 'script', 'module' or 'weights'.
       * @param {object} data - graph payload; its shape depends on `type`.
       * @param {object} container - the opened container (littleEndian, and for 'script':
       *     name, trace(), constants, inputs, outputs, nodes).
       * @throws {pytorch.Error} for TorchScript constants that are neither tensors nor
       *     supported xnnpack op contexts, or for container modules without `_modules`.
       */
      constructor(metadata, type, data, container) {
          this._nodes = [];
          this._inputs = [];
          this._outputs = [];
          this._groups = true;
          this._littleEndian = container.littleEndian;
          switch (type) {
              case 'script': {
                  this._name = container.name;
                  const traced = container.trace();
                  // TorchScript variable name -> tensor object backing that variable.
                  const initializers = new Map();
                  if (container.constants) {
                      for (const constant of container.constants) {
                          if (pytorch.Utility.isTensor(constant)) {
                              constant.initializer = new pytorch.Tensor(constant.__variable__, constant, this._littleEndian);
                              initializers.set(constant.__variable__, constant);
                          }
                          else if (constant && constant.__class__ && constant.__class__.__module__ === '__torch__.torch.classes.xnnpack') {
                              switch (constant.__class__.__name__) {
                                  case 'LinearOpContext':
                                  case 'Conv2dOpContext':
                                      // These contexts keep their tensors as plain fields; register each one.
                                      for (const key of Object.keys(constant)) {
                                          const value = constant[key];
                                          if (pytorch.Utility.isTensor(value)) {
                                              value.initializer = new pytorch.Tensor(value.__variable__, value, this._littleEndian);
                                              initializers.set(value.__variable__, value);
                                          }
                                      }
                                      break;
                                  default:
                                      throw new pytorch.Error("Unsupported constant context '" + constant.__class__.__name__ + "'.");
                              }
                          }
                          else {
                              throw new pytorch.Error('Unsupported constant.');
                          }
                      }
                  }
                  if (data) {
                      // Breadth-first walk of the module tree: link tensors and submodules to
                      // their parent module, name submodules after their field, and register
                      // tensors as initializers.
                      const queue = [ data ];
                      while (queue.length > 0) {
                          const module = queue.shift();
                          if (module.__class__ && module.__class__.__module__ === '__torch__.torch.classes._nnapi' && module.__class__.__name__ === 'Compilation') {
                              continue;
                          }
                          for (const key of Object.keys(module)) {
                              if (key !== '__module__' && key !== '__name__' && key !== '__class__' && key !== '__parent__') {
                                  const obj = module[key];
                                  if (!Array.isArray(obj) && obj === Object(obj)) {
                                      if (pytorch.Utility.isTensor(obj)) {
                                          const parameter = obj;
                                          parameter.__parent__ = module;
                                          if (!parameter.initializer && parameter.storage()) {
                                              parameter.initializer = new pytorch.Tensor(parameter.name, parameter, this._littleEndian);
                                          }
                                          // NOTE(review): __count__ presumably counts references to the
                                          // variable; only singly-referenced tensors are registered.
                                          if (parameter.__variable__ && parameter.__count__ === 1) {
                                              initializers.set(parameter.__variable__, parameter);
                                          }
                                      }
                                      else if (obj && obj.__class__) {
                                          obj.__parent__ = module;
                                          if (!obj.__id__) {
                                              obj.__id__ = key;
                                          }
                                          queue.push(obj);
                                      }
                                  }
                              }
                          }
                      }
                  }
                  if (traced) {
                      if (container.inputs) {
                          for (const input of container.inputs) {
                              this._inputs.push(new pytorch.Parameter(input, true, [
                                  new pytorch.Argument(input, null, null)
                              ]));
                          }
                      }
                      if (container.outputs) {
                          for (const output of container.outputs) {
                              this._outputs.push(new pytorch.Parameter(output, true, [
                                  new pytorch.Argument(output, null, null)
                              ]));
                          }
                      }
                      if (container.nodes) {
                          for (const node of container.nodes) {
                              const item = {
                                  type: node.type,
                                  node: node
                              };
                              this._nodes.push(new pytorch.Node(metadata, '', item, initializers));
                          }
                      }
                  }
                  if (data) {
                      // Modules that still have parameters and were not marked __hide__
                      // become standalone module nodes.
                      this._loadScriptModule(metadata, container, data, initializers);
                  }
                  break;
              }
              case 'module': {
                  this._name = data.name || '';
                  this._type = (data.obj.__module__ && data.obj.__name__) ? (data.obj.__module__ + '.' + data.obj.__name__) : '';
                  this._loadModule(metadata, data.obj, [], []);
                  break;
              }
              case 'weights': {
                  this._name = data.name || '';
                  // One node per state group; every state becomes an input whose arguments carry tensors.
                  for (const state_group of data.layers) {
                      const attributes = state_group.attributes || [];
                      const inputs = state_group.states.map((parameter) => {
                          return new pytorch.Parameter(parameter.name, true,
                              parameter.arguments.map((state) => {
                                  const tensor = new pytorch.Tensor(state.id, pytorch.Utility.toTensor(state.value), this._littleEndian);
                                  return new pytorch.Argument(state.id, null, tensor);
                              }));
                      });
                      const obj = {
                          name: state_group.name,
                          type: state_group.type || 'torch.nn.Module',
                          attributes: attributes,
                          inputs: inputs,
                          outputs: []
                      };
                      this._nodes.push(new pytorch.Node(metadata, '', obj, null));
                  }
                  break;
              }
          }
      }

      /**
       * Recursively flattens an nn.Module tree into nodes.
       * @param {object} metadata
       * @param {object} current - module whose `_modules` map is walked.
       * @param {string[]} groups - stack of enclosing Sequential names (pushed/popped in place).
       * @param {string[]} inputs - names of outputs feeding the first child.
       * @returns {string[]} the output names of the last child created, used to chain
       *     the children of a Sequential container.
       * @throws {pytorch.Error} when a container module has no `_modules`.
       */
      _loadModule(metadata, current, groups, inputs) {
          // A non-container module without children is a leaf: emit it directly.
          if (current.__class__ && current.__class__.__module__ !== 'torch.nn.modules.container' && (!current._modules || current._modules.size === 0)) {
              this._createNode(metadata, groups, '', current, inputs, false);
              return [];
          }
          if (!current._modules) {
              throw new pytorch.Error('Module does not contain modules.');
          }
          // Only children of a Sequential container are wired output-to-input.
          const sequential = current.__class__ && current.__class__.__module__ === 'torch.nn.modules.container' && current.__class__.__name__ === 'Sequential';
          for (const pair of current._modules) {
              const key = pair[0];
              const value = pair[1];
              if (value) {
                  const type = value.__class__.__module__ + '.' + value.__class__.__name__;
                  switch (type) {
                      case 'torch.nn.modules.container.Sequential':
                          groups.push(key);
                          inputs = this._loadModule(metadata, value, groups, sequential ? inputs : []);
                          groups.pop();
                          break;
                      default: {
                          inputs = this._createNode(metadata, groups, key, value, sequential ? inputs : [], sequential);
                          break;
                      }
                  }
              }
          }
          return inputs;
      }

      /**
       * Creates a node for module `obj` and appends it to the graph.
       * @param {object} metadata
       * @param {string[]} groups - enclosing group names, joined with '/' to form the node's group.
       * @param {string} key - the module's name within its parent.
       * @param {object} obj - the module instance.
       * @param {string[]} args - upstream output names wired to the first schema input.
       * @param {boolean} output - when true the node gets one 'output' argument named after the node.
       * @returns {string[]} a one-element array with the new node's name.
       */
      _createNode(metadata, groups, key, obj, args, output) {
          const type = obj.__class__.__module__ + '.' + obj.__class__.__name__;
          const schema = metadata.type(type);
          let inputSchema = [ { name: 'input'} ];
          if (schema && schema.inputs && schema.inputs.length > 0) {
              inputSchema = schema.inputs.slice();
          }
          const inputName = inputSchema.shift().name;
          const inputs = [];
          if (args.length > 0) {
              inputs.push(new pytorch.Parameter(inputName, true, args.map((argument) => {
                  return new pytorch.Argument(argument, null, null);
              })));
          }
          // Remaining schema inputs are matched positionally to the module's tensors.
          // NOTE(review): `_buffers` is only used when `_parameters` is missing — confirm intended.
          const parameters = obj._parameters || obj._buffers || [];
          for (const parameter of parameters) {
              const parameterKey = parameter[0];
              const value = pytorch.Utility.toTensor(parameter[1]);
              let visible = true;
              let parameterName = '';
              if (inputSchema.length > 0) {
                  const input = inputSchema.shift();
                  parameterName = input.name;
                  visible = input.visible === false ? false : true;
              }
              if (value) {
                  const initializer = new pytorch.Tensor('', value, this._littleEndian);
                  inputs.push(new pytorch.Parameter(parameterName || parameterKey, visible, [ new pytorch.Argument('', null, initializer) ]));
              }
          }
          const group = groups.join('/');
          const name = group ? (group + '/' + key) : key;
          const outputs = output ? [ new pytorch.Parameter('output', true, [ new pytorch.Argument(name, null, null) ]) ] : [];
          // Fields not starting with '_' become node attributes.
          const attributes = [];
          for (const attributeName of Object.keys(obj)) {
              if (attributeName.startsWith('_')) {
                  continue;
              }
              attributes.push({ name: attributeName, value: obj[attributeName] });
          }
          const item = {
              name: name,
              type: type,
              attributes: attributes,
              children: obj._modules && obj._modules.size > 0 ? true : false,
              inputs: inputs,
              outputs: outputs
          };
          const node = new pytorch.Node(metadata, group, item, {});
          this._nodes.push(node);
          return [ node.name ];
      }

      /**
       * Emits a module node for `module` and each of its submodules that has
       * tensors and is not marked `__hide__`.
       */
      _loadScriptModule(metadata, container, module, initializers) {
          if (module) {
              if (pytorch.Graph._getParameters(module).length > 0 && !module.__hide__) {
                  const item = { module: module };
                  this._nodes.push(new pytorch.Node(metadata, '', item, initializers));
              }
              const submodules = pytorch.Graph._getSubmodules(module);
              for (const submodule of submodules) {
                  this._loadScriptModule(metadata, container, submodule, initializers);
              }
          }
      }

      /**
       * Returns the tensor fields of a module and tags each tensor with its field
       * name in `__id__`. Objects without a fully named `__class__` have no parameters.
       * @param {object} module
       * @returns {object[]}
       */
      static _getParameters(module) {
          const parameters = [];
          // __class__ is checked before dereferencing it (as _getSubmodules does), so
          // plain objects such as an untyped root module yield [] instead of throwing.
          if (module && module.__class__ && module.__class__.__module__ && module.__class__.__name__) {
              for (const key of Object.keys(module)) {
                  if (pytorch.Utility.isTensor(module[key])) {
                      const parameter = module[key];
                      parameter.__id__ = key;
                      parameters.push(parameter);
                  }
              }
          }
          return parameters;
      }

      /**
       * Returns the non-tensor object fields of a module that look like modules,
       * skipping keys that start with '__'.
       * NOTE(review): this checks `__module__`/`__name__` on the value itself, not on
       * `value.__class__` as _getParameters does — confirm which objects carry them.
       * @param {object} module
       * @returns {object[]}
       */
      static _getSubmodules(module) {
          const submodules = [];
          if (module && module.__class__ && module.__class__.__module__ && module.__class__.__name__) {
              for (const key of Object.keys(module)) {
                  if (!key.startsWith('__')) {
                      const value = module[key];
                      if (value && value.__class__ && value.__module__ && value.__name__ && !pytorch.Utility.isTensor(value)) {
                          submodules.push(value);
                      }
                  }
              }
          }
          return submodules;
      }

      /** Qualified root module type; only assigned for 'module' graphs. */
      get type() {
          return this._type;
      }

      get name() {
          return this._name;
      }

      get groups() {
          return this._groups;
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get nodes() {
          return this._nodes;
      }
  };
  /**
   * A named input or output slot of a node or graph that holds a list of arguments.
   */
  pytorch.Parameter = class {

      /**
       * @param {string} name - slot name.
       * @param {boolean} visible - whether the slot should be displayed.
       * @param {pytorch.Argument[]} args - arguments bound to the slot.
       */
      constructor(name, visible, args) {
          Object.assign(this, { _name: name, _visible: visible, _arguments: args });
      }

      get name() {
          return this._name;
      }

      get visible() {
          return this._visible;
      }

      get arguments() {
          return this._arguments;
      }
  };
  /**
   * A single value flowing into or out of a node, optionally backed by an initializer tensor.
   */
  pytorch.Argument = class {

      /**
       * @param {string} name - argument identifier; must be a string.
       * @param {*} type - declared type, used when there is no initializer.
       * @param {?object} initializer - tensor backing this argument, if any.
       * @throws {pytorch.Error} when `name` is not a string.
       */
      constructor(name, type, initializer) {
          if (typeof name === 'string') {
              this._name = name;
              this._type = type;
              this._initializer = initializer;
              return;
          }
          throw new pytorch.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
      }

      get name() {
          return this._name;
      }

      /** The initializer's type takes precedence over the declared type. */
      get type() {
          return this._initializer ? this._initializer.type : this._type;
      }

      get initializer() {
          return this._initializer;
      }
  };
  /**
   * A node of a PyTorch graph. The constructor accepts three item shapes:
   *  - a plain description `{ name, type, children, inputs, outputs, attributes }`;
   *  - `{ module }`: a TorchScript module whose tensors are listed as inputs;
   *  - `{ type, node }`: a traced TorchScript node whose initializer inputs may be
   *    attributed to (and folded together with) the module that owns them.
   */
  pytorch.Node = class {

      /**
       * @param {object} metadata - provides `type(name)` and `attribute(typeIdentifier, name)`.
       * @param {string} group - group path of the node ('' when ungrouped).
       * @param {object} item - node description, see the class comment.
       * @param {?(Map|object)} initializers - for traced nodes, a Map from argument id to tensor.
       */
      constructor(metadata, group, item, initializers) {
          this._group = group || '';
          this._name = item.name || '';
          // Copies the metadata schema for `name` (or a bare { name } fallback). The full
          // schema name is kept as `identifier`; `name` is cut at the first ':'.
          const setType = (metadata, name) => {
              this._type = Object.assign({}, metadata.type(name) || { name: name });
              const identifier = this._type.name;
              this._type.identifier = identifier;
              const index = identifier.indexOf(':');
              this._type.name = index === -1 ? identifier : identifier.substring(0, index);
          };
          if (!item.module && !item.node) {
              setType(metadata, item.type);
              this._function = item.children;
              this._inputs = item.inputs;
              this._outputs = item.outputs;
              this._attributes = item.attributes.map((attribute) => {
                  const schema = metadata.attribute(this._type.identifier, attribute.name);
                  return new pytorch.Attribute(schema, attribute.name, attribute.value);
              });
          }
          else {
              this._attributes = [];
              this._inputs = [];
              this._outputs = [];
              let module = item.module;
              if (module) {
                  // Every tensor of the module is an input; tensors bound to a variable
                  // are also exposed as outputs carrying that variable name.
                  this._type = { name: 'torch.nn.modules.module.Module' };
                  for (const parameter of pytorch.Graph._getParameters(module)) {
                      this._inputs.push(new pytorch.Parameter(parameter.__id__, true, [
                          new pytorch.Argument('', null, parameter.initializer || null)
                      ]));
                      if (parameter.__variable__) {
                          this._outputs.push(new pytorch.Parameter(parameter.__id__, true, [
                              new pytorch.Argument(parameter.__variable__, null, null)
                          ]));
                      }
                  }
              }
              if (item.node) {
                  setType(metadata, item.type);
                  module = null;
                  // Check whether all initializer inputs belong to one parent module
                  // (constants named 'CONSTANTS.c*' are always accepted).
                  let match = true;
                  let count = 0;
                  for (const input of item.node.inputs) {
                      for (const argument of input) {
                          const parameter = initializers.get(argument.id);
                          if (parameter) {
                              if (parameter.__parent__ && (module === null || module === parameter.__parent__)) {
                                  module = parameter.__parent__;
                                  count++;
                              }
                              // __variable__ can be absent (constants are registered under it
                              // unchecked); treat that as a non-match instead of throwing.
                              else if (typeof parameter.__variable__ === 'string' && parameter.__variable__.startsWith('CONSTANTS.c')) {
                                  argument.initializer = parameter.initializer;
                                  count++;
                              }
                              else {
                                  match = false;
                                  break;
                              }
                          }
                      }
                      if (!match) {
                          break;
                      }
                  }
                  if (module) {
                      // Fold only when the node consumes every tensor of the module (ignoring
                      // 'num_batches_tracked'); the module is then flagged __hide__ so it is
                      // not also emitted as a separate module node.
                      const params = pytorch.Graph._getParameters(module).filter((p) => p.__id__ !== 'num_batches_tracked');
                      if (params.length === count && match) {
                          module.__hide__ = true;
                          for (const input of item.node.inputs) {
                              for (const argument of input) {
                                  const parameter = initializers.get(argument.id);
                                  if (parameter && parameter.initializer) {
                                      argument.initializer = parameter.initializer;
                                  }
                              }
                          }
                      }
                      else {
                          module = null;
                      }
                  }
                  // Inputs/outputs are named from the schema when available, else by position.
                  for (let inputIndex = 0; inputIndex < item.node.inputs.length; inputIndex++) {
                      let inputName = inputIndex.toString();
                      if (this._type && this._type.inputs && this._type.inputs.length > inputIndex) {
                          inputName = this._type.inputs[inputIndex].name;
                      }
                      this._inputs.push(new pytorch.Parameter(inputName, true,
                          item.node.inputs[inputIndex].map((input) => new pytorch.Argument(input.id, null, input.initializer || null))
                      ));
                  }
                  for (let outputIndex = 0; outputIndex < item.node.outputs.length; outputIndex++) {
                      let outputName = outputIndex.toString();
                      if (this._type && this._type.outputs && this._type.outputs.length > outputIndex) {
                          outputName = this._type.outputs[outputIndex].name;
                      }
                      this._outputs.push(new pytorch.Parameter(outputName, true,
                          item.node.outputs[outputIndex].map((output) => new pytorch.Argument(output.id, null, null))
                      ));
                  }
                  for (const attribute of item.node.attributes) {
                      const name = attribute.name;
                      const value = attribute.value;
                      const schema = metadata.attribute(this._type.identifier, name);
                      this._attributes.push(new pytorch.Attribute(schema, name, value));
                  }
              }
              if (module) {
                  // Name the node after its module path, joining __id__ values with '.'
                  // and stopping at an ancestor that has neither a parent nor an id.
                  if (module.__id__) {
                      let current = module;
                      this._name = current.__id__;
                      while (current.__parent__ != null) {
                          current = current.__parent__;
                          if (!current.__parent__ && !current.__id__) {
                              break;
                          }
                          this._name = [ current.__id__, this._name ].join('.');
                      }
                  }
              }
          }
      }

      get name() {
          return this._name;
      }

      get group() {
          return this._group;
      }

      get type() {
          return this._type;
      }

      /** Copied from `item.children` for plain items; undefined otherwise. */
      get function() {
          return this._function;
      }

      get attributes() {
          return this._attributes;
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }
  };
  /**
   * A named attribute of a node. Visibility is derived from the schema: attributes
   * marked `visible: false`, or whose value equals the schema default, are hidden.
   */
  pytorch.Attribute = class {

      /**
       * @param {?object} schema - attribute schema (`type`, `visible`, `default`), may be null.
       * @param {string} name - attribute name.
       * @param {*} value - attribute value.
       */
      constructor(schema, name, value) {
          this._name = name;
          this._value = value;
          // 'training' is always treated as a hidden boolean flag.
          if (name === 'training') {
              this._visible = false;
              this._type = 'boolean';
              return;
          }
          if (schema) {
              const has = (key) => Object.prototype.hasOwnProperty.call(schema, key);
              if (has('type')) {
                  this._type = schema.type;
              }
              // The value counts as default when it serializes identically to the default,
              // or when it is an array whose items all loosely equal a scalar default.
              const isDefault = () => {
                  const fallback = schema.default;
                  if (JSON.stringify(fallback) === JSON.stringify(value)) {
                      return true;
                  }
                  return Array.isArray(value) && !Array.isArray(fallback) && value.every((item) => item == fallback);
              };
              if (schema.visible === false || (has('default') && isDefault())) {
                  this._visible = false;
              }
          }
          // Lists of torch.nn objects are shown as a placeholder rather than expanded.
          const isModuleList = Array.isArray(value) && value.length > 0 &&
              value.every((obj) => obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__module__.startsWith('torch.nn'));
          if (isModuleList) {
              this._value = '?';
          }
      }

      get type() {
          return this._type;
      }

      get name() {
          return this._name;
      }

      get value() {
          return this._value;
      }

      /** Visible unless explicitly hidden. */
      get visible() {
          return this._visible !== false;
      }
  };
  /**
   * Read-only view over a PyTorch tensor's storage: exposes its name, type
   * and decoded contents (as nested arrays) for display.
   */
  pytorch.Tensor = class {

      /**
       * @param {string} name - tensor name; falsy values become ''.
       * @param {*} tensor - object exposing `storage()` (with `dtype` and `data`) and `size()`.
       * @param {boolean} littleEndian - byte order passed to the DataView readers.
       */
      constructor(name, tensor, littleEndian) {
          const storage = tensor.storage();
          const size = tensor.size();
          this._name = name || '';
          this._type = new pytorch.TensorType(storage.dtype, new pytorch.TensorShape(size));
          this._data = storage.data;
          this._littleEndian = littleEndian;
      }

      get kind() {
          return 'Tensor';
      }

      get name() {
          return this._name;
      }

      get type() {
          return this._type;
      }

      /** Message explaining why the data cannot be decoded, or null when it can. */
      get state() {
          return this._context().state;
      }

      /** Fully decoded contents, or null when `state` is set. */
      get value() {
          const context = this._context();
          if (context.state) {
              return null;
          }
          context.limit = Number.MAX_SAFE_INTEGER;
          return this._decode(context, 0);
      }

      /** Formatted contents, truncated with '...' once the element limit is passed; '' when `state` is set. */
      toString() {
          const context = this._context();
          if (context.state) {
              return '';
          }
          context.limit = 10000;
          const value = this._decode(context, 0);
          return pytorch.Tensor._stringify(value, '', ' ');
      }

      /**
       * Validates the tensor and builds the decoding state.
       * On failure only `state` (a message) is meaningful; otherwise `data`,
       * `dataType`, `dimensions` and `dataView` are filled in.
       */
      _context() {
          const context = {};
          context.state = null;
          context.index = 0;
          context.count = 0;
          if (!this._type.dataType) {
              context.state = 'Tensor has no data type.';
              return context;
          }
          const itemsize = pytorch.Tensor._itemsize(this._type.dataType);
          if (itemsize === 0) {
              context.state = "Tensor data type '" + this._type.dataType + "' is not supported.";
              return context;
          }
          if (!this._type.shape) {
              context.state = 'Tensor has no dimensions.';
              return context;
          }
          if (!this._data) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          try {
              context.data = this._data instanceof Uint8Array ? this._data : this._data.peek();
          }
          catch (err) {
              context.state = err.message;
              return context;
          }
          context.dataType = this._type.dataType;
          context.dimensions = this._type.shape.dimensions;
          // Reject storage that is too small for the shape up front; otherwise
          // _decode would throw a RangeError from the DataView part-way through.
          // A rank-0 shape yields a product of 1, matching _decode's scalar handling.
          const count = context.dimensions.reduce((product, dimension) => product * dimension, 1);
          if (count * itemsize > context.data.length) {
              context.state = 'Tensor data is too small for its shape.';
              return context;
          }
          context.dataView = new DataView(context.data.buffer, context.data.byteOffset, context.data.byteLength);
          return context;
      }

      /**
       * Byte width of one element of a supported data type, or 0 when the
       * type is not supported by _decode.
       * @param {string} dataType
       * @returns {number}
       */
      static _itemsize(dataType) {
          switch (dataType) {
              case 'boolean':
              case 'uint8':
              case 'qint8':
              case 'int8':
                  return 1;
              case 'int16':
              case 'float16':
                  return 2;
              case 'int32':
              case 'float32':
                  return 4;
              case 'int64':
              case 'float64':
                  return 8;
              default:
                  return 0;
          }
      }

      /**
       * Recursively decodes dimension `dimension` and everything below it,
       * advancing `context.index` (byte offset) and `context.count` (elements read).
       * Pushes '...' and stops once `context.count` exceeds `context.limit`.
       */
      _decode(context, dimension) {
          const results = [];
          // Rank-0 tensors are decoded as a single element and unwrapped at the end.
          const dimensions = (context.dimensions.length == 0) ? [ 1 ] : context.dimensions;
          const size = dimensions[dimension];
          if (dimension == dimensions.length - 1) {
              for (let i = 0; i < size; i++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  switch (context.dataType) {
                      case 'boolean':
                          results.push(context.dataView.getUint8(context.index) === 0 ? false : true);
                          context.index++;
                          context.count++;
                          break;
                      case 'uint8':
                          results.push(context.dataView.getUint8(context.index));
                          context.index++;
                          context.count++;
                          break;
                      case 'qint8':
                      case 'int8':
                          results.push(context.dataView.getInt8(context.index));
                          context.index++;
                          context.count++;
                          break;
                      case 'int16':
                          results.push(context.dataView.getInt16(context.index, this._littleEndian));
                          context.index += 2;
                          context.count++;
                          break;
                      case 'int32':
                          results.push(context.dataView.getInt32(context.index, this._littleEndian));
                          context.index += 4;
                          context.count++;
                          break;
                      case 'int64':
                          results.push(context.dataView.getInt64(context.index, this._littleEndian));
                          context.index += 8;
                          context.count++;
                          break;
                      case 'float16':
                          results.push(context.dataView.getFloat16(context.index, this._littleEndian));
                          context.index += 2;
                          context.count++;
                          break;
                      case 'float32':
                          results.push(context.dataView.getFloat32(context.index, this._littleEndian));
                          context.index += 4;
                          context.count++;
                          break;
                      case 'float64':
                          results.push(context.dataView.getFloat64(context.index, this._littleEndian));
                          context.index += 8;
                          context.count++;
                          break;
                  }
              }
          }
          else {
              for (let j = 0; j < size; j++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  results.push(this._decode(context, dimension + 1));
              }
          }
          if (context.dimensions.length == 0) {
              return results[0];
          }
          return results;
      }

      /**
       * Formats a decoded value (nested arrays, numbers, Int64/Uint64, strings)
       * one element per line, indenting nested arrays by `indent`.
       */
      static _stringify(value, indentation, indent) {
          if (Array.isArray(value)) {
              const result = [];
              result.push(indentation + '[');
              const items = value.map((item) => pytorch.Tensor._stringify(item, indentation + indent, indent));
              if (items.length > 0) {
                  result.push(items.join(',\n'));
              }
              result.push(indentation + ']');
              return result.join('\n');
          }
          if (value && (value instanceof base.Int64 || value instanceof base.Uint64)) {
              return indentation + value.toString();
          }
          if (typeof value == 'string') {
              return indentation + value;
          }
          if (value == Infinity) {
              return indentation + 'Infinity';
          }
          if (value == -Infinity) {
              return indentation + '-Infinity';
          }
          if (isNaN(value)) {
              return indentation + 'NaN';
          }
          return indentation + value.toString();
      }
  };
  727. }
  728. if (value == -Infinity) {
  729. return indentation + '-Infinity';
  730. }
  731. if (isNaN(value)) {
  732. return indentation + 'NaN';
  733. }
  734. return indentation + value.toString();
  735. }
  736. };
  737. pytorch.TensorType = class {
  738. constructor(dtype, shape) {
  739. this._dataType = dtype.__reduce__();
  740. this._shape = shape;
  741. }
  742. get dataType() {
  743. return this._dataType;
  744. }
  745. get shape() {
  746. return this._shape;
  747. }
  748. toString() {
  749. return this._dataType + this._shape.toString();
  750. }
  751. };
  752. pytorch.TensorShape = class {
  753. constructor(dimensions) {
  754. this._dimensions = dimensions || [];
  755. }
  756. get dimensions() {
  757. return this._dimensions;
  758. }
  759. toString() {
  760. if (this._dimensions && this._dimensions.length > 0) {
  761. return '[' + this._dimensions.map((dimension) => dimension.toString()).join(',') + ']';
  762. }
  763. return '';
  764. }
  765. };
  766. pytorch.Execution = class extends python.Execution {
  767. constructor(sources, exceptionCallback) {
  768. super(sources, exceptionCallback);
  769. this.registerModule('ops');
  770. this.registerModule('ops.torchvision');
  771. this.registerModule('torch');
  772. this.registerModule('torchvision');
  773. this.context.scope.ops._caffe2 = { __name__: 'torch', __class__: this.context.scope.builtins.module };
  774. const self = this;
  775. const torch = this.context.scope.torch;
  776. this.registerType('builtins.number', class {});
  777. this.registerType('__torch__.torch.classes._nnapi.Compilation', class {
  778. constructor() {
  779. this.__hide__ = true;
  780. }
  781. __init__() {
  782. }
  783. init(serialized_model_tensor, parameter_buffers) {
  784. this.serialized_model_tensor = serialized_model_tensor;
  785. this.parameter_buffers = parameter_buffers;
  786. const storage = serialized_model_tensor.storage();
  787. new pytorch.nnapi.SerializedModel(storage.data, parameter_buffers);
  788. }
  789. run(inputs, outputs) {
  790. this.serialized_model_tensor.__variable__ = this.serialized_model_tensor.__variable__ || self.variable();
  791. this.serialized_model_tensor.__count__ = (this.serialized_model_tensor.__count__ || 0) + 1;
  792. self.push({
  793. type: 'torch.classes._nnapi.Compilation',
  794. attributes: [],
  795. inputs: [
  796. [ { id: this.serialized_model_tensor.__variable__ } ],
  797. inputs.map((input) => { return { id: input.__variable__ }; }),
  798. this.parameter_buffers.map((buffer) => { return { id: buffer.__variable__ }; })
  799. ],
  800. outputs: [
  801. outputs.map((output) => { return { id: output.__variable__ }; })
  802. ],
  803. });
  804. }
  805. });
  806. this.registerType('torch.autograd.variable.Variable', class {});
  807. this.registerType('torch.backends.cudnn.rnn.Unserializable', class {});
  808. this.registerType('torch.distributions.constraints._LowerCholesky', class {});
  809. this.registerType('torch.distributions.constraints._Real', class {});
  810. this.registerType('torch.distributions.multivariate_normal.MultivariateNormal', class {});
  811. this.registerType('torch.distributions.transforms.LowerCholeskyTransform', class {});
  812. this.registerType('torch.nn.backends.thnn._get_thnn_function_backend', class {});
  813. this.registerType('torch.nn.intrinsic.modules.fused.ConvReLU2d', class {});
  814. this.registerType('torch.nn.intrinsic.qat.modules.conv_fused.ConvReLU2d', class {});
  815. this.registerType('torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU2d', class {});
  816. this.registerType('torch.nn.intrinsic.quantized.modules.linear_relu.LinearReLU', class {});
  817. this.registerType('torch.nn.modules.activation.CELU', class {});
  818. this.registerType('torch.nn.modules.activation.ELU', class {});
  819. this.registerType('torch.nn.modules.activation.GELU', class {});
  820. this.registerType('torch.nn.modules.activation.GLU', class {});
  821. this.registerType('torch.nn.modules.activation.Hardtanh', class {});
  822. this.registerType('torch.nn.modules.activation.Hardswish', class {});
  823. this.registerType('torch.nn.modules.activation.Hardsigmoid', class {});
  824. this.registerType('torch.nn.modules.activation.LeakyReLU', class {});
  825. this.registerType('torch.nn.modules.activation.LogSigmoid', class {});
  826. this.registerType('torch.nn.modules.activation.LogSoftmax', class {});
  827. this.registerType('torch.nn.modules.activation.MultiheadAttention', class {});
  828. this.registerType('torch.nn.modules.activation.ReLU', class {});
  829. this.registerType('torch.nn.modules.activation.ReLU6', class {});
  830. this.registerType('torch.nn.modules.activation.PReLU', class {});
  831. this.registerType('torch.nn.modules.activation.RReLU', class {});
  832. this.registerType('torch.nn.modules.activation.SELU', class {});
  833. this.registerType('torch.nn.modules.activation.Sigmoid', class {});
  834. this.registerType('torch.nn.modules.activation.SiLU', class {});
  835. this.registerType('torch.nn.modules.activation.Softmax', class {});
  836. this.registerType('torch.nn.modules.activation.Softmax2d', class {});
  837. this.registerType('torch.nn.modules.activation.Softplus', class {});
  838. this.registerType('torch.nn.modules.activation.Tanh', class {});
  839. this.registerType('torch.nn.modules.activation.Threshold', class {});
  840. this.registerType('torch.nn.modules.batchnorm.BatchNorm1d', class {});
  841. this.registerType('torch.nn.modules.batchnorm.BatchNorm2d', class {});
  842. this.registerType('torch.nn.modules.batchnorm.BatchNorm3d', class {});
  843. this.registerType('torch.nn.modules.batchnorm.SyncBatchNorm', class {});
  844. this.registerType('torch.nn.modules.container.ModuleDict', class {});
  845. this.registerType('torch.nn.modules.container.ModuleList', class {});
  846. this.registerType('torch.nn.modules.container.ParameterList', class {});
  847. this.registerType('torch.nn.modules.container.Sequential', class {});
  848. this.registerType('torch.nn.modules.conv.Conv1d', class {});
  849. this.registerType('torch.nn.modules.conv.Conv2d', class {});
  850. this.registerType('torch.nn.modules.conv.Conv3d', class {});
  851. this.registerType('torch.nn.modules.conv.ConvTranspose1d', class {});
  852. this.registerType('torch.nn.modules.conv.ConvTranspose2d', class {});
  853. this.registerType('torch.nn.modules.conv.ConvTranspose3d', class {});
  854. this.registerType('torch.nn.modules.distance.CosineSimilarity', class {});
  855. this.registerType('torch.nn.modules.dropout.AlphaDropout', class {});
  856. this.registerType('torch.nn.modules.dropout.Dropout', class {});
  857. this.registerType('torch.nn.modules.dropout.Dropout2d', class {});
  858. this.registerType('torch.nn.modules.dropout.Dropout3d', class {});
  859. this.registerType('torch.nn.modules.fold.Fold', class {});
  860. this.registerType('torch.nn.modules.fold.Unfold', class {});
  861. this.registerType('torch.nn.modules.flatten.Flatten', class {});
  862. this.registerType('torch.nn.modules.flatten.Unflatten', class {});
  863. this.registerType('torch.nn.modules.instancenorm.InstanceNorm1d', class {});
  864. this.registerType('torch.nn.modules.instancenorm.InstanceNorm2d', class {});
  865. this.registerType('torch.nn.modules.instancenorm.InstanceNorm3d', class {});
  866. this.registerType('torch.nn.modules.linear._LinearWithBias', class {});
  867. this.registerType('torch.nn.modules.linear.Bilinear', class {});
  868. this.registerType('torch.nn.modules.linear.LazyLinear', class {});
  869. this.registerType('torch.nn.modules.linear.Linear', class {});
  870. this.registerType('torch.nn.modules.linear.Identity', class {});
  871. this.registerType('torch.nn.modules.loss.BCELoss', class {});
  872. this.registerType('torch.nn.modules.loss.BCEWithLogitsLoss', class {});
  873. this.registerType('torch.nn.modules.loss.CrossEntropyLoss', class {});
  874. this.registerType('torch.nn.modules.loss.CTCLoss', class {});
  875. this.registerType('torch.nn.modules.loss.KLDivLoss', class {});
  876. this.registerType('torch.nn.modules.loss.L1Loss', class {});
  877. this.registerType('torch.nn.modules.loss.MarginRankingLoss', class {});
  878. this.registerType('torch.nn.modules.loss.MSELoss', class {});
  879. this.registerType('torch.nn.modules.loss.NLLLoss', class {});
  880. this.registerType('torch.nn.modules.loss.NLLLoss2d', class {});
  881. this.registerType('torch.nn.modules.loss.SmoothL1Loss', class {});
  882. this.registerType('torch.nn.modules.module._IncompatibleKeys', class {});
  883. this.registerType('torch.nn.modules.module.Module', class {});
  884. this.registerType('torch.nn.modules.normalization.CrossMapLRN2d', class {});
  885. this.registerType('torch.nn.modules.normalization.GroupNorm', class {});
  886. this.registerType('torch.nn.modules.normalization.LayerNorm', class {});
  887. this.registerType('torch.nn.modules.normalization.LocalResponseNorm', class {});
  888. this.registerType('torch.nn.modules.padding.ReflectionPad1d', class {});
  889. this.registerType('torch.nn.modules.padding.ReflectionPad2d', class {});
  890. this.registerType('torch.nn.modules.padding.ReplicationPad1d', class {});
  891. this.registerType('torch.nn.modules.padding.ReplicationPad2d', class {});
  892. this.registerType('torch.nn.modules.padding.ReplicationPad3d', class {});
  893. this.registerType('torch.nn.modules.padding.ZeroPad2d', class {});
  894. this.registerType('torch.nn.modules.padding.ConstantPad1d', class {});
  895. this.registerType('torch.nn.modules.padding.ConstantPad2d', class {});
  896. this.registerType('torch.nn.modules.padding.ConstantPad3d', class {});
  897. this.registerType('torch.nn.modules.pixelshuffle.PixelShuffle', class {});
  898. this.registerType('torch.nn.modules.pooling.AdaptiveAvgPool1d', class {});
  899. this.registerType('torch.nn.modules.pooling.AdaptiveAvgPool2d', class {});
  900. this.registerType('torch.nn.modules.pooling.AdaptiveAvgPool3d', class {});
  901. this.registerType('torch.nn.modules.pooling.AdaptiveMaxPool1d', class {});
  902. this.registerType('torch.nn.modules.pooling.AdaptiveMaxPool2d', class {});
  903. this.registerType('torch.nn.modules.pooling.AdaptiveMaxPool3d', class {});
  904. this.registerType('torch.nn.modules.pooling.AvgPool1d', class {});
  905. this.registerType('torch.nn.modules.pooling.AvgPool2d', class {});
  906. this.registerType('torch.nn.modules.pooling.AvgPool3d', class {});
  907. this.registerType('torch.nn.modules.pooling.FractionalMaxPool2d', class {});
  908. this.registerType('torch.nn.modules.pooling.LPPool2d', class {});
  909. this.registerType('torch.nn.modules.pooling.MaxPool1d', class {});
  910. this.registerType('torch.nn.modules.pooling.MaxPool2d', class {});
  911. this.registerType('torch.nn.modules.pooling.MaxPool3d', class {});
  912. this.registerType('torch.nn.modules.pooling.MaxUnpool1d', class {});
  913. this.registerType('torch.nn.modules.pooling.MaxUnpool2d', class {});
  914. this.registerType('torch.nn.modules.pooling.MaxUnpool3d', class {});
  915. this.registerType('torch.nn.modules.rnn.GRU', class {});
  916. this.registerType('torch.nn.modules.rnn.GRUCell', class {});
  917. this.registerType('torch.nn.modules.rnn.LSTM', class {});
  918. this.registerType('torch.nn.modules.rnn.LSTMCell', class {});
  919. this.registerType('torch.nn.modules.rnn.RNN', class {});
  920. this.registerType('torch.nn.modules.sparse.Embedding', class {});
  921. this.registerType('torch.nn.modules.sparse.EmbeddingBag', class {});
  922. this.registerType('torch.nn.modules.transformer.Transformer', class {});
  923. this.registerType('torch.nn.modules.transformer.TransformerDecoder', class {});
  924. this.registerType('torch.nn.modules.transformer.TransformerDecoderLayer', class {});
  925. this.registerType('torch.nn.modules.transformer.TransformerEncoder', class {});
  926. this.registerType('torch.nn.modules.transformer.TransformerEncoderLayer', class {});
  927. this.registerType('torch.nn.modules.upsampling.Upsample', class {});
  928. this.registerType('torch.nn.modules.upsampling.UpsamplingBilinear2d', class {});
  929. this.registerType('torch.nn.modules.upsampling.UpsamplingNearest2d', class {});
  930. this.registerType('torch.nn.parallel.data_parallel.DataParallel', class {});
  931. this.registerType('torch.nn.parallel.distributed.DistributedDataParallel', class {});
  932. this.registerType('torch.nn.qat.modules.conv.Conv2d', class {});
  933. this.registerType('torch.nn.quantized.modules.activation.ReLU', class {});
  934. this.registerType('torch.nn.quantized.modules.activation.ReLU6', class {});
  935. this.registerType('torch.nn.quantized.modules.batchnorm.BatchNorm2d', class {});
  936. this.registerType('torch.nn.quantized.modules.conv.Conv2d', class {});
  937. this.registerType('torch.nn.quantized.modules.conv.ConvTranspose2d', class {});
  938. this.registerType('torch.nn.quantized.modules.DeQuantize', class {});
  939. this.registerType('torch.nn.quantized.modules.functional_modules.FloatFunctional', class {});
  940. this.registerType('torch.nn.quantized.modules.functional_modules.QFunctional', class {});
  941. this.registerType('torch.nn.quantized.modules.linear.Linear', class {});
  942. this.registerType('torch.nn.quantized.modules.linear.LinearPackedParams', class {});
  943. this.registerType('torch.nn.quantized.modules.normalization.InstanceNorm2d', class {});
  944. this.registerType('torch.nn.quantized.modules.Quantize', class {});
  945. this.registerType('torch.nn.utils.prune.L1Unstructured', class {});
  946. this.registerType('torch.nn.utils.spectral_norm.SpectralNorm', class {});
  947. this.registerType('torch.nn.utils.spectral_norm.SpectralNormStateDictHook', class {});
  948. this.registerType('torch.nn.utils.spectral_norm.SpectralNormLoadStateDictPreHook', class {});
  949. this.registerType('torch.nn.utils.weight_norm.WeightNorm', class {});
  950. this.registerType('torch.optim.adam.Adam', class {});
  951. this.registerType('torch.optim.adamw.AdamW', class {});
  952. this.registerType('torch.optim.adagrad.Adagrad', class {});
  953. this.registerType('torch.optim.adadelta.Adadelta', class {});
  954. this.registerType('torch.optim.lr_scheduler.CosineAnnealingLR', class {});
  955. this.registerType('torch.optim.lr_scheduler.CyclicLR', class {});
  956. this.registerType('torch.optim.lr_scheduler.ExponentialLR', class {});
  957. this.registerType('torch.optim.lr_scheduler.LambdaLR', class {});
  958. this.registerType('torch.optim.lr_scheduler.MultiStepLR', class {});
  959. this.registerType('torch.optim.lr_scheduler.OneCycleLR', class {});
  960. this.registerType('torch.optim.lr_scheduler.ReduceLROnPlateau', class {});
  961. this.registerType('torch.optim.lr_scheduler.StepLR', class {});
  962. this.registerType('torch.optim.optimizer._RequiredParameter', class {});
  963. this.registerType('torch.optim.rmsprop.RMSprop', class {});
  964. this.registerType('torch.optim.sgd.SGD', class {});
  965. this.registerType('torch.quantization.fake_quantize.FakeQuantize', class {});
  966. this.registerType('torch.quantization.observer._PartialWrapper', class {});
  967. this.registerType('torch.quantization.observer.MinMaxObserver', class {});
  968. this.registerType('torch.quantization.QConfig.QConfig', class {});
  969. this.registerType('torch.quantization.stubs.DeQuantStub', class {});
  970. this.registerType('torch.quantization.stubs.QuantStub', class {});
  971. this.registerType('torch.utils.data.dataloader.DataLoader', class {});
  972. this.registerType('torch.utils.data.dataset.Subset', class {});
  973. this.registerType('torch.utils.data.dataset.ConcatDataset', class {});
  974. this.registerType('torch.utils.data.sampler.BatchSampler', class {});
  975. this.registerType('torch.utils.data.sampler.RandomSampler', class {});
  976. this.registerType('torch.utils.data.sampler.SequentialSampler', class {});
  977. this.registerType('torchvision.datasets.folder.ImageFolder', class {});
  978. this.registerType('torchvision.datasets.mnist.MNIST', class {});
  979. this.registerType('torchvision.datasets.vision.StandardTransform', class {});
  980. this.registerType('torchvision.models.alexnet.AlexNet', class {});
  981. this.registerType('torchvision.models.densenet.DenseNet', class {});
  982. this.registerType('torchvision.models.densenet._DenseBlock', class {});
  983. this.registerType('torchvision.models.densenet._DenseLayer', class {});
  984. this.registerType('torchvision.models.densenet._Transition', class {});
  985. this.registerType('torchvision.models.detection._utils.BalancedPositiveNegativeSampler', class {});
  986. this.registerType('torchvision.models.detection._utils.BoxCoder', class {});
  987. this.registerType('torchvision.models.detection._utils.Matcher', class {});
  988. this.registerType('torchvision.models.detection._utils.SSDMatcher', class {});
  989. this.registerType('torchvision.models.detection.anchor_utils.AnchorGenerator', class {});
  990. this.registerType('torchvision.models.detection.anchor_utils.DefaultBoxGenerator', class {});
  991. this.registerType('torchvision.models.detection.backbone_utils.BackboneWithFPN', class {});
  992. this.registerType('torchvision.models.detection.faster_rcnn.FasterRCNN', class {});
  993. this.registerType('torchvision.models.detection.faster_rcnn.FastRCNNPredictor', class {});
  994. this.registerType('torchvision.models.detection.faster_rcnn.TwoMLPHead', class {});
  995. this.registerType('torchvision.models.detection.keypoint_rcnn.KeypointRCNN', class {});
  996. this.registerType('torchvision.models.detection.keypoint_rcnn.KeypointRCNNHeads', class {});
  997. this.registerType('torchvision.models.detection.keypoint_rcnn.KeypointRCNNPredictor', class {});
  998. this.registerType('torchvision.models.detection.mask_rcnn.MaskRCNN', class {});
  999. this.registerType('torchvision.models.detection.mask_rcnn.MaskRCNNHeads', class {});
  1000. this.registerType('torchvision.models.detection.mask_rcnn.MaskRCNNPredictor', class {});
  1001. this.registerType('torchvision.models.detection.retinanet.RetinaNetClassificationHead', class {});
  1002. this.registerType('torchvision.models.detection.retinanet.RetinaNetHead', class {});
  1003. this.registerType('torchvision.models.detection.retinanet.RetinaNetRegressionHead', class {});
  1004. this.registerType('torchvision.models.detection.roi_heads.RoIHeads', class {});
  1005. this.registerType('torchvision.models.detection.rpn.AnchorGenerator', class {});
  1006. this.registerType('torchvision.models.detection.rpn.RegionProposalNetwork', class {});
  1007. this.registerType('torchvision.models.detection.rpn.RPNHead', class {});
  1008. this.registerType('torchvision.models.detection.ssd.SSD', class {});
  1009. this.registerType('torchvision.models.detection.ssdlite.SSDLiteClassificationHead', class {});
  1010. this.registerType('torchvision.models.detection.ssdlite.SSDLiteFeatureExtractorMobileNet', class {});
  1011. this.registerType('torchvision.models.detection.ssdlite.SSDLiteHead', class {});
  1012. this.registerType('torchvision.models.detection.ssdlite.SSDLiteRegressionHead', class {}); this.registerType('torchvision.models.detection.transform.GeneralizedRCNNTransform', class {});
  1013. this.registerType('torchvision.models.googlenet.BasicConv2d', class {});
  1014. this.registerType('torchvision.models.googlenet.GoogLeNet', class {});
  1015. this.registerType('torchvision.models.googlenet.Inception', class {});
  1016. this.registerType('torchvision.models.googlenet.InceptionAux', class {});
  1017. this.registerType('torchvision.models.inception.BasicConv2d', class {});
  1018. this.registerType('torchvision.models.inception.Inception3', class {});
  1019. this.registerType('torchvision.models.inception.InceptionAux', class {});
  1020. this.registerType('torchvision.models.inception.InceptionA', class {});
  1021. this.registerType('torchvision.models.inception.InceptionB', class {});
  1022. this.registerType('torchvision.models.inception.InceptionC', class {});
  1023. this.registerType('torchvision.models.inception.InceptionD', class {});
  1024. this.registerType('torchvision.models.inception.InceptionE', class {});
  1025. this.registerType('torchvision.models.mnasnet._InvertedResidual', class {});
  1026. this.registerType('torchvision.models.mnasnet.MNASNet', class {});
  1027. this.registerType('torchvision.models.mobilenet.ConvBNReLU', class {});
  1028. this.registerType('torchvision.models.mobilenet.MobileNetV2', class {});
  1029. this.registerType('torchvision.models.mobilenet.InvertedResidual', class {});
  1030. this.registerType('torchvision.models.mobilenetv2.ConvBNActivation', class {});
  1031. this.registerType('torchvision.models.mobilenetv2.InvertedResidual', class {});
  1032. this.registerType('torchvision.models.mobilenetv2.MobileNetV2', class {});
  1033. this.registerType('torchvision.models.mobilenetv3.InvertedResidual', class {});
  1034. this.registerType('torchvision.models.mobilenetv3.MobileNetV3', class {});
  1035. this.registerType('torchvision.models.mobilenetv3.SqueezeExcitation', class {});
  1036. this.registerType('torchvision.models.resnet.Bottleneck', class {});
  1037. this.registerType('torchvision.models.resnet.BasicBlock', class {});
  1038. this.registerType('torchvision.models.quantization.mobilenet.QuantizableInvertedResidual', class {});
  1039. this.registerType('torchvision.models.quantization.mobilenet.QuantizableMobileNetV2', class {});
  1040. this.registerType('torchvision.models.quantization.resnet.QuantizableBasicBlock', class {});
  1041. this.registerType('torchvision.models.quantization.resnet.QuantizableBottleneck', class {});
  1042. this.registerType('torchvision.models.quantization.resnet.QuantizableResNet', class {});
  1043. this.registerType('torchvision.models.segmentation.deeplabv3.ASPP', class {});
  1044. this.registerType('torchvision.models.segmentation.deeplabv3.ASPPConv', class {});
  1045. this.registerType('torchvision.models.segmentation.deeplabv3.ASPPPooling', class {});
  1046. this.registerType('torchvision.models.segmentation.deeplabv3.DeepLabHead', class {});
  1047. this.registerType('torchvision.models.segmentation.deeplabv3.DeepLabV3', class {});
  1048. this.registerType('torchvision.models.segmentation.fcn.FCN', class {});
  1049. this.registerType('torchvision.models.segmentation.fcn.FCNHead', class {});
  1050. this.registerType('torchvision.models.shufflenetv2.ShuffleNetV2', class {});
  1051. this.registerType('torchvision.models.shufflenetv2.InvertedResidual', class {});
  1052. this.registerType('torchvision.models.squeezenet.Fire', class {});
  1053. this.registerType('torchvision.models.squeezenet.SqueezeNet', class {});
  1054. this.registerType('torchvision.models.resnet.ResNet', class {});
  1055. this.registerType('torchvision.models.vgg.VGG', class {});
  1056. this.registerType('torchvision.models.video.resnet.BasicBlock', class {});
  1057. this.registerType('torchvision.models.video.resnet.BasicStem', class {});
  1058. this.registerType('torchvision.models.video.resnet.Conv2Plus1D', class {});
  1059. this.registerType('torchvision.models.video.resnet.Conv3DNoTemporal', class {});
  1060. this.registerType('torchvision.models.video.resnet.Conv3DSimple', class {});
  1061. this.registerType('torchvision.models.video.resnet.R2Plus1dStem', class {});
  1062. this.registerType('torchvision.models.video.resnet.VideoResNet', class {});
  1063. this.registerType('torchvision.models._utils.IntermediateLayerGetter', class {});
  1064. this.registerType('torchvision.ops.deform_conv.DeformConv2d', class {});
  1065. this.registerType('torchvision.ops.feature_pyramid_network.FeaturePyramidNetwork', class {});
  1066. this.registerType('torchvision.ops.feature_pyramid_network.LastLevelMaxPool', class {});
  1067. this.registerType('torchvision.ops.feature_pyramid_network.LastLevelP6P7', class {});
  1068. this.registerType('torchvision.ops.misc.ConvTranspose2d', class {});
  1069. this.registerType('torchvision.ops.misc.FrozenBatchNorm2d', class {});
  1070. this.registerType('torchvision.ops.poolers.LevelMapper', class {});
  1071. this.registerType('torchvision.ops.poolers.MultiScaleRoIAlign', class {});
  1072. this.registerType('torchvision.transforms.functional.InterpolationMode', class {});
  1073. this.registerType('torchvision.transforms.transforms.Compose', class {});
  1074. this.registerType('torchvision.transforms.transforms.Grayscale', class {});
  1075. this.registerType('torchvision.transforms.transforms.Normalize', class {});
  1076. this.registerType('torchvision.transforms.transforms.RandomAffine', class {});
  1077. this.registerType('torchvision.transforms.transforms.Resize', class {});
  1078. this.registerType('torchvision.transforms.transforms.ToPILImage', class {});
  1079. this.registerType('torchvision.transforms.transforms.ToTensor', class {});
  1080. this.registerFunction('annotate', function(type, value) {
  1081. if (type === self.context.scope.builtins.int) {
  1082. return Number.isInteger(value) ? value : NaN;
  1083. }
  1084. if (type === self.context.scope.builtins.float) {
  1085. return typeof value === 'number' ? value : NaN;
  1086. }
  1087. if (type === self.context.scope.builtins.number) {
  1088. if (pytorch.Utility.isTensor(value)) {
  1089. value.resize_([]);
  1090. }
  1091. }
  1092. return value;
  1093. });
  // int(value): Python int() for integers and single-element int64 tensors.
  // A tensor whose int64 storage holds exactly 8 bytes is decoded as a
  // little-endian int64; everything else that is not an integer yields NaN.
  this.registerFunction('int', function(value) {
      if (pytorch.Utility.isTensor(value)) {
          const storage = value.storage();
          const isSingleInt64 = storage && storage.dtype.__reduce__() === 'int64' && storage.data.length === 8;
          if (isSingleInt64) {
              const bytes = storage.data;
              return new DataView(bytes.buffer, bytes.byteOffset, bytes.byteLength).getInt64(0, true);
          }
      }
      return Number.isInteger(value) ? value : NaN;
  });
  // float(value): Python float() for numbers and single-element float32 tensors.
  // A float32 tensor is only decoded when its storage has a known size and
  // exactly 4 bytes of data; any other float32 tensor yields NaN.
  this.registerFunction('float', function(value) {
      if (pytorch.Utility.isTensor(value)) {
          const storage = value.storage();
          if (storage && storage.dtype.__reduce__() === 'float32') {
              const readable = storage.size() !== undefined && storage.data.length === 4;
              if (!readable) {
                  return NaN;
              }
              const bytes = storage.data;
              return new DataView(bytes.buffer, bytes.byteOffset, bytes.byteLength).getFloat32(0, true);
          }
      }
      return Number(value) === value ? value : NaN;
  });
  // str(value): approximated by a JSON rendering of the value.
  this.registerFunction('str', function(value) {
      const text = JSON.stringify(value);
      return text;
  });
  // unchecked_cast(type, value): the target type is ignored.
  this.registerFunction('unchecked_cast', function(type, value) {
      return value;
  });
  // prim::data: the tensor itself stands in for its data.
  this.registerFunction('ops.prim.data', function(tensor) {
      return tensor;
  });
  this.registerFunction('ops.prim.device', function(tensor) {
      const device = tensor.device;
      return device;
  });
  this.registerFunction('ops.prim.dtype', function(tensor) {
      const dtype = tensor.dtype;
      return dtype.scalar_type();
  });
  // No None check is performed; the value passes through.
  this.registerFunction('ops.prim.unchecked_unwrap_optional', function(value) {
      return value;
  });
  // prim::NumToTensor: wraps a number in an empty torch.Tensor.
  this.registerFunction('ops.prim.NumToTensor', function(value) {
      const tensor = self.invoke('torch.Tensor', []);
      // TODO: the number is kept on an ad-hoc `value` field rather than in storage.
      tensor.value = value;
      return tensor;
  });
  // prim::min: accepts either one list argument or the values as separate arguments.
  this.registerFunction('ops.prim.min', function(value) {
      const values = Array.isArray(value) ? value : Array.prototype.slice.call(arguments);
      return Math.min.apply(null, values);
  });
  // prim::shape: the tensor's size() when available, otherwise undefined.
  this.registerFunction('ops.prim.shape', function(tensor) {
      if (!tensor || !tensor.size) {
          return undefined;
      }
      return tensor.size();
  });
  {
      // Builds a plain-object stand-in for a __torch__.torch.classes.quantized
      // packed-params instance: a __class__ descriptor followed by the fields.
      const packedParams = (name, fields) => Object.assign({
          __class__: {
              __module__: '__torch__.torch.classes.quantized',
              __name__: name
          }
      }, fields);
      // Returns a fresh prepack function producing packed params of class `name`.
      const convPrepack = (name) => function(weight, bias, stride, padding, dilation, groups) {
          return packedParams(name, {
              weight: weight,
              bias: bias,
              stride: stride,
              padding: padding,
              dilation: dilation,
              groups: groups
          });
      };
      // conv1d and conv_transpose2d also report Conv2dPackedParamsBase.
      this.registerFunction('ops.quantized.conv_prepack', convPrepack('Conv2dPackedParamsBase'));
      this.registerFunction('ops.quantized.conv1d_prepack', convPrepack('Conv2dPackedParamsBase'));
      this.registerFunction('ops.quantized.conv2d_prepack', convPrepack('Conv2dPackedParamsBase'));
      this.registerFunction('ops.quantized.conv3d_prepack', convPrepack('Conv3dPackedParamsBase'));
      this.registerFunction('ops.quantized.conv_transpose2d_prepack', convPrepack('Conv2dPackedParamsBase'));
      this.registerFunction('ops.quantized.linear_prepack', function(weight, bias) {
          return packedParams('LinearPackedParamsBase', { weight: weight, bias: bias });
      });
  }
  // prim::RaiseException: surfaces a TorchScript `raise` as a pytorch.Error.
  this.registerFunction('ops.prim.RaiseException', function(message) {
      const error = new pytorch.Error(message);
      throw error;
  });
  // range(stop), range(start, stop) or range(start, stop, step) over integers,
  // following Python semantics. Always returns an iterator:
  // - the one-argument form uses Array.prototype.keys();
  // - the other forms use a lazy generator, so large ranges are not materialized.
  // Throws pytorch.Error for a zero step or non-integer arguments.
  this.registerFunction('range', function(start, stop, step) {
      if (Number.isInteger(start) && stop === undefined && step === undefined) {
          // A negative count is an empty range in Python; Array(-1) would throw a RangeError.
          return Array(Math.max(0, start)).keys();
      }
      if (Number.isInteger(start) && Number.isInteger(stop) && (step === undefined || Number.isInteger(step))) {
          const increment = step === undefined ? 1 : step;
          if (increment === 0) {
              throw new pytorch.Error('range() arg 3 must not be zero');
          }
          const iterate = function* () {
              for (let value = start; increment > 0 ? value < stop : value > stop; value += increment) {
                  yield value;
              }
          };
          return iterate();
      }
      throw new pytorch.Error('Unsupported function range(' + JSON.stringify(start) + ', ' + JSON.stringify(stop) + ', ' + JSON.stringify(step) + ')');
  });
  {
      // Instantiates the tensor class paired with the storage's class: the
      // 'Storage' part of the class name becomes 'Tensor' (e.g.
      // torch.FloatStorage -> torch.FloatTensor). Then restores the
      // [storage, offset, size, stride] state.
      const rebuild = (storage, storage_offset, size, stride) => {
          const typeName = storage.__class__.__module__ + '.' + storage.__class__.__name__.replace('Storage', 'Tensor');
          const tensor = self.invoke(typeName, []);
          tensor.__setstate__([ storage, storage_offset, size, stride ]);
          return tensor;
      };
      this.registerFunction('torch._utils._rebuild_tensor', function (storage, storage_offset, size, stride) {
          return rebuild(storage, storage_offset, size, stride);
      });
      this.registerFunction('torch._utils._rebuild_tensor_v2', function (storage, storage_offset, size, stride, requires_grad, backward_hooks) {
          const tensor = rebuild(storage, storage_offset, size, stride);
          tensor.requires_grad = requires_grad;
          tensor.backward_hooks = backward_hooks;
          return tensor;
      });
      this.registerFunction('torch._utils._rebuild_parameter', function(data, requires_grad, backward_hooks) {
          const parameter = self.invoke('torch.nn.parameter.Parameter', [ data, requires_grad ]);
          parameter.backward_hooks = backward_hooks;
          return parameter;
      });
      this.registerFunction('torch._utils._rebuild_qtensor', function(storage, storage_offset, size, stride, quantizer_params, requires_grad, backward_hooks) {
          const tensor = rebuild(storage, storage_offset, size, stride);
          tensor.quantizer_params = quantizer_params;
          tensor.requires_grad = requires_grad;
          tensor.backward_hooks = backward_hooks;
          return tensor;
      });
  }
  // dict[key] = value (no return value).
  this.registerFunction('torch._set_item', function(dict, key, value) {
      dict[key] = value;
  });
  // Python `and`: returns the left operand if falsy, otherwise the right one.
  this.registerFunction('torch.__and__', function(left, right) {
      return left && right;
  });
  // `key in dict`, approximated as "the key maps to a defined value".
  this.registerFunction('torch.__contains__', function(dict, key) {
      const entry = dict[key];
      return entry !== undefined;
  });
  // Element `index` of the arithmetic sequence start, start + step, ...
  this.registerFunction('torch.__derive_index', function(index, start, step) {
      const offset = index * step;
      return start + offset;
  });
  // `left is right`: only comparisons involving None (null) are understood.
  this.registerFunction('torch.__is__', function(left, right) {
      const leftIsNone = left === null;
      const rightIsNone = right === null;
      if (leftIsNone || rightIsNone) {
          return leftIsNone && rightIsNone;
      }
      throw new pytorch.Error("Unknown 'torch.__is__' expression type.");
  });
  // `left is not right`: only comparisons involving None (null) are understood.
  this.registerFunction('torch.__isnot__', function(left, right) {
      const leftIsNone = left === null;
      const rightIsNone = right === null;
      if (leftIsNone || rightIsNone) {
          return !(leftIsNone && rightIsNone);
      }
      throw new pytorch.Error("Unknown 'torch.__isnot__' expression type.");
  });
  // `not value`: only booleans are accepted.
  this.registerFunction('torch.__not__', function(value) {
      if (typeof value !== 'boolean') {
          throw new pytorch.Error("Unknown 'torch.__not__' expression type.");
      }
      return !value;
  });
  // Number of elements in range(lo, hi, step). TorchScript computes this with
  // integer arithmetic, so the quotient is floored.
  // Example: (0, 10, 4) -> 3, not 3.25.
  // Throws pytorch.Error when step is 0.
  this.registerFunction('torch.__range_length', function(lo, hi, step) {
      if (step === 0) {
          throw new pytorch.Error('range() arg 3 must not be zero');
      }
      if (step > 0 && lo < hi) {
          return 1 + Math.floor((hi - 1 - lo) / step);
      }
      else if (step < 0 && lo > hi) {
          return 1 + Math.floor((lo - 1 - hi) / (0 - step));
      }
      return 0;
  });
  // TODO: no None check is performed; the optional value is passed through as-is.
  this.registerFunction('torch._unwrap_optional', function(value) {
      const unwrapped = value;
      return unwrapped;
  });
  // Python `+`: numeric addition, list concatenation or string concatenation.
  // Throws pytorch.Error for any other operand combination.
  this.registerFunction('torch.add', function(left, right) {
      if (typeof left === 'number' && typeof right === 'number') {
          // Previously returned left * right, which made e.g. 2 + 3 evaluate to 6.
          return left + right;
      }
      if (Array.isArray(left) && Array.isArray(right)) {
          return left.concat(right);
      }
      if (typeof left === 'string' && typeof right === 'string') {
          return left + right;
      }
      throw new pytorch.Error('Unknown torch.add expression type.');
  });
  // list.append(value): mutates `list` in place and returns the appended value.
  this.registerFunction('torch.append', function(list, value) {
      list.push(value);
      return value;
  });
  // list.extend(values): appends every element of `value`; returns undefined.
  this.registerFunction('torch.extend', function(list, value) {
      const items = [...value];
      list.push(...items);
  });
  // list.insert(index, value): mutates `list` and returns the inserted value.
  this.registerFunction('torch.insert', function(list, index, value) {
      list.splice(index, 0, value);
      return value;
  });
  // clear(): deletes every own enumerable key of an object; non-objects are ignored.
  // NOTE: for arrays this leaves holes and does not change `length`.
  this.registerFunction('torch.clear', function(value) {
      if (Object(value) !== value) {
          return;
      }
      Object.keys(value).forEach((key) => {
          delete value[key];
      });
  });
  // dict(pairs): builds a plain object from a list of [key, value] pairs.
  // A missing argument yields {}; a non-array argument throws pytorch.Error.
  this.registerFunction('torch.dict', function(args) {
      const obj = {};
      if (!args) {
          return obj;
      }
      if (!Array.isArray(args)) {
          throw new pytorch.Error("'torch.dict' arguments not supported.");
      }
      for (const pair of args) {
          obj[pair[0]] = pair[1];
      }
      return obj;
  });
  // Rank of the tensor, or undefined when no size is known.
  this.registerFunction('torch.dim', function(tensor) {
      const size = tensor && tensor.size ? tensor.size() : undefined;
      // TODO: rank is unknown when no size has been recorded.
      return size ? size.length : undefined;
  });
  // Product of the tensor's dimensions, or NaN when no size is known.
  this.registerFunction('torch.numel', function(tensor) {
      const size = tensor && tensor.size ? tensor.size() : undefined;
      if (!size) {
          return NaN;
      }
      return size.reduce((product, dimension) => product * dimension, 1);
  });
  // `left == right` for two strings or two numbers. An undefined (unknown)
  // operand compares as equal; anything else throws pytorch.Error.
  this.registerFunction('torch.eq', function(left, right) {
      const bothStrings = typeof left === 'string' && typeof right === 'string';
      const bothNumbers = typeof left === 'number' && typeof right === 'number';
      if (bothStrings || bothNumbers) {
          return left === right;
      }
      if (left === undefined || right === undefined) {
          return true;
      }
      throw new pytorch.Error("Unknown 'torch.eq' expression type.");
  });
  // Scalar rounding helpers backed by Math.
  this.registerFunction('torch.floor', function(value) {
      return Math.floor(value);
  });
  this.registerFunction('torch.ceil', function(value) {
      return Math.ceil(value);
  });
  // Python `//`: floored quotient.
  this.registerFunction('torch.floordiv', function(left, right) {
      const quotient = left / right;
      return Math.floor(quotient);
  });
  // format(template, ...args): fills each '{}' placeholder with the next argument.
  // The split pattern also matches '{}D'; that whole token (including the 'D')
  // is replaced by the argument.
  // Lists render as '[a, b]'. Unknown values (undefined, null, NaN) render as
  // '?'. Other falsy values such as 0 or '' are rendered rather than replaced.
  this.registerFunction('torch.format', function() {
      const args = Array.from(arguments);
      const list = args.shift().split(/({}D?)/);
      const render = (value) => value === undefined || value === null || Number.isNaN(value) ? '?' : value.toString();
      return list.map((text) => {
          if (text === '{}' || text === '{}D') {
              const arg = args.shift();
              return Array.isArray(arg) ? '[' + arg.map((item) => item.toString()).join(', ') + ']' : render(arg);
          }
          return text;
      }).join('');
  });
  // `left > right` for known numbers. An unknown (NaN-like) left operand against
  // a known right operand is assumed to be greater; anything else throws.
  this.registerFunction('torch.gt', function(left, right) {
      const bothNumbers = typeof left === 'number' && typeof right === 'number';
      if (bothNumbers && !isNaN(left) && !isNaN(right)) {
          return left > right;
      }
      if (isNaN(left) && !isNaN(right)) {
          return true;
      }
      throw new pytorch.Error("Unknown 'torch.gt' expression type.");
  });
  // `left >= right` for known numbers. As in torch.gt, an unknown (NaN-like)
  // left operand against a known right operand is assumed to satisfy the
  // comparison. Throws pytorch.Error otherwise.
  this.registerFunction('torch.ge', function(left, right) {
      if (typeof left === 'number' && typeof right === 'number') {
          if (!isNaN(left) && !isNaN(right)) {
              // Was `left > right`, which made equal operands compare false.
              return left >= right;
          }
      }
      if (isNaN(left) && !isNaN(right)) {
          return true;
      }
      throw new pytorch.Error("Unknown 'torch.ge' expression type.");
  });
  // The typed list builders return the unpickled payload unchanged.
  for (const name of [ 'build_boollist', 'build_doublelist', 'build_intlist', 'build_tensorlist' ]) {
      this.registerFunction('torch.jit._pickle.' + name, function(data) {
          return data;
      });
  }
  // Looks up a tensor in the CONSTANTS table under the key 'c' + id.
  this.registerFunction('torch.jit._pickle.build_tensor_from_id', function(data) {
      const constants = self.context.getx('CONSTANTS');
      const key = 'c' + data.toString();
      return constants[key];
  });
  // The type tag is ignored; the value is returned unchanged.
  this.registerFunction('torch.jit._pickle.restore_type_tag', function(value /*, type_str */) {
      return value;
  });
  // dict.keys(): own enumerable string keys.
  this.registerFunction('torch.keys', function(dict) {
      return Object.keys(dict);
  });
  // len(value): array length, or the object's own __len__(), otherwise NaN.
  this.registerFunction('torch.len', function(value) {
      if (Array.isArray(value)) {
          return value.length;
      }
      return value && value.__len__ ? value.__len__() : NaN;
  });
  // `left <= right` for numbers; NaN operands compare false. An undefined
  // (unknown) operand is treated as satisfying the comparison; anything else
  // throws pytorch.Error.
  this.registerFunction('torch.le', function(left, right) {
      if (typeof left === 'number' && typeof right === 'number') {
          return !isNaN(left) && !isNaN(right) && left <= right;
      }
      if (left === undefined || right === undefined) {
          return true;
      }
      throw new pytorch.Error("Unknown 'torch.le' expression type.");
  });
  // list(args): the argument is returned unchanged.
  this.registerFunction('torch.list', function(args) {
      return args;
  });
  // list_with_default(size, defaults): returns `size`; defaults are ignored.
  this.registerFunction('torch.list_with_default', function(size /*, defaults */) {
      return size;
  });
  // `left < right`, numbers only.
  this.registerFunction('torch.lt', function(left, right) {
      if (typeof left !== 'number' || typeof right !== 'number') {
          throw new pytorch.Error("Unknown 'torch.lt' expression type.");
      }
      return left < right;
  });
  // `left * right`: numeric product, NaN for unknown operands, or an
  // element-wise scaling of a list of numbers by a number.
  this.registerFunction('torch.mul', function(left, right) {
      if (typeof left === 'number' && typeof right === 'number') {
          return left * right;
      }
      // NOTE(review): the global isNaN() coerces its argument. Number([a, b]) is
      // NaN, so any list with two or more elements returns NaN here and never
      // reaches the list branch below. Python `list * int` also repeats the list
      // rather than scaling its elements; confirm the intended semantics before
      // changing this.
      const unknown = isNaN(left) || isNaN(right);
      if (unknown) {
          return NaN;
      }
      const isNumberList = Array.isArray(left) && left.every((value) => typeof value === 'number');
      if (isNumberList && typeof right === 'number') {
          return left.map((item) => item * right);
      }
      throw new pytorch.Error("Unknown 'torch.mul' expression type.");
  });
  // `left / right` (true division) for numbers. Operands that coerce to NaN
  // yield NaN; anything else throws pytorch.Error.
  this.registerFunction('torch.div', function(left, right) {
      const numeric = typeof left === 'number' && typeof right === 'number';
      if (numeric) {
          return left / right;
      }
      if (isNaN(left) || isNaN(right)) {
          return NaN;
      }
      throw new pytorch.Error("Unknown 'torch.div' expression type.");
  });
  // Python/torch remainder: a non-zero result takes the sign of the divisor
  // (e.g. -1 % 3 === 2), whereas JavaScript's `%` follows the dividend.
  // Operands that coerce to NaN yield NaN; anything else throws pytorch.Error.
  this.registerFunction('torch.remainder', function(left, right) {
      if (typeof left === 'number' && typeof right === 'number') {
          const result = left % right;
          return result !== 0 && (result < 0) !== (right < 0) ? result + right : result;
      }
      if (isNaN(left) || isNaN(right)) {
          return NaN;
      }
      throw new pytorch.Error("Unknown 'torch.remainder' expression type.");
  });
  // `left != right`.
  // - Numbers: NaN operands compare false.
  // - Lists: lists of different length differ. Same-length lists are treated
  //   as equal without comparing elements (a heuristic kept from before).
  // - Strings: compared by value.
  // - An undefined (unknown) operand counts as "not equal".
  // Anything else throws pytorch.Error.
  this.registerFunction('torch.ne', function(left, right) {
      if (typeof left === 'number' && typeof right === 'number') {
          if (isNaN(left) || isNaN(right)) {
              return false;
          }
          return left !== right;
      }
      if (Array.isArray(left) && Array.isArray(right)) {
          // Different-length lists previously fell through to the "Unknown" error.
          return left.length !== right.length;
      }
      if (typeof left === 'string' && typeof right === 'string') {
          return left !== right;
      }
      if (left === undefined || right === undefined) {
          return true;
      }
      throw new pytorch.Error("Unknown 'torch.ne' expression type.");
  });
  // Unary minus, numbers only.
  this.registerFunction('torch.neg', function(value) {
      if (typeof value !== 'number') {
          throw new pytorch.Error("Unknown 'torch.neg' expression type.");
      }
      return -value;
  });
  // TODO: the quantization scale is not tracked; -1 is a placeholder.
  this.registerFunction('torch.q_scale', function(/* tensor */) {
      return -1;
  });
  // Transpose is not modeled; the tensor is returned unchanged.
  this.registerFunction('torch.t', function(tensor) {
      return tensor;
  });
  // tensor.size() or tensor.size(dim).
  // - With no dim, returns the whole size array.
  // - Negative dims count from the end; the valid range is [-rank, rank - 1].
  // - An out-of-range or non-integer dim throws pytorch.Error.
  // Returns NaN when the tensor has no array-valued size.
  this.registerFunction('torch.size', function(tensor, dim) {
      if (tensor && tensor.size) {
          const size = tensor.size();
          if (Array.isArray(size)) {
              if (dim === undefined) {
                  return size;
              }
              if (Number.isInteger(dim)) {
                  if (dim >= 0 && dim < size.length) {
                      return size[dim];
                  }
                  // `<=` so that dim === -rank (the first dimension) is accepted.
                  if (dim < 0 && -dim <= size.length) {
                      return size[size.length + dim];
                  }
              }
              throw new pytorch.Error('Dimension out of range (expected to be in range of ' + JSON.stringify(size) + ', but got ' + JSON.stringify(dim) + ').');
          }
      }
      return NaN;
  });
  // Python list slicing l[start:end], with step restricted to 1.
  // - None (null) or undefined start/end mean the list bounds.
  // - Negative indices count from the end.
  // - An omitted step defaults to 1; any other step throws pytorch.Error.
  this.registerFunction('torch.slice', function(l, start, end, step) {
      if (step !== undefined && step !== null && step !== 1) {
          throw new pytorch.Error('Slicing only supports step=1');
      }
      const length = l.length;
      const first = start === undefined || start === null ? 0 : Math.max(0, start >= 0 ? start : length + start);
      // A None end previously became Math.min(length, null) === 0, yielding an empty slice.
      const last = end === undefined || end === null ? length : Math.min(length, end);
      return l.slice(first, last);
  });
  // `left - right`, numbers only.
  this.registerFunction('torch.sub', function(left, right) {
      if (typeof left !== 'number' || typeof right !== 'number') {
          throw new pytorch.Error("Unknown 'torch.sub' expression type.");
      }
      return left - right;
  });
  // dict.values(): values for the own enumerable string keys, in key order.
  this.registerFunction('torch.values', function(dict) {
      const result = [];
      for (const key of Object.keys(dict)) {
          result.push(dict[key]);
      }
      return result;
  });
  // Warnings are ignored.
  this.registerFunction('torch.warn', function() {
  });
  // prim::Uninitialized: represented as undefined.
  this.registerFunction('uninitialized', function(/* type */) {
      return undefined;
  });
  // torch.device(type[, index]). `index` is only set when one was supplied.
  this.registerType('torch.device', class {
      constructor(type, index) {
          this.type = type;
          // Index 0 is a valid device index; the previous truthiness test dropped it.
          if (index !== undefined && index !== null) {
              this.index = index;
          }
      }
  });
  // torch.dtype wraps a pytorch.ScalarType value together with the descriptor
  // returned by pytorch.Utility.getScalarType() (its name and itemsize).
  this.registerType('torch.dtype', class {
      constructor(type) {
          this._type = type;
          this._data = pytorch.Utility.getScalarType(type);
      }
      // The wrapped ScalarType value.
      scalar_type() {
          const type = this._type;
          return type;
      }
      // Size of one element in bytes, per the scalar-type descriptor.
      itemsize() {
          const data = this._data;
          return data.itemsize;
      }
      // Bare dtype name, e.g. 'float32'.
      __reduce__() {
          const data = this._data;
          return data.name;
      }
      // Qualified name, e.g. 'torch.float32'.
      __str__() {
          return `torch.${this._data.name}`;
      }
  });
  // Hook handle restored from a pickled [hooks_dict_ref, id] state.
  this.registerType('torch.utils.hooks.RemovableHandle', class {
      __setstate__(state) {
          const hooks = state[0];
          // A missing hooks reference is replaced with an empty Map.
          this.hooks_dict_ref = hooks ? hooks : new Map();
          this.id = state[1];
      }
  });
  // Base class for typed storages: an element count, a torch.dtype and, once
  // loaded, the raw byte data (exposed through the `data` getter).
  this.registerType('torch.storage._StorageBase', class {
      constructor(size, dtype) {
          this._size = size;
          this._dtype = dtype;
      }
      get dtype() {
          return this._dtype;
      }
      // Raw bytes, or undefined until _set_cdata() has been called.
      get data() {
          return this._cdata;
      }
      // Bytes per element. torch.dtype exposes this through itemsize(); reading
      // a nonexistent `element_size` property always returned undefined.
      element_size() {
          return this._dtype.itemsize();
      }
      // Number of elements.
      size() {
          return this._size;
      }
      // Attaches the raw bytes. Throws pytorch.Error unless their length equals
      // size() * itemsize.
      _set_cdata(data) {
          const length = this.size() * this.dtype.itemsize();
          if (length !== data.length) {
              throw new pytorch.Error('Storage data size mismatch.');
          }
          this._cdata = data;
      }
      // Reads an int64 element count (which must match size()), then reads
      // count * itemsize bytes of data from the unpickler.
      _set_from_file(unpickler) {
          const size = unpickler.int64();
          if (size !== this.size()) {
              throw new pytorch.Error('Storage size mismatch.');
          }
          const itemsize = this.dtype.itemsize();
          const data = unpickler.stream(itemsize * size);
          this._set_cdata(data);
      }
      // Creates a storage of the calling subclass from an int64 element count
      // followed by the raw bytes.
      static _new_with_file(unpickler) {
          const size = unpickler.int64();
          const storage = new this(size);
          const itemsize = storage.dtype.itemsize();
          const data = unpickler.stream(itemsize * size);
          storage._set_cdata(data);
          return storage;
      }
  });
  // Concrete storage classes, each bound to one torch dtype constant. The
  // constant is looked up on `torch` when an instance is constructed, not at
  // registration time, because the torch.<dtype> constants are assigned further
  // down in this function.
  for (const [ name, dtype ] of [
      [ 'torch.BoolStorage', 'bool' ],
      [ 'torch.ByteStorage', 'uint8' ],
      [ 'torch.CharStorage', 'int8' ],
      [ 'torch.ShortStorage', 'int16' ],
      [ 'torch.IntStorage', 'int32' ],
      [ 'torch.LongStorage', 'int64' ],
      [ 'torch.HalfStorage', 'float16' ],
      [ 'torch.FloatStorage', 'float32' ],
      [ 'torch.DoubleStorage', 'float64' ],
      [ 'torch.QInt8Storage', 'qint8' ],
      [ 'torch.QUInt8Storage', 'quint8' ]
  ]) {
      this.registerType(name, class extends torch.storage._StorageBase {
          constructor(size) {
              super(size, torch[dtype]);
          }
      });
  }
  // Minimal torch.Tensor: a storage plus offset, shape and stride, normally
  // restored via __setstate__([storage, offset, shape, stride]).
  this.registerType('torch.Tensor', class {
      constructor() {
      }
      // dtype of the underlying storage.
      get dtype() {
          return this.storage().dtype;
      }
      get shape() {
          return this._shape;
      }
      size() {
          return this._shape;
      }
      // Returns the storage, lazily creating an empty one if none was set.
      storage() {
          if (!this._storage) {
              // Derive the storage class from this tensor's class name, e.g.
              // torch.LongTensor -> torch.LongStorage; a plain torch.Tensor defaults
              // to FloatStorage. This previously read `this.__storage__`, which is
              // never defined, and so threw a TypeError.
              const className = this.__class__.__name__;
              const name = className === 'Tensor' ? 'FloatStorage' : className.replace('Tensor', 'Storage');
              this._storage = self.invoke(this.__class__.__module__ + '.' + name, []);
          }
          return this._storage;
      }
      storage_offset() {
          return this._storage_offset;
      }
      stride() {
          return this._stride;
      }
      // Replaces the recorded shape; the storage is not touched.
      resize_(shape) {
          this._shape = shape;
      }
      // Length of the first dimension.
      __len__() {
          return this._shape[0];
      }
      __setstate__(state) {
          this._storage = state[0];
          this._storage_offset = state[1];
          this._shape = state[2];
          this._stride = state[3];
      }
      // Not implemented; returns undefined.
      tolist() {
      }
  });
  // torch.nn.Parameter: wraps a data tensor (an empty torch.Tensor when none is
  // given) and a requires_grad flag that defaults to true.
  this.registerType('torch.nn.parameter.Parameter', class extends torch.Tensor {
      constructor(data, requires_grad) {
          super();
          this.data = data ? data : self.invoke('torch.Tensor', [[]]);
          this.requires_grad = requires_grad === undefined ? true : requires_grad;
      }
      // 4- and 5-element states both carry the data tensor first; other state
      // lengths are ignored.
      __setstate__(state) {
          if (state.length === 4 || state.length === 5) {
              this.data = state[0];
          }
      }
  });
  // A Parameter constructed without data; device and dtype arguments are ignored.
  this.registerType('torch.nn.parameter.UninitializedParameter', class extends torch.nn.parameter.Parameter {
      constructor(requires_grad /*, device, dtype */) {
          const data = undefined;
          super(data, requires_grad);
      }
  });
  // Typed tensor classes; each adds nothing to torch.Tensor beyond its name.
  for (const name of [
      'torch.BoolTensor',
      'torch.ByteTensor',
      'torch.CharTensor',
      'torch.ShortTensor',
      'torch.IntTensor',
      'torch.LongTensor',
      'torch.HalfTensor',
      'torch.FloatTensor',
      'torch.DoubleTensor',
      'torch.QInt8Tensor',
      'torch.QUInt8Tensor',
      'torch.cuda.FloatTensor',
      'torch.cuda.DoubleTensor'
  ]) {
      this.registerType(name, class extends torch.Tensor {});
  }
  // dtype singletons exposed on `torch`: [attribute name, pytorch.ScalarType key].
  // Note that torch.bool maps to ScalarType.boolean.
  for (const [ name, scalarType ] of [
      [ 'uint8', 'uint8' ],
      [ 'int8', 'int8' ],
      [ 'int16', 'int16' ],
      [ 'int32', 'int32' ],
      [ 'int64', 'int64' ],
      [ 'float16', 'float16' ],
      [ 'float32', 'float32' ],
      [ 'float64', 'float64' ],
      [ 'complex32', 'complex32' ],
      [ 'complex64', 'complex64' ],
      [ 'complex128', 'complex128' ],
      [ 'bool', 'boolean' ],
      [ 'qint8', 'qint8' ],
      [ 'quint8', 'quint8' ],
      [ 'qint32', 'qint32' ],
      [ 'bfloat16', 'bfloat16' ]
  ]) {
      torch[name] = new torch.dtype(pytorch.ScalarType[scalarType]);
  }
  1823. }
  1824. debug(file) {
  1825. const buffer = this.source(file + '.debug_pkl');
  1826. if (buffer) {
  1827. return null;
  1828. // const unpickler = python.Unpickler.open(buffer);
  1829. // return unpickler.load((name, args) => this.invoke(name, args), null);
  1830. }
  1831. return null;
  1832. }
  1833. };
  pytorch.Container = class {
      // Detects the container format, trying in order:
      // 1. a ZIP archive;
      // 2. a raw stream whose header matches the signature below;
      // 3. a TAR archive that contains a 'pickle' entry.
      // Returns null when nothing matches.
      static open(context, metadata, exception) {
          const zip = pytorch.Container.Zip.open(context.entries('zip'), metadata, exception);
          if (zip) {
              return zip;
          }
          const stream = context.stream;
          // Header bytes; an undefined slot (index 1) matches any byte.
          const signature = [ 0x80, undefined, 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19 ];
          if (signature.length <= stream.length) {
              const header = stream.peek(signature.length);
              const matches = header.every((value, index) => signature[index] === undefined || signature[index] === value);
              if (matches) {
                  return new pytorch.Container.Pickle(stream, exception);
              }
          }
          const entries = context.entries('tar');
          return entries.has('pickle') ? new pytorch.Container.Tar(entries, exception) : null;
      }
  };
  1842. if (signature.length <= stream.length && stream.peek(signature.length).every((value, index) => signature[index] === undefined || signature[index] === value)) {
  1843. return new pytorch.Container.Pickle(stream, exception);
  1844. }
  1845. const entries = context.entries('tar');
  1846. if (entries.has('pickle')) {
  1847. return new pytorch.Container.Tar(entries, exception);
  1848. }
  1849. return null;
  1850. }
  1851. };
  1852. pytorch.Container.Tar = class {
  1853. constructor(entries, exceptionCallback) {
  1854. this._entries = entries;
  1855. this._exceptionCallack = exceptionCallback;
  1856. }
  1857. get format() {
  1858. return 'PyTorch v0.1.1';
  1859. }
  1860. get type() {
  1861. this._unpickle();
  1862. return this._type;
  1863. }
  1864. get data() {
  1865. this._unpickle();
  1866. return this._data;
  1867. }
  1868. get littleEndian() {
  1869. this._unpickle();
  1870. return this._littleEndian;
  1871. }
  1872. _unpickle() {
  1873. if (!this._entries) {
  1874. return;
  1875. }
  1876. this._type = '';
  1877. this._data = null;
  1878. this._littleEndian = true;
  1879. const execution = new pytorch.Execution(null, this._exceptionCallback);
  1880. const entries = {};
  1881. for (const entry of this._entries) {
  1882. switch (entry.name) {
  1883. case 'sys_info': entries.sys_info = entry.stream.peek(); break;
  1884. case 'pickle': entries.pickle = entry.stream.peek(); break;
  1885. case 'storages': entries.storages = entry.stream.peek(); break;
  1886. case 'tensors': entries.tensors = entry.stream.peek(); break;
  1887. }
  1888. }
  1889. this._exceptionCallback = null;
  1890. this._entries = null;
  1891. if (entries.sys_info) {
  1892. const unpickler = python.Unpickler.open(entries.sys_info);
  1893. const sys_info = unpickler.load((name, args) => execution.invoke(name, args));
  1894. if (sys_info.protocol_version != 1000) {
  1895. throw new pytorch.Error("Unsupported protocol version '" + sys_info.protocol_version + "'.");
  1896. }
  1897. if (sys_info.type_sizes &&
  1898. ((sys_info.type_sizes.int && sys_info.type_sizes.int != 4) ||
  1899. (sys_info.type_sizes.long && sys_info.type_sizes.long != 4) ||
  1900. (sys_info.type_sizes.short && sys_info.type_sizes.short != 2))) {
  1901. throw new pytorch.Error('Unsupported type sizes.');
  1902. }
  1903. this._littleEndian = sys_info.little_endian;
  1904. }
  1905. const deserialized_objects = {};
  1906. if (entries.storages) {
  1907. const unpickler = python.Unpickler.open(entries.storages);
  1908. const num_storages = unpickler.load((name, args) => execution.invoke(name, args));
  1909. for (let i = 0; i < num_storages; i++) {
  1910. const args = unpickler.load();
  1911. const key = args[0];
  1912. const storage_type = execution.type(args[2]);
  1913. const obj = storage_type._new_with_file(unpickler);
  1914. deserialized_objects[key] = obj;
  1915. }
  1916. /*
  1917. let storage_views = unpickler.load();
  1918. for target_cdata, root_cdata, offset, size in storage_views:
  1919. root = deserialized_objects[root_cdata]
  1920. deserialized_objects[target_cdata] = root[offset:offset + size]
  1921. */
  1922. }
  1923. if (entries.tensors) {
  1924. const unpickler = python.Unpickler.open(entries.tensors);
  1925. const num_tensors = unpickler.load((name, args) => execution.invoke(name, args));
  1926. for (let i = 0; i < num_tensors; i++) {
  1927. const args = unpickler.load();
  1928. const key = args[0];
  1929. const storage_id = args[1];
  1930. const storage = deserialized_objects[storage_id];
  1931. const ndim = unpickler.int32();
  1932. unpickler.read(4);
  1933. const shape = new Array(ndim);
  1934. for (let j = 0; j < ndim; j++) {
  1935. shape[j] = unpickler.int64();
  1936. }
  1937. const stride = new Array(ndim);
  1938. for (let j = 0; j < ndim; j++) {
  1939. stride[j] = unpickler.int64();
  1940. }
  1941. const storage_offset = unpickler.int64();
  1942. const tensor = execution.invoke('torch._utils._rebuild_tensor', [ storage, storage_offset, shape, stride ]);
  1943. deserialized_objects[key] = tensor;
  1944. }
  1945. }
  1946. if (entries.pickle) {
  1947. const unpickler = python.Unpickler.open(entries.pickle);
  1948. const persistent_load = (saved_id) => {
  1949. return deserialized_objects[saved_id];
  1950. };
  1951. const obj = unpickler.load((name, args) => execution.invoke(name, args), persistent_load);
  1952. const weights = pytorch.Utility.findWeights(obj);
  1953. if (weights) {
  1954. this._type = 'weights';
  1955. this._data = weights;
  1956. }
  1957. else {
  1958. throw new pytorch.Error('File does not contain root module or state dictionary.');
  1959. }
  1960. }
  1961. }
  1962. };
  /**
   * Container for the legacy (pre-ZIP) PyTorch serialization format, reported as 'PyTorch v0.1.10'.
   * The stream holds, in order: a magic number, a protocol version, a sys_info record, the pickled
   * root object, the list of storage keys, and then the raw storage payloads.
   * Parsing is deferred until `type`, `data` or `littleEndian` is first read.
   */
  pytorch.Container.Pickle = class {

    /**
     * @param {Object} stream - input stream; `length` and `peek()` are used.
     * @param {Function} exception - callback forwarded to pytorch.Execution for non-fatal errors.
     */
    constructor(stream, exception) {
      this._stream = stream;
      this._exceptionCallback = exception;
    }

    get format() {
      return 'PyTorch v0.1.10';
    }

    get type() {
      this._unpickle();
      return this._type;
    }

    get data() {
      this._unpickle();
      return this._data;
    }

    get littleEndian() {
      this._unpickle();
      return this._littleEndian;
    }

    /**
     * Parses the stream once and sets `_type` ('module' or 'weights'), `_data` and `_littleEndian`.
     * @throws {pytorch.Error} on unsupported protocol versions, an empty or 'None' root object,
     *   unknown persistent ids, storage keys that were never declared, or when neither a root
     *   module nor a state dictionary can be found.
     */
    _unpickle() {
      if (!this._stream) {
        return;
      }
      const execution = new pytorch.Execution(null, this._exceptionCallback);
      // Streams below the size limit are unpickled from one peeked buffer; larger ones are passed
      // through as a stream (presumably because they cannot be materialized as a single buffer).
      const unpickler = python.Unpickler.open(this._stream.length < 0x7ffff000 ? this._stream.peek() : this._stream);
      // Drop the inputs up front so parsing is attempted at most once, even if it throws below.
      this._stream = null;
      this._exceptionCallback = null;
      unpickler.load(); // magic_number
      const protocol_version = unpickler.load();
      if (protocol_version != 1001) {
        throw new pytorch.Error("Unsupported protocol version '" + protocol_version + "'.");
      }
      const sys_info = unpickler.load();
      if (sys_info.protocol_version != 1001) {
        throw new pytorch.Error("Unsupported protocol version '" + sys_info.protocol_version + "'.");
      }
      this._littleEndian = sys_info.little_endian;
      // Storages (and storage views) keyed by the ids found in persistent references.
      const deserialized_objects = new Map();
      const persistent_load = (saved_id) => {
        const typename = saved_id.shift();
        const data = saved_id;
        switch (typename) {
          case 'module': {
            // data[0] is the module class; the remaining entries carry its source, which is unused.
            return data[0];
          }
          case 'storage': {
            const data_type = execution.type(data.shift());
            const root_key = data.shift();
            data.shift(); // location
            const size = data.shift();
            const view_metadata = data.shift();
            if (!deserialized_objects.has(root_key)) {
              deserialized_objects.set(root_key, new data_type(size));
            }
            if (view_metadata) {
              const view_key = view_metadata.shift();
              view_metadata.shift(); // view_offset
              view_metadata.shift(); // view_size
              if (!deserialized_objects.has(view_key)) {
                // Views are not materialized yet: storage.slice(view_offset, view_offset + view_size)
                deserialized_objects.set(view_key, null);
              }
              return deserialized_objects.get(view_key);
            }
            return deserialized_objects.get(root_key);
          }
        }
        throw new pytorch.Error("Unknown persistent load type '" + typename + "'.");
      };
      const obj = unpickler.load((name, args) => execution.invoke(name, args), persistent_load);
      if (!obj) {
        throw new pytorch.Error('File format is not PyTorch.');
      }
      if (obj === 'None') {
        throw new pytorch.Error("File contains 'None' root object.");
      }
      // The raw storage payloads follow, read sequentially in the order of this key list.
      const deserialized_storage_keys = unpickler.load();
      for (const deserialized_storage_key of deserialized_storage_keys) {
        const storage = deserialized_objects.get(deserialized_storage_key);
        if (!storage) {
          // A key that was never referenced by the pickled object (or names an unmaterialized view)
          // cannot receive data; fail clearly instead of with a TypeError.
          throw new pytorch.Error("Unknown storage key '" + deserialized_storage_key + "'.");
        }
        storage._set_from_file(unpickler);
      }
      const root = pytorch.Utility.findModule(obj);
      if (root) {
        this._type = 'module';
        this._data = root;
      }
      else {
        const weights = pytorch.Utility.findWeights(obj);
        if (weights) {
          this._type = 'weights';
          this._data = weights;
        }
        else {
          throw new pytorch.Error('File does not contain root module or state dictionary.');
        }
      }
    }
  };
  /**
   * Container for ZIP-based PyTorch archives: TorchScript v1.0/v1.1 (root file `model.json`) and
   * the `data.pkl` layout. Every entry name is resolved relative to the directory that contains
   * the detected root file.
   * https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/docs/serialization.md
   */
  pytorch.Container.Zip = class {

    /**
     * Recognizes a PyTorch ZIP archive by its root file.
     * @param {Map<string, Object>} entries - archive entries keyed by path; values expose `peek()`.
     * @param {Object} metadata - operator schema lookup forwarded to the execution engine.
     * @param {Function} exception - callback that receives non-fatal errors.
     * @returns {pytorch.Container.Zip|null} the container, or null when no usable root file exists.
     */
    static open(entries, metadata, exception) {
      const name = Array.from(entries.keys()).find((name) => name === 'model.json' || name === 'data.pkl' || name.endsWith('/model.json') || name.endsWith('/data.pkl'));
      if (!name) {
        return null;
      }
      let model = null;
      if (name.endsWith('.json')) {
        try {
          const stream = entries.get(name);
          const buffer = stream.peek();
          const decoder = new TextDecoder('utf-8');
          const text = decoder.decode(buffer);
          model = JSON.parse(text);
          if (!model.mainModule) {
            return null;
          }
        }
        catch (error) {
          // Unreadable or unexpected model.json: the archive is simply not recognized.
          return null;
        }
      }
      return new pytorch.Container.Zip(entries, name, model, metadata, exception);
    }

    constructor(entries, name, model, metadata, exception) {
      this._entries = entries;
      this._metadata = metadata;
      this._exceptionCallback = exception;
      // Parsed model.json for TorchScript v1.0/v1.1 archives; null for the data.pkl layout.
      this._model = model;
      const lastIndex = name.lastIndexOf('/');
      this._prefix = lastIndex === -1 ? '' : name.substring(0, lastIndex + 1);
    }

    /**
     * Human-readable format name, computed once. An unknown `version` value is reported through the
     * exception callback and rendered as 'v-<value>'.
     */
    get format() {
      if (this._format === undefined) {
        if (this._entry('model.json')) {
          this._format = this._entry('attributes.pkl') ? 'TorchScript v1.1' : 'TorchScript v1.0';
        }
        else if (this._entry('data.pkl')) {
          // https://github.com/pytorch/pytorch/blob/master/caffe2/serialize/inline_container.h
          // kProducedFileFormatVersion
          const versions = new Map([
            [ '1', 'v1.3' ],
            [ '2', 'v1.5' ], // 7a2889b014ce36fcc333b2c6de6f29f976652f84 (#28122)
            [ '3', 'v1.6' ], // 2ec6a30722b0ef85632a2f3e7ce6f80da403008a (#36085)
            [ '4', 'v1.6' ], // 95489b590f00801bdee7f41783f30874883cf6bb (#38620)
            [ '5', 'v1.7' ], // cb26661fe4faf26386703180a9045e6ac6d157df (#40364)
            [ '6', 'v1.9+' ] // 3ee7637ffa50df0d9b231c7b40778ac1c390bf4a (#59714)
          ]);
          const value = this.version;
          if (!versions.has(value)) {
            this._exceptionCallback(new pytorch.Error("Unsupported PyTorch Zip version '" + value + "'."));
          }
          const version = versions.get(value);
          const constants = this._entry('constants.pkl');
          this._format = (constants ? 'TorchScript' : 'PyTorch') + ' ' + (version || 'v-' + value.toString());
        }
      }
      return this._format;
    }

    /**
     * First line of the `version` entry, or '' when the entry is absent.
     */
    get version() {
      const stream = this._entry('version');
      if (stream) {
        const decoder = new TextDecoder('utf-8');
        const buffer = stream.peek();
        const value = decoder.decode(buffer);
        return value.split('\n').shift();
      }
      return '';
    }

    /**
     * Producer name and version from model.json; '' when unknown (e.g. the data.pkl layout).
     */
    get producer() {
      // Reading `data` first triggers _load(), which is where `_producer` is set.
      return this.data && this._producer ? this._producer : '';
    }

    get name() {
      return this._name;
    }

    get littleEndian() {
      return true;
    }

    get type() {
      this._load();
      return this._type;
    }

    get data() {
      this._load();
      return this._data;
    }

    /**
     * Values from `constants.pkl`, computed once. Tensors are tagged with a `CONSTANTS.c<i>`
     * variable name; xnnpack op contexts get `.weight`/`.bias` names on their first two fields.
     * @throws {pytorch.Error} for any other kind of constant.
     */
    get constants() {
      if (this._constants === undefined) {
        this._constants = [];
        const stream = this._entry('constants.pkl');
        if (stream) {
          const buffer = stream.peek();
          this._constants = this._unpickle(buffer, this._storage('constants'));
          for (let i = 0; i < this._constants.length; i++) {
            const constant = this._constants[i];
            const variable = 'CONSTANTS.c' + i.toString();
            if (pytorch.Utility.isTensor(constant)) {
              constant.__variable__ = variable;
            }
            else if (constant && constant.__class__ && constant.__class__.__module__ === '__torch__.torch.classes.xnnpack') {
              switch (constant.__class__.__name__) {
                case 'LinearOpContext':
                case 'Conv2dOpContext':
                  if (pytorch.Utility.isTensor(constant[0])) {
                    constant[0].__variable__ = variable + '.weight';
                  }
                  if (pytorch.Utility.isTensor(constant[1])) {
                    constant[1].__variable__ = variable + '.bias';
                  }
                  break;
                default:
                  throw new pytorch.Error("Unsupported constant context '" + constant.__class__.__name__ + "'.");
              }
            }
            else {
              throw new pytorch.Error('Unsupported constant.');
            }
          }
        }
      }
      return this._constants;
    }

    /**
     * Execution engine over the archive's TorchScript sources (entries under `<prefix>code`),
     * created once, with a `CONSTANTS` object exposing `c0`, `c1`, ... in its context.
     * @throws {pytorch.Error} when two entries map to the same source path.
     */
    get execution() {
      if (this._execution === undefined) {
        const sources = new Map();
        for (const entry of this._entries) {
          const name = entry[0];
          if (name.startsWith(this._prefix + 'code')) {
            const file = name.substring(this._prefix.length);
            if (sources.has(file)) {
              throw new pytorch.Error("Duplicate source file '" + file + "'.");
            }
            const stream = entry[1];
            const buffer = stream.peek();
            sources.set(file, buffer);
          }
        }
        // Assigned before reading `constants`, because unpickling constants calls back into `execution`.
        this._execution = new pytorch.Container.Zip.Execution(sources, this._exceptionCallback, this._metadata);
        const constants = {};
        for (let i = 0; i < this.constants.length; i++) {
          constants['c' + i.toString()] = this.constants[i];
        }
        this._execution.context.set('CONSTANTS', constants);
      }
      return this._execution;
    }

    /**
     * @param {string} name - path relative to the archive prefix.
     * @returns {Object|undefined} the archive entry, if present.
     */
    _entry(name) {
      return this._entries.get(this._prefix + name);
    }

    /**
     * Loads the root object once and classifies it. `_type` becomes 'script' for TorchScript
     * archives; otherwise it is 'module' or 'weights'.
     * @throws {pytorch.Error} for unknown tensor data types, missing tensor data, or when neither a
     *   root module nor a state dictionary is found.
     */
    _load() {
      if (this._data === undefined) {
        this._data = null;
        const stream = this._entry('data.pkl');
        if (stream) {
          const buffer = stream.peek();
          this._data = this._unpickle(buffer, this._storage('data'));
        }
        else if (this._model) {
          this._producer = this._model.producerName + (this._model.producerVersion ? ' v' + this._model.producerVersion : '');
          this._data = this._model.mainModule || {};
          this._name = this._data.name || '';
          if (this._data.torchscriptArena) {
            this._torchscriptArena = this._data.torchscriptArena.key;
          }
          const queue = [ this._data ];
          // Tensor payloads are peeked lazily, once per referenced archive entry.
          const buffers = new Map();
          const read = (key) => {
            if (!buffers.has(key)) {
              const entry = this._entries.get(key);
              if (!entry) {
                throw new pytorch.Error("Tensor data '" + key + "' not found.");
              }
              buffers.set(key, entry.peek());
            }
            return buffers.get(key);
          };
          const tensorTypeMap = new Map([
            [ 'FLOAT', 'Float' ],
            [ 'FLOAT16', 'Half' ],
            [ 'DOUBLE', 'Double' ],
            [ 'INT8', 'Char' ],
            [ 'INT32', 'Int' ],
            [ 'INT64', 'Long' ]
          ]);
          const constants = this._model.tensors || [];
          this._constants = constants.map((constant) => {
            const key = this._prefix + constant.data.key;
            if (!tensorTypeMap.has(constant.dataType)) {
              throw new pytorch.Error("Unknown tensor data type '" + constant.dataType + "'.");
            }
            const type = tensorTypeMap.get(constant.dataType);
            const shape = constant.dims ? constant.dims.map((dim) => parseInt(dim, 10)) : null;
            const storage_type = this.execution.type('torch.' + type + 'Storage');
            const size = (shape || []).reduce((a, b) => a * b, 1);
            const offset = parseInt(constant.offset, 10) || 0;
            const storage = new storage_type([ size ]);
            const itemsize = storage.dtype.itemsize();
            const buffer = read(key);
            const length = size * itemsize;
            const data = buffer.slice(offset, offset + length);
            storage._set_cdata(data);
            const tensor = this.execution.invoke('torch._utils._rebuild_tensor', [ storage, 0, shape, 0 ]);
            tensor.name = constant.data.key;
            return tensor;
          });
          this._attributes = [];
          const attributesStream = this._entry('attributes.pkl');
          if (attributesStream) {
            const buffer = attributesStream.peek();
            const unpickler = python.Unpickler.open(buffer);
            this._attributes.push(...unpickler.load((name, args) => this.execution.invoke(name, args)));
          }
          // Breadth-first walk that turns model.json module descriptions into module-like objects.
          while (queue.length > 0) {
            const module = queue.shift();
            if (!module.__class__) {
              module.__class__ = {
                __module__: 'torch.nn.modules.module',
                __name__: 'Module'
              };
            }
            if (module.name) {
              module.__id__ = module.name;
            }
            if (module.submodules) {
              for (const submodule of module.submodules) {
                module[submodule.name] = submodule;
                submodule.__parent__ = module;
                queue.push(submodule);
              }
              delete module.submodules;
            }
            const attributes = [];
            if (module.attributes) {
              attributes.push(...module.attributes);
              delete module.attributes;
            }
            const parameters = [];
            if (module.parameters) {
              parameters.push(...module.parameters);
              delete module.parameters;
            }
            if (module.arguments) {
              parameters.push(...module.arguments);
              delete module.arguments;
            }
            for (const parameter of parameters) {
              const tensor = this._constants[parameter.tensorId];
              module[parameter.name] = tensor;
              if (!parameter.__class__) {
                parameter.__class__ = {
                  __module__: 'torch',
                  __name__: 'Tensor'
                };
              }
            }
            for (const attribute of attributes) {
              module[attribute.name] = this._attributes[attribute.id];
            }
          }
          delete this._model;
        }
        if (this.format.startsWith('TorchScript ')) {
          this._type = 'script';
        }
        else {
          const obj = this._data;
          const root = pytorch.Utility.findModule(obj);
          if (root) {
            this._type = 'module';
            this._data = root;
          }
          else {
            const weights = pytorch.Utility.findWeights(obj);
            if (weights) {
              this._type = 'weights';
              this._data = weights;
            }
            else {
              throw new pytorch.Error('File does not contain root module or state dictionary.');
            }
          }
        }
      }
    }

    /**
     * Unpickles `data`, resolving persistent 'storage' references against `storage_map`.
     * Each root storage is created once per call; storage views are cached under their view key
     * but are not materialized (they resolve to null).
     * @param {*} data - pickle bytes.
     * @param {Map<string, *>} storage_map - raw storage buffers keyed by storage key.
     * @throws {pytorch.Error} for persistent ids other than 'storage'.
     */
    _unpickle(data, storage_map) {
      const loaded_storages = new Map();
      const persistent_load = (saved_id) => {
        const typename = saved_id.shift();
        if (typename !== 'storage') {
          throw new pytorch.Error("Unknown persistent load type '" + typename + "'.");
        }
        const data_type = this.execution.type(saved_id.shift());
        const root_key = saved_id.shift();
        /* const location = */ saved_id.shift();
        const size = saved_id.shift();
        if (!loaded_storages.has(root_key)) {
          const storage = new data_type(size);
          storage._set_cdata(storage_map.get(root_key));
          loaded_storages.set(root_key, storage);
        }
        const storage = loaded_storages.get(root_key);
        const view_metadata = saved_id.shift();
        if (view_metadata) {
          const view_key = view_metadata.shift();
          view_metadata.shift(); // view_offset
          view_metadata.shift(); // view_size
          if (!loaded_storages.has(view_key)) {
            // Views are not materialized yet: storage.slice(view_offset, view_offset + view_size)
            loaded_storages.set(view_key, null);
          }
          // Look up the cached view itself (previously the root storage was returned by mistake).
          return loaded_storages.get(view_key);
        }
        return storage;
      };
      const unpickler = python.Unpickler.open(data);
      return unpickler.load((name, args) => this.execution.invoke(name, args), persistent_load);
    }

    /**
     * @param {string} dirname - directory under the archive prefix (e.g. 'data', 'constants').
     * @returns {Map<string, *>} peeked buffers keyed by their path relative to that directory.
     */
    _storage(dirname) {
      const map = new Map();
      const prefix = this._prefix + dirname + '/';
      for (const entry of this._entries) {
        if (entry[0].startsWith(prefix)) {
          const key = entry[0].substring(prefix.length);
          const buffer = entry[1].peek();
          map.set(key, buffer);
        }
      }
      return map;
    }

    /**
     * Runs `forward` symbolically with placeholder inputs and records inputs, outputs and nodes.
     * @returns {boolean} true when `forward` was traced.
     * @throws {pytorch.Error} when the module has no `forward`, or a parameter type is unsupported.
     */
    trace() {
      this._inputs = [];
      this._outputs = [];
      this.execution.reset();
      if (this._torchscriptArena) {
        const program = this.execution.parse(this._torchscriptArena);
        for (const statement of program.body) {
          if (statement.type === 'def') {
            const self = this;
            const globals = this.execution.context;
            const func = {
              __class__: this.execution.context.scope.builtins.function,
              __name__: statement.name,
              __code__: statement,
              __call__: function(args) {
                return self.execution.apply(this.__code__, args, globals);
              }
            };
            this.data[statement.name] = func;
          }
        }
      }
      if (this.data.forward) {
        // Builds a placeholder value for a forward() parameter from its type annotation.
        const defaultValue = (type, name) => {
          if (type.type === 'type' && type.name.type) {
            switch (type.name.value) {
              case 'Tensor': {
                const tensor = this.execution.invoke('torch.Tensor', []);
                tensor.__variable__ = name;
                tensor.__origin__ = 'graph-input';
                return tensor;
              }
              case 'Tuple':
                return type.arguments.map((type, index) => defaultValue(type, name + '[' + index.toString() + ']'));
              case 'List':
                return type.arguments.map((type, index) => defaultValue(type, name + '[' + index.toString() + ']'));
              case 'Dict':
                return {};
              case 'int':
                return 0;
              case 'float':
                return 0.0;
              case 'bool':
                return false;
              case 'Optional':
                return undefined;
            }
          }
          throw new pytorch.Error("Unknown function parameter type '" + JSON.stringify(type) + "'.");
        };
        const args = [ this.data ]; // self
        if (this.data.forward.__code__ && this.data.forward.__code__.parameters) {
          for (const parameter of this.data.forward.__code__.parameters) {
            if (parameter.name !== 'self') {
              const type = parameter.parameterType;
              const value = defaultValue(type, parameter.name);
              if (pytorch.Utility.isTensor(value)) {
                value.__variable__ = parameter.name;
                value.__origin__ = 'graph-input';
                this._inputs.push(parameter.name);
              }
              args.push(value);
            }
          }
        }
        const result = this.data.forward.__call__(args);
        if (Array.isArray(result)) {
          for (const output of result) {
            if (pytorch.Utility.isTensor(output)) {
              this._outputs.push(output.__variable__);
            }
          }
        }
        else if (pytorch.Utility.isTensor(result)) {
          this._outputs.push(result.__variable__);
        }
        else if (Object(result) === result) {
          for (const key of Object.keys(result)) {
            const value = result[key];
            if (Array.isArray(value)) {
              for (const output of value) {
                if (pytorch.Utility.isTensor(output)) {
                  this._outputs.push(output.__variable__);
                }
              }
            }
            else if (pytorch.Utility.isTensor(value)) {
              this._outputs.push(value.__variable__);
            }
          }
        }
        this._nodes = this.execution.nodes;
        return true;
      }
      throw new pytorch.Error("Module 'forward' not implemented.");
    }

    get inputs() {
      return this._inputs;
    }

    get outputs() {
      return this._outputs;
    }

    get nodes() {
      return this._nodes;
    }
  };
  2505. pytorch.Container.Zip.Execution = class extends pytorch.Execution {
  2506. constructor(sources, exceptionCallback, metadata) {
  2507. super(sources, exceptionCallback);
  2508. this._metadata = metadata;
  2509. this.reset();
  2510. }
  2511. reset() {
  2512. this._nodes = [];
  2513. this._variableIndex = 0;
  2514. }
  2515. get nodes() {
  2516. return this._nodes;
  2517. }
  2518. call(target, name, args, context) {
  2519. let resolvedTarget = pytorch.Utility.target(target);
  2520. let outputTypes = null;
  2521. if (resolvedTarget && resolvedTarget + '.' + name === 'ops.prim.NumToTensor' &&
  2522. args.length === 1 && args[0].type === 'call' && args[0].target.member.type == 'id') {
  2523. const innerCall = args[0];
  2524. resolvedTarget = pytorch.Utility.target(innerCall.target.target);
  2525. name = innerCall.target.member.value;
  2526. args = innerCall.arguments;
  2527. outputTypes = [ 'int64' ];
  2528. }
  2529. if (resolvedTarget) {
  2530. const type = resolvedTarget + '.' + name;
  2531. // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
  2532. let schemas = this._metadata.type(type);
  2533. if (schemas) {
  2534. schemas = !Array.isArray(schemas) ? [ schemas ] : schemas;
  2535. const evalArgs = args.map((argument) => argument.type === '=' && argument.target && argument.target.type === 'id' ? this.expression(argument.expression, context) : this.expression(argument, context));
  2536. for (const schema of schemas) {
  2537. const copyArgs = Array.prototype.slice.call(args);
  2538. const copyEvalArgs = Array.prototype.slice.call(evalArgs);
  2539. const node = {
  2540. type: schema.name,
  2541. inputs: [],
  2542. attributes: [],
  2543. outputs: []
  2544. };
  2545. const referencedParameters = [];
  2546. let next = false;
  2547. const parameters = Array.prototype.slice.call(schema.inputs || []).concat(Array.prototype.slice.call(schema.attributes || []));
  2548. while (copyEvalArgs.length > 0) {
  2549. if (parameters.length <= 0) {
  2550. next = true;
  2551. break;
  2552. }
  2553. const paramsBase = copyEvalArgs[0];
  2554. if (paramsBase && paramsBase.__class__ && paramsBase.__class__.__module__ === '__torch__.torch.classes.quantized') {
  2555. switch (paramsBase.__class__.__name__) {
  2556. case 'Conv2dPackedParamsBase':
  2557. case 'Conv3dPackedParamsBase': {
  2558. copyArgs.shift();
  2559. copyEvalArgs.shift();
  2560. copyArgs.unshift({ type: null });
  2561. copyEvalArgs.unshift(paramsBase.bias);
  2562. copyArgs.unshift({ type: null });
  2563. copyEvalArgs.unshift(paramsBase.weight);
  2564. break;
  2565. }
  2566. case 'LinearPackedParamsBase': {
  2567. copyArgs.shift();
  2568. copyEvalArgs.shift();
  2569. copyArgs.unshift({ type: null });
  2570. copyEvalArgs.unshift(paramsBase.bias);
  2571. copyArgs.unshift({ type: null });
  2572. copyEvalArgs.unshift(paramsBase.weight);
  2573. break;
  2574. }
  2575. default:
  2576. throw new pytorch.Error("Unsupported type '" + paramsBase.__name__ + "'.");
  2577. }
  2578. }
  2579. const op_context = copyEvalArgs[0];
  2580. if (op_context && op_context.__class__ && op_context.__class__.__module__ === '__torch__.torch.classes.xnnpack') {
  2581. switch (op_context.__class__.__name__) {
  2582. case 'LinearOpContext':
  2583. case 'Conv2dOpContext':
  2584. copyArgs.shift();
  2585. copyEvalArgs.shift();
  2586. for (const key of Object.keys(op_context).filter((key) => Number.isInteger(parseInt(key, 10)))) {
  2587. copyArgs.push({ type: null });
  2588. copyEvalArgs.push(op_context[key]);
  2589. }
  2590. break;
  2591. default:
  2592. throw new pytorch.Error("Unsupported type '" + paramsBase.__name__ + "'.");
  2593. }
  2594. }
  2595. if (copyArgs.every((arg) => arg.type === '=' && arg.target && arg.target.type === 'id') &&
  2596. parameters.every((parameter) => parameter.type !== 'Tensor' && parameter.type !== 'Tensor[]')) {
  2597. const map = new Map();
  2598. for (const parameter of parameters) {
  2599. map.set(parameter.name, parameter);
  2600. }
  2601. while (copyArgs.length > 0) {
  2602. const argument = copyArgs.shift();
  2603. const value = copyEvalArgs.shift();
  2604. const parameter = map.get(argument.target.value);
  2605. if (!parameter) {
  2606. next = true;
  2607. break;
  2608. }
  2609. if (!pytorch.Utility.isType(value, parameter.type)) {
  2610. if (parameter.optional) {
  2611. continue;
  2612. }
  2613. next = true;
  2614. break;
  2615. }
  2616. node.attributes.push({ name: parameter.name, value: value });
  2617. }
  2618. continue;
  2619. }
  2620. if (next) {
  2621. break;
  2622. }
  2623. const parameter = parameters.shift();
  2624. const argument = copyEvalArgs[0];
  2625. switch (parameter.type) {
  2626. case 'Tensor': {
  2627. if (Array.isArray(argument) || (!pytorch.Utility.isTensor(argument) && argument !== null && argument !== undefined)) {
  2628. if (parameter.optional) {
  2629. if (argument === undefined) {
  2630. copyArgs.shift();
  2631. copyEvalArgs.shift();
  2632. }
  2633. continue;
  2634. }
  2635. next = true;
  2636. break;
  2637. }
  2638. copyArgs.shift();
  2639. copyEvalArgs.shift();
  2640. const item = (argument === null || argument === undefined) ? {} : argument;
  2641. item.__variable__ = item.__variable__ || this.variable();
  2642. const inputs = [];
  2643. inputs.push({ id: item.__variable__ });
  2644. referencedParameters.push(item);
  2645. node.inputs.push(inputs);
  2646. break;
  2647. }
  2648. case 'Tensor[]': {
  2649. const argument = copyEvalArgs[0];
  2650. if (!Array.isArray(argument) || !argument.every((item) => pytorch.Utility.isTensor(item) || item === null)) {
  2651. if (parameter.optional) {
  2652. continue;
  2653. }
  2654. next = true;
  2655. break;
  2656. }
  2657. copyArgs.shift();
  2658. copyEvalArgs.shift();
  2659. const inputs = [];
  2660. for (let item of argument) {
  2661. if (item === null) {
  2662. item = {};
  2663. }
  2664. item.__variable__ = item.__variable__ || this.variable();
  2665. inputs.push({ id: item.__variable__ });
  2666. referencedParameters.push(item);
  2667. }
  2668. node.inputs.push(inputs);
  2669. break;
  2670. }
  2671. default: {
  2672. const arg = copyArgs[0];
  2673. if (!pytorch.Utility.isType(argument, parameter.type) && argument !== null) {
  2674. if (parameter.optional) {
  2675. continue;
  2676. }
  2677. next = true;
  2678. break;
  2679. }
  2680. if (arg.type !== '=') {
  2681. copyArgs.shift();
  2682. copyEvalArgs.shift();
  2683. node.attributes.push({ name: parameter.name, value: argument });
  2684. }
  2685. else {
  2686. throw new pytorch.Error('Expected named argument.');
  2687. }
  2688. break;
  2689. }
  2690. }
  2691. if (next) {
  2692. break;
  2693. }
  2694. }
  2695. if (next) {
  2696. continue;
  2697. }
  2698. const result = [];
  2699. for (let i = 0; i < schema.outputs.length; i++) {
  2700. const parameter = schema.outputs[i];
  2701. switch (parameter.type) {
  2702. case 'Tensor': {
  2703. const parameter = this.invoke('torch.Tensor', []);
  2704. parameter.__origin__ = type;
  2705. if (i === 0) {
  2706. switch (type) {
  2707. case 'torch.cat':
  2708. case 'torch.conv2d':
  2709. case 'torch.dropout':
  2710. case 'torch.flatten':
  2711. case 'torch.max_pool2d':
  2712. case 'torch.adaptive_avg_pool2d':
  2713. case 'torch.avg_pool2d':
  2714. case 'torch.quantize_per_tensor':
  2715. case 'torch.relu_':
  2716. case 'torch.hardtanh_':
  2717. case 'torch.unsqueeze':
  2718. case 'torch.slice': {
  2719. parameter.resize_([ NaN, NaN, NaN, NaN ]);
  2720. break;
  2721. }
  2722. case 'torch.conv3d': {
  2723. parameter.resize_([ NaN, NaN, NaN, NaN, NaN ]);
  2724. break;
  2725. }
  2726. case 'torch.embedding': {
  2727. parameter.resize_([ NaN, NaN, NaN ]);
  2728. break;
  2729. }
  2730. case 'torch.mean':
  2731. case 'torch.mul':
  2732. case 'torch.add':
  2733. case 'torch.batch_norm':
  2734. case 'torch.relu': {
  2735. const input = this.expression(args[0], context);
  2736. if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
  2737. parameter.resize_(input.size());
  2738. }
  2739. break;
  2740. }
  2741. case 'torch.ones':
  2742. case 'torch.zeros':
  2743. case 'torch.zeros_like': {
  2744. parameter.resize_(this.expression(args[0], context));
  2745. break;
  2746. }
  2747. case 'torch.new_full': {
  2748. parameter.resize_(this.expression(args[1], context));
  2749. break;
  2750. }
  2751. case 'ops.quantized.cat':
  2752. case 'ops.quantized.cat_relu':
  2753. case 'ops.quantized.linear':
  2754. case 'ops.quantized.conv2d':
  2755. case 'ops.quantized.conv2d_relu':
  2756. case 'ops.quantized.add_relu':
  2757. parameter.resize_([ NaN, NaN, NaN, NaN ]);
  2758. break;
  2759. case 'torch.view':
  2760. parameter.resize_(this.expression(args[1], context));
  2761. break;
  2762. }
  2763. }
  2764. parameter.__variable__ = this.variable();
  2765. result.push(parameter);
  2766. node.outputs.push([ { id: parameter.__variable__ } ]);
  2767. break;
  2768. }
  2769. case 'Tensor[]': {
  2770. let count = 1;
  2771. switch (type) {
  2772. case 'torch.chunk':
  2773. count = node.attributes.filter((attribute) => attribute.name == 'chunks')[0].value;
  2774. break;
  2775. case 'torch.meshgrid':
  2776. count = node.inputs[0].length;
  2777. break;
  2778. case 'torch.unbind':
  2779. count = args[0].__tuple__ || count;
  2780. break;
  2781. }
  2782. const tensors = [];
  2783. const outputs = [];
  2784. for (let i = 0; i < count; i ++) {
  2785. const tensor = this.invoke('torch.Tensor', []);
  2786. tensor.__origin__ = type;
  2787. tensor.__variable__ = this.variable();
  2788. tensors.push(tensor);
  2789. outputs.push({ id: tensor.__variable__ });
  2790. }
  2791. result.push(tensors);
  2792. node.outputs.push(outputs);
  2793. break;
  2794. }
  2795. default: {
  2796. if (!outputTypes || schema.outputs.length !== 1 || schema.outputs[0].type !== outputTypes[0]) {
  2797. next = true;
  2798. break;
  2799. }
  2800. const tensor = this.invoke('torch.Tensor', []);
  2801. tensor.resize_([]);
  2802. tensor.__origin__ = type;
  2803. tensor.__variable__ = this.variable();
  2804. result.push(tensor);
  2805. node.outputs.push([ { id: tensor.__variable__ } ]);
  2806. break;
  2807. }
  2808. }
  2809. }
  2810. if (next) {
  2811. continue;
  2812. }
  2813. for (const parameter of referencedParameters) {
  2814. parameter.__count__ = (parameter.__count__ || 0) + 1;
  2815. }
  2816. this.push(node);
  2817. if (result.length > 1) {
  2818. return result;
  2819. }
  2820. return result[0];
  2821. }
  2822. }
  2823. }
  2824. return super.call(target, name, args, context);
  2825. }
  2826. block(statements, context) {
  2827. statements = Array.prototype.slice.call(statements);
  2828. while (statements.length > 0) {
  2829. if (statements.length > 1) {
  2830. const assign = statements[0];
  2831. const condition = statements[1];
  2832. // _x = torch.ne(torch.len(torch.size(input)), 5)
  2833. // if _x:
  2834. // ops.prim.RaiseException(...)
  2835. if (assign.type === '=' &&
  2836. condition.type === 'if' &&
  2837. pytorch.Utility.isEqual(assign.target, condition.condition) &&
  2838. pytorch.Utility.isCall(assign.expression, 'torch.ne', 2) &&
  2839. pytorch.Utility.isCall(assign.expression.arguments[0], 'torch.len', 1) &&
  2840. pytorch.Utility.isCall(assign.expression.arguments[0].arguments[0], 'torch.size', 1) &&
  2841. condition.then.statements.length == 1 &&
  2842. pytorch.Utility.isCall(condition.then.statements[0], 'ops.prim.RaiseException', 1)) {
  2843. const tensor = this.expression(assign.expression.arguments[0].arguments[0].arguments[0], context);
  2844. if (tensor && tensor.size) {
  2845. const number = this.expression(assign.expression.arguments[1], context);
  2846. const size = tensor.size();
  2847. if (size && size.length && size.length !== number &&
  2848. size.every((item) => isNaN(item)) && number >= 3 && number <= 5) {
  2849. if (tensor.__origin__ === 'torch.quantize_per_tensor') {
  2850. tensor.resize_(Array(number).fill(NaN));
  2851. }
  2852. }
  2853. }
  2854. }
  2855. // _0 = torch.eq(torch.len(torch.size(x)), 2)
  2856. // if _0:
  2857. // pass
  2858. // else:
  2859. // ops.prim.RaiseException("AssertionError: ")
  2860. if (assign.type === '=' &&
  2861. condition.type === 'if' &&
  2862. pytorch.Utility.isEqual(assign.target, condition.condition) &&
  2863. pytorch.Utility.isCall(assign.expression, 'torch.eq', 2) &&
  2864. pytorch.Utility.isCall(assign.expression.arguments[0], 'torch.len', 1) &&
  2865. pytorch.Utility.isCall(assign.expression.arguments[0].arguments[0], 'torch.size', 1) &&
  2866. condition.else.statements.length == 1 &&
  2867. pytorch.Utility.isCall(condition.else.statements[0], 'ops.prim.RaiseException', 1)) {
  2868. const tensor = this.expression(assign.expression.arguments[0].arguments[0].arguments[0], context);
  2869. if (tensor && tensor.shape === undefined) {
  2870. const number = this.expression(assign.expression.arguments[1], context);
  2871. tensor.resize_(Array(number).fill(NaN));
  2872. }
  2873. }
  2874. }
  2875. const statement = statements.shift();
  2876. // input_shape = torch.slice(torch.size(x), -2, 9223372036854775807, 1)
  2877. if (statement.type === '=' &&
  2878. pytorch.Utility.isCall(statement.expression, 'torch.slice', 4) &&
  2879. pytorch.Utility.isCall(statement.expression.arguments[0], 'torch.size', 1)) {
  2880. const tensor = this.expression(statement.expression.arguments[0].arguments[0], context);
  2881. if (tensor && tensor.shape === undefined) {
  2882. tensor.resize_([ 1, 3, 299, 299 ]);
  2883. }
  2884. }
  2885. // torch.slice(ops.prim.shape(input), 0, 2, 1)
  2886. if (statement.type === '=' &&
  2887. pytorch.Utility.isCall(statement.expression, 'torch.slice', 4) &&
  2888. pytorch.Utility.isCall(statement.expression.arguments[0], 'ops.prim.shape', 1)) {
  2889. const tensor = this.expression(statement.expression.arguments[0].arguments[0], context);
  2890. if (tensor && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
  2891. tensor.resize_([ NaN, NaN, NaN, NaN ]);
  2892. }
  2893. }
  2894. // _3 = torch.le(xxxx, torch.dim(f0))
  2895. if (statement.type === '=' &&
  2896. pytorch.Utility.isCall(statement.expression, 'torch.le', 2) &&
  2897. pytorch.Utility.isCall(statement.expression.arguments[1], 'torch.dim', 1)) {
  2898. const tensor = this.expression(statement.expression.arguments[1].arguments[0], context);
  2899. if (tensor && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
  2900. tensor.resize_([ NaN, NaN, NaN, NaN ]);
  2901. }
  2902. }
  2903. // if torch.ne(torch.dim(image), 3):
  2904. // xxxx
  2905. // ops.prim.RaiseException(_7)
  2906. if (statement.type === 'if' &&
  2907. pytorch.Utility.isCall(statement.condition, 'torch.ne', 2) &&
  2908. pytorch.Utility.isCall(statement.condition.arguments[0], 'torch.dim', 1) &&
  2909. statement.then.statements.length > 0 &&
  2910. pytorch.Utility.isCall(statement.then.statements.slice(-1).pop(), 'ops.prim.RaiseException', 1)) {
  2911. const tensor = this.expression(statement.condition.arguments[0].arguments[0], context);
  2912. const size = this.expression(statement.condition.arguments[1], context);
  2913. if (tensor && Number.isInteger(size) && size < 10) {
  2914. tensor.resize_(Array.isArray(tensor.shape) && tensor.shape.length > size ? tensor.shape.slice(-size) : Array(size).fill(NaN));
  2915. }
  2916. }
  2917. // dim = torch.sub(torch.dim(input), 2)
  2918. if (statement.type === '=' &&
  2919. statement.target.type === 'id' && statement.target.value === 'dim' &&
  2920. pytorch.Utility.isCall(statement.expression, 'torch.sub', 2) &&
  2921. pytorch.Utility.isCall(statement.expression.arguments[0], 'torch.dim', 1)) {
  2922. const tensor = this.expression(statement.expression.arguments[0].arguments[0], context);
  2923. if (tensor && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
  2924. tensor.resize_([ NaN, NaN, NaN, NaN ]);
  2925. }
  2926. }
  2927. // a, b = torch.unbind(size, 0)
  2928. if (statement.type === '=' &&
  2929. statement.target.type === 'tuple' &&
  2930. pytorch.Utility.isCall(statement.expression, 'torch.unbind', 2)) {
  2931. statement.expression.arguments[0].__tuple__ = statement.target.value.length;
  2932. }
  2933. // x = torch.len(input)
  2934. if (statement.type === '=' &&
  2935. statement.target.type === 'id' &&
  2936. pytorch.Utility.isCall(statement.expression, 'torch.len', 1)) {
  2937. const tensor = this.expression(statement.expression.arguments[0], context);
  2938. if (tensor && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
  2939. tensor.resize_([ NaN, NaN, NaN, NaN ]);
  2940. }
  2941. }
  2942. const value = this.statement(statement, context);
  2943. if (value !== undefined) {
  2944. return value;
  2945. }
  2946. }
  2947. }
  2948. push(node) {
  2949. this._nodes.push(node);
  2950. }
  2951. variable() {
  2952. this._variableIndex++;
  2953. return this._variableIndex.toString();
  2954. }
  2955. };
  // Scalar type name -> numeric code; each code is the name's position in this
  // list (uint8 = 0 ... quint4x2 = 16).
  pytorch.ScalarType = {};
  [
    'uint8', 'int8', 'int16', 'int32', 'int64',
    'float16', 'float32', 'float64',
    'complex32', 'complex64', 'complex128',
    'boolean', 'qint8', 'quint8', 'qint32', 'bfloat16', 'quint4x2'
  ].forEach((name, code) => {
    pytorch.ScalarType[name] = code;
  });
  2965. complex32: 8,
  2966. complex64: 9,
  2967. complex128: 10,
  2968. boolean: 11,
  2969. qint8: 12,
  2970. quint8: 13,
  2971. qint32: 14,
  2972. bfloat16: 15,
  2973. quint4x2: 16
  2974. };
  2975. pytorch.MemoryFormat = {
  2976. Contiguous: 0,
  2977. Preserve: 1,
  2978. ChannelsLast: 2,
  2979. ChannelsLast3d: 3
  2980. };
  2981. pytorch.Layout = {
  2982. Strided: 0,
  2983. Sparse: 1,
  2984. Mkldnn: 2
  2985. };
  2986. pytorch.Utility = class {
  2987. static getScalarType(scalarType) {
  2988. if (!pytorch.Utility._scalarTypes) {
  2989. pytorch.Utility._scalarTypes = [
  2990. { name: 'uint8', itemsize: 1 },
  2991. { name: 'int8', itemsize: 1 },
  2992. { name: 'int16', itemsize: 2 },
  2993. { name: 'int32', itemsize: 4 },
  2994. { name: 'int64', itemsize: 8 },
  2995. { name: 'float16', itemsize: 2 },
  2996. { name: 'float32', itemsize: 4 },
  2997. { name: 'float64', itemsize: 8 },
  2998. { name: 'complex32', itemsize: 4 },
  2999. { name: 'complex64', itemsize: 8 },
  3000. { name: 'complex128', itemsize: 16 },
  3001. { name: 'boolean', itemsize: 1 },
  3002. { name: 'qint8', itemsize: 1 },
  3003. { name: 'quint8', itemsize: 1 },
  3004. { name: 'qint32', itemsize: 4 },
  3005. { name: 'bfloat16', itemsize: 2 },
  3006. { name: 'quint4x2' }
  3007. ];
  3008. }
  3009. if (scalarType < pytorch.Utility._scalarTypes.length) {
  3010. return pytorch.Utility._scalarTypes[scalarType];
  3011. }
  3012. throw new pytorch.Error("Unknown scalar type '" + scalarType + "'.");
  3013. }
  3014. static target(expression) {
  3015. if (expression.type == 'id') {
  3016. return expression.value;
  3017. }
  3018. if (expression.type == '.') {
  3019. return pytorch.Utility.target(expression.target) + '.' + pytorch.Utility.target(expression.member);
  3020. }
  3021. return null;
  3022. }
  3023. static isTensor(obj) {
  3024. if (obj && obj.__class__) {
  3025. switch (obj.__class__.__module__) {
  3026. case 'torch':
  3027. case 'torch.cuda':
  3028. return obj.__class__.__name__.endsWith('Tensor');
  3029. case 'torch.nn.parameter':
  3030. return obj.__class__.__name__ === 'Parameter';
  3031. }
  3032. }
  3033. return false;
  3034. }
  3035. static toTensor(obj) {
  3036. if (obj && obj.__class__) {
  3037. switch (obj.__class__.__module__) {
  3038. case 'torch':
  3039. case 'torch.cuda':
  3040. return obj.__class__.__name__.endsWith('Tensor') ? obj : null;
  3041. case 'torch.nn.parameter':
  3042. return obj.__class__.__name__ === 'Parameter' ? obj.data : null;
  3043. }
  3044. }
  3045. return null;
  3046. }
  3047. static isType(obj, type) {
  3048. switch (type) {
  3049. case 'Tensor':
  3050. return !Array.isArray(obj) && (pytorch.Utility.isTensor(obj) || obj === null);
  3051. case 'Tensor[]':
  3052. return Array.isArray(obj) && obj.length > 0 && obj.every((tensor) => pytorch.Utility.isTensor(tensor) || tensor === null);
  3053. case 'Scalar':
  3054. return (obj !== null && obj !== Object(obj)) || (pytorch.Utility.isTensor(obj) && Array.isArray(obj.size()) && obj.size().length === 0);
  3055. case 'boolean':
  3056. return obj === true || obj === false;
  3057. case 'int64':
  3058. return Number.isInteger(obj) || obj instanceof base.Int64 || (typeof obj === 'number' && isNaN(obj));
  3059. case 'int64[]':
  3060. return Array.isArray(obj) && obj.every((item) => Number.isInteger(item) || (typeof item === 'number' && isNaN(item)) || item === undefined);
  3061. case 'int64[1]':
  3062. return pytorch.Utility.isType(obj, 'int64') || pytorch.Utility.isType(obj, 'int64[]');
  3063. case 'float32':
  3064. case 'float64':
  3065. return obj !== null && obj !== Object(obj);
  3066. case 'Layout':
  3067. case 'ScalarType':
  3068. case 'MemoryFormat':
  3069. return Number.isInteger(obj) || obj === null;
  3070. case 'Device':
  3071. return obj === null || obj === Object(obj);
  3072. }
  3073. return true;
  3074. }
  3075. static isCall(expression, name, size) {
  3076. if (expression.type === 'call' &&
  3077. expression.arguments.length === size &&
  3078. pytorch.Utility.target(expression.target) === name) {
  3079. return true;
  3080. }
  3081. return false;
  3082. }
  3083. static isEqual(a, b) {
  3084. return (a.type === 'id' && b.type === 'id' && a.value === b.value);
  3085. }
  3086. static findModule(root) {
  3087. if (root) {
  3088. const keys = [ '', 'model', 'net' ];
  3089. for (const key of keys) {
  3090. const obj = key === '' ? root : root[key];
  3091. if (obj) {
  3092. if (obj._modules) {
  3093. return [ { name: '', obj: obj } ];
  3094. }
  3095. const objKeys = Object.keys(obj).filter((key) => obj[key] && obj[key]._modules);
  3096. if (objKeys.length > 1) {
  3097. return objKeys.map((key) => { return { name: key, obj: obj[key] }; });
  3098. }
  3099. }
  3100. }
  3101. }
  3102. return null;
  3103. }
  3104. static findWeights(root) {
  3105. if (!root) {
  3106. return null;
  3107. }
  3108. if (root instanceof Map) {
  3109. const obj = {};
  3110. for (const pair of root) {
  3111. const key = pair[0];
  3112. const value = pair[1];
  3113. obj[key] = value;
  3114. }
  3115. root = obj;
  3116. }
  3117. const keys = root && !Array.isArray(root) ? Object.keys(root) : [];
  3118. if (keys.length > 1) {
  3119. keys.splice(0, keys.length);
  3120. }
  3121. keys.push(...[
  3122. 'state_dict', 'state', 'model_state', 'model', 'model_state_dict', 'model_dict', 'net_dict', 'params', 'generator',
  3123. 'discriminator', 'g_state', 'network', 'net', 'netG', 'net_states', 'state_dict_stylepredictor', 'state_dict_ghiasi', 'runner', ''
  3124. ]);
  3125. for (const key of keys) {
  3126. const obj = key === '' ? root : root[key];
  3127. let graphs = null;
  3128. graphs = graphs || pytorch.Utility._convertTensor(obj);
  3129. graphs = graphs || pytorch.Utility._convertObjectList(obj);
  3130. graphs = graphs || pytorch.Utility._convertStateDict(obj);
  3131. if (graphs) {
  3132. return graphs;
  3133. }
  3134. }
  3135. return null;
  3136. }
  3137. static _convertTensor(obj) {
  3138. if (obj && pytorch.Utility.isTensor(obj)) {
  3139. const layers = [];
  3140. const argument = { id: '', value: obj };
  3141. const parameter = { name: 'value', arguments: [ argument ] };
  3142. layers.push({ states: [ parameter ] });
  3143. return [ { layers: layers } ];
  3144. }
  3145. return null;
  3146. }
  3147. static _convertObjectList(list) {
  3148. if (list && Array.isArray(list) && list.every((obj) => obj && Object.keys(obj).filter((key) => pytorch.Utility.isTensor(obj[key]).length > 0))) {
  3149. const layers = [];
  3150. for (const obj of list) {
  3151. const type = obj.__class__ ? obj.__class__.__module__ + '.' + obj.__class__.__name__ : '?';
  3152. const layer = { type: type, states: [], attributes: [] };
  3153. for (const key of Object.keys(obj)) {
  3154. const value = obj[key];
  3155. if (pytorch.Utility.isTensor(value)) {
  3156. layer.states.push({ name: key, arguments: [ { id: '', value: value } ] });
  3157. }
  3158. else {
  3159. layer.attributes.push({ name: key, value: value });
  3160. }
  3161. }
  3162. layers.push(layer);
  3163. }
  3164. return [ { layers: layers } ];
  3165. }
  3166. return null;
  3167. }
  3168. static _convertStateDict(obj) {
  3169. const clean = (obj) => {
  3170. if (obj && Array.isArray(obj)) {
  3171. return obj;
  3172. }
  3173. if (obj && obj instanceof Map) {
  3174. return obj;
  3175. }
  3176. if (obj && Object(obj) === obj) {
  3177. const integer = new Set([ 'epoch', 'i_batch', 'num_vid', 'seen' ]);
  3178. const target = {};
  3179. for (const key of Object.keys(obj)) {
  3180. const value = obj[key];
  3181. if (key.indexOf('optim') !== -1 || key.indexOf('opt') !== -1) {
  3182. if (value === null || (value.state && value.param_groups)) {
  3183. continue;
  3184. }
  3185. }
  3186. if (key.indexOf('loss') !== -1 && Array.isArray(value)) {
  3187. continue;
  3188. }
  3189. if (integer.has(key) && Number.isInteger(value)) {
  3190. continue;
  3191. }
  3192. target[key] = value;
  3193. }
  3194. return target;
  3195. }
  3196. return obj;
  3197. };
  3198. const validate = (map) => {
  3199. let tensor = false;
  3200. if (map && map instanceof Map) {
  3201. for (const pair of map) {
  3202. const key = pair[0];
  3203. const value = pair[1];
  3204. if (key.split('.').pop() === '_metadata') {
  3205. continue;
  3206. }
  3207. if (pytorch.Utility.isTensor(value)) {
  3208. tensor = true;
  3209. continue;
  3210. }
  3211. else if (value && Array.isArray(value) && value.every((item) => pytorch.Utility.isTensor(item))) {
  3212. tensor = true;
  3213. continue;
  3214. }
  3215. else if (typeof value === 'string' || typeof value === 'number' || typeof value === 'boolean') {
  3216. continue;
  3217. }
  3218. return false;
  3219. }
  3220. }
  3221. return tensor;
  3222. };
  3223. const flatten = (obj) => {
  3224. if (!obj || Array.isArray(obj)) {
  3225. return null;
  3226. }
  3227. if (obj instanceof Map) {
  3228. if (validate(obj)) {
  3229. return obj;
  3230. }
  3231. return null;
  3232. }
  3233. if (Object(obj) !== obj) {
  3234. return null;
  3235. }
  3236. const map = new Map(Object.keys(obj).map((key) => [ key, obj[key] ]));
  3237. if (validate(map)) {
  3238. return map;
  3239. }
  3240. map.clear();
  3241. for (const key of Object.keys(obj)) {
  3242. const value = flatten(obj[key]);
  3243. if (value && value instanceof Map) {
  3244. for (const pair of value) {
  3245. map.set(key + '.' + pair[0], pair[1]);
  3246. }
  3247. continue;
  3248. }
  3249. return null;
  3250. }
  3251. return map;
  3252. };
  3253. if (!obj) {
  3254. return null;
  3255. }
  3256. obj = clean(obj);
  3257. const map = new Map();
  3258. if (Array.isArray(obj) && obj.every((item) => validate(item))) {
  3259. for (let i = 0; i < obj.length; i++) {
  3260. map.set(i.toString(), flatten(obj[i]));
  3261. }
  3262. }
  3263. else if (obj instanceof Map && validate(obj)) {
  3264. map.set('', flatten(obj));
  3265. }
  3266. else if (Object(obj) === obj && Object.keys(obj).every((key) => validate(obj[key]))) {
  3267. for (const key of Object.keys(obj)) {
  3268. map.set(key, obj[key]);
  3269. }
  3270. }
  3271. else if (Object(obj) === obj && Object.keys(obj).every((key) => pytorch.Utility.isTensor(obj[key]))) {
  3272. map.set('', new Map(Object.keys(obj).map((key) => [ key, obj[key] ])));
  3273. }
  3274. else {
  3275. const value = flatten(obj);
  3276. if (value) {
  3277. map.set('', value);
  3278. }
  3279. }
  3280. if (map.size > 0) {
  3281. const graphs = [];
  3282. for (const pair of map) {
  3283. const graph_key = pair[0];
  3284. const layer_map = pair[1];
  3285. const layers = new Map();
  3286. for (const item of layer_map) {
  3287. const key = item[0];
  3288. const value = item[1];
  3289. let layerName = '';
  3290. let parameter = '';
  3291. const keys = key.split('.');
  3292. if (keys[keys.length - 1] === '_metadata') {
  3293. continue;
  3294. }
  3295. if (keys.length >= 2 && keys[keys.length - 2] === '_packed_params') {
  3296. parameter = keys.slice(-2).join('.');
  3297. keys.pop();
  3298. keys.pop();
  3299. }
  3300. else {
  3301. parameter = keys.pop();
  3302. if (keys.length < 0) {
  3303. keys.push('');
  3304. }
  3305. }
  3306. layerName = keys.join('.');
  3307. if (!layers.has(layerName)) {
  3308. layers.set(layerName, { name: layerName, states: [], attributes: [] });
  3309. }
  3310. const layer = layers.get(layerName);
  3311. if (pytorch.Utility.isTensor(value)) {
  3312. layer.states.push({ name: parameter, arguments: [ { id: key, value: value } ] });
  3313. if (layer.name == '' && layer.states.length > 12) {
  3314. return null;
  3315. }
  3316. }
  3317. else if (value && Array.isArray(value) && value.every((item) => pytorch.Utility.isTensor(item))) {
  3318. layer.states.push({ name: parameter, arguments: value.map((item) => { return { id: '', value: item }; }) });
  3319. }
  3320. else if (typeof value === 'string' || typeof value === 'number' || typeof value === 'boolean') {
  3321. layer.attributes.push({ name: parameter, value: value });
  3322. }
  3323. }
  3324. graphs.push({
  3325. name: graph_key,
  3326. layers: layers.values()
  3327. });
  3328. }
  3329. return graphs;
  3330. }
  3331. return null;
  3332. }
  3333. };
  // Namespace for reading models serialized by PyTorch's NNAPI backend.
  pytorch.nnapi = {};

  /**
   * Parses an NNAPI serialized model.
   *
   * Layout read below (all little-endian):
   *   1. int32 header: version (must be 1), then the counts of operands, values,
   *      operations, inputs, and outputs.
   *   2. Per operand: type, dimension count, scale (float32), zero point.
   *   3. Per value: operand index, source type, source length.
   *   4. Per operation: operation code, input count, output count.
   *   5. Every operand's dimensions.
   *   6. Every value's payload.
   *   7. Every operation's input/output operand indices.
   *   8. The model's input/output operand indices.
   *
   * @param {Uint8Array} serialized_model - Presumably a Uint8Array (the reader uses .buffer, .byteOffset and .subarray).
   * @throws {pytorch.Error} On an unsupported version, operand type or value source, a bad operand index,
   *   truncation, or trailing bytes.
   */
  pytorch.nnapi.SerializedModel = class {

    constructor(serialized_model /*, buffer_ptrs */) {
      const reader = new pytorch.nnapi.SerializedModel.BinaryReader(serialized_model);
      this.version = reader.int32();
      if (this.version !== 1) {
        throw new pytorch.Error('Invalid NNAPI serialized model version.');
      }
      const operands = new Array(reader.int32());
      const values = new Array(reader.int32());
      this.operations = new Array(reader.int32());
      this.inputs = new Array(reader.int32());
      this.outputs = new Array(reader.int32());
      // NNAPI OperandCode -> display name. Names ending in '*' are the TENSOR_* codes.
      // Code 3 is TENSOR_FLOAT32 (not scalar FLOAT32), and MODEL is code 15.
      const types = new Map([
        [ 0, 'float32' ],
        [ 1, 'int32' ],
        [ 2, 'uint32' ],
        [ 3, 'float32*' ],
        [ 4, 'int32*' ],
        [ 5, 'quant8_asymm*' ],
        [ 6, 'boolean' ],
        [ 7, 'quant16_symm*' ],
        [ 8, 'float16*' ],
        [ 9, 'boolean*' ],
        [ 10, 'float16' ],
        [ 11, 'quant8_symm_per_channel*' ],
        [ 12, 'quant16_asymm*' ],
        [ 13, 'quant8_symm*' ],
        [ 14, 'quant8_asymm_signed*' ],
        [ 15, 'model' ]
      ]);
      // NNAPI OperationCode -> name.
      const operations = new Map([
        [ 0, 'ADD' ],
        [ 1, 'AVERAGE_POOL_2D' ],
        [ 2, 'CONCATENATION' ],
        [ 3, 'CONV_2D' ],
        [ 4, 'DEPTHWISE_CONV_2D' ],
        [ 5, 'DEPTH_TO_SPACE' ],
        [ 6, 'DEQUANTIZE' ],
        [ 7, 'EMBEDDING_LOOKUP' ],
        [ 8, 'FLOOR' ],
        [ 9, 'FULLY_CONNECTED' ],
        [ 10, 'HASHTABLE_LOOKUP' ],
        [ 11, 'L2_NORMALIZATION' ],
        [ 12, 'L2_POOL_2D' ],
        [ 13, 'LOCAL_RESPONSE_NORMALIZATION' ],
        [ 14, 'LOGISTIC' ],
        [ 15, 'LSH_PROJECTION' ],
        [ 16, 'LSTM' ],
        [ 17, 'MAX_POOL_2D' ],
        [ 18, 'MUL' ],
        [ 19, 'RELU' ],
        [ 20, 'RELU1' ],
        [ 21, 'RELU6' ],
        [ 22, 'RESHAPE' ],
        [ 23, 'RESIZE_BILINEAR' ],
        [ 24, 'RNN' ],
        [ 25, 'SOFTMAX' ],
        [ 26, 'SPACE_TO_DEPTH' ],
        [ 27, 'SVDF' ],
        [ 28, 'TANH' ],
        [ 29, 'BATCH_TO_SPACE_ND' ],
        [ 30, 'DIV' ],
        [ 31, 'MEAN' ],
        [ 32, 'PAD' ],
        [ 33, 'SPACE_TO_BATCH_ND' ],
        [ 34, 'SQUEEZE' ],
        [ 35, 'STRIDED_SLICE' ],
        [ 36, 'SUB' ],
        [ 37, 'TRANSPOSE' ],
        [ 38, 'ABS' ],
        [ 39, 'ARGMAX' ],
        [ 40, 'ARGMIN' ],
        [ 41, 'AXIS_ALIGNED_BBOX_TRANSFORM' ],
        [ 42, 'BIDIRECTIONAL_SEQUENCE_LSTM' ],
        [ 43, 'BIDIRECTIONAL_SEQUENCE_RNN' ],
        [ 44, 'BOX_WITH_NMS_LIMIT' ],
        [ 45, 'CAST' ],
        [ 46, 'CHANNEL_SHUFFLE' ],
        [ 47, 'DETECTION_POSTPROCESSING' ],
        [ 48, 'EQUAL' ],
        [ 49, 'EXP' ],
        [ 50, 'EXPAND_DIMS' ],
        [ 51, 'GATHER' ],
        [ 52, 'GENERATE_PROPOSALS' ],
        [ 53, 'GREATER' ],
        [ 54, 'GREATER_EQUAL' ],
        [ 55, 'GROUPED_CONV_2D' ],
        [ 56, 'HEATMAP_MAX_KEYPOINT' ],
        [ 57, 'INSTANCE_NORMALIZATION' ],
        [ 58, 'LESS' ],
        [ 59, 'LESS_EQUAL' ],
        [ 60, 'LOG' ],
        [ 61, 'LOGICAL_AND' ],
        [ 62, 'LOGICAL_NOT' ],
        [ 63, 'LOGICAL_OR' ],
        [ 64, 'LOG_SOFTMAX' ],
        [ 65, 'MAXIMUM' ],
        [ 66, 'MINIMUM' ],
        [ 67, 'NEG' ],
        [ 68, 'NOT_EQUAL' ],
        [ 69, 'PAD_V2' ],
        [ 70, 'POW' ],
        [ 71, 'PRELU' ],
        [ 72, 'QUANTIZE' ],
        [ 73, 'QUANTIZED_16BIT_LSTM' ],
        [ 74, 'RANDOM_MULTINOMIAL' ],
        [ 75, 'REDUCE_ALL' ],
        [ 76, 'REDUCE_ANY' ],
        [ 77, 'REDUCE_MAX' ],
        [ 78, 'REDUCE_MIN' ],
        [ 79, 'REDUCE_PROD' ],
        [ 80, 'REDUCE_SUM' ],
        [ 81, 'ROI_ALIGN' ],
        [ 82, 'ROI_POOLING' ],
        [ 83, 'RSQRT' ],
        [ 84, 'SELECT' ],
        [ 85, 'SIN' ],
        [ 86, 'SLICE' ],
        [ 87, 'SPLIT' ],
        [ 88, 'SQRT' ],
        [ 89, 'TILE' ],
        [ 90, 'TOPK_V2' ],
        [ 91, 'TRANSPOSE_CONV_2D' ],
        [ 92, 'UNIDIRECTIONAL_SEQUENCE_LSTM' ],
        [ 93, 'UNIDIRECTIONAL_SEQUENCE_RNN' ],
        [ 94, 'RESIZE_NEAREST_NEIGHBOR' ],
        [ 95, 'QUANTIZED_LSTM' ],
        [ 96, 'IF' ],
        [ 97, 'WHILE' ],
        [ 98, 'ELU' ],
        [ 99, 'HARD_SWISH' ],
        [ 100, 'FILL' ],
        [ 101, 'RANK' ],
      ]);
      for (let i = 0; i < operands.length; i++) {
        const type = reader.int32();
        operands[i] = {
          type: types.has(type) ? types.get(type) : type,
          dimensions: new Array(reader.uint32()),
          scale: reader.float32(),
          zero_point: reader.int32()
        };
      }
      for (let i = 0; i < values.length; i++) {
        values[i] = {
          index: reader.int32(),
          source_type: reader.int32(),
          source_length: reader.uint32()
        };
      }
      for (let i = 0; i < this.operations.length; i++) {
        const operation_type = reader.int32();
        this.operations[i] = {
          operation_type: operations.has(operation_type) ? operations.get(operation_type) : operation_type,
          inputs: new Array(reader.uint32()),
          outputs: new Array(reader.uint32())
        };
      }
      for (const operand of operands) {
        for (let i = 0; i < operand.dimensions.length; i++) {
          operand.dimensions[i] = reader.uint32();
        }
      }
      for (const value of values) {
        const operand = operands[value.index];
        if (!operand) {
          throw new pytorch.Error('Invalid NNAPI operand index.');
        }
        switch (value.source_type) {
          case 0: { // immediate
            switch (operand.type) {
              case 'boolean':
                // One byte of payload followed by three padding bytes.
                operand.value = reader.byte() ? true : false;
                reader.skip(3);
                break;
              case 'int32':
                operand.value = reader.int32();
                break;
              case 'float32':
                operand.value = reader.float32();
                break;
              case 'int32*':
                operand.value = reader.read(value.source_length);
                break;
              default:
                throw new pytorch.Error("Unsupported NNAPI operand type '" + operand.type.toString() + "'.");
            }
            break;
          }
          case 2: { // numbered buffer
            if (value.source_length !== 12) {
              throw new pytorch.Error('Invalid NNAPI numbered buffer source length.');
            }
            const number = reader.uint32();
            const offset = reader.uint32();
            const operand_length = reader.uint32();
            operand.value = [ number, offset, operand_length ];
            break;
          }
          case 3: { // numbered memory
            throw new pytorch.Error('NNAPI numbered memory buffer not implemented.');
          }
          default: {
            throw new pytorch.Error('Unsupported NNAPI value source type.');
          }
        }
      }
      for (const operation of this.operations) {
        for (let i = 0; i < operation.inputs.length; i++) {
          operation.inputs[i] = operands[reader.uint32()];
        }
        for (let i = 0; i < operation.outputs.length; i++) {
          operation.outputs[i] = operands[reader.uint32()];
        }
      }
      for (let i = 0; i < this.inputs.length; i++) {
        this.inputs[i] = operands[reader.uint32()];
      }
      for (let i = 0; i < this.outputs.length; i++) {
        this.outputs[i] = operands[reader.uint32()];
      }
      if (!reader.end()) {
        throw new pytorch.Error('Invalid NNAPI serialized model length.');
      }
    }
  };

  /**
   * Little-endian cursor over a Uint8Array. Every read advances the position and
   * throws `pytorch.Error` if it would move past the end of the buffer.
   */
  pytorch.nnapi.SerializedModel.BinaryReader = class {

    constructor(buffer) {
      this._buffer = buffer;
      this._dataView = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
      this._position = 0;
    }

    // True once every byte has been consumed.
    end() {
      return this._position >= this._buffer.length;
    }

    skip(offset) {
      this._position += offset;
      if (this._position > this._buffer.length) {
        throw new pytorch.Error('Expected ' + (this._position - this._buffer.length) + ' more bytes. The file might be corrupted. Unexpected end of file.');
      }
    }

    // Returns a view (not a copy) of the next `length` bytes.
    read(length) {
      const position = this._position;
      this.skip(length);
      return this._buffer.subarray(position, this._position);
    }

    byte() {
      const position = this._position;
      this.skip(1);
      return this._dataView.getUint8(position);
    }

    int32() {
      const position = this._position;
      this.skip(4);
      return this._dataView.getInt32(position, true);
    }

    uint32() {
      const position = this._position;
      this.skip(4);
      return this._dataView.getUint32(position, true);
    }

    // Reads a 64-bit value whose high word must be zero; returns the low signed 32 bits.
    int64() {
      const value = this.int32();
      if (this.int32() !== 0) {
        throw new pytorch.Error('Invalid int64 value.');
      }
      return value;
    }

    float32() {
      const position = this._position;
      this.skip(4);
      return this._dataView.getFloat32(position, true);
    }
  };
  /**
   * Operator schema registry loaded from 'pytorch-metadata.json'.
   *
   * Every schema is stored under its full name. A name containing ':' (e.g. 'op:overload')
   * is also collected into an overload list under the part before the ':', unless a
   * schema with that bare name exists, in which case the bare schema takes precedence.
   */
  pytorch.Metadata = class {

    /**
     * Returns the shared metadata instance, loading it on first use.
     * Best effort: when the request or parsing fails, an empty registry is cached instead.
     * @param {{request: Function}} context - Host context used to fetch the JSON file.
     * @returns {Promise<pytorch.Metadata>}
     */
    static open(context) {
      if (pytorch.Metadata._metadata) {
        return Promise.resolve(pytorch.Metadata._metadata);
      }
      return context.request('pytorch-metadata.json', 'utf-8', null).then((data) => {
        pytorch.Metadata._metadata = new pytorch.Metadata(data);
        return pytorch.Metadata._metadata;
      }).catch(() => {
        pytorch.Metadata._metadata = new pytorch.Metadata(null);
        return pytorch.Metadata._metadata;
      });
    }

    /**
     * @param {string|null} data - JSON array of schema objects with a `name` field, or null for an empty registry.
     */
    constructor(data) {
      this._map = new Map();
      this._attributeCache = new Map();
      if (data) {
        const items = JSON.parse(data);
        for (const item of items) {
          this._map.set(item.name, item);
          const index = item.name.indexOf(':');
          if (index !== -1) {
            const name = item.name.substring(0, index);
            if (!this._map.has(name)) {
              this._map.set(name, []);
            }
            const overloads = this._map.get(name);
            // A schema registered under the bare name is an object, not a list;
            // calling push on it used to throw and wipe all metadata via open()'s catch.
            if (Array.isArray(overloads)) {
              overloads.push(item.name);
            }
          }
        }
      }
    }

    /**
     * @param {string} name - Full schema name or bare overload-group name.
     * @returns {Object|Array<Object>|null} The schema, the list of overload schemas, or null.
     */
    type(name) {
      const schema = this._map.get(name);
      if (schema) {
        return Array.isArray(schema) ? schema.map((name) => this._map.get(name)) : schema;
      }
      return null;
    }

    /**
     * Finds the input or attribute named `name` in the schema of `type`.
     * Results (including misses, cached as null) are memoized per 'type:name'.
     * @returns {Object|null}
     */
    attribute(type, name) {
      const attributeName = type + ':' + name;
      if (!this._attributeCache.has(attributeName)) {
        this._attributeCache.set(attributeName, null);
        const schema = this.type(type);
        if (schema) {
          if (schema.inputs) {
            for (const input of schema.inputs) {
              this._attributeCache.set(type + ':' + input.name, input);
            }
          }
          if (schema.attributes) {
            for (const attribute of schema.attributes) {
              this._attributeCache.set(type + ':' + attribute.name, attribute);
            }
          }
        }
      }
      return this._attributeCache.get(attributeName);
    }
  };
  // Error type thrown by the PyTorch loader; `name` carries a fixed, user-facing title.
  pytorch.Error = class extends Error {
    constructor(message) {
      super(message);
      const title = 'Error loading PyTorch model.';
      this.name = title;
    }
  };
  // Under CommonJS (module.exports is an object), expose the model factory; in other
  // environments (e.g. a plain browser script) nothing is exported.
  if (!(typeof module === 'undefined' || typeof module.exports !== 'object')) {
    module.exports.ModelFactory = pytorch.ModelFactory;
  }