pytorch.js 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765
/*jshint esversion: 6 */
// Experimental
// Namespace object that every class in this file is attached to; an already-defined `pytorch` is reused.
var pytorch = pytorch || {};
// Reuse an already-defined `base`, otherwise load it with require('./base'). This file uses base.Int64.
var base = base || require('./base');
  5. pytorch.ModelFactory = class {
  6. match(context, host) {
  7. var extension = context.identifier.split('.').pop();
  8. if (extension == 'pt' || extension == 'pth' || extension == 'pkl') {
  9. var buffer = context.buffer;
  10. var torch = [ 0x80, 0x02, 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19 ];
  11. if (buffer && buffer.length > torch.length) {
  12. if (torch.every((value, index) => value == buffer[index])) {
  13. return true;
  14. }
  15. }
  16. }
  17. return false;
  18. }
  19. open(context, host, callback) {
  20. host.require('./pickle', (err, pickle) => {
  21. if (err) {
  22. callback(err, null);
  23. return;
  24. }
  25. pytorch.OperatorMetadata.open(host, (err, metadata) => {
  26. this._openModel(context, host, pickle, callback);
  27. });
  28. });
  29. }
  30. _openModel(context, host, pickle, callback) {
  31. try {
  32. var identifier = context.identifier;
  33. var unpickler = new pickle.Unpickler(context.buffer);
  34. var signature = [ 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19 ];
  35. var magic_number = unpickler.load();
  36. if (!Array.isArray(magic_number) ||
  37. signature.length != magic_number.length ||
  38. !signature.every((value, index) => value == magic_number[index]))
  39. {
  40. callback(new pytorch.Error('Invalid signature.', null));
  41. return;
  42. }
  43. var protocol_version = unpickler.load();
  44. if (protocol_version != 1001) {
  45. callback(new pytorch.Error("Unsupported protocol version '" + protocol_version + "'.", null));
  46. return;
  47. }
  48. var sysInfo = unpickler.load();
  49. if (sysInfo.protocol_version != 1001) {
  50. callback(new pytorch.Error("Unsupported protocol version '" + sysInfo.protocol_version + "'.", null));
  51. return;
  52. }
  53. if (sysInfo.type_sizes) {
  54. if ((sysInfo.type_sizes.int && sysInfo.type_sizes.int != 4) ||
  55. (sysInfo.type_sizes.long && sysInfo.type_sizes.long != 4) ||
  56. (sysInfo.type_sizes.short && sysInfo.type_sizes.short != 2))
  57. {
  58. callback(new pytorch.Error('Unsupported type sizes.'));
  59. return;
  60. }
  61. }
  62. var constructorTable = {};
  63. var functionTable = {};
  64. constructorTable['argparse.Namespace'] = function (args) { this.args = args; };
  65. constructorTable['torch.nn.modules.activation.LeakyReLU'] = function () {};
  66. constructorTable['torch.nn.modules.activation.ReLU'] = function () {};
  67. constructorTable['torch.nn.modules.activation.PReLU'] = function () {};
  68. constructorTable['torch.nn.modules.activation.Sigmoid'] = function () {};
  69. constructorTable['torch.nn.modules.activation.Tanh'] = function () {};
  70. constructorTable['torch.nn.modules.batchnorm.BatchNorm1d'] = function () {};
  71. constructorTable['torch.nn.modules.batchnorm.BatchNorm2d'] = function () {};
  72. constructorTable['torch.nn.modules.batchnorm.BatchNorm3d'] = function () {};
  73. constructorTable['torch.nn.modules.container.ModuleList'] = function () {};
  74. constructorTable['torch.nn.modules.container.Sequential'] = function () {};
  75. constructorTable['torch.nn.modules.conv.Conv1d'] = function () {};
  76. constructorTable['torch.nn.modules.conv.Conv2d'] = function () {};
  77. constructorTable['torch.nn.modules.conv.Conv3d'] = function () {};
  78. constructorTable['torch.nn.modules.conv.ConvTranspose1d'] = function () {};
  79. constructorTable['torch.nn.modules.conv.ConvTranspose2d'] = function () {};
  80. constructorTable['torch.nn.modules.conv.ConvTranspose3d'] = function () {};
  81. constructorTable['torch.nn.modules.dropout.Dropout'] = function () {};
  82. constructorTable['torch.nn.modules.dropout.Dropout2d'] = function () {};
  83. constructorTable['torch.nn.modules.dropout.Dropout3d'] = function () {};
  84. constructorTable['torch.nn.modules.linear.Linear'] = function () {};
  85. constructorTable['torch.nn.modules.normalization.GroupNorm'] = function () {};
  86. constructorTable['torch.nn.modules.pooling.AvgPool1d'] = function () {};
  87. constructorTable['torch.nn.modules.pooling.AvgPool2d'] = function () {};
  88. constructorTable['torch.nn.modules.pooling.AvgPool3d'] = function () {};
  89. constructorTable['torch.nn.modules.pooling.MaxPool1d'] = function() {};
  90. constructorTable['torch.nn.modules.pooling.MaxPool2d'] = function () {};
  91. constructorTable['torch.nn.modules.pooling.MaxPool3d'] = function() {};
  92. constructorTable['torch.nn.modules.pooling.AdaptiveAvgPool1d'] = function() {};
  93. constructorTable['torch.nn.modules.pooling.AdaptiveAvgPool2d'] = function() {};
  94. constructorTable['torch.nn.modules.pooling.AdaptiveAvgPool3d'] = function() {};
  95. constructorTable['torch.nn.modules.rnn.LSTM'] = function () {};
  96. constructorTable['torch.nn.modules.sparse.Embedding'] = function () {};
  97. constructorTable['torchvision.models.squeezenet.Fire'] = function () {};
  98. constructorTable['torchvision.models.squeezenet.SqueezeNet'] = function () {};
  99. constructorTable['torch.nn.modules.upsampling.Upsample'] = function() {};
  100. constructorTable['torchvision.models.alexnet.AlexNet'] = function () {};
  101. constructorTable['torchvision.models.densenet.DenseNet'] = function () {};
  102. constructorTable['torchvision.models.densenet._DenseBlock'] = function () {};
  103. constructorTable['torchvision.models.densenet._DenseLayer'] = function () {};
  104. constructorTable['torchvision.models.densenet._Transition'] = function () {};
  105. constructorTable['torchvision.models.inception.BasicConv2d'] = function () {};
  106. constructorTable['torchvision.models.inception.Inception3'] = function () {};
  107. constructorTable['torchvision.models.inception.InceptionAux'] = function () {};
  108. constructorTable['torchvision.models.inception.InceptionA'] = function () {};
  109. constructorTable['torchvision.models.inception.InceptionB'] = function () {};
  110. constructorTable['torchvision.models.inception.InceptionC'] = function () {};
  111. constructorTable['torchvision.models.inception.InceptionD'] = function () {};
  112. constructorTable['torchvision.models.inception.InceptionE'] = function () {};
  113. constructorTable['torch.nn.modules.padding.ReflectionPad2d'] = function () {};
  114. constructorTable['torchvision.models.resnet.Bottleneck'] = function () {};
  115. constructorTable['torchvision.models.resnet.BasicBlock'] = function() {};
  116. constructorTable['torchvision.models.resnet.ResNet'] = function () {};
  117. constructorTable['torchvision.models.vgg.VGG'] = function () {};
  118. constructorTable['torch.nn.backends.thnn._get_thnn_function_backend'] = function () {};
  119. constructorTable['torch.nn.parameter.Parameter'] = function(data, requires_grad) { this.data = data; this.requires_grad = requires_grad; };
  120. constructorTable['torch.ByteStorage'] = function (size) { this.size = size; this.dataTypeSize = 1; this.dataType = 'uint8'; };
  121. constructorTable['torch.LongStorage'] = function (size) { this.size = size; this.dataTypeSize = 4; this.dataType = 'int64'; };
  122. constructorTable['torch.HalfStorage'] = function (size) { this.size = size; this.dataTypeSize = 2; this.dataType = 'float16'; };
  123. constructorTable['torch.FloatStorage'] = function (size) { this.size = size; this.dataTypeSize = 4; this.dataType = 'float32'; };
  124. constructorTable['torch.DoubleStorage'] = function (size) { this.size = size; this.dataTypeSize = 8; this.dataType = 'float64'; };
  125. constructorTable['torch.FloatTensor'] = function () {
  126. this.__setstate__ = function(state) {
  127. this.storage = state[0];
  128. this.storage_offset = state[1];
  129. this.size = state[2];
  130. this.stride = state[3];
  131. };
  132. };
  133. functionTable['collections.OrderedDict'] = function(args) {
  134. var obj = [];
  135. obj.__setitem__ = function(key, value) {
  136. obj.push({ key: key, value: value });
  137. };
  138. if (args) {
  139. args.forEach((arg) => {
  140. obj.__setitem__(arg[0], arg[1]);
  141. });
  142. }
  143. return obj;
  144. };
  145. functionTable['torch._utils._rebuild_tensor'] = function (storage, storage_offset, size, stride) {
  146. var obj = {};
  147. obj.__type__ = storage.__type__.replace('Storage', 'Tensor');
  148. obj.storage = storage;
  149. obj.storage_offset = storage_offset;
  150. obj.size = size;
  151. obj.stride = stride;
  152. return obj;
  153. };
  154. functionTable['torch._utils._rebuild_tensor_v2'] = function (storage, storage_offset, size, stride, requires_grad, backward_hooks) {
  155. var obj = {};
  156. obj.__type__ = storage.__type__.replace('Storage', 'Tensor');
  157. obj.storage = storage;
  158. obj.storage_offset = storage_offset;
  159. obj.size = size;
  160. obj.stride = stride;
  161. obj.requires_grad = requires_grad;
  162. obj.backward_hooks = backward_hooks;
  163. return obj;
  164. };
  165. functionTable['torch._utils._rebuild_parameter'] = function(data, requires_grad, backward_hooks) {
  166. var obj = {};
  167. obj.__type__ = 'torch.nn.parameter.Parameter';
  168. constructorTable[obj.__type__].apply(obj, [ data, requires_grad ]);
  169. obj.backward_hooks = backward_hooks;
  170. return obj;
  171. };
  172. var function_call = (name, args) => {
  173. var func = functionTable[name];
  174. if (func) {
  175. return func.apply(null, args);
  176. }
  177. var obj = { __type__: name };
  178. var constructor = constructorTable[name];
  179. if (constructor) {
  180. constructor.apply(obj, args);
  181. }
  182. else {
  183. debugger;
  184. host.exception(new pytorch.Error("Unknown function '" + name + "' in '" + identifier + "'."), false);
  185. }
  186. return obj;
  187. };
  188. var module_source_map = {};
  189. var deserialized_objects = {};
  190. var persistent_load = (saved_id) => {
  191. var typename = saved_id.shift();
  192. var data = saved_id;
  193. switch (typename) {
  194. case 'module':
  195. module_source_map[data[0]] = data[2];
  196. return data[0];
  197. case 'storage':
  198. var data_type = data.shift();
  199. var root_key = data.shift();
  200. var location = data.shift();
  201. var size = data.shift();
  202. var view_metadata = data.shift();
  203. var storage = deserialized_objects[root_key];
  204. if (!storage) {
  205. storage = function_call(data_type, [ size ]);
  206. deserialized_objects[root_key] = storage;
  207. }
  208. if (view_metadata) {
  209. var view_key = view_metadata.shift();
  210. var view_offset = view_metadata.shift();
  211. var view_size = view_metadata.shift();
  212. var view = deserialized_objects[view_key];
  213. if (!view) {
  214. view = null; // storage.slice(view_offset, view_offset + view_size);
  215. deserialized_objects[view_key] = view;
  216. }
  217. return view;
  218. }
  219. return storage;
  220. }
  221. throw new pickle.Error("Unknown persistent load type '" + typename + "'.");
  222. };
  223. var root = unpickler.load(function_call, persistent_load);
  224. var deserialized_storage_keys = unpickler.load();
  225. deserialized_storage_keys.forEach((key) => {
  226. if (deserialized_objects[key]) {
  227. var storage = deserialized_objects[key];
  228. storage.data = unpickler.read(storage.dataTypeSize * storage.size);
  229. }
  230. });
  231. if ((Array.isArray(root) && root.__setitem__ && root.every((item) => item.value.__type__ == 'torch.FloatTensor')) ||
  232. (root != null && root.state_dict && Array.isArray(root.state_dict))) {
  233. callback(new pytorch.Error("File does not contain a model graph. Use 'torch.save()' to save both the graph and tensor data."), null);
  234. return;
  235. }
  236. if (!root._modules) {
  237. callback(new pytorch.Error('Root object does not contain modules.'), null);
  238. return;
  239. }
  240. var model = new pytorch.Model(sysInfo, root);
  241. callback(null, model);
  242. }
  243. catch (error) {
  244. host.exception(error, false);
  245. callback(new pytorch.Error(error.message), null);
  246. return;
  247. }
  248. }
  249. };
  250. pytorch.Model = class {
  251. constructor(sysInfo, root) {
  252. this._graphs = [ new pytorch.Graph(sysInfo, root) ];
  253. }
  254. get format() {
  255. return 'PyTorch';
  256. }
  257. get graphs() {
  258. return this._graphs;
  259. }
  260. };
  261. pytorch.Graph = class {
  262. constructor(sysInfo, root) {
  263. this._type = root.__type__;
  264. this._nodes = [];
  265. this._inputs = [];
  266. this._outputs = [];
  267. this._groups = true;
  268. this._littleEndian = sysInfo.little_endian;
  269. var input = 'data';
  270. this._inputs.push(new pytorch.Argument(input, true, [ new pytorch.Connection(input, null, null) ]));
  271. var outputs = this._loadModule(root, [], [ input ]);
  272. outputs.forEach((output) => {
  273. this._outputs.push(new pytorch.Argument(output, true, [ new pytorch.Connection(output, null, null) ]));
  274. });
  275. }
  276. _loadModule(parent, groups, inputs) {
  277. if (parent.__type__ &&
  278. !parent.__type__.startsWith('torch.nn.modules.container.') &&
  279. (!parent._modules || parent._modules.length == 0)) {
  280. var node = new pytorch.Node(parent, groups, inputs, this._littleEndian);
  281. this._nodes.push(node);
  282. return [];
  283. }
  284. if (!parent._modules) {
  285. throw new pytorch.Error('Module does not contain modules.');
  286. }
  287. parent._modules.forEach((module) => {
  288. switch (module.value.__type__) {
  289. case 'torch.nn.modules.container.Sequential':
  290. groups.push(module.key);
  291. inputs = this._loadModule(module.value, groups, inputs);
  292. groups.pop(module.key);
  293. break;
  294. case 'torchvision.models.densenet._Transition':
  295. case 'torchvision.models.resnet.Bottleneck':
  296. case 'torchvision.models.densenet._DenseBlock':
  297. case 'torchvision.models.densenet._DenseLayer':
  298. case 'torchvision.models.inception.BasicConv2d':
  299. case 'torchvision.models.inception.InceptionAux':
  300. case 'torchvision.models.inception.InceptionA':
  301. case 'torchvision.models.inception.InceptionB':
  302. case 'torchvision.models.inception.InceptionC':
  303. case 'torchvision.models.inception.InceptionD':
  304. case 'torchvision.models.inception.InceptionE':
  305. groups.push(module.key);
  306. inputs = this._loadSource(module, groups, inputs);
  307. groups.pop(module.key);
  308. break;
  309. default:
  310. var node = new pytorch.Node(module, groups, inputs, this._littleEndian);
  311. this._nodes.push(node);
  312. inputs = [ node.name ];
  313. break;
  314. }
  315. });
  316. return inputs;
  317. }
  318. _loadSource(parent, groups, inputs) {
  319. var node = new pytorch.Node(parent, groups, inputs);
  320. this._nodes.push(node);
  321. inputs = [ node.name ];
  322. return inputs;
  323. }
  324. get type() {
  325. return this._type;
  326. }
  327. get groups() {
  328. return this._groups;
  329. }
  330. get inputs() {
  331. return this._inputs;
  332. }
  333. get outputs() {
  334. return this._outputs;
  335. }
  336. get nodes() {
  337. return this._nodes;
  338. }
  339. };
  340. pytorch.Argument = class {
  341. constructor(name, visible, connections) {
  342. this._name = name;
  343. this._visible = visible;
  344. this._connections = connections;
  345. }
  346. get name() {
  347. return this._name;
  348. }
  349. get visible() {
  350. return this._visible;
  351. }
  352. get connections() {
  353. return this._connections;
  354. }
  355. };
  356. pytorch.Connection = class {
  357. constructor(id, type, initializer) {
  358. this._id = id;
  359. this._type = type;
  360. this._initializer = initializer;
  361. }
  362. get id() {
  363. return this._id;
  364. }
  365. get type() {
  366. if (this._initializer) {
  367. return this._initializer.type;
  368. }
  369. return this._type;
  370. }
  371. get initializer() {
  372. return this._initializer;
  373. }
  374. };
  375. pytorch.Node = class {
  376. constructor(module, groups, connections, littleEndian) {
  377. this._group = groups.join('/');
  378. groups.push(module.key);
  379. this._name = groups.join('/');
  380. groups.pop();
  381. var obj = module.value;
  382. this._operator = obj.__type__.split('.').pop();
  383. this._inputs = [];
  384. this._inputs.push(new pytorch.Argument('input', true, connections.map((connection) => {
  385. return new pytorch.Connection(connection, null, null);
  386. })));
  387. var initializers = [];
  388. if (obj._parameters) {
  389. obj._parameters.forEach((parameter) => {
  390. initializers.push(parameter);
  391. });
  392. }
  393. if (obj._buffers) {
  394. obj._buffers.forEach((buffer) => {
  395. initializers.push(buffer);
  396. });
  397. }
  398. initializers.forEach((parameter) => {
  399. if (parameter && parameter.value && (parameter.value.data || parameter.value.storage)) {
  400. var initializer = null;
  401. if (parameter.value.data) {
  402. initializer = new pytorch.Tensor(parameter.value.data, littleEndian);
  403. }
  404. else if (parameter.value.storage) {
  405. initializer = new pytorch.Tensor(parameter.value, littleEndian);
  406. }
  407. var visible = (this._operator != 'LSTM' || initializer == null);
  408. this._inputs.push(new pytorch.Argument(parameter.key, visible, [ new pytorch.Connection(null, null, initializer) ]));
  409. }
  410. });
  411. this._outputs = [];
  412. this._outputs.push(new pytorch.Argument('output', true, [ new pytorch.Connection(this._name, null, null) ]));
  413. this._attributes = [];
  414. Object.keys(obj).forEach((key) => {
  415. if (!key.startsWith('_')) {
  416. this._attributes.push(new pytorch.Attribute(this, key, obj[key]));
  417. }
  418. });
  419. }
  420. get name() {
  421. return this._name;
  422. }
  423. get group() {
  424. return this._group;
  425. }
  426. get operator() {
  427. return this._operator;
  428. }
  429. get category() {
  430. var schema = pytorch.OperatorMetadata.operatorMetadata.getSchema(this._operator);
  431. return (schema && schema.category) ? schema.category : null;
  432. }
  433. get attributes() {
  434. return this._attributes;
  435. }
  436. get inputs() {
  437. return this._inputs;
  438. }
  439. get outputs() {
  440. return this._outputs;
  441. }
  442. };
  443. pytorch.Attribute = class {
  444. constructor(node, name, value) {
  445. this._node = node;
  446. this._name = name;
  447. this._value = value;
  448. var schema = pytorch.OperatorMetadata.operatorMetadata.getAttributeSchema(this._node.operator, this._name);
  449. if (schema) {
  450. if (schema.hasOwnProperty('visible') && !schema.visible) {
  451. this._visible = false;
  452. }
  453. else if (schema.hasOwnProperty('default')) {
  454. if (JSON.stringify(schema.default) == JSON.stringify(value)) {
  455. this._visible = false;
  456. }
  457. }
  458. }
  459. }
  460. get name() {
  461. return this._name;
  462. }
  463. get value() {
  464. return this._value;
  465. }
  466. get visible() {
  467. return this._visible == false ? false : true;
  468. }
  469. };
  470. pytorch.Tensor = class {
  471. constructor(tensor, littleEndian) {
  472. this._tensor = tensor;
  473. this._type = new pytorch.TensorType(tensor.storage.dataType, new pytorch.TensorShape(tensor.size));
  474. this._littleEndian = littleEndian;
  475. }
  476. get kind() {
  477. return 'Tensor';
  478. }
  479. get type() {
  480. return this._type;
  481. }
  482. get state() {
  483. return this._context().state;
  484. }
  485. get value() {
  486. var context = this._context();
  487. if (context.state) {
  488. return null;
  489. }
  490. context.limit = Number.MAX_SAFE_INTEGER;
  491. return this._decode(context, 0);
  492. }
  493. toString() {
  494. var context = this._context();
  495. if (context.state) {
  496. return '';
  497. }
  498. context.limit = 10000;
  499. var value = this._decode(context, 0);
  500. switch (this.dataType) {
  501. case 'int64':
  502. return pytorch.Tensor._stringify(value, '', ' ');
  503. }
  504. return JSON.stringify(value, null, 4);
  505. }
  506. _context() {
  507. var context = {};
  508. context.state = null;
  509. context.index = 0;
  510. context.count = 0;
  511. if (!this._type.dataType) {
  512. context.state = 'Tensor has no data type.';
  513. return context;
  514. }
  515. if (!this._type.shape) {
  516. context.state = 'Tensor has no dimensions.';
  517. return context;
  518. }
  519. if (!this._tensor.storage || !this._tensor.storage.data) {
  520. context.state = 'Tensor data is empty.';
  521. return context;
  522. }
  523. context.data = this._tensor.storage.data;
  524. context.dataType = this._type.dataType;
  525. context.dimensions = this._type.shape.dimensions;
  526. context.dataView = new DataView(context.data.buffer, context.data.byteOffset, context.data.byteLength);
  527. return context;
  528. }
  529. _decode(context, dimension) {
  530. var results = [];
  531. var size = context.dimensions[dimension];
  532. if (dimension == context.dimensions.length - 1) {
  533. for (var i = 0; i < size; i++) {
  534. if (context.count > context.limit) {
  535. results.push('...');
  536. return results;
  537. }
  538. switch (context.dataType)
  539. {
  540. case 'uint8':
  541. results.push(context.dataView.getUint8(context.index, this._littleEndian));
  542. context.index += 1;
  543. context.count++;
  544. break;
  545. case 'float16':
  546. results.push(context.dataView.getFloat16(context.index, this._littleEndian));
  547. context.index += 2;
  548. context.count++;
  549. break;
  550. case 'float32':
  551. results.push(context.dataView.getFloat32(context.index, this._littleEndian));
  552. context.index += 4;
  553. context.count++;
  554. break;
  555. case 'float64':
  556. results.push(context.dataView.getFloat64(context.index, this._littleEndian));
  557. context.index += 8;
  558. context.count++;
  559. break;
  560. case 'int64':
  561. results.push(new base.Int64(context.data.subarray(context.index, context.index + 8)));
  562. context.index += 8;
  563. context.count++;
  564. break;
  565. }
  566. }
  567. }
  568. else {
  569. for (var j = 0; j < size; j++) {
  570. if (context.count > context.limit) {
  571. results.push('...');
  572. return results;
  573. }
  574. results.push(this._decode(context, dimension + 1));
  575. }
  576. }
  577. return results;
  578. }
  579. static _stringify(value, indentation, indent) {
  580. if (Array.isArray(value)) {
  581. var result = [];
  582. result.push('[');
  583. var items = value.map((item) => pytorch.Tensor._stringify(item, indentation + indent, indent));
  584. if (items.length > 0) {
  585. result.push(items.join(',\n'));
  586. }
  587. result.push(']');
  588. return result.join('\n');
  589. }
  590. return indentation + value.toString();
  591. }
  592. };
  593. pytorch.TensorType = class {
  594. constructor(dataType, shape) {
  595. this._dataType = dataType;
  596. this._shape = shape;
  597. }
  598. get dataType() {
  599. return this._dataType;
  600. }
  601. get shape() {
  602. return this._shape;
  603. }
  604. toString() {
  605. return this._dataType + this._shape.toString();
  606. }
  607. };
  608. pytorch.TensorShape = class {
  609. constructor(dimensions) {
  610. this._dimensions = dimensions;
  611. }
  612. get dimensions() {
  613. return this._dimensions;
  614. }
  615. toString() {
  616. return this._dimensions ? ('[' + this._dimensions.map((dimension) => dimension.toString()).join(',') + ']') : '';
  617. }
  618. };
  619. pytorch.OperatorMetadata = class {
  620. static open(host, callback) {
  621. if (pytorch.OperatorMetadata.operatorMetadata) {
  622. callback(null, pytorch.OperatorMetadata.operatorMetadata);
  623. }
  624. else {
  625. host.request(null, 'pytorch-metadata.json', 'utf-8', (err, data) => {
  626. pytorch.OperatorMetadata.operatorMetadata = new pytorch.OperatorMetadata(data);
  627. callback(null, pytorch.OperatorMetadata.operatorMetadata);
  628. });
  629. }
  630. }
  631. constructor(data) {
  632. this._map = {};
  633. if (data) {
  634. var items = JSON.parse(data);
  635. if (items) {
  636. items.forEach((item) => {
  637. if (item.name && item.schema)
  638. {
  639. var name = item.name;
  640. var schema = item.schema;
  641. this._map[name] = schema;
  642. }
  643. });
  644. }
  645. }
  646. }
  647. getSchema(operator) {
  648. return this._map[operator] || null;
  649. }
  650. getAttributeSchema(operator, name) {
  651. var schema = this._map[operator];
  652. if (schema && schema.attributes && schema.attributes.length > 0) {
  653. if (!schema.attributesMap) {
  654. schema.attributesMap = {};
  655. schema.attributes.forEach((attribute) => {
  656. schema.attributesMap[attribute.name] = attribute;
  657. });
  658. }
  659. return schema.attributesMap[name] || null;
  660. }
  661. return null;
  662. }
  663. };
  664. pytorch.Error = class extends Error {
  665. constructor(message) {
  666. super(message);
  667. this.name = 'Error loading PyTorch model.';
  668. }
  669. };
// Export the factory through CommonJS when a `module.exports` object is available.
if (typeof module !== 'undefined' && typeof module.exports === 'object') {
    module.exports.ModelFactory = pytorch.ModelFactory;
}