pytorch.js 132 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151
/* jshint esversion: 6 */
/* eslint "indent": [ "error", 4, { "SwitchCase": 1 } ] */
// Experimental

// Namespace and dependency bindings. The `var x = x || ...` form reuses an
// already-defined binding of the same name if one exists; otherwise it creates
// the namespace object or loads the dependency with `require`. It must stay
// `var`: with `let`/`const` the right-hand `x` would be read inside its
// temporal dead zone and throw a ReferenceError.
var pytorch = pytorch || {};
var base = base || require('./base');
// `long` is wrapped as `{ Long: ... }`, presumably to match the shape of a
// pre-existing global `long` object — TODO confirm against callers.
var long = long || { Long: require('long') };
  7. pytorch.ModelFactory = class {
  8. match(context) {
  9. const identifier = context.identifier;
  10. const extension = identifier.split('.').pop().toLowerCase();
  11. if (extension === 'pth' || extension === 'pkl' || extension === 'pt' || extension === 'bin' ||
  12. extension === 'h5' || extension === 't7' || extension === 'dms' || extension === 'model' ||
  13. extension === 'ckpt' || extension == 'pt1' || identifier.toLowerCase().endsWith('.pth.tar')) {
  14. if (pytorch.Container.open(context)) {
  15. return true;
  16. }
  17. }
  18. return false;
  19. }
  20. open(context, host) {
  21. const identifier = context.identifier;
  22. return host.require('./pickle').then((pickle) => {
  23. return host.require('./python').then((python) => {
  24. return pytorch.Metadata.open(host).then((metadata) => {
  25. try {
  26. const container = pytorch.Container.open(context, metadata, pickle, python, (error, fatal) => {
  27. const message = error && error.message ? error.message : error.toString();
  28. host.exception(new pytorch.Error(message.replace(/\.$/, '') + " in '" + identifier + "'."), fatal);
  29. });
  30. return new pytorch.Model(metadata, container);
  31. }
  32. catch (error) {
  33. host.exception(error, false);
  34. const message = error && error.message ? error.message : error.toString();
  35. throw new pytorch.Error(message.replace(/\.$/, '') + " in '" + identifier + "'.");
  36. }
  37. });
  38. });
  39. });
  40. }
  41. };
  42. pytorch.Model = class {
  43. constructor(metadata, container) {
  44. this._format = container.format;
  45. this._producer = container.producer || '';
  46. this._graphs = [ new pytorch.Graph(metadata, container) ];
  47. }
  48. get format() {
  49. return this._format;
  50. }
  51. get graphs() {
  52. return this._graphs;
  53. }
  54. };
  55. pytorch.Graph = class {
  56. constructor(metadata, container) {
  57. this._nodes = [];
  58. this._inputs = [];
  59. this._outputs = [];
  60. this._groups = true;
  61. this._littleEndian = container.littleEndian;
  62. if (container.format.startsWith('TorchScript ')) {
  63. this._name = container.name;
  64. let traced = container.trace();
  65. let initializers = new Map();
  66. if (container.data) {
  67. let queue = [ container.data ];
  68. while (queue.length > 0) {
  69. let module = queue.shift();
  70. for (const key of Object.keys(module)) {
  71. if (key !== '__module__' && key !== '__name__' && key !== '__parent__') {
  72. let obj = module[key];
  73. if (!Array.isArray(obj) && obj === Object(obj)) {
  74. if (pytorch.Utility.isTensor(obj)) {
  75. let parameter = obj;
  76. parameter.__parent__ = module;
  77. if (!parameter.initializer && parameter.storage) {
  78. parameter.initializer = new pytorch.Tensor(parameter.name, parameter, true);
  79. }
  80. if (parameter.__variable__ && parameter.__count__ === 1) {
  81. initializers.set(parameter.__variable__, parameter);
  82. }
  83. }
  84. else if (obj && obj.__module__ && obj.__name__) {
  85. obj.__parent__ = module;
  86. if (!obj.__id__) {
  87. obj.__id__ = key;
  88. }
  89. queue.push(obj);
  90. }
  91. }
  92. }
  93. }
  94. }
  95. }
  96. if (traced) {
  97. if (container.inputs) {
  98. for (const input of container.inputs) {
  99. this._inputs.push(new pytorch.Parameter(input, true, [
  100. new pytorch.Argument(input, null, null)
  101. ]));
  102. }
  103. }
  104. if (container.outputs) {
  105. for (const output of container.outputs) {
  106. this._outputs.push(new pytorch.Parameter(output, true, [
  107. new pytorch.Argument(output, null, null)
  108. ]));
  109. }
  110. }
  111. if (container.nodes) {
  112. for (const node of container.nodes) {
  113. const item = {
  114. type: node.type,
  115. node: node
  116. };
  117. this._nodes.push(new pytorch.Node(metadata, '', item, initializers));
  118. }
  119. }
  120. }
  121. if (container.data) {
  122. this._loadScriptModule(metadata, container, container.data, initializers);
  123. }
  124. }
  125. else if (container.data) {
  126. const data = container.data;
  127. this._type = (data.__module__ && data.__name__) ? (data.__module__ + '.' + data.__name__) : '';
  128. const input = 'data';
  129. this._inputs.push(new pytorch.Parameter(input, true, [ new pytorch.Argument(input, null, null) ]));
  130. const outputs = this._loadModule(metadata, container.data, [], [ input ]);
  131. for (const output of outputs) {
  132. this._outputs.push(new pytorch.Parameter(output, true, [ new pytorch.Argument(output, null, null) ]));
  133. }
  134. }
  135. else if (container.state) {
  136. for (const state_group of container.state) {
  137. const attributes = state_group.attributes || [];
  138. const inputs = state_group.states.map((state) => {
  139. const tensor = new pytorch.Tensor(state.id, state.value, this._littleEndian);
  140. const visible = state_group.states.length === 0 || tensor.type.toString() !== 'int64' || tensor.value < 1000;
  141. return new pytorch.Parameter(state.name, visible, [
  142. new pytorch.Argument(state.id, null, tensor)
  143. ]);
  144. });
  145. const obj = {
  146. name: state_group.name,
  147. type: 'torch.nn.Module',
  148. attributes: attributes,
  149. inputs: inputs,
  150. outputs: []
  151. };
  152. this._nodes.push(new pytorch.Node(metadata, '', obj, null));
  153. }
  154. }
  155. }
  156. _loadModule(metadata, parent, groups, inputs) {
  157. if (parent.__module__ &&
  158. !parent.__module__ === 'torch.nn.modules.container' &&
  159. (!parent._modules || parent._modules.length == 0)) {
  160. this._createNode(groups, '', parent, inputs);
  161. return [];
  162. }
  163. if (!parent._modules) {
  164. throw new pytorch.Error('Module does not contain modules.');
  165. }
  166. for (const module of parent._modules) {
  167. if (module && module.value) {
  168. const type = module.value.__module__ + '.' + module.value.__name__;
  169. switch (type) {
  170. case 'torch.nn.modules.container.Sequential':
  171. groups.push(module.key);
  172. inputs = this._loadModule(metadata, module.value, groups, inputs);
  173. groups.pop(module.key);
  174. break;
  175. case 'torchvision.models.densenet._Transition':
  176. case 'torchvision.models.resnet.Bottleneck':
  177. case 'torchvision.models.densenet._DenseBlock':
  178. case 'torchvision.models.densenet._DenseLayer':
  179. case 'torchvision.models.inception.BasicConv2d':
  180. case 'torchvision.models.inception.InceptionAux':
  181. case 'torchvision.models.inception.InceptionA':
  182. case 'torchvision.models.inception.InceptionB':
  183. case 'torchvision.models.inception.InceptionC':
  184. case 'torchvision.models.inception.InceptionD':
  185. case 'torchvision.models.inception.InceptionE': {
  186. groups.push(module.key);
  187. const node = this._createNode(metadata, groups, module.key, module.value, inputs, this._littleEndian);
  188. inputs = [ node.name ];
  189. groups.pop(module.key);
  190. break;
  191. }
  192. default: {
  193. const node = this._createNode(metadata, groups, module.key, module.value, inputs);
  194. inputs = [ node.name ];
  195. break;
  196. }
  197. }
  198. }
  199. }
  200. return inputs;
  201. }
  202. _createNode(metadata, groups, key, obj, args) {
  203. const type = obj.__module__ + '.' + obj.__name__;
  204. const schema = metadata.type(type);
  205. let inputSchema = [ { name: 'input'} ];
  206. if (schema && schema.inputs && schema.inputs.length > 0) {
  207. inputSchema = schema.inputs.slice();
  208. }
  209. let inputs = [];
  210. inputs.push(new pytorch.Parameter(inputSchema.shift().name, true, args.map((argument) => {
  211. return new pytorch.Argument(argument, null, null);
  212. })));
  213. let parameters = [];
  214. if (obj._parameters) {
  215. parameters = parameters.concat(obj._parameters);
  216. }
  217. if (obj._buffers) {
  218. parameters = parameters.concat(obj._buffers);
  219. }
  220. for (const parameter of parameters) {
  221. let visible = true;
  222. let inputName = '';
  223. if (inputSchema.length > 0) {
  224. const input = inputSchema.shift();
  225. inputName = input.name;
  226. visible = input.visible === false ? false : true;
  227. }
  228. if (parameter && parameter.value && (parameter.value.data || parameter.value.storage)) {
  229. let initializer = null;
  230. if (parameter.value.data) {
  231. initializer = new pytorch.Tensor('', parameter.value.data, this._littleEndian);
  232. }
  233. else if (parameter.value.storage) {
  234. initializer = new pytorch.Tensor('', parameter.value, this._littleEndian);
  235. }
  236. inputs.push(new pytorch.Parameter(inputName || parameter.key, visible, [ new pytorch.Argument('', null, initializer) ]));
  237. }
  238. }
  239. const group = groups.join('/');
  240. const name = group ? (group + '/' + key) : key;
  241. const outputs = [ new pytorch.Parameter('output', true, [ new pytorch.Argument(name, null, null) ]) ];
  242. let attributes = [];
  243. for (const name of Object.keys(obj)) {
  244. if (!name.startsWith('_')) {
  245. attributes.push({ name: name, value: obj[name] });
  246. }
  247. }
  248. const item = {
  249. name: name,
  250. type: type,
  251. attributes: attributes,
  252. inputs: inputs,
  253. outputs: outputs
  254. }
  255. const node = new pytorch.Node(metadata, group, item, {});
  256. this._nodes.push(node);
  257. return node;
  258. }
  259. _loadScriptModule(metadata, container, module, initializers) {
  260. if (module) {
  261. if (pytorch.Graph._getParameters(module).length > 0 && !module.__hide__) {
  262. const item = { module: module };
  263. this._nodes.push(new pytorch.Node(metadata, '', item, initializers));
  264. }
  265. const submodules = pytorch.Graph._getSubmodules(module);
  266. for (const submodule of submodules) {
  267. this._loadScriptModule(metadata, container, submodule, initializers);
  268. }
  269. }
  270. }
  271. static _getParameters(module) {
  272. let parameters = [];
  273. if (module && module.__module__ && module.__name__) {
  274. for (const key of Object.keys(module)) {
  275. if (pytorch.Utility.isTensor(module[key])) {
  276. const parameter = module[key];
  277. parameter.__id__ = key;
  278. parameters.push(parameter);
  279. }
  280. }
  281. }
  282. return parameters;
  283. }
  284. static _getSubmodules(module) {
  285. let submodules = [];
  286. if (module && module.__module__ && module.__name__) {
  287. for (const key of Object.keys(module)) {
  288. if (!key.startsWith('__')) {
  289. const value = module[key];
  290. if (value && value.__module__ && value.__name__ && !pytorch.Utility.isTensor(value)) {
  291. submodules.push(value);
  292. }
  293. }
  294. }
  295. }
  296. return submodules;
  297. }
  298. get type() {
  299. return this._type;
  300. }
  301. get name() {
  302. return this._name;
  303. }
  304. get groups() {
  305. return this._groups;
  306. }
  307. get inputs() {
  308. return this._inputs;
  309. }
  310. get outputs() {
  311. return this._outputs;
  312. }
  313. get nodes() {
  314. return this._nodes;
  315. }
  316. };
  317. pytorch.Parameter = class {
  318. constructor(name, visible, args) {
  319. this._name = name;
  320. this._visible = visible;
  321. this._arguments = args;
  322. }
  323. get name() {
  324. return this._name;
  325. }
  326. get visible() {
  327. return this._visible;
  328. }
  329. get arguments() {
  330. return this._arguments;
  331. }
  332. };
  333. pytorch.Argument = class {
  334. constructor(id, type, initializer) {
  335. if (typeof id !== 'string') {
  336. throw new pytorch.Error("Invalid argument identifier '" + JSON.stringify(id) + "'.");
  337. }
  338. this._id = id;
  339. this._type = type;
  340. this._initializer = initializer;
  341. }
  342. get id() {
  343. return this._id;
  344. }
  345. get type() {
  346. if (this._initializer) {
  347. return this._initializer.type;
  348. }
  349. return this._type;
  350. }
  351. get initializer() {
  352. return this._initializer;
  353. }
  354. };
  355. pytorch.Node = class {
  356. constructor(metadata, group, item, initializers) {
  357. this._metadata = metadata;
  358. this._group = group || '';
  359. this._name = item.name || '';
  360. if (!item.module && !item.node) {
  361. this._type = item.type;
  362. this._inputs = item.inputs;
  363. this._outputs = item.outputs;
  364. this._attributes = item.attributes.map((attribute) => {
  365. const schema = metadata.attribute(this._type, attribute.name);
  366. return new pytorch.Attribute(schema, attribute.name, attribute.value)
  367. });
  368. }
  369. else {
  370. this._attributes = [];
  371. this._inputs = [];
  372. this._outputs = [];
  373. let module = item.module;
  374. if (module) {
  375. this._type = 'torch.nn.modules.module.Module';
  376. for (const parameter of pytorch.Graph._getParameters(module)) {
  377. this._inputs.push(new pytorch.Parameter(parameter.__id__, true, [
  378. new pytorch.Argument('', null, parameter.initializer || null)
  379. ]));
  380. if (parameter.__variable__) {
  381. this._outputs.push(new pytorch.Parameter(parameter.__id__, true, [
  382. new pytorch.Argument(parameter.__variable__, null, null)
  383. ]));
  384. }
  385. }
  386. }
  387. if (item.node) {
  388. this._type = item.type;
  389. const schema = metadata.type(this._type);
  390. module = null;
  391. let match = true;
  392. let count = 0;
  393. for (const input of item.node.inputs) {
  394. for (const argument of input) {
  395. const parameter = initializers.get(argument.id);
  396. if (parameter) {
  397. if (parameter.__parent__ && (module == null || module == parameter.__parent__)) {
  398. module = parameter.__parent__;
  399. count++;
  400. }
  401. else {
  402. match = false;
  403. break;
  404. }
  405. }
  406. }
  407. if (!match) {
  408. break;
  409. }
  410. }
  411. if (module) {
  412. const params = pytorch.Graph._getParameters(module).filter((p) => p.__id__ !== 'num_batches_tracked');
  413. if (params.length == count && match) {
  414. module.__hide__ = true;
  415. for (const input of item.node.inputs) {
  416. for (const argument of input) {
  417. const parameter = initializers.get(argument.id);
  418. if (parameter && parameter.initializer) {
  419. argument.initializer = parameter.initializer;
  420. }
  421. }
  422. }
  423. }
  424. else {
  425. module = null;
  426. }
  427. }
  428. for (let inputIndex = 0; inputIndex < item.node.inputs.length; inputIndex++) {
  429. let inputName = inputIndex.toString();
  430. if (schema && schema.inputs && schema.inputs.length > inputIndex) {
  431. inputName = schema.inputs[inputIndex].name;
  432. }
  433. this._inputs.push(new pytorch.Parameter(inputName, true,
  434. item.node.inputs[inputIndex].map((input) => new pytorch.Argument(input.id, null, input.initializer || null))
  435. ));
  436. }
  437. for (let outputIndex = 0; outputIndex < item.node.outputs.length; outputIndex++) {
  438. let outputName = outputIndex.toString();
  439. if (schema && schema.outputs && schema.outputs.length > outputIndex) {
  440. outputName = schema.outputs[outputIndex].name;
  441. }
  442. this._outputs.push(new pytorch.Parameter(outputName, true, [
  443. new pytorch.Argument(item.node.outputs[outputIndex], null, null)
  444. ]));
  445. }
  446. for (let i = 0; i < item.node.attributes.length; i++) {
  447. let attributeSchema = null;
  448. let name = i.toString();
  449. let value = item.node.attributes[i];
  450. if (value && value.type === '=' && value.target.type == 'id') {
  451. name = value.target.value;
  452. value = value.expression;
  453. attributeSchema = metadata.attribute(this._type, name);
  454. }
  455. else if (schema && schema.attributes && schema.attributes.length > i) {
  456. attributeSchema = schema.attributes[i];
  457. name = attributeSchema.name;
  458. }
  459. this._attributes.push(new pytorch.Attribute(attributeSchema, name, value));
  460. }
  461. }
  462. if (module) {
  463. if (module.__id__) {
  464. let current = module;
  465. this._name = current.__id__;
  466. while (current.__parent__ != null) {
  467. current = current.__parent__;
  468. if (!current.__parent__ && !current.__id__) {
  469. break;
  470. }
  471. this._name = [ current.__id__, this._name ].join('.')
  472. }
  473. }
  474. }
  475. }
  476. }
  477. get name() {
  478. return this._name;
  479. }
  480. get group() {
  481. return this._group;
  482. }
  483. get operator() {
  484. const index = this._type.indexOf(':');
  485. return index === -1 ? this._type : this._type.substring(0, index);
  486. }
  487. get metadata() {
  488. return this._metadata.type(this._type);
  489. }
  490. get function() {
  491. return this._type.startsWith('torch.nn.modules.') && this._type !== 'torch.nn.modules.module.Module';
  492. }
  493. get attributes() {
  494. return this._attributes;
  495. }
  496. get inputs() {
  497. return this._inputs;
  498. }
  499. get outputs() {
  500. return this._outputs;
  501. }
  502. };
  503. pytorch.Attribute = class {
  504. constructor(schema, name, value) {
  505. this._name = name;
  506. this._value = value;
  507. if (this._name === 'training') {
  508. this._visible = false;
  509. this._type = 'boolean';
  510. return;
  511. }
  512. if (value && value.type) {
  513. switch (value.type) {
  514. case 'number':
  515. this._value = value.value;
  516. break;
  517. case 'string':
  518. this._value = value.value;
  519. break;
  520. case 'boolean':
  521. this._value = value.value;
  522. break;
  523. case 'id':
  524. this._value = value.value;
  525. break;
  526. }
  527. }
  528. if (schema) {
  529. if (Object.prototype.hasOwnProperty.call(schema, 'type')) {
  530. this._type = schema.type;
  531. }
  532. switch (this._type) {
  533. case 'boolean':
  534. if (this._value == 'False') {
  535. this._value = false;
  536. }
  537. else if (this._value == 'True') {
  538. this._value = true;
  539. }
  540. break;
  541. case 'int32':
  542. case 'int64':
  543. if (typeof this._value !== 'number') {
  544. if (typeof this._value === 'string') {
  545. this._value = parseInt(this._value, 10);
  546. }
  547. }
  548. break;
  549. case 'float32':
  550. case 'float64':
  551. if (typeof this._value !== 'number') {
  552. if (typeof this._value === 'string') {
  553. this._value = parseFloat(this._value);
  554. }
  555. }
  556. break;
  557. case 'int32[]':
  558. case 'int64[]': {
  559. switch (this._value.type) {
  560. case 'list':
  561. this._value = this._value.value.map((item) => {
  562. if (item.type === 'number') {
  563. let number = parseInt(item.value, 10);
  564. if (!Number.isNaN(item.value - number)) {
  565. return number;
  566. }
  567. }
  568. return item;
  569. });
  570. break;
  571. }
  572. break;
  573. }
  574. }
  575. if (Object.prototype.hasOwnProperty.call(schema, 'visible') && !schema.visible) {
  576. this._visible = false;
  577. }
  578. else if (Object.prototype.hasOwnProperty.call(schema, 'default')) {
  579. if (JSON.stringify(schema.default) == JSON.stringify(this._value)) {
  580. this._visible = false;
  581. }
  582. else if (Array.isArray(this._value) &&
  583. !Array.isArray(schema.default) &&
  584. this.value.every((item) => item == schema.default)) {
  585. this._visible = false;
  586. }
  587. }
  588. }
  589. if (Array.isArray(value) && value.length > 0 && value.every((obj) => obj && obj.__module__ && obj.__module__.startsWith('torch.nn'))) {
  590. this._value = '?';
  591. }
  592. }
  593. get type() {
  594. return this._type;
  595. }
  596. get name() {
  597. return this._name;
  598. }
  599. get value() {
  600. return this._value;
  601. }
  602. get visible() {
  603. return this._visible == false ? false : true;
  604. }
  605. };
  pytorch.Tensor = class {

      /**
       * @param {string} name - tensor name; falsy values become ''.
       * @param {object} tensor - object exposing `storage.dataType`, `storage.data`
       *     (a typed array whose buffer is wrapped in a DataView) and `size`
       *     (the dimension list).
       * @param {boolean} littleEndian - byte order used to decode every element type.
       */
      constructor(name, tensor, littleEndian) {
          this._name = name || '';
          this._type = new pytorch.TensorType(tensor.storage.dataType, new pytorch.TensorShape(tensor.size));
          this._data = tensor.storage.data;
          this._littleEndian = littleEndian;
      }

      get kind() {
          return 'Tensor';
      }

      get name() {
          return this._name;
      }

      get type() {
          return this._type;
      }

      /** Message explaining why the tensor cannot be decoded, or null when it can. */
      get state() {
          return this._context().state;
      }

      /** Fully decoded value (nested arrays, a bare element for rank 0), or null when `state` is set. */
      get value() {
          const context = this._context();
          if (context.state) {
              return null;
          }
          context.limit = Number.MAX_SAFE_INTEGER;
          return this._decode(context, 0);
      }

      /** Pretty-printed value, cut off with '...' once more than 10000 elements were read; '' when `state` is set. */
      toString() {
          const context = this._context();
          if (context.state) {
              return '';
          }
          context.limit = 10000;
          const value = this._decode(context, 0);
          return pytorch.Tensor._stringify(value, '', ' ');
      }

      // Builds the decoding cursor. When decoding is impossible only `state`
      // (a human-readable reason) is meaningful and the caller must stop.
      _context() {
          const context = { state: null, index: 0, count: 0 };
          const dataType = this._type.dataType;
          if (!dataType) {
              context.state = 'Tensor has no data type.';
              return context;
          }
          switch (dataType) {
              case 'uint8':
              case 'qint8':
              case 'int8':
              case 'int16':
              case 'int32':
              case 'int64':
              case 'float16':
              case 'float32':
              case 'float64':
                  break;
              default:
                  context.state = "Tensor data type '" + dataType + "' is not supported.";
                  return context;
          }
          if (!this._type.shape) {
              context.state = 'Tensor has no dimensions.';
              return context;
          }
          if (!this._data) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          context.data = this._data;
          context.dataType = dataType;
          context.dimensions = this._type.shape.dimensions;
          context.dataView = new DataView(context.data.buffer, context.data.byteOffset, context.data.byteLength);
          return context;
      }

      // Decodes dimension `dimension` and everything nested inside it. A rank-0
      // tensor is treated as shape [1] and its single element is returned unwrapped.
      // Once more than `context.limit` elements were read, '...' is appended and
      // decoding of the current level stops.
      _decode(context, dimension) {
          const dimensions = context.dimensions.length === 0 ? [ 1 ] : context.dimensions;
          const size = dimensions[dimension];
          const innermost = dimension === dimensions.length - 1;
          const results = [];
          for (let i = 0; i < size; i++) {
              if (context.count > context.limit) {
                  results.push('...');
                  return results;
              }
              results.push(innermost ? this._read(context) : this._decode(context, dimension + 1));
          }
          return context.dimensions.length === 0 ? results[0] : results;
      }

      // Reads one element at `context.index`, advances the cursor past it and
      // bumps the element count.
      _read(context) {
          const view = context.dataView;
          const index = context.index;
          const littleEndian = this._littleEndian;
          let value;
          let itemsize;
          switch (context.dataType) {
              case 'uint8':
                  value = view.getUint8(index);
                  itemsize = 1;
                  break;
              case 'qint8':
              case 'int8':
                  value = view.getInt8(index);
                  itemsize = 1;
                  break;
              case 'int16':
                  value = view.getInt16(index, littleEndian);
                  itemsize = 2;
                  break;
              case 'int32':
                  value = view.getInt32(index, littleEndian);
                  itemsize = 4;
                  break;
              case 'int64': {
                  // Honour the tensor's byte order like every other type: in
                  // big-endian data the high 32-bit word comes first.
                  const low = view.getUint32(littleEndian ? index : index + 4, littleEndian);
                  const high = view.getUint32(littleEndian ? index + 4 : index, littleEndian);
                  value = new long.Long(low, high, false);
                  itemsize = 8;
                  break;
              }
              case 'float16':
                  value = view.getFloat16(index, littleEndian);
                  itemsize = 2;
                  break;
              case 'float32':
                  value = view.getFloat32(index, littleEndian);
                  itemsize = 4;
                  break;
              case 'float64':
                  value = view.getFloat64(index, littleEndian);
                  itemsize = 8;
                  break;
              default:
                  // Unreachable: _context() rejects every other data type.
                  throw new pytorch.Error("Tensor data type '" + context.dataType + "' is not supported.");
          }
          context.index += itemsize;
          context.count++;
          return value;
      }

      // Renders a decoded value: arrays one item per line, nested items indented
      // by `indent` relative to `indentation`.
      static _stringify(value, indentation, indent) {
          if (Array.isArray(value)) {
              const result = [ indentation + '[' ];
              const items = value.map((item) => pytorch.Tensor._stringify(item, indentation + indent, indent));
              if (items.length > 0) {
                  result.push(items.join(',\n'));
              }
              result.push(indentation + ']');
              return result.join('\n');
          }
          if (value && long.Long.isLong(value)) {
              return indentation + value.toString();
          }
          if (typeof value === 'string') {
              return indentation + value;
          }
          if (value === Infinity) {
              return indentation + 'Infinity';
          }
          if (value === -Infinity) {
              return indentation + '-Infinity';
          }
          if (isNaN(value)) {
              return indentation + 'NaN';
          }
          return indentation + value.toString();
      }
  };
  pytorch.TensorType = class {

      /**
       * @param {string} dataType - element type name such as 'float32' or 'int64'.
       * @param {pytorch.TensorShape} shape - dimensions; may be missing, since
       *     pytorch.Tensor reports 'Tensor has no dimensions.' for that case.
       */
      constructor(dataType, shape) {
          this._dataType = dataType;
          this._shape = shape;
      }

      get dataType() {
          return this._dataType;
      }

      get shape() {
          return this._shape;
      }

      /** Data type followed by the shape text (e.g. 'float32[2,3]'); the shape part is omitted when there is no shape. */
      toString() {
          // Guard: a missing shape used to throw a TypeError here.
          const shape = this._shape ? this._shape.toString() : '';
          return this._dataType + shape;
      }
  };
  pytorch.TensorShape = class {

      /**
       * @param {Array} dimensions - size of each axis; any falsy value is
       *     stored as an empty list.
       */
      constructor(dimensions) {
          this._dimensions = dimensions || [];
      }

      get dimensions() {
          return this._dimensions;
      }

      /** Bracketed, comma-separated dimensions such as '[1,3,224,224]', or '' when there are none. */
      toString() {
          const dims = this._dimensions;
          if (!dims || !(dims.length > 0)) {
              return '';
          }
          const parts = dims.map((dim) => dim.toString());
          return '[' + parts.join(',') + ']';
      }
  };
  pytorch.Metadata = class {

      /**
       * Returns the shared metadata instance, fetching 'pytorch-metadata.json'
       * via `host.request(null, 'pytorch-metadata.json', 'utf-8')` on first use.
       * Any request or parse failure yields (and caches) an empty instance
       * rather than a rejected promise.
       * @param {object} host - object exposing a promise-returning `request`.
       * @returns {Promise<pytorch.Metadata>}
       */
      static open(host) {
          if (pytorch.Metadata._metadata) {
              return Promise.resolve(pytorch.Metadata._metadata);
          }
          return host.request(null, 'pytorch-metadata.json', 'utf-8').then((data) => {
              pytorch.Metadata._metadata = new pytorch.Metadata(data);
              return pytorch.Metadata._metadata;
          }).catch(() => {
              pytorch.Metadata._metadata = new pytorch.Metadata(null);
              return pytorch.Metadata._metadata;
          });
      }

      /**
       * @param {string|null} data - JSON text of an array of `{ name, schema }`
       *     entries. An entry named 'base:overload' is additionally listed under
       *     'base' as one of that name's overloads.
       */
      constructor(data) {
          this._map = new Map();
          this._attributeCache = new Map();
          if (!data) {
              return;
          }
          const items = JSON.parse(data);
          if (!items) {
              return;
          }
          for (const item of items) {
              // Skip nameless entries: they cannot be looked up, and dereferencing
              // `item.name` would throw, making `open` discard ALL metadata.
              if (!item || !item.name) {
                  continue;
              }
              if (item.schema) {
                  item.schema.name = item.name;
                  this._map.set(item.name, item.schema);
              }
              const index = item.name.indexOf(':');
              if (index !== -1) {
                  const name = item.name.substring(0, index);
                  if (!this._map.has(name)) {
                      this._map.set(name, []);
                  }
                  const overloads = this._map.get(name);
                  // A plain schema already registered under the base name is kept;
                  // pushing onto it would throw and discard all metadata.
                  if (Array.isArray(overloads)) {
                      overloads.push(item.name);
                  }
              }
          }
      }

      /**
       * @param {string} operator - entry name.
       * @returns the schema for `operator`, an array of overload schemas when the
       *     name only exists as a 'base' of 'base:overload' entries, or null.
       */
      type(operator) {
          const schema = this._map.get(operator);
          if (schema) {
              return Array.isArray(schema) ? schema.map((name) => this._map.get(name)) : schema;
          }
          return null;
      }

      /**
       * Looks up attribute `name` in the schema of `operator`, caching every
       * attribute of that schema on first access. Misses are cached as null.
       * NOTE(review): for overloaded names `type()` returns an array, which has
       * no `attributes`, so lookups on those always yield null.
       */
      attribute(operator, name) {
          const key = operator + ':' + name;
          if (!this._attributeCache.has(key)) {
              this._attributeCache.set(key, null);
              const schema = this.type(operator);
              if (schema && schema.attributes) {
                  for (const attribute of schema.attributes) {
                      this._attributeCache.set(operator + ':' + attribute.name, attribute);
                  }
              }
          }
          return this._attributeCache.get(key);
      }
  };
  pytorch.Error = class extends Error {

      /**
       * Error raised while loading a PyTorch model. `name` is set to a fixed
       * human-readable label (an own property) instead of the class name.
       * @param {string} message - description of the failure.
       */
      constructor(message) {
          super(message);
          const label = 'Error loading PyTorch model.';
          this.name = label;
      }
  };
  872. pytorch.Execution = class {
  873. constructor(python, sources, exceptionCallback) {
  874. const self = this;
  875. this._python = python;
  876. this._sources = sources;
  877. this._exceptionCallback = exceptionCallback;
  878. this._utf8Decoder = new TextDecoder('utf-8');
  879. this._unknownNameMap = new Set();
  880. this._knownPackageMap = new Set([ 'torch', 'torchvision', 'collections', '__builtin__', '_codecs', 'argparse', 'numpy' ]);
  881. this._packages = new Map();
  882. this._context = new pytorch.Execution.Context();
  883. this._context.scope.builtins = {};
  884. this._context.scope.builtins.type = { __module__: 'builtins', __name__: 'type' };
  885. this._context.scope.builtins.module = { __module__: 'builtins', __name__: 'module', __class__: this._context.scope.builtins.type };
  886. this._context.scope.builtins.function = { __module__: 'builtins', __name__: 'function', __class__:this._context.scope.builtins.type };
  887. this._context.scope.builtins.method = { __module__: 'builtins', __name__: 'method', __class__:this._context.scope.builtins.type };
  888. this._registerConstructor('argparse.Namespace', function (args) {
  889. this.args = args;
  890. });
  891. this._registerConstructor('torch.autograd.variable.Variable', function() {});
  892. this._registerConstructor('torch.backends.cudnn.rnn.Unserializable', function() {});
  893. this._registerConstructor('torch.device', function(type, index) {
  894. this.type = type;
  895. this.index = index;
  896. });
  897. this._registerConstructor('torch.distributions.multivariate_normal.MultivariateNormal', function() {});
  898. this._registerConstructor('torch.nn.backends.thnn._get_thnn_function_backend', function() {});
  899. this._registerConstructor('torch.nn.intrinsic.modules.fused.ConvReLU2d', function() {});
  900. this._registerConstructor('torch.nn.modules.activation.CELU', function() {});
  901. this._registerConstructor('torch.nn.modules.activation.ELU', function() {});
  902. this._registerConstructor('torch.nn.modules.activation.GLU', function() {});
  903. this._registerConstructor('torch.nn.modules.activation.Hardtanh', function() {});
  904. this._registerConstructor('torch.nn.modules.activation.LeakyReLU', function() {});
  905. this._registerConstructor('torch.nn.modules.activation.LogSigmoid', function() {});
  906. this._registerConstructor('torch.nn.modules.activation.LogSoftmax', function() {});
  907. this._registerConstructor('torch.nn.modules.activation.MultiheadAttention', function() {});
  908. this._registerConstructor('torch.nn.modules.activation.ReLU', function() {});
  909. this._registerConstructor('torch.nn.modules.activation.ReLU6', function() {});
  910. this._registerConstructor('torch.nn.modules.activation.PReLU', function() {});
  911. this._registerConstructor('torch.nn.modules.activation.SELU', function() {});
  912. this._registerConstructor('torch.nn.modules.activation.Sigmoid', function() {});
  913. this._registerConstructor('torch.nn.modules.activation.Softmax', function() {});
  914. this._registerConstructor('torch.nn.modules.activation.Softmax2d', function() {});
  915. this._registerConstructor('torch.nn.modules.activation.Softplus', function() {});
  916. this._registerConstructor('torch.nn.modules.activation.Tanh', function() {});
  917. this._registerConstructor('torch.nn.modules.activation.Threshold', function() {});
  918. this._registerConstructor('torch.nn.modules.batchnorm.BatchNorm1d', function() {});
  919. this._registerConstructor('torch.nn.modules.batchnorm.BatchNorm2d', function() {});
  920. this._registerConstructor('torch.nn.modules.batchnorm.BatchNorm3d', function() {});
  921. this._registerConstructor('torch.nn.modules.batchnorm.SyncBatchNorm', function() {});
  922. this._registerConstructor('torch.nn.modules.container.ModuleDict', function() {});
  923. this._registerConstructor('torch.nn.modules.container.ModuleList', function() {});
  924. this._registerConstructor('torch.nn.modules.container.ParameterList', function() {});
  925. this._registerConstructor('torch.nn.modules.container.Sequential', function() {});
  926. this._registerConstructor('torch.nn.modules.conv.Conv1d', function() {});
  927. this._registerConstructor('torch.nn.modules.conv.Conv2d', function() {});
  928. this._registerConstructor('torch.nn.modules.conv.Conv3d', function() {});
  929. this._registerConstructor('torch.nn.modules.conv.ConvTranspose1d', function() {});
  930. this._registerConstructor('torch.nn.modules.conv.ConvTranspose2d', function() {});
  931. this._registerConstructor('torch.nn.modules.conv.ConvTranspose3d', function() {});
  932. this._registerConstructor('torch.nn.modules.distance.CosineSimilarity', function() {});
  933. this._registerConstructor('torch.nn.modules.dropout.Dropout', function() {});
  934. this._registerConstructor('torch.nn.modules.dropout.Dropout2d', function() {});
  935. this._registerConstructor('torch.nn.modules.dropout.Dropout3d', function() {});
  936. this._registerConstructor('torch.nn.modules.fold.Unfold', function() {});
  937. this._registerConstructor('torch.nn.modules.flatten.Flatten', function() {});
  938. this._registerConstructor('torch.nn.modules.instancenorm.InstanceNorm1d', function() {});
  939. this._registerConstructor('torch.nn.modules.instancenorm.InstanceNorm2d', function() {});
  940. this._registerConstructor('torch.nn.modules.instancenorm.InstanceNorm3d', function() {});
  941. this._registerConstructor('torch.nn.modules.linear.Linear', function() {});
  942. this._registerConstructor('torch.nn.modules.linear.Identity', function() {});
  943. this._registerConstructor('torch.nn.modules.loss.BCELoss', function() {});
  944. this._registerConstructor('torch.nn.modules.loss.BCEWithLogitsLoss', function() {});
  945. this._registerConstructor('torch.nn.modules.loss.CrossEntropyLoss', function() {});
  946. this._registerConstructor('torch.nn.modules.loss.L1Loss', function() {});
  947. this._registerConstructor('torch.nn.modules.loss.MSELoss', function() {});
  948. this._registerConstructor('torch.nn.modules.loss.NLLLoss', function() {});
  949. this._registerConstructor('torch.nn.modules.loss.SmoothL1Loss', function() {});
  950. this._registerConstructor('torch.nn.modules.normalization.CrossMapLRN2d', function() {});
  951. this._registerConstructor('torch.nn.modules.normalization.GroupNorm', function() {});
  952. this._registerConstructor('torch.nn.modules.normalization.LayerNorm', function() {});
  953. this._registerConstructor('torch.nn.modules.normalization.LocalResponseNorm', function() {});
  954. this._registerConstructor('torch.nn.modules.padding.ReflectionPad1d', function() {});
  955. this._registerConstructor('torch.nn.modules.padding.ReflectionPad2d', function() {});
  956. this._registerConstructor('torch.nn.modules.padding.ReplicationPad1d', function() {});
  957. this._registerConstructor('torch.nn.modules.padding.ReplicationPad2d', function() {});
  958. this._registerConstructor('torch.nn.modules.padding.ReplicationPad3d', function() {});
  959. this._registerConstructor('torch.nn.modules.padding.ZeroPad2d', function() {});
  960. this._registerConstructor('torch.nn.modules.padding.ConstantPad1d', function() {});
  961. this._registerConstructor('torch.nn.modules.padding.ConstantPad2d', function() {});
  962. this._registerConstructor('torch.nn.modules.padding.ConstantPad3d', function() {});
  963. this._registerConstructor('torch.nn.modules.pixelshuffle.PixelShuffle', function() {});
  964. this._registerConstructor('torch.nn.modules.pooling.AdaptiveAvgPool1d', function() {});
  965. this._registerConstructor('torch.nn.modules.pooling.AdaptiveAvgPool2d', function() {});
  966. this._registerConstructor('torch.nn.modules.pooling.AdaptiveAvgPool3d', function() {});
  967. this._registerConstructor('torch.nn.modules.pooling.AdaptiveMaxPool1d', function() {});
  968. this._registerConstructor('torch.nn.modules.pooling.AdaptiveMaxPool2d', function() {});
  969. this._registerConstructor('torch.nn.modules.pooling.AdaptiveMaxPool3d', function() {});
  970. this._registerConstructor('torch.nn.modules.pooling.AvgPool1d', function() {});
  971. this._registerConstructor('torch.nn.modules.pooling.AvgPool2d', function() {});
  972. this._registerConstructor('torch.nn.modules.pooling.AvgPool3d', function() {});
  973. this._registerConstructor('torch.nn.modules.pooling.FractionalMaxPool2d', function() {});
  974. this._registerConstructor('torch.nn.modules.pooling.MaxPool1d', function() {});
  975. this._registerConstructor('torch.nn.modules.pooling.MaxPool2d', function() {});
  976. this._registerConstructor('torch.nn.modules.pooling.MaxPool3d', function() {});
  977. this._registerConstructor('torch.nn.modules.pooling.MaxUnpool1d', function() {});
  978. this._registerConstructor('torch.nn.modules.pooling.MaxUnpool2d', function() {});
  979. this._registerConstructor('torch.nn.modules.pooling.MaxUnpool3d', function() {});
  980. this._registerConstructor('torch.nn.modules.rnn.GRU', function() {});
  981. this._registerConstructor('torch.nn.modules.rnn.GRUCell', function() {});
  982. this._registerConstructor('torch.nn.modules.rnn.LSTM', function() {});
  983. this._registerConstructor('torch.nn.modules.rnn.LSTMCell', function() {});
  984. this._registerConstructor('torch.nn.modules.rnn.RNN', function() {});
  985. this._registerConstructor('torch.nn.modules.sparse.Embedding', function() {});
  986. this._registerConstructor('torch.nn.modules.transformer.TransformerEncoder', function() {});
  987. this._registerConstructor('torch.nn.modules.transformer.TransformerEncoderLayer', function() {});
  988. this._registerConstructor('torch.nn.modules.upsampling.Upsample', function() {});
  989. this._registerConstructor('torch.nn.modules.upsampling.UpsamplingBilinear2d', function() {});
  990. this._registerConstructor('torch.nn.modules.upsampling.UpsamplingNearest2d', function() {});
  991. this._registerConstructor('torch.nn.parallel.data_parallel.DataParallel', function() {});
  992. this._registerConstructor('torch.nn.parallel.distributed.DistributedDataParallel', function() {});
  993. this._registerConstructor('torch.nn.parameter.Parameter', function(data, requires_grad) {
  994. this.data = data;
  995. this.requires_grad = requires_grad;
  996. });
  997. this._registerConstructor('torch.nn.quantized.modules.functional_modules.FloatFunctional', function() {});
  998. this._registerConstructor('torch.nn.utils.spectral_norm.SpectralNorm', function() {});
  999. this._registerConstructor('torch.nn.utils.spectral_norm.SpectralNormStateDictHook', function() {});
  1000. this._registerConstructor('torch.nn.utils.spectral_norm.SpectralNormLoadStateDictPreHook', function() {});
  1001. this._registerConstructor('torch.nn.utils.weight_norm.WeightNorm', function() {});
  1002. this._registerConstructor('torch.optim.adam.Adam', function() {});
  1003. this._registerConstructor('torch.optim.adagrad.Adagrad', function() {});
  1004. this._registerConstructor('torch.optim.lr_scheduler.MultiStepLR', function() {});
  1005. this._registerConstructor('torch.optim.lr_scheduler.StepLR', function() {});
  1006. this._registerConstructor('torch.optim.rmsprop.RMSprop', function() {});
  1007. this._registerConstructor('torch.optim.sgd.SGD', function() {});
  1008. this._registerConstructor('torch.quantization.stubs.DeQuantStub', function() {});
  1009. this._registerConstructor('torch.quantization.stubs.QuantStub', function() {});
  1010. this._registerConstructor('torchvision.datasets.folder.ImageFolder', function() {});
  1011. this._registerConstructor('torchvision.models.alexnet.AlexNet', function() {});
  1012. this._registerConstructor('torchvision.models.densenet.DenseNet', function() {});
  1013. this._registerConstructor('torchvision.models.densenet._DenseBlock', function() {});
  1014. this._registerConstructor('torchvision.models.densenet._DenseLayer', function() {});
  1015. this._registerConstructor('torchvision.models.densenet._Transition', function() {});
  1016. this._registerConstructor('torchvision.models.detection._utils.BalancedPositiveNegativeSampler', function() {});
  1017. this._registerConstructor('torchvision.models.detection._utils.BoxCoder', function() {});
  1018. this._registerConstructor('torchvision.models.detection._utils.Matcher', function() {});
  1019. this._registerConstructor('torchvision.models.detection.backbone_utils.BackboneWithFPN', function() {});
  1020. this._registerConstructor('torchvision.models.detection.faster_rcnn.FasterRCNN', function() {});
  1021. this._registerConstructor('torchvision.models.detection.faster_rcnn.FastRCNNPredictor', function() {});
  1022. this._registerConstructor('torchvision.models.detection.faster_rcnn.TwoMLPHead', function() {});
  1023. this._registerConstructor('torchvision.models.detection.roi_heads.RoIHeads', function() {});
  1024. this._registerConstructor('torchvision.models.detection.rpn.AnchorGenerator', function() {});
  1025. this._registerConstructor('torchvision.models.detection.rpn.RegionProposalNetwork', function() {});
  1026. this._registerConstructor('torchvision.models.detection.rpn.RPNHead', function() {});
  1027. this._registerConstructor('torchvision.models.detection.transform.GeneralizedRCNNTransform', function() {});
  1028. this._registerConstructor('torchvision.models.googlenet.BasicConv2d', function() {});
  1029. this._registerConstructor('torchvision.models.googlenet.GoogLeNet', function() {});
  1030. this._registerConstructor('torchvision.models.googlenet.Inception', function() {});
  1031. this._registerConstructor('torchvision.models.inception.BasicConv2d', function() {});
  1032. this._registerConstructor('torchvision.models.inception.Inception3', function() {});
  1033. this._registerConstructor('torchvision.models.inception.InceptionAux', function() {});
  1034. this._registerConstructor('torchvision.models.inception.InceptionA', function() {});
  1035. this._registerConstructor('torchvision.models.inception.InceptionB', function() {});
  1036. this._registerConstructor('torchvision.models.inception.InceptionC', function() {});
  1037. this._registerConstructor('torchvision.models.inception.InceptionD', function() {});
  1038. this._registerConstructor('torchvision.models.inception.InceptionE', function() {});
  1039. this._registerConstructor('torchvision.models.mobilenet.ConvBNReLU', function() {});
  1040. this._registerConstructor('torchvision.models.mobilenet.MobileNetV2', function() {});
  1041. this._registerConstructor('torchvision.models.mobilenet.InvertedResidual', function() {});
  1042. this._registerConstructor('torchvision.models.resnet.Bottleneck', function() {});
  1043. this._registerConstructor('torchvision.models.resnet.BasicBlock', function() {});
  1044. this._registerConstructor('torchvision.models.quantization.resnet.QuantizableBottleneck', function() {});
  1045. this._registerConstructor('torchvision.models.quantization.resnet.QuantizableResNet', function() {});
  1046. this._registerConstructor('torchvision.models.segmentation.deeplabv3.ASPP', function() {});
  1047. this._registerConstructor('torchvision.models.segmentation.deeplabv3.ASPPConv', function() {});
  1048. this._registerConstructor('torchvision.models.segmentation.deeplabv3.ASPPPooling', function() {});
  1049. this._registerConstructor('torchvision.models.segmentation.deeplabv3.DeepLabHead', function() {});
  1050. this._registerConstructor('torchvision.models.segmentation.deeplabv3.DeepLabV3', function() {});
  1051. this._registerConstructor('torchvision.models.segmentation.fcn.FCN', function() {});
  1052. this._registerConstructor('torchvision.models.segmentation.fcn.FCNHead', function() {});
  1053. this._registerConstructor('torchvision.models.shufflenetv2.ShuffleNetV2', function() {});
  1054. this._registerConstructor('torchvision.models.shufflenetv2.InvertedResidual', function() {});
  1055. this._registerConstructor('torchvision.models.squeezenet.Fire', function() {});
  1056. this._registerConstructor('torchvision.models.squeezenet.SqueezeNet', function() {});
  1057. this._registerConstructor('torchvision.models.resnet.ResNet', function() {});
  1058. this._registerConstructor('torchvision.models.vgg.VGG', function() {});
  1059. this._registerConstructor('torchvision.models.video.resnet.BasicBlock', function() {});
  1060. this._registerConstructor('torchvision.models.video.resnet.BasicStem', function() {});
  1061. this._registerConstructor('torchvision.models.video.resnet.Conv3DNoTemporal', function() {});
  1062. this._registerConstructor('torchvision.models.video.resnet.Conv3DSimple', function() {});
  1063. this._registerConstructor('torchvision.models.video.resnet.VideoResNet', function() {});
  1064. this._registerConstructor('torchvision.models._utils.IntermediateLayerGetter', function() {});
  1065. this._registerConstructor('torchvision.ops.feature_pyramid_network.FeaturePyramidNetwork', function() {});
  1066. this._registerConstructor('torchvision.ops.feature_pyramid_network.LastLevelMaxPool', function() {});
  1067. this._registerConstructor('torchvision.ops.misc.FrozenBatchNorm2d', function() {});
  1068. this._registerConstructor('torchvision.ops.poolers.LevelMapper', function() {});
  1069. this._registerConstructor('torchvision.ops.poolers.MultiScaleRoIAlign', function() {});
  1070. this._registerConstructor('torchvision.transforms.transforms.Compose', function() {});
  1071. this._registerConstructor('torchvision.transforms.transforms.Normalize', function() {});
  1072. this._registerConstructor('torchvision.transforms.transforms.Resize', function() {});
  1073. this._registerConstructor('torchvision.transforms.transforms.ToTensor', function() {});
  1074. this._registerConstructor('torch.ByteStorage', function (size) {
  1075. this.size = size; this.dataTypeSize = 1; this.dataType = 'uint8';
  1076. });
  1077. this._registerConstructor('torch.CharStorage', function (size) {
  1078. this.size = size; this.dataTypeSize = 1; this.dataType = 'int8';
  1079. });
  1080. this._registerConstructor('torch.ShortStorage', function (size) {
  1081. this.size = size; this.dataTypeSize = 2; this.dataType = 'int16';
  1082. });
  1083. this._registerConstructor('torch.IntStorage', function (size) {
  1084. this.size = size; this.dataTypeSize = 4; this.dataType = 'int32';
  1085. });
  1086. this._registerConstructor('torch.LongStorage', function (size) {
  1087. this.size = size; this.dataTypeSize = 8; this.dataType = 'int64';
  1088. });
  1089. this._registerConstructor('torch.HalfStorage', function (size) {
  1090. this.size = size; this.dataTypeSize = 2; this.dataType = 'float16';
  1091. });
  1092. this._registerConstructor('torch.FloatStorage', function (size) {
  1093. this.size = size; this.dataTypeSize = 4; this.dataType = 'float32';
  1094. });
  1095. this._registerConstructor('torch.DoubleStorage', function (size) {
  1096. this.size = size; this.dataTypeSize = 8; this.dataType = 'float64';
  1097. });
  1098. this._registerConstructor('torch.QInt8Storage', function (size) {
  1099. this.size = size; this.dataTypeSize = 1; this.dataType = 'qint8';
  1100. });
  1101. this._registerConstructor('torch.FloatTensor', function () {
  1102. this.__setstate__ = function(state) {
  1103. this.storage = state[0];
  1104. this.storage_offset = state[1];
  1105. this.size = state[2];
  1106. this.stride = state[3];
  1107. };
  1108. });
  1109. this._registerConstructor('torch.DoubleTensor', function () {
  1110. this.__setstate__ = function(state) {
  1111. this.storage = state[0];
  1112. this.storage_offset = state[1];
  1113. this.size = state[2];
  1114. this.stride = state[3];
  1115. };
  1116. });
  1117. this._registerConstructor('torch.cuda.FloatTensor', function () {
  1118. this.__setstate__ = function(state) {
  1119. this.storage = state[0];
  1120. this.storage_offset = state[1];
  1121. this.size = state[2];
  1122. this.stride = state[3];
  1123. };
  1124. });
  1125. this._registerConstructor('torch.cuda.DoubleTensor', function () {
  1126. this.__setstate__ = function(state) {
  1127. this.storage = state[0];
  1128. this.storage_offset = state[1];
  1129. this.size = state[2];
  1130. this.stride = state[3];
  1131. };
  1132. });
  1133. this._registerConstructor('numpy.dtype', function(obj, align, copy) {
  1134. switch (obj) {
  1135. case 'i1': this.name = 'int8'; this.itemsize = 1; break;
  1136. case 'i2': this.name = 'int16'; this.itemsize = 2; break;
  1137. case 'i4': this.name = 'int32'; this.itemsize = 4; break;
  1138. case 'i8': this.name = 'int64'; this.itemsize = 8; break;
  1139. case 'b1': this.name = 'uint8'; this.itemsize = 1; break;
  1140. case 'u1': this.name = 'uint8'; this.itemsize = 1; break;
  1141. case 'u2': this.name = 'uint16'; this.itemsize = 2; break;
  1142. case 'u4': this.name = 'uint32'; this.itemsize = 4; break;
  1143. case 'u8': this.name = 'uint64'; this.itemsize = 8; break;
  1144. case 'f4': this.name = 'float32'; this.itemsize = 4; break;
  1145. case 'f8': this.name = 'float64'; this.itemsize = 8; break;
  1146. default:
  1147. if (obj.startsWith('V')) {
  1148. this.itemsize = Number(obj.substring(1));
  1149. this.name = 'void' + (this.itemsize * 8).toString();
  1150. }
  1151. else if (obj.startsWith('O')) {
  1152. this.itemsize = Number(obj.substring(1));
  1153. this.name = 'object';
  1154. }
  1155. else if (obj.startsWith('S')) {
  1156. this.itemsize = Number(obj.substring(1));
  1157. this.name = 'string';
  1158. }
  1159. else if (obj.startsWith('U')) {
  1160. this.itemsize = Number(obj.substring(1));
  1161. this.name = 'string';
  1162. }
  1163. else if (obj.startsWith('M')) {
  1164. this.itemsize = Number(obj.substring(1));
  1165. this.name = 'datetime';
  1166. }
  1167. else {
  1168. throw new pytorch.Error("Unknown dtype '" + obj.toString() + "'.");
  1169. }
  1170. break;
  1171. }
  1172. this.align = align;
  1173. this.copy = copy;
  1174. this.__setstate__ = function(state) {
  1175. switch (state.length) {
  1176. case 8:
  1177. this.version = state[0];
  1178. this.byteorder = state[1];
  1179. this.subarray = state[2];
  1180. this.names = state[3];
  1181. this.fields = state[4];
  1182. this.elsize = state[5];
  1183. this.alignment = state[6];
  1184. this.int_dtypeflags = state[7];
  1185. break;
  1186. default:
  1187. throw new pytorch.Error("Unknown numpy.dtype setstate length '" + state.length.toString() + "'.");
  1188. }
  1189. };
  1190. });
  1191. this._registerConstructor('numpy.core.multiarray._reconstruct', function(subtype, shape, dtype) {
  1192. this.subtype = subtype;
  1193. this.shape = shape;
  1194. this.dtype = dtype;
  1195. this.__setstate__ = function(state) {
  1196. this.version = state[0];
  1197. this.shape = state[1];
  1198. this.typecode = state[2];
  1199. this.is_f_order = state[3];
  1200. this.rawdata = state[4];
  1201. };
  1202. this.__read__ = function(unpickler) {
  1203. let array = {};
  1204. const subtype = this.subtype.split('.');
  1205. array.__name__ = subtype.pop();
  1206. array.__module__ = subtype.join('.');
  1207. array.dtype = this.typecode;
  1208. array.shape = this.shape;
  1209. let size = array.dtype.itemsize;
  1210. for (let i = 0; i < array.shape.length; i++) {
  1211. size = size * array.shape[i];
  1212. }
  1213. if (typeof this.rawdata == 'string') {
  1214. array.data = unpickler.unescape(this.rawdata, size);
  1215. if (array.data.length != size) {
  1216. throw new pytorch.Error('Invalid string array data size.');
  1217. }
  1218. }
  1219. else {
  1220. array.data = this.rawdata;
  1221. if (array.data.length != size) {
  1222. // throw new pytorch.Error('Invalid array data size.');
  1223. }
  1224. }
  1225. return array;
  1226. };
  1227. });
  // bytearray(source, encoding): only 'latin-1' is supported. Each UTF-16 code unit
  // of `source` becomes one byte (Uint8Array stores values modulo 256).
  this._registerFunction('__builtin__.bytearray', function(source, encoding /*, errors */) {
      if (encoding !== 'latin-1') {
          throw new pytorch.Error("Unsupported bytearray encoding '" + JSON.stringify(encoding) + "'.");
      }
      const bytes = new Uint8Array(source.length);
      for (let index = 0; index < source.length; index++) {
          bytes[index] = source.charCodeAt(index);
      }
      return bytes;
  });
  // collections.Counter: the iterable argument is ignored; an empty object is returned.
  this._registerFunction('collections.Counter', function(/* iterable */) {
      return {};
  });
  // collections.OrderedDict is modeled as an array of { key, value } records that also
  // carries a `__setitem__` method. `__setitem__` always appends; existing keys are not replaced.
  this._registerFunction('collections.OrderedDict', function(args) {
      const entries = [];
      entries.__setitem__ = function(key, value) {
          entries.push({ key: key, value: value });
      };
      for (const pair of args || []) {
          entries.__setitem__(pair[0], pair[1]);
      }
      return entries;
  });
  // numpy scalar: decodes the leading bytes of `rawData` (a Uint8Array, or a string
  // whose char codes are taken as bytes) as a little-endian value of type `dtype.name`.
  this._registerFunction('numpy.core.multiarray.scalar', function(dtype, rawData) {
      let bytes = rawData;
      if (rawData.constructor !== Uint8Array) {
          bytes = new Uint8Array(rawData.length);
          for (let index = 0; index < rawData.length; index++) {
              bytes[index] = rawData.charCodeAt(index);
          }
      }
      const view = new DataView(bytes.buffer, bytes.byteOffset, bytes.byteLength);
      const littleEndian = true;
      switch (dtype.name) {
          case 'float32': return view.getFloat32(0, littleEndian);
          case 'float64': return view.getFloat64(0, littleEndian);
          case 'uint8': return view.getUint8(0);
          case 'int8': return view.getInt8(0);
          case 'int16': return view.getInt16(0, littleEndian);
          case 'int32': return view.getInt32(0, littleEndian);
          case 'int64': {
              // Built from the low and high 32-bit words (unsigned flag = false).
              const low = view.getInt32(0, littleEndian);
              const high = view.getInt32(4, littleEndian);
              return new long.Long(low, high, false);
          }
          default:
              throw new pytorch.Error("Unknown scalar type '" + dtype.name + "'.");
      }
  });
  // _codecs.encode(obj, encoding): identity; the encoding argument is ignored.
  this._registerFunction('_codecs.encode', function(obj /*, encoding */) {
      return obj;
  });
  // collections.defaultdict: the default factory is ignored; a plain empty object is returned.
  this._registerFunction('collections.defaultdict', function(/* default_factory */) {
      return {};
  });
  // annotate(type, value): the type annotation is dropped and the value passed through.
  this._registerFunction('annotate', function(type, value) {
      return value;
  });
  // int()/float() conversions are not evaluated yet; NaN stands for an unknown value.
  this._registerFunction('int', function(/* tensor */) {
      return NaN; // TODO
  });
  this._registerFunction('float', function(/* tensor */) {
      return NaN; // TODO
  });
  // getattr(obj, name, default): only own properties are found; anything else yields the default.
  this._registerFunction('getattr', function(obj, name, defaultValue) {
      return Object.prototype.hasOwnProperty.call(obj, name) ? obj[name] : defaultValue;
  });
  // Type-only operations: the value passes through unchanged.
  this._registerFunction('unchecked_cast', function(type, value) {
      return value;
  });
  this._registerFunction('ops.prim.data', function(tensor) {
      return tensor;
  });
  this._registerFunction('ops.prim.unchecked_unwrap_optional', function(value) {
      return value;
  });
  // Placeholder tensor that records the wrapped scalar.
  this._registerFunction('ops.prim.NumToTensor', function(value) {
      return { __module__: 'torch', __name__: 'Tensor', value: value }; // TODO
  });
  // Minimum of a list via Function.prototype.apply (an undefined list yields Infinity).
  this._registerFunction('ops.prim.min', function(value) {
      return Math.min.apply(null, value);
  });
  this._registerFunction('ops.prim.shape', function(value) {
      return value.size;
  });
  // Quantized prepack ops are not evaluated; each returns a placeholder tensor tagged with its op name.
  for (const prepackName of [ 'ops.quantized.conv_prepack', 'ops.quantized.conv2d_prepack', 'ops.quantized.linear_prepack' ]) {
      this._registerFunction(prepackName, function(/* weight, bias, ... */) {
          return { __module__: 'torch', __name__: 'Tensor', __origin__: prepackName }; // TODO
      });
  }
  this._registerFunction('ops.prim.RaiseException', function(message) {
      throw new pytorch.Error(message);
  });
  // Only the single-argument form range(n) with an integer n is supported; it yields 0..n-1.
  this._registerFunction('range', function(start, stop, step) {
      const isCountOnly = Number.isInteger(start) && stop === undefined && step === undefined;
      if (!isCountOnly) {
          throw new pytorch.Error('Unsupported function range(' + JSON.stringify(start) + ', ' + JSON.stringify(stop) + ', ' + JSON.stringify(step) + ')');
      }
      return Array(start).keys();
  });
  // torch._utils rebuild hooks: produce plain tensor descriptors whose type name is the
  // storage's name with 'Storage' replaced by 'Tensor'.
  this._registerFunction('torch._utils._rebuild_tensor', function (storage, storage_offset, size, stride) {
      return {
          __module__: storage.__module__,
          __name__: storage.__name__.replace('Storage', 'Tensor'),
          storage,
          storage_offset,
          size,
          stride
      };
  });
  this._registerFunction('torch._utils._rebuild_tensor_v2', function (storage, storage_offset, size, stride, requires_grad, backward_hooks) {
      return {
          __module__: storage.__module__,
          __name__: storage.__name__.replace('Storage', 'Tensor'),
          storage,
          storage_offset,
          size,
          stride,
          requires_grad,
          backward_hooks
      };
  });
  // Delegates to the registered 'torch.nn.parameter.Parameter' type through `self.invoke`
  // (`self` is bound outside this view — presumably the Execution instance).
  this._registerFunction('torch._utils._rebuild_parameter', function(data, requires_grad, backward_hooks) {
      const parameter = self.invoke('torch.nn.parameter.Parameter', [ data, requires_grad ]);
      parameter.backward_hooks = backward_hooks;
      return parameter;
  });
  this._registerFunction('torch._utils._rebuild_qtensor', function(storage, storage_offset, size, stride, quantizer_params, requires_grad, backward_hooks) {
      return {
          __module__: storage.__module__,
          __name__: storage.__name__.replace('Storage', 'Tensor'),
          storage,
          storage_offset,
          size,
          stride,
          quantizer_params,
          requires_grad,
          backward_hooks
      };
  });
  // dict[key] = value (mutates `dict` in place).
  this._registerFunction('torch._set_item', function(dict, key, value) {
      dict[key] = value;
  });
  // `key in dict`: a key bound to undefined counts as absent.
  this._registerFunction('torch.__contains__', function(dict, key) {
      return dict[key] !== undefined;
  });
  // Loop-variable value of iteration `index` of range(start, ..., step).
  this._registerFunction('torch.__derive_index', function(index, start, step) {
      return start + index * step;
  });
  // `is` / `is not` are only supported when at least one operand is None (null).
  this._registerFunction('torch.__is__', function(left, right) {
      if (left === null || right === null) {
          return left === right;
      }
      throw new pytorch.Error("Unknown 'torch.__is__' expression type.");
  });
  this._registerFunction('torch.__isnot__', function(left, right) {
      if (left === null || right === null) {
          return left !== right;
      }
      throw new pytorch.Error("Unknown 'torch.__isnot__' expression type.");
  });
  this._registerFunction('torch.__not__', function(value) {
      if (typeof value === 'boolean') {
          return !value;
      }
      throw new pytorch.Error("Unknown 'torch.__not__' expression type.");
  });
  // len(range(lo, hi, step)). Python/TorchScript compute this with integer (floor)
  // division, so a plain `/` would yield fractional lengths such as 3.25 for (0, 10, 4).
  // Both numerators are non-negative in their branches, so Math.floor is exact.
  this._registerFunction('torch.__range_length', function(lo, hi, step) {
      if (step === 0) {
          throw new pytorch.Error('range() arg 3 must not be zero');
      }
      if (step > 0 && lo < hi) {
          return 1 + Math.floor((hi - 1 - lo) / step);
      }
      if (step < 0 && lo > hi) {
          return 1 + Math.floor((lo - 1 - hi) / (0 - step));
      }
      return 0;
  });
  this._registerFunction('torch._unwrap_optional', function(value) {
      return value; // TODO
  });
  // Python `+`: numeric addition, string concatenation, or list concatenation.
  // (Previously numbers were multiplied instead of added.)
  this._registerFunction('torch.add', function(left, right) {
      if (typeof left === 'number' && typeof right === 'number') {
          return left + right;
      }
      if (typeof left === 'string' && typeof right === 'string') {
          return left + right;
      }
      if (Array.isArray(left) && Array.isArray(right)) {
          return left.concat(right);
      }
      throw new pytorch.Error('Unknown torch.add expression type.');
  });
  // list.append: mutates `tensors` in place and returns the appended item.
  this._registerFunction('torch.append', function(tensors, tensor) {
      tensors.push(tensor);
      return tensor;
  });
  // dict(): only the argument-less form is supported.
  this._registerFunction('torch.dict', function(args) {
      if (!args) {
          return {};
      }
      throw new pytorch.Error("'torch.dict' arguments not supported.");
  });
  // Rank of a tensor descriptor; 0 when no `size` is known.
  this._registerFunction('torch.dim', function(tensor) {
      return tensor && tensor.size ? tensor.size.length : 0; // TODO
  });
  // Equality for two strings or two numbers; other operand types are rejected.
  this._registerFunction('torch.eq', function(left, right) {
      const kind = typeof left;
      if ((kind === 'string' || kind === 'number') && typeof right === kind) {
          return left === right;
      }
      throw new pytorch.Error("Unknown 'torch.eq' expression type.");
  });
  // Not evaluated yet: always undefined.
  this._registerFunction('torch.floordiv', function(/* left, right */) {
      return undefined;
  });
  this._registerFunction('torch.gt', function(left, right) {
      if (typeof left !== 'number' || typeof right !== 'number') {
          throw new pytorch.Error("Unknown 'torch.gt' expression type.");
      }
      return left > right;
  });
  // TorchScript pickle list/tensor builders: the unpickled data is already usable as-is.
  for (const builderName of [
      'torch.jit._pickle.build_boollist',
      'torch.jit._pickle.build_doublelist',
      'torch.jit._pickle.build_intlist',
      'torch.jit._pickle.build_tensorlist',
      'torch.jit._pickle.build_tensor_from_id'
  ]) {
      this._registerFunction(builderName, function(data) {
          return data;
      });
  }
  // The type tag is ignored; the value is returned unchanged.
  this._registerFunction('torch.jit._pickle.restore_type_tag', function(value /*, type_str */) {
      return value;
  });
  this._registerFunction('torch.keys', function(dict) {
      return Object.keys(dict);
  });
  // len(value); NaN marks an unknown length for a falsy value.
  this._registerFunction('torch.len', function(value) {
      return value ? value.length : NaN;
  });
  // Numeric `<=`; an unknown (NaN) operand yields false.
  this._registerFunction('torch.le', function(left, right) {
      if (typeof left !== 'number' || typeof right !== 'number') {
          throw new pytorch.Error("Unknown 'torch.le' expression type.");
      }
      return !isNaN(left) && !isNaN(right) && left <= right;
  });
  this._registerFunction('torch.list', function(args) {
      return args;
  });
  // Defaults are ignored; the given size list is returned.
  this._registerFunction('torch.list_with_default', function(size /*, defaults */) {
      return size;
  });
  this._registerFunction('torch.lt', function(left, right) {
      if (typeof left === 'number' && typeof right === 'number') {
          return left < right;
      }
      throw new pytorch.Error("Unknown 'torch.lt' expression type.");
  });
  // Numeric multiplication. The global (coercing) isNaN is used on purpose so that
  // non-numeric operands such as placeholder objects propagate NaN instead of throwing.
  this._registerFunction('torch.mul', function(left, right) {
      if (typeof left === 'number' && typeof right === 'number') {
          return left * right;
      }
      if (isNaN(left) || isNaN(right)) {
          return NaN;
      }
      throw new pytorch.Error("Unknown 'torch.mul' expression type.");
  });
  // `!=` for numbers and lists.
  this._registerFunction('torch.ne', function(left, right) {
      if (typeof left === 'number' && typeof right === 'number') {
          if (isNaN(left) || isNaN(right)) {
              return false;
          }
          return left !== right;
      }
      if (Array.isArray(left) && Array.isArray(right)) {
          // Lists of different lengths are never equal. (This case previously threw
          // a misleading "Unknown expression type" error.)
          if (left.length !== right.length) {
              return true;
          }
          // Same length: elements are not compared and the lists are treated as equal.
          return false;
      }
      throw new pytorch.Error("Unknown 'torch.ne' expression type.");
  });
  this._registerFunction('torch.neg', function(value) {
      if (typeof value === 'number') {
          return -value;
      }
      throw new pytorch.Error("Unknown 'torch.neg' expression type.");
  });
  this._registerFunction('torch.q_scale', function(/* tensor */) {
      return -1; // TODO
  });
  // Transpose is not modeled; the tensor descriptor is returned unchanged.
  this._registerFunction('torch.t', function(tensor) {
      return tensor;
  });
  // tensor.size() / tensor.size(dim). Valid dims are in [-rank, rank); negative values
  // count from the end, so dim === -rank selects the first dimension. (The previous
  // check `-dim < rank` rejected that case.) Returns NaN when no size list is known.
  this._registerFunction('torch.size', function(tensor, dim) {
      if (tensor && Array.isArray(tensor.size)) {
          const rank = tensor.size.length;
          if (dim === undefined) {
              return tensor.size;
          }
          if (Number.isInteger(dim)) {
              if (dim >= 0 && dim < rank) {
                  return tensor.size[dim];
              }
              if (dim < 0 && dim >= -rank) {
                  return tensor.size[rank + dim];
              }
          }
          throw new pytorch.Error('Dimension out of range (expected to be in range of ' + JSON.stringify(tensor.size) + ', but got ' + JSON.stringify(dim) + ').');
      }
      return NaN;
  });
  // l[start:end] for step 1, following Python/TorchScript list slicing. Negative
  // indices count from the end, then both bounds are clamped to [0, l.length].
  // Clamping the end at 0 matters: otherwise Array.prototype.slice would
  // re-interpret a still-negative end (e.g. -4 on a 3-element list) as an offset from the end.
  this._registerFunction('torch.slice', function(l, start, end, step) {
      if (step !== 1) {
          throw new pytorch.Error('Slicing only supports step=1');
      }
      const length = l.length;
      const normalize = (index) => index < 0 ? index + length : index;
      const begin = Math.max(0, normalize(start));
      const finish = Math.max(0, Math.min(length, normalize(end)));
      return l.slice(begin, finish);
  });
  // Numeric subtraction. (Previously the operands were multiplied.)
  this._registerFunction('torch.sub', function(left, right) {
      if (typeof left === 'number' && typeof right === 'number') {
          return left - right;
      }
      throw new pytorch.Error("Unknown 'torch.sub' expression type.");
  });
  // dict.values(): own enumerable values in key order.
  this._registerFunction('torch.values', function(dict) {
      return Object.values(dict);
  });
  // Warnings are ignored.
  this._registerFunction('torch.warn', function() {
  });
  // Placeholder value for an uninitialized variable of the given type.
  this._registerFunction('uninitialized', function(type) {
      return { __module__: 'torch', __name__: type, __origin__: 'uninitialized' };
  });
  1585. }
  1586. get context() {
  1587. return this._context;
  1588. }
  1589. parse(file) {
  1590. const data = this._sources[file];
  1591. if (data) {
  1592. const code = this._utf8Decoder.decode(data);
  1593. const reader = new this._python.Parser(code, file);
  1594. const program = reader.parse();
  1595. if (!program) {
  1596. throw new pytorch.Error("Module '" + file + "' parse error.");
  1597. }
  1598. return program;
  1599. }
  1600. return null;
  1601. }
  1602. package(name, file, raw) {
  1603. if (this._python && !this._packages.has(name)) {
  1604. file = file || 'code/' + name.split('.').join('/') + '.py';
  1605. const program = this.parse(file);
  1606. if (program) {
  1607. let globals = this._context.getx(name);
  1608. if (globals === undefined) {
  1609. globals = {};
  1610. this._context.setx(name, globals);
  1611. }
  1612. globals.__class__ = this._context.scope.builtins.module;
  1613. globals.__name__ = name;
  1614. globals.__file__ = file;
  1615. this._packages.set(name, globals);
  1616. let context = this._context.push(globals);
  1617. this._block(program.body, context);
  1618. if (raw) {
  1619. return program;
  1620. }
  1621. }
  1622. }
  1623. return this._packages.get(name);
  1624. }
  1625. type(name) {
  1626. const type = this._context.getx(name);
  1627. if (type !== undefined) {
  1628. return type;
  1629. }
  1630. let parts = name.split('.');
  1631. const className = parts.pop();
  1632. const moduleName = parts.join('.');
  1633. const module = this.package(moduleName);
  1634. if (module) {
  1635. return module[className];
  1636. }
  1637. return null;
  1638. }
  1639. invoke(name, args) {
  1640. const target = this.type(name);
  1641. if (target) {
  1642. if (target.__class__ === this._context.scope.builtins.type) {
  1643. let obj = {};
  1644. obj.__proto__ = target;
  1645. if (obj.__init__ && typeof obj.__init__ === 'function') {
  1646. obj.__init__.apply(obj, args);
  1647. }
  1648. return obj;
  1649. }
  1650. else if (target.__class__ === this._context.scope.builtins.function) {
  1651. if (target.__call__) {
  1652. throw new pytorch.Error('Unexpected function __call__.');
  1653. }
  1654. else {
  1655. return target.apply(null, args);
  1656. }
  1657. }
  1658. }
  1659. this._raiseUnkownName(name);
  1660. const typeParts = name.split('.');
  1661. const typeName = typeParts.pop();
  1662. const typeModule = typeParts.join('.');
  1663. return {
  1664. __module__: typeModule,
  1665. __name__: typeName
  1666. };
  1667. }
  1668. call(target, name, args, context) {
  1669. const callTarget = this._target(target, context);
  1670. let callArguments = args.map((argument) => this.expression(argument, context));
  1671. if (!callTarget || (name !== null && !callTarget[name])) {
  1672. const targetName = pytorch.Utility.target(target) + '.' + name;
  1673. if (this.type(targetName)) {
  1674. return this.invoke(targetName, callArguments);
  1675. }
  1676. throw new pytorch.Error("Unsupported function '" + targetName + "'.");
  1677. }
  1678. const func = name ? callTarget[name] : callTarget;
  1679. if (func.__class__ === this._context.scope.builtins.type) {
  1680. let obj = {};
  1681. obj.__proto__ = func;
  1682. if (obj.__init__ && typeof obj.__init__ === 'function') {
  1683. obj.__init__.apply(obj, args);
  1684. }
  1685. return obj;
  1686. }
  1687. if (func.__class__ === this._context.scope.builtins.function) {
  1688. if (func.__call__) {
  1689. return func.__call__(callArguments);
  1690. }
  1691. }
  1692. if (func.__class__ === this._context.scope.builtins.method) {
  1693. if (func.__call__) {
  1694. return func.__call__([ callTarget ].concat(callArguments));
  1695. }
  1696. }
  1697. if (typeof func === 'function') {
  1698. return func.apply(callTarget, callArguments);
  1699. }
  1700. throw new pytorch.Error("Unsupported call expression.");
  1701. }
  1702. apply(method, args, context) {
  1703. let locals = Array.prototype.slice.call(args);
  1704. context = context.push();
  1705. for (const parameter of method.parameters) {
  1706. context.set(parameter.name, locals.shift());
  1707. }
  1708. return this._block(method.body.statements, context)
  1709. }
  /**
   * Executes a list of parsed Python statements in `context`.
   *
   * Statements are consumed from a work queue. `if` and `for` splice their bodies
   * onto the front of that queue instead of recursing, so a `return` inside them
   * returns from this call directly.
   *
   * @param statements statement nodes (copied; the caller's array is not modified).
   * @param context    execution context receiving bindings.
   * @returns the value of the first executed `return`, otherwise undefined.
   * @throws pytorch.Error for unsupported statements, unknown conditions or invalid scopes.
   */
  _block(statements, context) {
      statements = Array.prototype.slice.call(statements);
      while (statements.length > 0) {
          const statement = statements.shift();
          switch (statement.type) {
              case 'pass': {
                  break;
              }
              case 'return': {
                  return this.expression(statement.expression, context);
              }
              case 'def': {
                  // A def inside a class body becomes a `method`; at module level a `function`.
                  const module = context.get('__name__');
                  const self = this;
                  const parent = context.get('__class__');
                  let type = null;
                  if (parent === this._context.scope.builtins.type) {
                      type = this._context.scope.builtins.method;
                  }
                  else if (parent === this._context.scope.builtins.module) {
                      type = this._context.scope.builtins.function;
                  }
                  else {
                      throw new pytorch.Error('Invalid function scope.');
                  }
                  // __call__ executes the body via apply() in the defining context (__globals__).
                  const func = {
                      __class__: type,
                      __globals__: context,
                      __module__: module,
                      __name__: statement.name,
                      __code__: statement,
                      __call__: function(args) {
                          return self.apply(this.__code__, args, this.__globals__);
                      }
                  }
                  context.set(statement.name, func);
                  break;
              }
              case 'class': {
                  // The class object doubles as the scope its body is executed in.
                  const scope = {
                      __class__:this._context.scope.builtins.type,
                      __module__: context.get('__name__'),
                      __name__: statement.name,
                  };
                  context.set(statement.name, scope)
                  context = context.push(scope);
                  this._block(statement.body.statements, context);
                  context = context.pop();
                  break;
              }
              case 'var': {
                  context.set(statement.name, undefined);
                  break;
              }
              case '=': {
                  this.expression(statement, context);
                  break;
              }
              case 'if': {
                  // Truthy -> then-branch; exactly `false` -> else-branch; any other
                  // falsy value (null, undefined, 0, NaN, '') is treated as unknown and throws.
                  const condition = this.expression(statement.condition, context);
                  if (condition === true || condition) {
                      statements = statement.then.statements.concat(statements);
                      break;
                  }
                  else if (condition === false) {
                      statements = statement.else.statements.concat(statements);
                      break;
                  }
                  throw new pytorch.Error("Unknown condition.");
              }
              case 'for': {
                  // Only `for <id> in <single expression>` is supported. The iterable is
                  // expanded up front: each iteration becomes a synthetic assignment of a
                  // 'number' node to the loop variable followed by a copy of the body.
                  if (statement.target.length == 1 &&
                      statement.variable.length === 1 && statement.variable[0].type === 'id') {
                      const range = this.expression(statement.target[0], context);
                      const variable = statement.variable[0];
                      let loop = []
                      for (const value of range) {
                          loop.push({ type: '=', target: variable, expression: { type: 'number', value: value }});
                          loop = loop.concat(statement.body.statements);
                      }
                      statements = loop.concat(statements);
                      break;
                  }
                  throw new pytorch.Error("Unsupported 'for' statement.");
              }
              case 'call': {
                  this.expression(statement, context);
                  break;
              }
              case 'import': {
                  // Every listed module is loaded; a binding is created only for `import ... as ...`.
                  for (const module of statement.modules) {
                      const moduleName = pytorch.Utility.target(module.name);
                      const globals = this.package(moduleName);
                      if (module.as) {
                          context.set(module.as, globals);
                      }
                  }
                  break;
              }
              default: {
                  throw new pytorch.Error("Unknown statement '" + statement.type + "'.");
              }
          }
      }
  }
  /**
   * Evaluates a single expression node in `context`.
   *
   * Assignments ('=') return undefined. Any node shape not handled below, including
   * unsupported assignment targets, ends at the final "Unknown expression" error.
   *
   * @param expression AST node.
   * @param context    execution context; `self` is resolved through context.getx('self').
   * @returns the evaluated value.
   * @throws pytorch.Error for unsupported expressions.
   */
  expression(expression, context) {
      const self = context.getx('self');
      switch (expression.type) {
          case '=': {
              // Supported targets: `name`, `name[index]` (auto-creating `__annotations__`),
              // `obj.attr`, and a tuple of names with the same length as the value.
              const target = expression.target;
              if (target.type === 'id') {
                  context.set(target.value, this.expression(expression.expression, context));
                  return;
              }
              else if (target.type === '[]') {
                  if (target.target.type === 'id' &&
                      target.arguments.type === 'list' &&
                      target.arguments.value.length === 1) {
                      const index = this.expression(target.arguments.value[0], context);
                      if (target.target.value === '__annotations__') {
                          context.set(target.target.value, context.get(target.target.value) || {});
                      }
                      context.get(target.target.value)[index] = this.expression(expression.expression, context);
                      return;
                  }
              }
              else if (target.type === '.' &&
                  target.member.type === 'id') {
                  this.expression(target.target, context)[target.member.value] = this.expression(expression.expression, context);
                  return;
              }
              else if (target.type === 'tuple') {
                  const value = this.expression(expression.expression, context);
                  if (target.value.length == value.length && target.value.every((item) => item.type === 'id')) {
                      for (let i = 0; i < value.length; i++) {
                          context.set(target.value[i].value, value[i]);
                      }
                      return;
                  }
              }
              break;
          }
          case 'list': {
              return expression.value.map((item) => this.expression(item, context));
          }
          case 'string': {
              // Drops the first and last character (presumably the quote delimiters);
              // escape sequences are not decoded.
              return expression.value.substring(1, expression.value.length - 1);
          }
          case 'number': {
              return Number(expression.value);
          }
          case '[]': {
              // `name[index]` on a bound name indexes it directly; `List[...]`/`Optional[...]`
              // type expressions are rejected; everything else evaluates target then index.
              if (expression.target.type === 'id' &&
                  expression.arguments.type === 'list' &&
                  expression.arguments.value.length === 1) {
                  if (context.get(expression.target.value)) {
                      const index = this.expression(expression.arguments.value[0], context);
                      return context.get(expression.target.value)[index];
                  }
                  if (expression.target.value === 'List' || expression.target.value === 'Optional') {
                      if (expression.arguments.value.every((item) => item.type === 'id')) {
                          throw new pytorch.Error('Unsupported index expression.');
                          // return { __typeref__: expression.target.value + '[' + expression.arguments.value.map((item) => item.value).join(',') + ']' };
                      }
                  }
              }
              const target = this.expression(expression.target, context);
              if (expression.arguments.type === 'list' && expression.arguments.value.length === 1) {
                  const index = this.expression(expression.arguments.value[0], context);
                  return target[index];
              }
              break;
          }
          case '.': {
              // Attribute access; dotted module paths are resolved by _target.
              if (expression.member.type == 'id') {
                  const target = this._target(expression.target, context);
                  return target[expression.member.value];
              }
              throw new pytorch.Error("Unsupported field expression.");
          }
          case 'call': {
              // uninitialized(Tensor) / uninitialized(Tuple[Tensor, ...]) produce placeholders
              // without evaluating the type argument.
              if (expression.target.type === 'id' && expression.target.value === 'uninitialized' && expression.arguments.length === 1) {
                  const argument = expression.arguments[0];
                  if (argument.type === 'id' && argument.value === 'Tensor') {
                      return { __module__: 'torch', __name__: 'Tensor', __origin__: 'uninitialized' };
                  }
                  if (argument.type === '[]' && argument.target.type === 'id' && argument.target.value === 'Tuple' &&
                      argument.arguments.type === 'list' && argument.arguments.value.every((item) => item.type === 'id' && item.value === 'Tensor')) {
                      return argument.arguments.value.map((/* item */) => { return { __module__: 'torch', __name__: 'Tensor', __origin__: 'uninitialized' }; });
                  }
              }
              // annotate(type, value) / unchecked_cast(type, value): only the value is evaluated.
              if (expression.target.type === 'id' && expression.target.value === 'annotate' && expression.arguments.length === 2) {
                  return this.expression(expression.arguments[1], context);
              }
              if (expression.target.type === 'id' && expression.target.value === 'unchecked_cast' && expression.arguments.length === 2) {
                  return this.expression(expression.arguments[1], context);
              }
              // `receiver.member(...)` dispatches with the member name; otherwise a direct call.
              if (expression.target.type === '.') {
                  return this.call(expression.target.target, expression.target.member.value, expression.arguments, context);
              }
              return this.call(expression.target, null, expression.arguments, context);
          }
          case 'id': {
              // Python keywords map to JS values; other names are looked up in the context chain.
              switch (expression.value) {
                  case 'self': return self;
                  case 'None': return null;
                  case 'True': return true;
                  case 'False': return false;
              }
              return context.get(expression.value);
          }
          case 'tuple': {
              return expression.value.map((expression) => this.expression(expression, context));
          }
      }
      throw new pytorch.Error("Unknown expression '" + expression.type + "'.");
  }
  1927. _target(expression, context) {
  1928. let current = expression;
  1929. let packageName = '';
  1930. for (;;) {
  1931. if (current.type === '.' && current.member && current.member.type === 'id') {
  1932. packageName = '.' + current.member.value + packageName;
  1933. current = current.target;
  1934. }
  1935. else if (current.type === 'id' && current.value !== 'self' && current.value !== 'CONSTANTS') {
  1936. packageName = current.value + packageName;
  1937. break;
  1938. }
  1939. else {
  1940. packageName = null;
  1941. break;
  1942. }
  1943. }
  1944. if (packageName) {
  1945. let target = context.getx(packageName);
  1946. if (!target) {
  1947. target = this.package(packageName);
  1948. if (!target) {
  1949. throw new pytorch.Error("Failed to resolve module '" + packageName + "'.");
  1950. }
  1951. }
  1952. return target;
  1953. }
  1954. return this.expression(expression, context);
  1955. }
  1956. _registerFunction(name, callback) {
  1957. if (this._context.getx(name)) {
  1958. throw new pytorch.Error("Function '" + name + "' is already registered.");
  1959. }
  1960. const parts = name.split('.');
  1961. callback.__class__ = this._context.scope.builtins.function;
  1962. callback.__name__ = parts.pop();
  1963. callback.__module__ = parts.join('.');
  1964. this._context.setx(name, callback);
  1965. }
  1966. _registerConstructor(name, callback) {
  1967. if (this._context.getx(name)) {
  1968. throw new pytorch.Error("Constructor '" + name + "' is already registered.");
  1969. }
  1970. const parts = name.split('.');
  1971. const typeName = parts.pop();
  1972. const typeModule = parts.join('.');
  1973. let type = {
  1974. __class__: this._context.scope.builtins.type,
  1975. __name__: typeName,
  1976. __module__: typeModule,
  1977. __init__: function() {
  1978. callback.apply(this, arguments);
  1979. }
  1980. }
  1981. this._context.setx(name, type);
  1982. }
  1983. _raiseUnkownName(name) {
  1984. if (name && !this._unknownNameMap.has(name)) {
  1985. this._unknownNameMap.add(name);
  1986. if (this._knownPackageMap.has(name.split('.').shift())) {
  1987. this._exceptionCallback(new pytorch.Error("Unknown function '" + name + "'."), false);
  1988. }
  1989. }
  1990. }
  1991. }
  /**
   * Scope chain used by pytorch.Execution: each context owns a `scope` object and an
   * optional parent context, and lookups fall back to the parent.
   */
  pytorch.Execution.Context = class {
      constructor(parent, scope) {
          this._parent = parent || null;
          this._scope = scope || {};
      }
      // Creates a child context of this one.
      push(scope) {
          return new pytorch.Execution.Context(this, scope);
      }
      // Returns the parent context (null at the root).
      pop() {
          return this._parent;
      }
      get scope() {
          return this._scope;
      }
      // Binds `name` in this context's own scope only.
      set(name, value) {
          this._scope[name] = value;
      }
      // Walks this context and its ancestors. Uses `in`, so properties inherited by a
      // scope object also match.
      get(name) {
          for (let context = this; context; context = context._parent) {
              if (name in context._scope) {
                  return context._scope[name];
              }
          }
          return undefined;
      }
      // Assigns a dotted name, creating missing intermediate objects. The first segment
      // is resolved through get() and, if missing, created in this context.
      setx(name, value) {
          const parts = name.split('.');
          const last = parts.pop();
          if (parts.length === 0) {
              this.set(last, value);
              return;
          }
          const head = parts.shift();
          let parent = this.get(head);
          if (!parent) {
              parent = {};
              this.set(head, parent);
          }
          for (const part of parts) {
              parent[part] = parent[part] || {};
              parent = parent[part];
          }
          parent[last] = value;
      }
      // Resolves a dotted name; any falsy link along the path yields undefined.
      getx(name) {
          const [ head, ...rest ] = name.split('.');
          let value = this.get(head);
          if (!value) {
              return undefined;
          }
          for (const part of rest) {
              if (!value[part]) {
                  return undefined;
              }
              value = value[part];
          }
          return value;
      }
  }
  pytorch.Container = class {
      /**
       * Picks a container implementation for the opened file:
       * a zip with model.json/data.pkl -> Zip, a legacy pickle stream -> Pickle,
       * a tar with a 'pickle' entry -> Tar; otherwise null.
       */
      static open(context, metadata, pickle, python, exception) {
          const isModelEntry = (entry) =>
              entry.name === 'model.json' || entry.name === 'data.pkl' ||
              entry.name.endsWith('/model.json') || entry.name.endsWith('/data.pkl');
          if (context.entries('zip').some(isModelEntry)) {
              return new pytorch.Container.Zip(context.entries('zip'), metadata, pickle, python, exception);
          }
          // 0x80 <protocol < 5> is a pickle PROTO header; the following 12 bytes are
          // presumably torch's serialization magic number pickled as a LONG1 (TODO confirm).
          const buffer = context.buffer;
          const magic = [ 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19 ];
          const isLegacyPickle = buffer && buffer.length > 14 && buffer[0] == 0x80 && buffer[1] < 0x05 &&
              magic.every((value, offset) => value == buffer[offset + 2]);
          if (isLegacyPickle) {
              return new pytorch.Container.Pickle(buffer, pickle, exception);
          }
          const hasPickleEntry = context.entries('tar').some((entry) => entry.name == 'pickle');
          return hasPickleEntry ? new pytorch.Container.Tar(context.entries('tar'), pickle, exception) : null;
      }
  }
  pytorch.Container.Tar = class {

      /**
       * Container for the legacy tar-based PyTorch layout (reported as
       * 'PyTorch v0.1.1'), whose archive holds the entries 'sys_info',
       * 'pickle', 'storages' and 'tensors'.
       * @param {Array<{name: string, data: *}>} entries - archive entries.
       * @param {*} pickle - module exposing an `Unpickler` class.
       * @param {?function} exceptionCallback - forwarded to pytorch.Execution.
       */
      constructor(entries, pickle, exceptionCallback) {
          this._entries = entries;
          this._pickle = pickle;
          // Must use the same property name that _unpickle() reads. It used to be
          // misspelled ('_exceptionCallack'), so the callback was silently dropped.
          this._exceptionCallback = exceptionCallback;
      }

      get format() {
          return 'PyTorch v0.1.1';
      }

      get data() {
          this._unpickle();
          return this._data;
      }

      get state() {
          this._unpickle();
          return this._state;
      }

      get littleEndian() {
          this._unpickle();
          return this._littleEndian;
      }

      /**
       * Decodes the archive once. Later calls return immediately because the
       * entries are released. Sets this._littleEndian from 'sys_info' and
       * this._state (grouped tensors, or null) from 'pickle'. this._data is
       * always null for this layout.
       * @throws {pytorch.Error} on an unsupported protocol version or type
       *     sizes, or when a tensor references an unknown storage.
       */
      _unpickle() {
          if (!this._entries) {
              return;
          }
          this._data = null;
          this._state = null;
          this._littleEndian = true;
          const execution = new pytorch.Execution(null, [], this._exceptionCallback);
          const invoke = (name, args) => execution.invoke(name, args);
          const entries = {};
          for (const entry of this._entries) {
              switch (entry.name) {
                  case 'sys_info':
                  case 'pickle':
                  case 'storages':
                  case 'tensors':
                      entries[entry.name] = entry.data;
                      break;
                  default:
                      break;
              }
          }
          this._exceptionCallback = null;
          this._entries = null;

          if (entries.sys_info) {
              const unpickler = new this._pickle.Unpickler(entries.sys_info);
              const sys_info = unpickler.load(invoke);
              if (sys_info.protocol_version != 1000) {
                  throw new pytorch.Error("Unsupported protocol version '" + sys_info.protocol_version + "'.");
              }
              const type_sizes = sys_info.type_sizes;
              if (type_sizes &&
                  ((type_sizes.int && type_sizes.int != 4) ||
                  (type_sizes.long && type_sizes.long != 4) ||
                  (type_sizes.short && type_sizes.short != 2))) {
                  throw new pytorch.Error('Unsupported type sizes.');
              }
              this._littleEndian = sys_info.little_endian;
          }

          const deserialized_objects = {};
          if (entries.storages) {
              const unpickler = new this._pickle.Unpickler(entries.storages);
              const num_storages = unpickler.load(invoke);
              for (let i = 0; i < num_storages; i++) {
                  // storage_args[1] is not used here (presumably the location).
                  const storage_args = unpickler.load();
                  const storage_key = storage_args[0];
                  const storage_type = storage_args[2];
                  // An 8-byte little-endian element count precedes the raw storage bytes.
                  const size = long.Long.fromBytesLE(unpickler.read(8), false).toNumber();
                  const storage = execution.invoke(storage_type, [ size ]);
                  storage.data = unpickler.read(storage.dataTypeSize * storage.size);
                  deserialized_objects[storage_key] = storage;
              }
              // TODO: the storage-view records that may follow
              // ((target_cdata, root_cdata, offset, size) tuples) are not decoded yet.
          }

          if (entries.tensors) {
              const unpickler = new this._pickle.Unpickler(entries.tensors);
              const num_tensors = unpickler.load(invoke);
              for (let i = 0; i < num_tensors; i++) {
                  const tensor_args = unpickler.load();
                  const tensor_key = tensor_args[0];
                  const storage_id = tensor_args[1];
                  const storage = deserialized_objects[storage_id];
                  if (!storage) {
                      throw new pytorch.Error("Unknown tensor storage '" + storage_id + "'.");
                  }
                  // Little-endian header: 4-byte ndim, 4 skipped bytes, ndim 8-byte
                  // sizes, ndim 8-byte strides, 8-byte storage offset.
                  const ndim = long.Long.fromBytesLE(unpickler.read(4), false).toNumber();
                  unpickler.read(4);
                  const shape = [];
                  for (let k = 0; k < ndim; k++) {
                      shape.push(long.Long.fromBytesLE(unpickler.read(8), false).toNumber());
                  }
                  const stride = [];
                  for (let k = 0; k < ndim; k++) {
                      stride.push(long.Long.fromBytesLE(unpickler.read(8), false).toNumber());
                  }
                  const storage_offset = long.Long.fromBytesLE(unpickler.read(8), false).toNumber();
                  const tensor_type_name = storage.__name__.replace('Storage', 'Tensor');
                  const tensor = execution.invoke(storage.__module__ + '.' + tensor_type_name, []);
                  tensor.__setstate__([ storage, storage_offset, shape, stride ]);
                  deserialized_objects[tensor_key] = tensor;
              }
          }

          if (entries.pickle) {
              const unpickler = new this._pickle.Unpickler(entries.pickle);
              const persistent_load = (saved_id) => deserialized_objects[saved_id];
              const obj = unpickler.load(invoke, persistent_load);
              if (obj) {
                  // A non-array object is turned into [{ key, value }] items; an array
                  // is expected to hold such items already (validated in _groupState).
                  const items = Array.isArray(obj) ? obj : Object.keys(obj).map((key) => ({ key: key, value: obj[key] }));
                  this._state = this._groupState(items);
              }
          }
      }

      /**
       * Groups { key, value } items into [{ name, states: [{ id, value, name }] }]
       * by splitting each key at its last '.'.
       * @param {Array} items - candidate state-dictionary items.
       * @returns {?Array} the groups, or null if any item is missing a key or
       *     value, is neither a tensor nor a Parameter, or has a key without '.'.
       */
      _groupState(items) {
          const state = [];
          const state_map = {};
          for (const item of items) {
              if (!item || !item.key || !item.value) {
                  return null;
              }
              let value = null;
              if (item.value.__module__ === 'torch.nn.parameter' && item.value.__name__ === 'Parameter') {
                  // NOTE(review): assumes the Parameter keeps its tensor at index 0 — confirm.
                  value = item.value[0];
              }
              else if (pytorch.Utility.isTensor(item.value)) {
                  value = item.value;
              }
              if (!value) {
                  return null;
              }
              const split = item.key.split('.');
              if (split.length < 2) {
                  return null;
              }
              const name = split.pop();
              const group_name = split.join('.');
              let group = state_map[group_name];
              if (!group) {
                  group = { name: group_name, states: [] };
                  state_map[group_name] = group;
                  state.push(group);
              }
              group.states.push({ id: item.key, value: value, name: name });
          }
          return state;
      }
  };
  pytorch.Container.Pickle = class {

      /**
       * Container for the legacy single-stream pickle layout (reported as
       * 'PyTorch v0.1.10'). The stream contains these pickled values in order:
       * magic number, protocol version, sys_info, the root object, and the list
       * of storage keys. The raw bytes of each listed storage follow.
       * @param {*} buffer - raw file contents.
       * @param {*} pickle - module exposing an `Unpickler` class.
       * @param {?function} exception - forwarded to pytorch.Execution.
       */
      constructor(buffer, pickle, exception) {
          this._buffer = buffer;
          this._pickle = pickle;
          this._exceptionCallback = exception;
      }

      get format() {
          return 'PyTorch v0.1.10';
      }

      get data() {
          this._unpickle();
          return this._data;
      }

      get state() {
          this._unpickle();
          return this._state;
      }

      get littleEndian() {
          this._unpickle();
          return this._littleEndian;
      }

      /**
       * Parses the buffer once. Later calls are no-ops because the buffer is
       * released. Sets this._data to the root module if one is found, otherwise
       * sets this._state to a grouped state dictionary.
       * @throws {pytorch.Error} on unsupported versions or type sizes, unknown
       *     persistent ids or storage keys, storage size mismatches, or when
       *     neither a module nor a state dictionary is found.
       */
      _unpickle() {
          if (!this._buffer) {
              return;
          }
          const execution = new pytorch.Execution(null, [], this._exceptionCallback);
          const unpickler = new this._pickle.Unpickler(this._buffer);
          this._buffer = null;
          this._pickle = null;
          this._exceptionCallback = null;
          unpickler.load(); // magic_number
          const protocol_version = unpickler.load();
          if (protocol_version != 1001) {
              throw new pytorch.Error("Unsupported protocol version '" + protocol_version + "'.");
          }
          const sys_info = unpickler.load();
          if (sys_info.protocol_version != 1001) {
              throw new pytorch.Error("Unsupported protocol version '" + sys_info.protocol_version + "'.");
          }
          if (sys_info.type_sizes &&
              ((sys_info.type_sizes.int && sys_info.type_sizes.int != 4) ||
              (sys_info.type_sizes.long && sys_info.type_sizes.long != 4) ||
              (sys_info.type_sizes.short && sys_info.type_sizes.short != 2))) {
              throw new pytorch.Error('Unsupported type sizes.');
          }
          this._littleEndian = sys_info.little_endian;
          const deserialized_objects = new Map();
          const persistent_load = (saved_id) => {
              const typename = saved_id.shift();
              const data = saved_id;
              switch (typename) {
                  case 'module': {
                      // Only data[0] is used. The recorded module source (data[2]) is ignored.
                      return data[0];
                  }
                  case 'storage': {
                      const data_type = data.shift();
                      const root_key = data.shift();
                      data.shift(); // location
                      const size = data.shift();
                      const view_metadata = data.shift();
                      if (!deserialized_objects.has(root_key)) {
                          deserialized_objects.set(root_key, execution.invoke(data_type, [ size ]));
                      }
                      if (view_metadata) {
                          const view_key = view_metadata.shift();
                          view_metadata.shift(); // view_offset
                          view_metadata.shift(); // view_size
                          if (!deserialized_objects.has(view_key)) {
                              // TODO: storage views are not supported yet; the view would be
                              // storage.slice(view_offset, view_offset + view_size).
                              deserialized_objects.set(view_key, null);
                          }
                          return deserialized_objects.get(view_key);
                      }
                      return deserialized_objects.get(root_key);
                  }
              }
              throw new pytorch.Error("Unknown persistent load type '" + typename + "'.");
          };
          const data = unpickler.load((name, args) => execution.invoke(name, args), persistent_load);
          if (!data) {
              throw new pytorch.Error('File format is not PyTorch.');
          }
          const deserialized_storage_keys = unpickler.load();
          for (const deserialized_storage_key of deserialized_storage_keys) {
              const storage = deserialized_objects.get(deserialized_storage_key);
              if (!storage) {
                  throw new pytorch.Error("Unknown storage key '" + deserialized_storage_key + "'.");
              }
              // An 8-byte little-endian element count precedes each storage's raw bytes.
              const size = long.Long.fromBytesLE(unpickler.read(8), false).toNumber();
              if (size != storage.size) {
                  throw new pytorch.Error('Storage size mismatch.');
              }
              storage.data = unpickler.read(storage.dataTypeSize * storage.size);
          }
          this._data = this._findRootModule(data);
          if (!this._data) {
              this._state = this._findStateDict(data);
          }
          if (!this._data && !this._state) {
              throw new pytorch.Error('File does not contain root module or state dictionary.');
          }
      }

      /**
       * Returns the first of root, root.model or root.net that has `_modules`.
       * @returns {?object} the module object or null.
       */
      _findRootModule(root) {
          const candidates = [ root, root.model, root.net ];
          for (const obj of candidates) {
              if (obj && obj._modules) {
                  return obj;
              }
          }
          return null;
      }

      /**
       * Searches root and a list of commonly used member names for something
       * convertible into a grouped state dictionary.
       * @returns {?Array} grouped state dictionary or null.
       */
      _findStateDict(root) {
          if (!root) {
              return null;
          }
          let dict = root;
          if (dict.encoder && Array.isArray(dict.encoder) &&
              dict.decoder && Array.isArray(dict.decoder) && !dict.state_dict) {
              dict = dict.encoder.concat(dict.decoder);
          }
          if (Array.isArray(dict) && dict.every((item) => item && item.key && item.value)) {
              const obj = {};
              for (const item of dict) {
                  obj[item.key] = item.value;
              }
              dict = obj;
          }
          const candidates = [
              dict, dict.state_dict, dict.state,
              dict.model_state, dict.model, dict.model_state_dict, dict.net_dict,
              dict.params, dict.generator, dict.discriminator, dict.g_state,
              dict.network, dict.net, dict.netG,
              dict.state_dict_stylepredictor, dict.state_dict_ghiasi
          ];
          for (const candidate of candidates) {
              const state_dict =
                  this._convertStateDictList(candidate) ||
                  this._convertStateDictMap(candidate) ||
                  this._convertStateDictGroupMap(candidate);
              if (state_dict) {
                  return state_dict;
              }
          }
          return null;
      }

      /**
       * Converts an array of { key, value } items (value a tensor or null) into
       * groups keyed by the part of `key` before its last '.'. Null values are skipped.
       * @returns {?Array} groups, or null if the list does not have that shape
       *     or a key contains no '.'.
       */
      _convertStateDictList(list) {
          if (!list ||
              !Array.isArray(list) ||
              !list.every((item) => item && item.key && (item.value === null || pytorch.Utility.isTensor(item.value)))) {
              return null;
          }
          const state_dict = [];
          const state_map = {};
          for (const item of list) {
              if (item.value === null) {
                  continue;
              }
              const split = item.key.split('.');
              if (split.length < 2) {
                  return null;
              }
              const state = {};
              state.id = item.key;
              state.name = split.pop();
              state.value = item.value;
              const state_group_name = split.join('.');
              let state_group = state_map[state_group_name];
              if (!state_group) {
                  state_group = {};
                  state_group.name = state_group_name;
                  state_group.states = [];
                  state_map[state_group_name] = state_group;
                  state_dict.push(state_group);
              }
              state_group.states.push(state);
          }
          return state_dict;
      }

      /**
       * Converts an object mapping dotted names to tensors (or Parameters that
       * wrap a tensor in `.data`) into groups. A key without '.' falls into the
       * group named ''.
       * @returns {?Array} groups, or null if any value is not a tensor.
       */
      _convertStateDictMap(obj) {
          if (!obj || Array.isArray(obj)) {
              return null;
          }
          const state_dict = [];
          const state_map = {};
          for (const key in obj) {
              const split = key.split('.');
              const state = {};
              state.id = key;
              state.name = split.pop();
              state.value = obj[key];
              if (state.value && state.value.__module__ === 'torch.nn.parameter' && state.value.__name__ === 'Parameter') {
                  if (pytorch.Utility.isTensor(state.value.data)) {
                      state.value = state.value.data;
                  }
              }
              if (!pytorch.Utility.isTensor(state.value)) {
                  return null;
              }
              const state_group_name = split.join('.');
              let state_group = state_map[state_group_name];
              if (!state_group) {
                  state_group = {};
                  state_group.name = state_group_name;
                  state_group.states = [];
                  state_map[state_group_name] = state_group;
                  state_dict.push(state_group);
              }
              state_group.states.push(state);
          }
          return state_dict;
      }

      /**
       * Converts an object whose members are groups into [{ name, states, attributes }].
       * Each member is either an array of { key, value: tensor } items or an
       * object of tensors, Parameters and primitive attributes.
       * @returns {?Array} groups, or null when a member does not fit those
       *     shapes or an object member holds no tensors.
       */
      _convertStateDictGroupMap(obj) {
          if (!obj || Array.isArray(obj)) {
              return null;
          }
          const state_dict = [];
          const state_map = {};
          for (const state_group_name in obj) {
              let state_group = state_map[state_group_name];
              if (!state_group) {
                  state_group = {};
                  state_group.name = state_group_name;
                  state_group.states = [];
                  state_group.attributes = [];
                  state_map[state_group_name] = state_group;
                  state_dict.push(state_group);
              }
              const item = obj[state_group_name];
              if (!item) {
                  return null;
              }
              if (Array.isArray(item)) {
                  for (const entry of item) {
                      if (!entry || !entry.key || !pytorch.Utility.isTensor(entry.value)) {
                          return null;
                      }
                      state_group.states.push({
                          id: state_group_name + '.' + entry.key,
                          name: entry.key,
                          value: entry.value
                      });
                  }
              }
              else if (item instanceof Uint8Array) {
                  return null;
              }
              else if (Object(item) === item) {
                  let hasTensors = false;
                  for (const key in item) {
                      const value = item[key];
                      if (pytorch.Utility.isTensor(value)) {
                          state_group.states.push({ name: key, value: value, id: state_group_name + '.' + key });
                          hasTensors = true;
                      }
                      else if (value !== Object(value)) {
                          // Primitive (non-object) values are kept as attributes.
                          state_group.attributes.push({ name: key, value: value });
                      }
                      else if (value && value.data && value.__module__ === 'torch.nn.parameter' && value.__name__ === 'Parameter') {
                          state_group.states.push({ name: key, value: value.data, id: state_group_name + '.' + key });
                          hasTensors = true;
                      }
                      else {
                          return null;
                      }
                  }
                  if (!hasTensors) {
                      return null;
                  }
              }
              else {
                  return null;
              }
          }
          return state_dict;
      }
  };
  pytorch.Container.Zip = class {

      /**
       * Container for ZIP-based PyTorch / TorchScript archives. All entries
       * are resolved relative to the directory that holds 'model.json' or
       * 'data.pkl'.
       * @throws {pytorch.Error} when neither 'data.pkl' nor 'model.json' is present.
       */
      constructor(entries, metadata, pickle, python, exceptionCallback) {
          this._entries = entries;
          this._metadata = metadata;
          this._pickle = pickle;
          this._python = python;
          this._exceptionCallback = exceptionCallback;
          // https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/docs/serialization.md
          const entry = this._entries.find((item) => item.name === 'model.json' || item.name === 'data.pkl' || item.name.endsWith('/model.json') || item.name.endsWith('/data.pkl'));
          if (!entry) {
              throw new pytorch.Error("PyTorch Zip container does not contain 'data.pkl' or 'model.json'.");
          }
          const lastIndex = entry.name.lastIndexOf('/');
          this._prefix = lastIndex === -1 ? '' : entry.name.substring(0, lastIndex + 1);
          this._utf8Decoder = new TextDecoder('utf-8');
      }

      /**
       * Format name, computed once.
       * - 'model.json' archives report 'TorchScript v1.0', or 'TorchScript v1.1'
       *   when 'attributes.pkl' is present.
       * - 'data.pkl' archives are named from the first line of the 'version'
       *   entry. Unknown versions are reported through the exception callback.
       */
      get format() {
          if (this._format === undefined) {
              if (this._entry('model.json')) {
                  this._format = this._entry('attributes.pkl') ? 'TorchScript v1.1' : 'TorchScript v1.0';
              }
              else if (this._entry('data.pkl')) {
                  // kProducedFileFormatVersion in ./third_party/src/pytorch/caffe2/serialize/inline_container.h
                  const versionEntry = this._entry('version');
                  const versionNumber = versionEntry ? this._utf8Decoder.decode(versionEntry.data).split('\n').shift() : '';
                  const versionTable = { '1': 'v1.3', '2': 'v1.4', '3': 'v1.5' };
                  const version = versionTable[versionNumber];
                  if (!version) {
                      this._exceptionCallback(new pytorch.Error("Unsupported PyTorch ZIP version '" + versionNumber + "'."));
                  }
                  this._format = (this._entry('constants.pkl') ? 'TorchScript' : 'PyTorch') + ' ' + (version || 'v#' + versionNumber.toString() );
              }
          }
          return this._format;
      }

      /**
       * Producer name from 'model.json'. Returns '' when there is none, which
       * includes 'data.pkl' archives; previously those returned undefined.
       */
      get producer() {
          return (this.data && this._producer) || '';
      }

      get name() {
          return this._name;
      }

      /**
       * Root object, loaded once. A 'data.pkl' archive is unpickled directly.
       * For a 'model.json' archive the JSON module tree is normalized in place:
       * - submodules become named members,
       * - parameters and arguments are resolved to constant tensors,
       * - each tensor gets a storage backed by its archive entry.
       */
      get data() {
          if (this._data === undefined) {
              this._data = null;
              const dataEntry = this._entry('data.pkl');
              if (dataEntry && dataEntry.data) {
                  this._data = this._unpickle(dataEntry.data, this._storage('data'));
              }
              else {
                  const modelEntry = this._entry('model.json');
                  if (modelEntry) {
                      const model = JSON.parse(this._utf8Decoder.decode(modelEntry.data));
                      this._producer = model.producerName + (model.producerVersion ? ' v' + model.producerVersion : '');
                      this._data = model.mainModule || {};
                      this._name = this._data.name || '';
                      if (this._data.torchscriptArena) {
                          this._torchscriptArena = this._data.torchscriptArena.key;
                      }
                      const entries = new Map();
                      for (const entry of this._entries) {
                          entries.set(entry.name, entry.data);
                      }
                      const tensorTypeMap = new Map([
                          [ 'FLOAT', 'Float' ],
                          [ 'FLOAT16', 'Half' ],
                          [ 'DOUBLE', 'Double' ],
                          [ 'INT8', 'Char' ],
                          [ 'INT32', 'Int' ],
                          [ 'INT64', 'Long' ]
                      ]);
                      // Assign before this.execution is first created below: the execution
                      // getter publishes this.constants as CONSTANTS.
                      this._constants = model.tensors || [];
                      for (const tensor of this._constants) {
                          const key = this._prefix + tensor.data.key;
                          if (!tensorTypeMap.has(tensor.dataType)) {
                              throw new pytorch.Error("Unknown tensor data type '" + tensor.dataType + "'.");
                          }
                          const type = tensorTypeMap.get(tensor.dataType);
                          tensor.__module__ = 'torch';
                          tensor.__name__ = 'Tensor';
                          tensor.name = tensor.data.key;
                          tensor.size = tensor.dims ? tensor.dims.map((dim) => parseInt(dim, 10)) : null;
                          tensor.storage = this.execution.invoke('torch.' + type + 'Storage', [ tensor.size ]);
                          tensor.storage.data = entries.get(key);
                      }
                      const queue = [ this._data ];
                      while (queue.length > 0) {
                          const current = queue.shift();
                          if (!current.__module__ && !current.__name__) {
                              current.__module__ = 'torch.nn.modules.module';
                              current.__name__ = 'Module';
                          }
                          if (current.name) {
                              current.__id__ = current.name;
                          }
                          if (current.submodules) {
                              for (const submodule of current.submodules) {
                                  current[submodule.name] = submodule;
                                  submodule.__parent__ = current;
                                  queue.push(submodule);
                              }
                              delete current.submodules;
                          }
                          let parameters = [];
                          if (current.parameters) {
                              parameters = parameters.concat(current.parameters);
                              delete current.parameters;
                          }
                          if (current.arguments) {
                              parameters = parameters.concat(current.arguments);
                              delete current.arguments;
                          }
                          for (const parameter of parameters) {
                              const tensor = this._constants[parameter.tensorId];
                              current[parameter.name] = tensor;
                              // NOTE(review): this tags the parameter descriptor, not `tensor`
                              // (which already carries __module__/__name__) — confirm intent.
                              if (!parameter.__module__ || !parameter.__name__) {
                                  parameter.__module__ = 'torch';
                                  parameter.__name__ = 'Tensor';
                              }
                          }
                      }
                  }
              }
          }
          return this._data;
      }

      /** Constants unpickled from 'constants.pkl' (or set by the model.json loader). */
      get constants() {
          if (this._constants === undefined) {
              this._constants = [];
              const entry = this._entry('constants.pkl');
              if (entry && entry.data) {
                  this._constants = this._unpickle(entry.data, this._storage('constants'));
              }
          }
          return this._constants;
      }

      /**
       * Lazily creates the execution context. It uses the 'code' entries as
       * sources and exposes the constants as CONSTANTS.c0, c1, ...
       * @throws {pytorch.Error} on duplicate source file names.
       */
      get execution() {
          if (this._execution === undefined) {
              this._types = new Map(); // class statements cached by _type()
              const sources = {};
              for (const entry of this._entries) {
                  if (entry.name.startsWith(this._prefix + 'code')) {
                      const file = entry.name.substring(this._prefix.length);
                      if (sources[file]) {
                          throw new pytorch.Error("Duplicate source file '" + file + "'.");
                      }
                      sources[file] = entry.data;
                  }
              }
              this._execution = new pytorch.Container.Zip.Execution(this._python, sources, this._exceptionCallback, this._metadata);
              // this.constants may unpickle 'constants.pkl', which calls back into
              // this.execution, so _execution must already be assigned here.
              const constants = {};
              for (let i = 0; i < this.constants.length; i++) {
                  constants['c' + i.toString()] = this.constants[i];
              }
              this._execution.context.set('CONSTANTS', constants);
          }
          return this._execution;
      }

      /** Finds the entry named `name` relative to the archive prefix. */
      _entry(name) {
          return this._entries.find((entry) => entry.name === this._prefix + name);
      }

      /**
       * Unpickles `data`, resolving 'storage' persistent ids against `storage_map`.
       * @param {*} data - pickle bytes.
       * @param {Map<string, *>} storage_map - storage bytes by key.
       * @throws {pytorch.Error} on a persistent id type other than 'storage'.
       */
      _unpickle(data, storage_map) {
          const deserialized_objects = new Map();
          const persistent_load = (saved_id) => {
              const typename = saved_id.shift();
              if (typename !== 'storage') {
                  throw new pytorch.Error("Unknown persistent load type '" + typename + "'.");
              }
              const data_type = saved_id.shift();
              const root_key = saved_id.shift();
              saved_id.shift(); // location
              const size = saved_id.shift();
              let storage = null;
              if (deserialized_objects.has(root_key)) {
                  storage = deserialized_objects.get(root_key);
              }
              else {
                  storage = this.execution.invoke(data_type, [ size ]);
                  storage.data = storage_map.get(root_key);
                  deserialized_objects.set(root_key, storage);
              }
              const view_metadata = saved_id.shift();
              if (view_metadata) {
                  const view_key = view_metadata.shift();
                  view_metadata.shift(); // view_offset
                  view_metadata.shift(); // view_size
                  if (!deserialized_objects.has(view_key)) {
                      // TODO: storage views are not supported yet; the view would be
                      // storage.slice(view_offset, view_offset + view_size).
                      deserialized_objects.set(view_key, null);
                  }
                  // Look the view up by its own key. This used to read root_key, so
                  // repeated references to a view returned the whole root storage.
                  return deserialized_objects.get(view_key);
              }
              return storage;
          };
          return new this._pickle.Unpickler(data).load((name, args) => this.execution.invoke(name, args), persistent_load);
      }

      /** Maps entry names below `<prefix><dirname>/` (with that prefix removed) to entry data. */
      _storage(dirname) {
          const map = new Map();
          const prefix = this._prefix + dirname + '/';
          for (const entry of this._entries) {
              if (entry.name.startsWith(prefix)) {
                  const key = entry.name.substring(prefix.length);
                  map.set(key, entry.data);
              }
          }
          return map;
      }

      /**
       * Returns the `class` statement for a dotted type name by parsing
       * 'code/<module path>.py'. Results are cached.
       */
      _type(name) {
          // Evaluate the execution getter first: it is what initializes this._types.
          const execution = this.execution;
          if (!this._types.has(name)) {
              const parts = name.split('.');
              const className = parts.pop();
              const file = 'code/' + parts.join('/') + '.py';
              const program = execution.parse(file);
              if (program) {
                  for (const statement of program.body) {
                      if (statement.type === 'class' && statement.name == className) {
                          this._types.set(name, statement);
                          break;
                      }
                  }
              }
          }
          return this._types.get(name);
      }

      /**
       * Symbolically runs the root module's `forward` to record graph nodes.
       * Tensor and Tuple[Tensor, ...] parameters become trace inputs; tensor
       * results become outputs.
       * @returns {boolean} true when `forward` was traced.
       * @throws {pytorch.Error} when the module has no `forward`.
       */
      trace() {
          this._inputs = [];
          this._outputs = [];
          this.execution.reset();
          if (this._torchscriptArena) {
              const program = this.execution.parse(this._torchscriptArena);
              const self = this;
              const globals = this.execution.context;
              for (const statement of program.body) {
                  if (statement.type === 'def') {
                      const func = {
                          __class__: this.execution.context.scope.builtins.function,
                          __name__: statement.name,
                          __code__: statement,
                          // A plain function so that `this` refers to the func object.
                          __call__: function(args) {
                              return self.execution.apply(this.__code__, args, globals);
                          }
                      };
                      this.data[statement.name] = func;
                  }
              }
          }
          if (this.data.forward) {
              const args = [ this.data ]; // self
              if (this.data.forward.__code__ && this.data.forward.__code__.parameters) {
                  for (const parameter of this.data.forward.__code__.parameters) {
                      if (parameter.name !== 'self') {
                          const type = parameter.parameterType;
                          if (type.type === 'type' && type.name.type) {
                              if (type.name.value === 'Tensor') {
                                  this._inputs.push(parameter.name);
                                  args.push({ __module__: 'torch', __name__: 'Tensor', __variable__: parameter.name, __origin__: 'trace-input' });
                              }
                              if (type.name.value === 'Tuple' && type.arguments.every((item) => item.type === 'type' && item.name.type === 'id' && item.name.value === 'Tensor')) {
                                  this._inputs.push(parameter.name);
                                  args.push(type.arguments.map(() => {
                                      return { __module__: 'torch', __name__: 'Tensor', __variable__: parameter.name, __origin__: 'trace-input' };
                                  }));
                              }
                          }
                      }
                  }
              }
              const result = this.data.forward.__call__(args);
              const outputs = !Array.isArray(result) ? [ result ] : result;
              for (const output of outputs) {
                  if (pytorch.Utility.isTensor(output)) {
                      this._outputs.push(output.__variable__);
                  }
              }
              this._nodes = this.execution.nodes;
              return true;
          }
          throw new pytorch.Error("Module 'forward' not implemented.");
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get nodes() {
          return this._nodes;
      }
  };
  pytorch.Container.Zip.Execution = class extends pytorch.Execution {

      /**
       * Execution context that records each call matching an operator schema
       * from `metadata` as a graph node (see `nodes`). All other calls are
       * delegated to pytorch.Execution.
       */
      constructor(python, sources, exceptionCallback, metadata) {
          super(python, sources, exceptionCallback);
          this._metadata = metadata;
          this.reset();
      }

      /** Clears recorded nodes and restarts variable numbering. */
      reset() {
          this._nodes = [];
          this._variableIndex = 0;
      }

      get nodes() {
          return this._nodes;
      }

      /**
       * Intercepts `target.name(args)`. Schemas are tried in order. The first
       * one whose tensor inputs and typed outputs match is recorded as a node,
       * and its output tensor (or array of outputs) is returned. Otherwise the
       * call goes to pytorch.Execution.call.
       * @throws {pytorch.Error} when a positional argument follows a named one.
       */
      call(target, name, args, context) {
          let callTarget = pytorch.Utility.target(target);
          let outputTypes = null;
          // Unwrap ops.prim.NumToTensor(x.f(...)) into x.f(...) with an int64 output.
          // The member check avoids a TypeError when the inner target has no `member`.
          if (callTarget && callTarget + '.' + name === 'ops.prim.NumToTensor' &&
              args.length === 1 && args[0].type === 'call' &&
              args[0].target.member && args[0].target.member.type === 'id') {
              const innerCall = args[0];
              callTarget = pytorch.Utility.target(innerCall.target.target);
              args = innerCall.arguments;
              name = innerCall.target.member.value;
              outputTypes = [ 'int64' ];
          }
          if (!callTarget) {
              return super.call(target, name, args, context);
          }
          const type = callTarget + '.' + name;
          // ./third_party/src/pytorch/aten/src/ATen/native/native_functions.yaml
          let schemas = this._metadata.type(type);
          if (!schemas) {
              return super.call(target, name, args, context);
          }
          if (!Array.isArray(schemas)) {
              schemas = [ schemas ];
          }
          for (const schema of schemas) {
              const callArgs = Array.prototype.slice.call(args);
              const node = {
                  type: schema.name,
                  inputs: [],
                  attributes: [],
                  outputs: []
              };
              const referencedParameters = [];
              let next = false;
              const inputSchemas = Array.prototype.slice.call(schema.inputs || []);
              while (inputSchemas.length > 0) {
                  const inputSchema = inputSchemas.shift();
                  const argument = this.expression(callArgs.shift(), context);
                  // A list argument must match a 'T[]' input and vice versa;
                  // otherwise move on to the next schema.
                  if ((Array.isArray(argument) && inputSchema.type !== 'T[]') ||
                      (!Array.isArray(argument) && inputSchema.type === 'T[]')) {
                      next = true;
                      break;
                  }
                  const parameters = Array.isArray(argument) ? argument : [ argument ];
                  const inputs = [];
                  for (let parameter of parameters) {
                      if (parameter !== undefined) {
                          if (!pytorch.Utility.isTensor(parameter) && parameter !== null) {
                              next = true;
                              break;
                          }
                          if (parameter === null) {
                              parameter = {};
                          }
                          if (!parameter.__variable__) {
                              parameter.__variable__ = this._variable();
                          }
                          inputs.push({ id: parameter.__variable__ });
                          referencedParameters.push(parameter);
                      }
                  }
                  if (next) {
                      break;
                  }
                  node.inputs.push(inputs);
              }
              if (next) {
                  continue;
              }
              // Remaining positional arguments become attributes, followed by named ones.
              while (callArgs.length > 0 && callArgs[0].type !== '=') {
                  node.attributes.push(this.expression(callArgs.shift(), context));
              }
              while (callArgs.length > 0) {
                  const arg = callArgs.shift();
                  if (arg.type === '=' && arg.target && arg.target.type === 'id') {
                      const value = this.expression(arg.expression, context);
                      node.attributes.push({ type: '=', target: arg.target, expression: value });
                  }
                  else {
                      // Previously `new pytorch.Attribute(...)`, which is not an error type.
                      throw new pytorch.Error('Expected named argument.');
                  }
              }
              const outputs = [];
              const outputSchemas = schema.outputs || [];
              for (let i = 0; i < outputSchemas.length; i++) {
                  // Typed (non-'T') outputs must match the types requested by the caller.
                  if (outputSchemas[i].type && outputSchemas[i].type !== 'T') {
                      if (!outputTypes || outputTypes.length !== outputSchemas.length || outputSchemas[i].type !== outputTypes[i]) {
                          next = true;
                          break;
                      }
                  }
                  const parameter = { __module__: 'torch', __name__: 'Tensor', __origin__: 'invoke-output-' + type };
                  switch (type) {
                      case 'torch.cat':
                      case 'torch.conv2d':
                      case 'torch.dropout':
                      case 'torch.flatten':
                      case 'torch.max_pool2d':
                      case 'torch.quantize_per_tensor':
                      case 'torch.relu_':
                      case 'torch.hardtanh_':
                      case 'torch.slice': {
                          parameter.size = [ NaN, NaN, NaN, NaN ];
                          break;
                      }
                      case 'torch.conv3d': {
                          parameter.size = [ NaN, NaN, NaN, NaN, NaN ];
                          break;
                      }
                      case 'torch.embedding': {
                          parameter.size = [ NaN, NaN, NaN ];
                          break;
                      }
                      case 'torch.ones':
                      case 'torch.zeros':
                      case 'torch.zeros_like': {
                          parameter.size = this.expression(args[0], context);
                          break;
                      }
                  }
                  parameter.__variable__ = this._variable();
                  outputs.push(parameter);
                  node.outputs.push(parameter.__variable__);
              }
              if (next) {
                  continue;
              }
              for (const parameter of referencedParameters) {
                  parameter.__count__ = (parameter.__count__ || 0) + 1;
              }
              this._nodes.push(node);
              if (outputs.length > 1) {
                  return outputs;
              }
              return outputs[0];
          }
          return super.call(target, name, args, context);
      }

      /** Returns the next variable id as a string ('1', '2', ...). */
      _variable() {
          this._variableIndex++;
          return this._variableIndex.toString();
      }
  };
  pytorch.Utility = class {

      /**
       * Flattens an identifier or member-access expression into a dotted name.
       * Member accesses are handled recursively.
       * @param {{type: string}} expression - parsed expression node.
       * @returns {?string} the dotted name. Returns null for any other node
       *     type; a nested null is concatenated as the text 'null'.
       */
      static target(expression) {
          switch (expression.type) {
              case 'id':
                  return expression.value;
              case '.': {
                  const owner = pytorch.Utility.target(expression.target);
                  const member = pytorch.Utility.target(expression.member);
                  return owner + '.' + member;
              }
              default:
                  return null;
          }
      }

      /**
       * Truthy when `obj` is a tensor-like object: its __module__ is 'torch' or
       * 'torch.cuda', and its __name__ ends with 'Tensor'. Callers should treat
       * the result as truthy/falsy only, because falsy inputs are returned as-is.
       */
      static isTensor(obj) {
          if (!obj) {
              return obj;
          }
          const moduleName = obj.__module__;
          if (moduleName !== 'torch' && moduleName !== 'torch.cuda') {
              return false;
          }
          const typeName = obj.__name__;
          return typeName && typeName.endsWith('Tensor');
      }
  };
  // Export the factory when loaded as a CommonJS module. When loaded as a
  // browser script (no `module` binding) this does nothing.
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
      Object.assign(module.exports, { ModelFactory: pytorch.ModelFactory });
  }