torchscript.js 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269
  1. /* jshint esversion: 6 */
  2. /* eslint "indent": [ "error", 4, { "SwitchCase": 1 } ] */
  3. // Experimental
  4. var torchscript = torchscript || {};
  5. var base = base || require('./base');
  6. var long = long || { Long: require('long') };
  7. var marked = marked || require('marked');
  8. var zip = zip || require('./zip');
  torchscript.ModelFactory = class {

      /**
       * Decides whether this factory can open the file: the name must carry one of
       * the PyTorch-style extensions and the archive must contain a 'version' entry
       * plus a sibling 'model.json' (the TorchScript v1 layout).
       * @param {object} context - file context exposing `identifier` and `entries`.
       * @returns {boolean} true when the archive looks like TorchScript v1.
       */
      match(context) {
          const identifier = context.identifier;
          const extension = identifier.split('.').pop().toLowerCase();
          const extensions = [ 'pt', 'pth', 'pkl', 'h5', 't7', 'dms', 'model', 'ckpt' ];
          if (extensions.indexOf(extension) !== -1 || identifier.endsWith('.pth.tar')) {
              return Boolean(torchscript.ModelFactory._openContainer(context));
          }
          return false;
      }

      /**
       * Loads the model asynchronously.
       * @param {object} context - file context exposing `identifier` and `entries`.
       * @param {object} host - host providing `require`, `request` and `exception`.
       * @returns {Promise<torchscript.Model>} rejects with torchscript.Error whose
       *     message names the offending file.
       */
      open(context, host) {
          return host.require('./python').then((python) => {
              return host.require('./pickle').then((pickle) => {
                  const identifier = context.identifier;
                  try {
                      const container = torchscript.ModelFactory._openContainer(context);
                      if (!container) {
                          // Reached only when open() is called without a successful match().
                          throw new torchscript.Error('Invalid TorchScript container.');
                      }
                      if (container.attributes) {
                          container.attributes = new pickle.Unpickler(container.attributes.data).load((name, args) => {
                              return { type: name, args: args[0] };
                          });
                      }
                      container.identifier = identifier;
                      return torchscript.Metadata.open(host).then((metadata) => {
                          try {
                              return new torchscript.Model(metadata, host, python, container);
                          }
                          catch (error) {
                              host.exception(error, false);
                              throw new torchscript.Error(torchscript.ModelFactory._errorMessage(error, identifier));
                          }
                      });
                  }
                  catch (error) {
                      host.exception(error, false);
                      return Promise.reject(new torchscript.Error(torchscript.ModelFactory._errorMessage(error, identifier)));
                  }
              });
          });
      }

      /**
       * Formats "<message> in '<identifier>'." with any trailing period of the
       * original message removed. Tolerates non-Error throw values, including
       * null and undefined.
       */
      static _errorMessage(error, identifier) {
          let message = error && error.message ? error.message : String(error);
          if (message.endsWith('.')) {
              message = message.substring(0, message.length - 1);
          }
          return message + " in '" + identifier + "'.";
      }

      /**
       * Locates the TorchScript v1 entries inside the archive.
       * @returns {object|null} `{ version, prefix, attributes, model, entries }`, or
       *     null when there is no 'version' entry or no matching 'model.json'.
       *     `attributes` ('attributes.pkl') is optional and may be undefined.
       */
      static _openContainer(context) {
          const entries = context.entries;
          if (!entries || entries.length === 0) {
              return null;
          }
          const version = entries.find((entry) => entry.name == 'version' || entry.name.endsWith('/version'));
          if (!version) {
              return null;
          }
          // Everything before the trailing 'version' is the archive folder prefix.
          const prefix = version.name.substring(0, version.name.length - 'version'.length);
          const model = entries.find((entry) => entry.name == prefix + 'model.json');
          if (!model) {
              return null;
          }
          return {
              version: version,
              prefix: prefix,
              attributes: entries.find((entry) => entry.name == prefix + 'attributes.pkl'),
              model: model,
              entries: entries
          };
      }
  };
  torchscript.Model = class {

      /**
       * Decodes 'model.json' and 'version' from the container and builds the
       * model's single graph from the main module.
       * @param {torchscript.Metadata} metadata - operator schema lookup.
       * @param {object} host - host services, forwarded to the graph.
       * @param {object} python - python module exposing `Parser`.
       * @param {object} container - archive container from _openContainer.
       */
      constructor(metadata, host, python, container) {
          const decoder = new TextDecoder('utf-8');
          const definition = JSON.parse(decoder.decode(container.model.data));
          const formatVersion = JSON.parse(decoder.decode(container.version.data));
          this._format = 'TorchScript v' + formatVersion.toString();
          const producerName = definition.producerName;
          if (producerName) {
              const producerVersion = definition.producerVersion;
              this._producer = producerVersion ? producerName + ' v' + producerVersion : producerName;
          }
          const graph = new torchscript.Graph(metadata, host, python, container, definition.mainModule, definition.tensors);
          this._graphs = [ graph ];
      }

      /** Format label, e.g. 'TorchScript v1'. */
      get format() { return this._format; }

      /** Producer name (and version when present), or undefined. */
      get producer() { return this._producer; }

      /** Always a single-element array. */
      get graphs() { return this._graphs; }
  };
  torchscript.Graph = class {

      /**
       * Builds the graph of a TorchScript v1 archive.
       * Operator nodes come from the `forward` code of `mainModule` (parsed by
       * torchscript.GraphContext). Modules that carry parameters and were not
       * flagged `hide` by an operator node are appended as 'Module' nodes.
       * A failure while parsing the code is reported through `host.exception`
       * and the graph falls back to module nodes only.
       * @param {torchscript.Metadata} metadata - operator schema lookup.
       * @param {object} host - host providing `exception`.
       * @param {object} python - python module exposing `Parser`.
       * @param {object} container - archive container; receives `tensors` and `parameters`.
       * @param {object} mainModule - root module record of model.json.
       * @param {Array<object>} tensors - tensor descriptors of model.json.
       */
      constructor(metadata, host, python, container, mainModule, tensors) {
          this._name = mainModule.name;
          this._inputs = [];
          this._outputs = [];
          this._nodes = [];
          container.tensors = tensors.map((tensor) => new torchscript.Tensor(tensor, container));
          const context = torchscript.Graph._createContext(host, python, container, mainModule);
          // Must run before operator nodes are created: torchscript.Node reads container.parameters.
          torchscript.Graph._indexParameters(container, mainModule);
          if (context) {
              for (const input of context.inputs) {
                  this._inputs.push(new torchscript.Parameter(input, true, [
                      new torchscript.Argument(input, null, null)
                  ]));
              }
              for (const output of context.outputs) {
                  this._outputs.push(new torchscript.Parameter(output, true, [
                      new torchscript.Argument(output, null, null)
                  ]));
              }
              for (const node of context.nodes) {
                  this._nodes.push(new torchscript.Node(metadata, container, null, node));
              }
          }
          this._loadModule(metadata, container, mainModule);
      }

      /**
       * Parses the main module's code.
       * Parse errors are reported (non-fatally) through `host.exception`; the
       * return value is then null.
       */
      static _createContext(host, python, container, mainModule) {
          try {
              return new torchscript.GraphContext(container, python, mainModule);
          }
          catch (error) {
              // String(error) also copes with null/undefined throw values.
              let message = error && error.message ? error.message : String(error);
              if (message.endsWith('.')) {
                  message = message.substring(0, message.length - 1);
              }
              host.exception(new torchscript.Error(message + " in '" + container.identifier + "'."), false);
              return null;
          }
      }

      /**
       * Walks the module tree breadth-first. For each module it:
       *  - links every submodule to its parent via `parent`;
       *  - resolves each parameter's `tensorId` to an entry of `container.tensors`;
       *  - indexes single-output parameters by output id in `container.parameters`.
       * NOTE(review): torchscript.Node reads `parameter.module`, which is not
       * assigned here — presumably set elsewhere; confirm.
       */
      static _indexParameters(container, mainModule) {
          container.parameters = {};
          const queue = [ mainModule ];
          while (queue.length > 0) {
              const module = queue.shift();
              for (const parameter of module.parameters || []) {
                  if (parameter.tensorId) {
                      parameter.initializer = container.tensors[parseInt(parameter.tensorId, 10)];
                      if (parameter.outputs && parameter.outputs.length == 1) {
                          container.parameters[parameter.outputs[0]] = parameter;
                      }
                  }
              }
              for (const submodule of module.submodules || []) {
                  submodule.parent = module;
                  queue.push(submodule);
              }
          }
      }

      /** Depth-first: adds a 'Module' node for every visible module that owns parameters. */
      _loadModule(metadata, container, module) {
          if (module.parameters && module.parameters.length > 0 && !module.hide) {
              this._nodes.push(new torchscript.Node(metadata, container, module, null));
          }
          if (module.submodules) {
              for (const submodule of module.submodules) {
                  this._loadModule(metadata, container, submodule);
              }
          }
      }

      /** Never assigned by this class; always undefined. */
      get type() {
          return this._type;
      }

      get name() {
          return this._name;
      }

      /** Never assigned by this class; always undefined. */
      get groups() {
          return this._groups;
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get nodes() {
          return this._nodes;
      }
  };
  torchscript.Parameter = class {

      /**
       * A named input/output slot of a node or graph.
       * @param {string} name - slot name.
       * @param {boolean} visible - whether the slot is shown.
       * @param {Array<torchscript.Argument>} args - values connected to the slot.
       */
      constructor(name, visible, args) {
          this._arguments = args;
          this._visible = visible;
          this._name = name;
      }

      get name() { return this._name; }

      get visible() { return this._visible; }

      get arguments() { return this._arguments; }
  };
  torchscript.Argument = class {

      /**
       * A value flowing into or out of a node.
       * @param {string} id - value identifier.
       * @param {*} type - explicit type; ignored when an initializer is present.
       * @param {torchscript.Tensor|null} initializer - constant tensor backing the value.
       */
      constructor(id, type, initializer) {
          this._id = id;
          this._type = type;
          this._initializer = initializer;
      }

      get id() { return this._id; }

      /** The initializer's type takes precedence over the explicit type. */
      get type() {
          return this._initializer ? this._initializer.type : this._type;
      }

      get initializer() { return this._initializer; }
  };
  torchscript.Node = class {

      /**
       * Creates a graph node from one of two sources:
       *  - `module`: a module record, rendered as a 'Module' node whose inputs and
       *    outputs are the module parameters;
       *  - `node`: an operator call recovered from the TorchScript code, whose
       *    inputs, outputs and attributes are named from the operator schema.
       * An operator node whose parameter inputs all belong to one module, and
       * cover all of that module's parameters, absorbs the module. The module is
       * flagged `hide` and the node is named after the module's dotted path.
       * @param {torchscript.Metadata} metadata - operator schema lookup.
       * @param {object} container - archive container; `container.parameters` maps
       *     value ids to parameter records.
       * @param {object|null} module - module record, or null.
       * @param {object|null} node - operator record with `name`, `inputs`,
       *     `outputs` and `attributes`, or null.
       */
      constructor(metadata, container, module, node) {
          this._metadata = metadata;
          this._attributes = [];
          this._inputs = [];
          this._outputs = [];
          if (module) {
              this._loadModuleParameters(module);
          }
          if (node) {
              // For operator nodes the relevant module is the absorbed one (or none).
              module = this._loadOperator(metadata, container, node);
          }
          if (module && module.name) {
              this._name = torchscript.Node._qualifiedName(module);
          }
      }

      /** Fills a 'Module' node: one input per parameter, plus its outputs when present. */
      _loadModuleParameters(module) {
          this._operator = 'Module';
          for (const parameter of module.parameters || []) {
              this._inputs.push(new torchscript.Parameter(parameter.name, true, [
                  new torchscript.Argument('', null, parameter.initializer || null)
              ]));
              if (parameter.outputs) {
                  this._outputs.push(new torchscript.Parameter(parameter.name, true,
                      parameter.outputs.map((id) => new torchscript.Argument(id, null, null))));
              }
          }
      }

      /**
       * Fills an operator node.
       * @returns {object|null} the absorbed module, or null.
       */
      _loadOperator(metadata, container, node) {
          this._operator = node.name;
          this._name = '';
          const schema = metadata.getSchema(this._operator);
          // Must run before the inputs are built: it copies parameter initializers onto the arguments.
          const module = torchscript.Node._absorbModule(container, node);
          node.inputs.forEach((input, index) => {
              const name = (schema && schema.inputs && schema.inputs.length > index) ? schema.inputs[index].name : index.toString();
              this._inputs.push(new torchscript.Parameter(name, true,
                  input.map((argument) => new torchscript.Argument(argument.id, null, argument.initializer || null))));
          });
          node.outputs.forEach((output, index) => {
              const name = (schema && schema.outputs && schema.outputs.length > index) ? schema.outputs[index].name : index.toString();
              this._outputs.push(new torchscript.Parameter(name, true, [
                  new torchscript.Argument(output, null, null)
              ]));
          });
          node.attributes.forEach((attribute, index) => {
              this._attributes.push(this._createAttribute(schema, attribute, index));
          });
          return module;
      }

      /** Builds one attribute; keyword arguments (`name=value`) are matched to the schema by name. */
      _createAttribute(schema, attribute, index) {
          let name = index.toString();
          let value = attribute;
          let attributeSchema = null;
          if (value && value.type === '=' && value.target.type == 'identifier') {
              name = value.target.value;
              value = value.expression;
              if (schema && schema.attributes) {
                  attributeSchema = schema.attributes.find((item) => item.name == name);
              }
          }
          else if (schema && schema.attributes && schema.attributes.length > index) {
              // Positional argument: take the schema attribute at the same position.
              attributeSchema = schema.attributes[index];
              name = attributeSchema.name;
          }
          return new torchscript.Attribute(this, attributeSchema, name, value);
      }

      /**
       * Decides whether the operator node consumes exactly the parameters of a
       * single module.
       * On success it flags the module `hide`, copies the parameter initializers
       * onto the node's input arguments, and returns the module.
       * @returns {object|null} the module, or null when there is no such module.
       */
      static _absorbModule(container, node) {
          let module = null;
          let count = 0;
          for (const input of node.inputs) {
              for (const argument of input) {
                  const parameter = torchscript.Node._lookupParameter(container, argument.id);
                  if (parameter) {
                      if (!parameter.module || (module != null && module != parameter.module)) {
                          return null;
                      }
                      module = parameter.module;
                      count++;
                  }
              }
          }
          if (!module || module.parameters.length != count) {
              return null;
          }
          module.hide = true;
          for (const input of node.inputs) {
              for (const argument of input) {
                  const parameter = torchscript.Node._lookupParameter(container, argument.id);
                  if (parameter && parameter.initializer) {
                      argument.initializer = parameter.initializer;
                  }
              }
          }
          return module;
      }

      /** Own-property lookup so ids such as 'constructor' do not hit Object.prototype. */
      static _lookupParameter(container, id) {
          const parameters = container.parameters;
          return Object.prototype.hasOwnProperty.call(parameters, id) ? parameters[id] : undefined;
      }

      /** Joins the module name with its ancestors' names, root first, using '.'. */
      static _qualifiedName(module) {
          let current = module;
          let name = current.name;
          while (current.parent != null) {
              current = current.parent;
              name = [ current.name, name ].join('.');
          }
          return name;
      }

      get name() {
          return this._name;
      }

      /** Never assigned by this class; always undefined. */
      get group() {
          return this._group;
      }

      get operator() {
          return this._operator;
      }

      get category() {
          const schema = this._metadata.getSchema(this._operator);
          return (schema && schema.category) ? schema.category : '';
      }

      /**
       * Returns a copy of the operator schema with descriptions rendered through
       * `marked`, or '' when the operator has no schema.
       */
      get documentation() {
          const schema = this._metadata.getSchema(this._operator);
          if (!schema) {
              return '';
          }
          // Deep copy so the shared metadata schema is not modified.
          const documentation = JSON.parse(JSON.stringify(schema));
          documentation.name = this._operator;
          if (documentation.description) {
              documentation.description = marked(documentation.description);
          }
          for (const key of [ 'attributes', 'inputs', 'outputs' ]) {
              if (documentation[key]) {
                  for (const item of documentation[key]) {
                      if (item.description) {
                          item.description = marked(item.description);
                      }
                  }
              }
          }
          return documentation;
      }

      get function() {
          return false;
      }

      get attributes() {
          return this._attributes;
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }
  };
  torchscript.Attribute = class {

      /**
       * A node attribute.
       * Literal AST values (number/string/boolean/identifier) are unwrapped to
       * their raw `value`. When a schema is given, the value is converted to the
       * schema type, and the attribute is hidden if the schema marks it invisible
       * or the value equals the schema default.
       * @param {torchscript.Node} node - owning node.
       * @param {object|null|undefined} schema - attribute schema from the metadata.
       * @param {string} name - attribute name.
       * @param {*} value - parsed expression.
       */
      constructor(node, schema, name, value) {
          this._node = node;
          this._name = name;
          this._value = value;
          if (value && [ 'number', 'string', 'boolean', 'identifier' ].indexOf(value.type) !== -1) {
              this._value = value.value;
          }
          if (schema) {
              if (Object.prototype.hasOwnProperty.call(schema, 'type')) {
                  this._type = schema.type;
              }
              this._value = torchscript.Attribute._convert(this._type, this._value);
              if (Object.prototype.hasOwnProperty.call(schema, 'visible') && !schema.visible) {
                  this._visible = false;
              }
              else if (Object.prototype.hasOwnProperty.call(schema, 'default')) {
                  if (JSON.stringify(schema.default) == JSON.stringify(this._value)) {
                      this._visible = false;
                  }
                  else if (Array.isArray(this._value) &&
                      !Array.isArray(schema.default) &&
                      this._value.every((item) => item == schema.default)) {
                      // e.g. [1, 1] against a scalar default of 1.
                      this._visible = false;
                  }
              }
          }
      }

      /** Converts a raw value to the schema type; unknown types pass through unchanged. */
      static _convert(type, value) {
          switch (type) {
              case 'boolean':
                  if (value == 'False') {
                      return false;
                  }
                  if (value == 'True') {
                      return true;
                  }
                  return value;
              case 'int32':
              case 'int64':
                  return parseInt(value, 10);
              case 'float32':
              case 'float64':
                  return parseFloat(value);
              case 'int32[]':
              case 'int64[]':
                  // Guard against null/undefined values, which previously threw on `.type`.
                  if (value && value.type == 'list' && value.value.every((item) => item.type === 'number')) {
                      return value.value.map((item) => {
                          const number = parseInt(item.value, 10);
                          return Number.isNaN(item.value - number) ? item.value : number;
                      });
                  }
                  return value;
              default:
                  return value;
          }
      }

      get type() {
          return this._type;
      }

      get name() {
          return this._name;
      }

      get value() {
          return this._value;
      }

      /** Hidden when explicitly marked invisible or when named 'training'. */
      get visible() {
          return !(this._visible == false || this.name == 'training');
      }
  };
  torchscript.Tensor = class {

      /**
       * Wraps one tensor descriptor of model.json.
       * @param {object} tensor - descriptor with `dataType`, `dims` and `data.key`
       *     (path of the raw bytes, relative to `container.prefix`).
       * @param {object} container - archive container with `prefix` and `entries`.
       * @throws {torchscript.Error} for unsupported data types (from TensorType).
       */
      constructor(tensor, container) {
          this._type = new torchscript.TensorType(tensor.dataType, new torchscript.TensorShape(tensor.dims));
          const key = container.prefix + tensor.data.key;
          const entry = container.entries.find((item) => item.name == key);
          this._name = tensor.data.key;
          // A missing archive entry is reported through `state` ('Tensor data is empty.')
          // instead of aborting the whole model load with a TypeError.
          this._data = entry ? entry.data : null;
          this._littleEndian = true;
      }

      get kind() {
          return 'Tensor';
      }

      get name() {
          return this._name;
      }

      get type() {
          return this._type;
      }

      /** Reason the tensor cannot be decoded, or null. */
      get state() {
          return this._context().state;
      }

      /** Fully decoded nested array (scalar for rank 0), or null when undecodable. */
      get value() {
          const context = this._context();
          if (context.state) {
              return null;
          }
          context.limit = Number.MAX_SAFE_INTEGER;
          return this._decode(context, 0);
      }

      /** Text rendering, truncated after roughly 10000 elements. */
      toString() {
          const context = this._context();
          if (context.state) {
              return '';
          }
          context.limit = 10000;
          return torchscript.Tensor._stringify(this._decode(context, 0), '', ' ');
      }

      _context() {
          const context = {};
          context.state = null;
          context.index = 0;
          context.count = 0;
          if (!this._type.dataType) {
              context.state = 'Tensor has no data type.';
              return context;
          }
          if (!this._type.shape) {
              context.state = 'Tensor has no dimensions.';
              return context;
          }
          if (!this._data) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          context.data = this._data;
          context.dataType = this._type.dataType;
          context.dimensions = this._type.shape.dimensions;
          context.dataView = new DataView(context.data.buffer, context.data.byteOffset, context.data.byteLength);
          return context;
      }

      /** Recursively decodes `dimension` and below; a rank-0 tensor decodes as one scalar. */
      _decode(context, dimension) {
          const dimensions = context.dimensions.length == 0 ? [ 1 ] : context.dimensions;
          const size = dimensions[dimension];
          const innermost = dimension == dimensions.length - 1;
          const results = [];
          for (let i = 0; i < size; i++) {
              if (context.count > context.limit) {
                  results.push('...');
                  return results;
              }
              if (innermost) {
                  this._decodeElement(context, results);
              }
              else {
                  results.push(this._decode(context, dimension + 1));
              }
          }
          return context.dimensions.length == 0 ? results[0] : results;
      }

      /**
       * Reads one element at `context.index` into `results` and advances the cursor.
       * Unsupported data types are skipped without consuming data.
       */
      _decodeElement(context, results) {
          const view = context.dataView;
          const index = context.index;
          let size = 0;
          switch (context.dataType) {
              case 'uint8':
                  results.push(view.getUint8(index));
                  size = 1;
                  break;
              case 'int8':
                  results.push(view.getInt8(index));
                  size = 1;
                  break;
              case 'int16':
                  results.push(view.getInt16(index, this._littleEndian));
                  size = 2;
                  break;
              case 'int32':
                  results.push(view.getInt32(index, this._littleEndian));
                  size = 4;
                  break;
              case 'int64':
                  results.push(new long.Long(view.getUint32(index, true), view.getUint32(index + 4, true), false));
                  size = 8;
                  break;
              case 'float16':
                  // DataView.getFloat16 is missing from ES6-era engines; decode the bits manually.
                  results.push(torchscript.Tensor._float16(view.getUint16(index, this._littleEndian)));
                  size = 2;
                  break;
              case 'float32':
                  results.push(view.getFloat32(index, this._littleEndian));
                  size = 4;
                  break;
              case 'float64':
                  results.push(view.getFloat64(index, this._littleEndian));
                  size = 8;
                  break;
              default:
                  return;
          }
          context.index += size;
          context.count++;
      }

      /** Converts IEEE 754 binary16 bits to a JavaScript number. */
      static _float16(bits) {
          const sign = (bits & 0x8000) ? -1 : 1;
          const exponent = (bits & 0x7C00) >> 10;
          const fraction = bits & 0x03FF;
          if (exponent === 0) {
              // Subnormal (or signed zero).
              return sign * Math.pow(2, -14) * (fraction / 1024);
          }
          if (exponent === 0x1F) {
              return fraction ? NaN : sign * Infinity;
          }
          return sign * Math.pow(2, exponent - 15) * (1 + fraction / 1024);
      }

      /** Pretty-prints nested arrays, one element per line. */
      static _stringify(value, indentation, indent) {
          if (Array.isArray(value)) {
              const result = [];
              result.push(indentation + '[');
              const items = value.map((item) => torchscript.Tensor._stringify(item, indentation + indent, indent));
              if (items.length > 0) {
                  result.push(items.join(',\n'));
              }
              result.push(indentation + ']');
              return result.join('\n');
          }
          if (value && long.Long.isLong(value)) {
              return indentation + value.toString();
          }
          if (typeof value == 'string') {
              return indentation + value;
          }
          if (value == Infinity) {
              return indentation + 'Infinity';
          }
          if (value == -Infinity) {
              return indentation + '-Infinity';
          }
          if (isNaN(value)) {
              return indentation + 'NaN';
          }
          return indentation + value.toString();
      }
  };
  torchscript.TensorType = class {

      /**
       * @param {string} dataType - data type enum name from model.json. The
       *     accepted names follow caffe2's TensorProto.DataType naming (e.g. 'FLOAT');
       *     confirm against the exporter if new names appear.
       * @param {torchscript.TensorShape} shape - tensor shape.
       * @throws {torchscript.Error} for data types torchscript.Tensor cannot decode.
       */
      constructor(dataType, shape) {
          switch (dataType) {
              case 'FLOAT': this._dataType = 'float32'; break;
              case 'DOUBLE': this._dataType = 'float64'; break;
              // 8/16-bit integer types: torchscript.Tensor already decodes int8/int16/uint8.
              case 'INT8': this._dataType = 'int8'; break;
              case 'INT16': this._dataType = 'int16'; break;
              case 'INT32': this._dataType = 'int32'; break;
              case 'INT64': this._dataType = 'int64'; break;
              case 'UINT8': this._dataType = 'uint8'; break;
              default: throw new torchscript.Error("Unknown tensor data type '" + dataType + "'.");
          }
          this._shape = shape;
      }

      get dataType() {
          return this._dataType;
      }

      get shape() {
          return this._shape;
      }

      toString() {
          return this._dataType + this._shape.toString();
      }
  };
  torchscript.TensorShape = class {

      /** @param {Array<number>|undefined} dimensions - dimension sizes; defaults to []. */
      constructor(dimensions) {
          this._dimensions = dimensions || [];
      }

      get dimensions() { return this._dimensions; }

      /** '[d0,d1,...]', or '' for a rank-0 shape. */
      toString() {
          const dims = this._dimensions;
          if (!(dims && dims.length > 0)) {
              return '';
          }
          const parts = dims.map((size) => size.toString());
          return '[' + parts.join(',') + ']';
      }
  };
  torchscript.Metadata = class {

      /**
       * Loads 'torchscript-metadata.json' once and caches the result on the class.
       * A failed request or parse yields an empty metadata instance.
       * @param {object} host - host providing `request`.
       * @returns {Promise<torchscript.Metadata>}
       */
      static open(host) {
          if (torchscript.Metadata._metadata) {
              return Promise.resolve(torchscript.Metadata._metadata);
          }
          return host.request(null, 'torchscript-metadata.json', 'utf-8').then((data) => {
              torchscript.Metadata._metadata = new torchscript.Metadata(data);
              return torchscript.Metadata._metadata;
          }).catch(() => {
              torchscript.Metadata._metadata = new torchscript.Metadata(null);
              return torchscript.Metadata._metadata;
          });
      }

      /** @param {string|null} data - JSON array of `{ name, schema }` items. */
      constructor(data) {
          // Null-prototype dictionaries: operator and attribute names come from the
          // model file, so names like 'constructor', 'toString' or '__proto__' must
          // not resolve to (or overwrite) Object.prototype members.
          this._map = Object.create(null);
          this._attributeCache = Object.create(null);
          if (data) {
              const items = JSON.parse(data);
              if (items) {
                  for (const item of items) {
                      if (item.name && item.schema) {
                          this._map[item.name] = item.schema;
                      }
                  }
              }
          }
      }

      /** @returns {object|null} schema of `operator`, or null. */
      getSchema(operator) {
          return this._map[operator] || null;
      }

      /** @returns {object|null} schema of attribute `name` of `operator`, or null (cached per operator). */
      getAttributeSchema(operator, name) {
          let map = this._attributeCache[operator];
          if (!map) {
              map = Object.create(null);
              const schema = this.getSchema(operator);
              if (schema && schema.attributes && schema.attributes.length > 0) {
                  for (const attribute of schema.attributes) {
                      map[attribute.name] = attribute;
                  }
              }
              this._attributeCache[operator] = map;
          }
          return map[name] || null;
      }
  };
  715. torchscript.GraphContext = class {
  716. constructor(container, python, mainModule) {
  717. this._container = container;
  718. this._mainModule = mainModule;
  719. this._inputs = [];
  720. this._outputs = [];
  721. this._nodes = [];
  722. this._moduleMap = {};
  723. this._argumentMap = {};
  724. this._numToTensorMap = {};
  725. if (mainModule.torchscriptArena && mainModule.torchscriptArena.key) {
  726. var codeKey = container.prefix + mainModule.torchscriptArena.key;
  727. var codeEntries = container.entries.filter((e) => e.name === codeKey);
  728. if (codeEntries.length == 1) {
  729. var codeEntry = codeEntries[0];
  730. var textDecoder = new TextDecoder('utf-8');
  731. var code = textDecoder.decode(codeEntry.data);
  732. var reader = new python.Parser(code);
  733. var program = reader.parse();
  734. var method = program.body.find((statement) => statement.type == 'def' && statement.name == 'forward');
  735. if (method) {
  736. this._body = method.body.statements;
  737. var methodParameters = method.parameters;
  738. if (methodParameters.length > 0 && methodParameters[0].name == 'self') {
  739. methodParameters.shift();
  740. }
  741. for (var parameter of methodParameters) {
  742. this._parameter(parameter);
  743. }
  744. if (this._body.length >= 2) {
  745. var returnStatement = this._body[this._body.length - 1];
  746. var assignStatement = this._body[this._body.length - 2];
  747. if (returnStatement.type == 'return' &&
  748. returnStatement.expression.type == 'identifier' &&
  749. assignStatement.target.type == 'identifier' &&
  750. assignStatement.target.value == returnStatement.expression.value) {
  751. returnStatement.expression = assignStatement.expression;
  752. this._body.pop();
  753. this._body.pop();
  754. this._body.push(returnStatement);
  755. }
  756. }
  757. while (this._body.length > 0) {
  758. var statement = this._body.shift();
  759. if (this._attributeStatement(statement)) {
  760. continue;
  761. }
  762. if (this._moduleStatement(statement)) {
  763. continue;
  764. }
  765. if (this._argumentStatement(statement)) {
  766. continue;
  767. }
  768. if (this._nodeStatement(statement)) {
  769. continue;
  770. }
  771. if (this._returnStatement(statement)) {
  772. continue;
  773. }
  774. throw new torchscript.Error("Unknown statement '" + JSON.stringify(statement) + "'.");
  775. }
  776. }
  777. }
  778. }
  779. }
  780. get inputs() {
  781. return this._inputs;
  782. }
  783. get outputs() {
  784. return this._outputs;
  785. }
  786. get nodes() {
  787. return this._nodes;
  788. }
  789. _parameter(parameter) {
  790. var type = parameter.parameterType;
  791. if (type.type == 'type' && type.value == 'Tuple' && type.arguments && type.arguments.length > 0) {
  792. if (this._body.length > 0) {
  793. var statement = this._body[0];
  794. if (statement.expression.type == 'identifier' && statement.expression.value == parameter.name) {
  795. if (statement.type === '=' && statement.target.type === 'tuple') {
  796. for (var input of statement.target.value) {
  797. if (input) {
  798. this._inputs.push(input.value);
  799. }
  800. }
  801. this._body.shift();
  802. }
  803. }
  804. }
  805. }
  806. else {
  807. this._inputs.push(parameter.name);
  808. }
  809. }
  810. _returnStatement(statement) {
  811. if (statement.type == 'return') {
  812. var variable = this._variable();
  813. if (this._nodeExpression(statement.expression, variable)) {
  814. this._outputs.push(variable.value);
  815. return true;
  816. }
  817. if (statement.expression.type == 'identifier') {
  818. this._outputs.push(statement.expression.value);
  819. return true;
  820. }
  821. if (statement.expression.type == 'tuple') {
  822. var outputs = [];
  823. for (var expression of statement.expression.value) {
  824. variable = this._variable();
  825. if (this._nodeExpression(expression, variable)) {
  826. outputs.push(variable.value);
  827. continue
  828. }
  829. if (expression.type == 'identifier') {
  830. outputs.push(expression.value);
  831. continue;
  832. }
  833. return false;
  834. }
  835. this._outputs = this._outputs.concat(outputs);
  836. return true;
  837. }
  838. }
  839. return false;
  840. }
  841. _nodeExpression(expression, target) {
  842. if (expression.type == 'call' && (target.type == 'identifier' || target.type == 'tuple')) {
  843. var name = this._name(expression.target);
  844. var namespace = 'torch.';
  845. if (name.startsWith(namespace)) {
  846. var node = {};
  847. node.name = name.substring(namespace.length);
  848. node.inputs = [];
  849. node.outputs = [];
  850. node.attributes = [];
  851. var args = expression.arguments;
  852. while (args.length > 0) {
  853. var argument = args[0];
  854. argument = this._moduleTensor(argument);
  855. if (argument.type == 'identifier' && this._argumentMap[argument.value]) {
  856. argument = this._argumentMap[argument.value];
  857. delete this._argumentMap[argument.value];
  858. }
  859. if (argument.type == 'identifier') {
  860. if (argument.value === 'False' || argument.value === 'True') {
  861. break;
  862. }
  863. node.inputs.push([ { id: argument.value } ]);
  864. args.shift();
  865. continue;
  866. }
  867. if (argument.type == 'list') {
  868. var list = [];
  869. for (var input of argument.value) {
  870. var variable = this._variable();
  871. if (this._nodeExpression(input, variable)) {
  872. list.push({ id: variable.value });
  873. }
  874. else if (this._argumentExpression(input, variable)) {
  875. list.push({ id: variable.value });
  876. }
  877. else if (input.type == 'identifier') {
  878. list.push({ id: input.value });
  879. }
  880. else {
  881. list = null;
  882. break;
  883. }
  884. }
  885. if (list) {
  886. node.inputs.push(list);
  887. args.shift();
  888. continue;
  889. }
  890. }
  891. if (argument.type == 'list') {
  892. break;
  893. }
  894. if (argument.type == 'number' || argument.type == 'string' || argument.type == 'boolean') {
  895. break;
  896. }
  897. if (argument.type == '=') {
  898. break;
  899. }
  900. variable = this._variable();
  901. if (this._nodeExpression(argument, variable)) {
  902. node.inputs.push([ { id: variable.value } ]);
  903. args.shift();
  904. continue;
  905. }
  906. if (this._argumentExpression(argument, variable)) {
  907. node.inputs.push([ { id: variable.value } ]);
  908. args.shift();
  909. continue;
  910. }
  911. if (argument.type == '.' &&
  912. argument.target.type == 'identifier' &&
  913. argument.target.value == 'CONSTANTS' &&
  914. argument.member.type == 'identifier' &&
  915. argument.member.value.startsWith('c')) {
  916. var constantId = [ argument.target.value, argument.member.value ].join('.');
  917. var constantIndex = parseInt(argument.member.value.substring(1), 10);
  918. var constantTensor = this._container.tensors[constantIndex];
  919. node.inputs.push([ { id: constantId, initializer: constantTensor } ]);
  920. args.shift();
  921. continue;
  922. }
  923. throw new torchscript.Error('Unknown function argument.');
  924. }
  925. while (args.length > 0) {
  926. if (args[0].type == 'list') {
  927. for (var i = 0; i < args[0].value.length; i++) {
  928. args[0].value[i] = this._attributeExpression(args[0].value[i]);
  929. }
  930. }
  931. var intExpression = this._attributeExpression(args[0]);
  932. if (intExpression) {
  933. args[0] = intExpression;
  934. }
  935. node.attributes.push(args[0]);
  936. args.shift();
  937. }
  938. if (target.type == 'identifier') {
  939. node.outputs.push(target.value);
  940. }
  941. if (target.type == 'tuple') {
  942. for (var identifier of target.value) {
  943. node.outputs.push(identifier.value);
  944. }
  945. }
  946. this._nodes.push(node);
  947. return true;
  948. }
  949. }
  950. return false;
  951. }
  952. _nodeStatement(statement) {
  953. if (statement.type == '=') {
  954. if (this._nodeExpression(statement.expression, statement.target)) {
  955. return true;
  956. }
  957. }
  958. return false;
  959. }
  960. _attributeExpression(expression) {
  961. if (expression.type == 'identifier') {
  962. if (this._numToTensorMap[expression.value]) {
  963. return { type: 'number', value: this._numToTensorMap[expression.value] };
  964. }
  965. }
  966. if (expression.type == 'call' &&
  967. expression.target.type == 'identifier' &&
  968. expression.target.value == 'int' &&
  969. expression.arguments.length == 1)
  970. {
  971. var replace = this._attributeExpression(expression.arguments[0]);
  972. if (replace) {
  973. return replace;
  974. }
  975. }
  976. return expression;
  977. }
  978. _attributeStatement(statement) {
  979. if (statement.type == '=' &&
  980. statement.target.type == 'identifier') {
  981. if (statement.expression.type == 'call' &&
  982. this._name(statement.expression.target) == 'ops.prim.NumToTensor' &&
  983. statement.expression.arguments.length == 1) {
  984. var size = statement.expression.arguments[0];
  985. if (size.type == 'call' &&
  986. size.arguments.length == 2 &&
  987. this._name(size.target) == 'torch.size' &&
  988. size.arguments[0].type == 'identifier' &&
  989. size.arguments[1].type == 'number') {
  990. this._numToTensorMap[statement.target.value] = this._name(size.target) + '(' + size.arguments.map((a) => a.value.toString()).join(',') + ')';
  991. return true;
  992. }
  993. if (size.type == 'identifier') {
  994. var duplicate1 = this._numToTensorMap[size.value];
  995. if (duplicate1) {
  996. this._numToTensorMap[statement.target.value] = duplicate1;
  997. return true;
  998. }
  999. }
  1000. }
  1001. if (statement.expression.type == 'call' &&
  1002. statement.expression.arguments.length == 2 &&
  1003. this._name(statement.expression.target) == 'torch.size' &&
  1004. statement.expression.arguments[0].type == 'identifier' &&
  1005. statement.expression.arguments[1].type == 'number') {
  1006. this._numToTensorMap[statement.target.value] = this._name(statement.expression.target) + '(' + statement.expression.arguments.map((a) => a.value.toString()).join(',') + ')';
  1007. return true;
  1008. }
  1009. if (statement.expression.type == 'call' &&
  1010. statement.expression.target.type == 'identifier' &&
  1011. statement.expression.target.value == 'int' &&
  1012. statement.expression.arguments.length == 1 &&
  1013. statement.expression.arguments[0].type == 'identifier') {
  1014. var duplicate2 = this._numToTensorMap[statement.expression.arguments[0].value];
  1015. if (duplicate2) {
  1016. this._numToTensorMap[statement.target.value] = duplicate2;
  1017. return true;
  1018. }
  1019. }
  1020. }
  1021. return false;
  1022. }
  1023. _module(expression) {
  1024. var module;
  1025. var submodule;
  1026. if (expression.type === '.') {
  1027. module = this._module(expression.target);
  1028. if (module && module.submodules) {
  1029. for (submodule of module.submodules) {
  1030. if (submodule.name === expression.member.value) {
  1031. return submodule;
  1032. }
  1033. }
  1034. }
  1035. }
  1036. if (expression.type == 'call' &&
  1037. expression.target.type == 'identifier' && expression.target.value == 'getattr' && expression.arguments.length == 2) {
  1038. module = this._module(expression.arguments[0]);
  1039. if (!module) {
  1040. return null;
  1041. }
  1042. var name = null;
  1043. if (expression.arguments[1].type == 'string') {
  1044. name = expression.arguments[1].value.substring(1, expression.arguments[1].value.length - 1);
  1045. }
  1046. if (module) {
  1047. for (submodule of module.submodules) {
  1048. if (submodule.name === name) {
  1049. return submodule;
  1050. }
  1051. }
  1052. }
  1053. }
  1054. if (expression.type == 'identifier') {
  1055. if (expression.value == 'self') {
  1056. return this._mainModule;
  1057. }
  1058. module = this._moduleMap[expression.value];
  1059. if (module) {
  1060. return module;
  1061. }
  1062. }
  1063. return null;
  1064. }
  1065. _moduleStatement(statement) {
  1066. if (statement.type == '=' &&
  1067. statement.target.type === 'identifier') {
  1068. var moduleName = statement.target.value;
  1069. var module = this._module(statement.expression);
  1070. if (module) {
  1071. this._moduleMap[moduleName] = module;
  1072. return true;
  1073. }
  1074. }
  1075. return false;
  1076. }
  1077. _argumentExpression(expression, target) {
  1078. expression = this._moduleTensor(expression);
  1079. if (expression.type === '.' && expression.member.type == 'identifier') {
  1080. var targetModule = this._module(expression.target);
  1081. if (targetModule && targetModule.parameters) {
  1082. for (var parameter of targetModule.parameters) {
  1083. parameter.module = targetModule;
  1084. if (parameter.name === expression.member.value) {
  1085. parameter.outputs = parameter.outputs || [];
  1086. parameter.outputs.push(target.value);
  1087. return true;
  1088. }
  1089. }
  1090. targetModule.unresolvedParameters = targetModule.unresolvedParameters || [];
  1091. for (var unresolvedParameter of targetModule.unresolvedParameters) {
  1092. unresolvedParameter.module = targetModule;
  1093. if (unresolvedParameter.name === expression.member.value) {
  1094. unresolvedParameter.outputs = unresolvedParameter.outputs || [];
  1095. unresolvedParameter.outputs.push(target.value);
  1096. return true;
  1097. }
  1098. }
  1099. targetModule.unresolvedParameters.push({
  1100. module: targetModule,
  1101. name: expression.member.value,
  1102. outputs: [ target.value ]
  1103. });
  1104. return true;
  1105. }
  1106. }
  1107. return false;
  1108. }
  1109. _argumentStatement(statement) {
  1110. if (statement.type === '=' && statement.target.type === 'identifier') {
  1111. if (this._argumentExpression(statement.expression, statement.target)) {
  1112. return true;
  1113. }
  1114. if (statement.target.type == 'identifier' &&
  1115. statement.expression.type == 'list') {
  1116. this._argumentMap[statement.target.value] = statement.expression;
  1117. return true;
  1118. }
  1119. }
  1120. return false;
  1121. }
  1122. _variable() {
  1123. return { type: 'identifier', value: '_gen' + Math.random().toString(36).substring(7) };
  1124. }
  1125. _name(expression) {
  1126. if (expression.type == 'identifier') {
  1127. return expression.value;
  1128. }
  1129. if (expression.type == '.') {
  1130. return [ this._name(expression.target), this._name(expression.member) ].join('.');
  1131. }
  1132. throw new torchscript.Error("Failed to resolve name '" + JSON.stringify(expression) + "'.");
  1133. }
  1134. _moduleTensor(expression) {
  1135. if (expression.type == 'call' &&
  1136. expression.arguments.length == 1 &&
  1137. this._name(expression.target) == 'torch.t') {
  1138. return expression.arguments[0];
  1139. }
  1140. return expression;
  1141. }
  1142. }
  // Error type thrown by this loader; `name` carries a descriptive title in place
  // of the generic 'Error' (presumably displayed by the host application — confirm).
  torchscript.Error = class extends Error {
    constructor(message) {
      super(message);
      Object.assign(this, { name: 'Error loading TorchScript model.' });
    }
  };
  // Under CommonJS, publish the model factory; in other environments this is a no-op.
  if (!(typeof module === 'undefined' || typeof module.exports !== 'object')) {
    Object.assign(module.exports, { ModelFactory: torchscript.ModelFactory });
  }