torch.js 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290
  1. var torch = torch || {};
  2. torch.ModelFactory = class {
  3. match(context) {
  4. return torch.T7Reader.open(context);
  5. }
  6. open(context, match) {
  7. return torch.Metadata.open(context).then((metadata) => {
  8. const identifier = context.identifier;
  9. const reader = match;
  10. reader.callback = (name) => {
  11. if (name && name != 'nn.JointTrainModule' && !name.startsWith('nn.MSDNet_') && !name.startsWith('onmt.')) {
  12. context.exception(new torch.Error("Unsupported type '" + name + "' in '" + identifier + "'."), false);
  13. }
  14. return null;
  15. };
  16. let root = reader.read();
  17. if (root && Array.isArray(root) && root.length == 2 && root[0].__class__ && !root[1].__class__) {
  18. root = root[0];
  19. }
  20. return new torch.Model(metadata, root);
  21. });
  22. }
  23. };
  24. torch.Model = class {
  25. constructor(metadata, root) {
  26. this._graphs = [];
  27. this._graphs.push(new torch.Graph(metadata, root));
  28. }
  29. get graphs() {
  30. return this._graphs;
  31. }
  32. get format() {
  33. return 'Torch v7';
  34. }
  35. };
  36. torch.Graph = class {
  37. constructor(metadata, root) {
  38. this._inputs = [];
  39. this._outputs = [];
  40. this._nodes = [];
  41. this._groups = 'false';
  42. if (Object.prototype.hasOwnProperty.call(root, 'model')) {
  43. root = root.model;
  44. }
  45. const inputs = [];
  46. const outputs = [];
  47. this._loadModule(metadata, root, [], '', inputs, outputs);
  48. this._inputs = this._inputs.concat(inputs.map((input, index) => {
  49. return new torch.Parameter('input' + (index != 0 ? (index + 1).toString() : ''), true, [ input ]);
  50. }));
  51. this._outputs = this._outputs.concat(outputs.map((output, index) => {
  52. return new torch.Parameter('output' + (index != 0 ? (index + 1).toString() : ''), true, [ output ]);
  53. }));
  54. }
  55. get inputs() {
  56. return this._inputs;
  57. }
  58. get outputs() {
  59. return this._outputs;
  60. }
  61. get nodes() {
  62. return this._nodes;
  63. }
  64. get groups() {
  65. return this._groups;
  66. }
  67. _loadModule(metadata, module, groups, key, inputs, outputs) {
  68. if (groups.length > 0) {
  69. this._groups = true;
  70. }
  71. const type = module.__class__ ? module.__class__.__module__ + '.' + module.__class__.__name__ : '';
  72. switch (type) {
  73. case 'nn.Sequential': {
  74. groups.push(key);
  75. let subInputs = inputs;
  76. let subOutputs = [];
  77. const length = module.modules.length;
  78. let index = 0;
  79. for (const subModule of module.modules) {
  80. if (index == length - 1) {
  81. subOutputs = outputs;
  82. }
  83. this._loadModule(metadata, subModule, groups, index.toString(), subInputs, subOutputs);
  84. subInputs = subOutputs;
  85. subOutputs = [];
  86. index++;
  87. }
  88. groups.pop();
  89. break;
  90. }
  91. case 'nn.Parallel':
  92. case 'nn.ParallelTable':
  93. case 'nn.JointTrain': {
  94. groups.push(key);
  95. let newInputs = [];
  96. let newOutputs = [];
  97. let index = 0;
  98. for (const subModule of module.modules) {
  99. const subInputs = [].concat(inputs);
  100. const subOutputs = [].concat(outputs);
  101. this._loadModule(metadata, subModule, groups, index.toString(), subInputs, subOutputs);
  102. if (inputs.length == 0) {
  103. newInputs = newInputs.concat(subInputs);
  104. }
  105. if (outputs.length == 0) {
  106. newOutputs = newOutputs.concat(subOutputs);
  107. }
  108. index++;
  109. }
  110. inputs = inputs.concat(newInputs);
  111. for (const newOutput of newOutputs) {
  112. outputs.push(newOutput);
  113. }
  114. groups.pop();
  115. break;
  116. }
  117. case 'nn.Concat':
  118. case 'nn.ConcatTable': {
  119. const prefix = key;
  120. if (inputs.length == 0) {
  121. inputs.push(new torch.Argument(groups.join('/') + ':' + key + ':in', null, null));
  122. }
  123. let concatInputs = [];
  124. let index = 0;
  125. for (const subModule of module.modules) {
  126. const streamInputs = inputs.map((input) => input);
  127. const streamOutputs = [];
  128. this._loadModule(metadata, subModule, groups, prefix + '.' + index.toString(), streamInputs, streamOutputs);
  129. concatInputs = concatInputs.concat(streamOutputs);
  130. index++;
  131. }
  132. delete module.modules;
  133. delete module.dimension;
  134. this._createNode(metadata, module, groups, key, concatInputs, outputs);
  135. break;
  136. }
  137. case 'nn.Inception': {
  138. delete module.modules; // TODO
  139. delete module.module; // TODO
  140. delete module.transfer; // TODO
  141. delete module.pool; // TODO
  142. this._createNode(metadata, module, groups, key, inputs, outputs);
  143. break;
  144. }
  145. case 'nn.gModule': {
  146. /*
  147. let index = 0;
  148. for (const subModule of module.modules) {
  149. subModule.modules = [];
  150. this._loadModule(metadata, subModule, groups, index.toString(), [], []);
  151. index++;
  152. }
  153. */
  154. this._createNode(metadata, module, groups, key, inputs, outputs);
  155. break;
  156. }
  157. default: {
  158. this._createNode(metadata, module, groups, key, inputs, outputs);
  159. break;
  160. }
  161. }
  162. }
  163. _createNode(metadata, module, group, subIndex, inputs, outputs) {
  164. const node = new torch.Node(metadata, module, group, subIndex, inputs, outputs);
  165. this._nodes.push(node);
  166. }
  167. };
  168. torch.Parameter = class {
  169. constructor(name, visible, args) {
  170. this._name = name;
  171. this._visible = visible;
  172. this._arguments = args;
  173. }
  174. get name() {
  175. return this._name;
  176. }
  177. get visible() {
  178. return this._visible;
  179. }
  180. get arguments() {
  181. return this._arguments;
  182. }
  183. };
  184. torch.Argument = class {
  185. constructor(name, type, initializer) {
  186. if (typeof name !== 'string') {
  187. throw new torch.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
  188. }
  189. this._name = name;
  190. this._type = type;
  191. this._initializer = initializer;
  192. }
  193. get name() {
  194. return this._name;
  195. }
  196. get type() {
  197. if (this._initializer) {
  198. return this._initializer.type;
  199. }
  200. return this._type;
  201. }
  202. get initializer() {
  203. return this._initializer;
  204. }
  205. };
  206. torch.Node = class {
  207. constructor(metadata, module, groups, name, inputs, outputs) {
  208. this._group = groups.join('/');
  209. if (module.name && typeof module.name === 'string') {
  210. this._name = module.name;
  211. delete module.name;
  212. }
  213. else {
  214. this._name = this._group ? (this._group + ':' + name) : name;
  215. }
  216. const type = module.__class__ ? module.__class__.__module__ + '.' + module.__class__.__name__ : 'nn.Module';
  217. this._type = metadata.type(type);
  218. let initializers = [];
  219. for (const entry of Object.entries(module)) {
  220. const key = entry[0];
  221. const obj = entry[1];
  222. if (obj && obj.__class__ && obj.__class__.__module__ === 'torch' && obj.__class__.__name__.endsWith('Storage')) {
  223. module[key] = obj.data();
  224. }
  225. }
  226. delete module.iSize;
  227. delete module.finput;
  228. delete module.fgradInput;
  229. delete module.output;
  230. delete module.gradInput;
  231. delete module.gradWeight;
  232. delete module.gradBias;
  233. delete module.grad_tmp;
  234. delete module.scaleT;
  235. delete module._input;
  236. delete module._output;
  237. delete module._gradInput;
  238. delete module._gradOutput;
  239. delete module.buffer;
  240. delete module.buffer2;
  241. delete module.tmp_in;
  242. delete module.tmp_out;
  243. delete module.accUpdateGradParameters;
  244. switch (this._type.name) {
  245. case 'nn.Linear':
  246. delete module.addBuffer;
  247. break;
  248. case 'nn.Normalize':
  249. case 'nn.Normalize2':
  250. delete module.addBuffer;
  251. delete module.normp;
  252. delete module.norm;
  253. break;
  254. case 'cudnn.SpatialConvolution':
  255. case 'cudnn.SpatialFullConvolution':
  256. case 'nn.SpatialConvolution':
  257. case 'nn.SpatialConvolutionMM':
  258. case 'nn.SpatialDilatedConvolution':
  259. case 'nn.SpatialFullConvolution':
  260. delete module.ones;
  261. delete module.input_slice;
  262. delete module.output_slice;
  263. delete module.convDescData;
  264. this._updateSize(module, 'adj');
  265. this._updateSize(module, 'd');
  266. this._updateSize(module, 'dilation');
  267. this._updateSize(module, 'k');
  268. this._updateSize(module, 'pad');
  269. break;
  270. case 'cudnn.BatchNormalization':
  271. case 'cudnn.SpatialBatchNormalization':
  272. case 'nn.BatchNormalization':
  273. case 'nn.SpatialBatchNormalization':
  274. case 'nn.InstanceNormalization':
  275. delete module.save_mean;
  276. delete module.save_std;
  277. delete module.gradWeight;
  278. delete module.normalized;
  279. delete module.centered;
  280. delete module.bn; // TODO InstanceNormalization
  281. break;
  282. case 'nn.SpatialCrossMapLRN':
  283. delete module.scale;
  284. break;
  285. case 'cudnn.SpatialMaxPooling':
  286. case 'cudnn.SpatialAveragePooling':
  287. case 'inn.SpatialMaxPooling':
  288. case 'nn.SpatialMaxPooling':
  289. case 'nn.SpatialAveragePooling':
  290. delete module.indices;
  291. this._updateSize(module, 'pad');
  292. this._updateSize(module, 'd');
  293. this._updateSize(module, 'k');
  294. break;
  295. case 'nn.SpatialZeroPadding':
  296. case 'nn.SpatialReflectionPadding':
  297. case 'nn.SpatialReplicationPadding':
  298. this._updateBox(module, 'pad');
  299. break;
  300. case 'nn.Dropout':
  301. delete module.noise;
  302. break;
  303. case 'nn.gModule':
  304. delete module.forwardnodes;
  305. delete module.backwardnodes;
  306. break;
  307. case 'nn.StereoJoin':
  308. delete module.output_L;
  309. break;
  310. default:
  311. break;
  312. }
  313. this._attributes = [];
  314. if (module.__class__) {
  315. for (const entry of Object.entries(module)) {
  316. const key = entry[0];
  317. const obj = entry[1];
  318. if (key == '_type') {
  319. continue;
  320. }
  321. if (Array.isArray(obj) && obj.every(((item) => item && item.__class__ && item.__class__.__module__ === 'nn'))) {
  322. continue;
  323. }
  324. if (obj.__class__ && obj.__class__.__module__ === 'torch' && obj.__class__.__name__.endsWith('Tensor')) {
  325. initializers.push(new torch.Parameter(key, true, [
  326. new torch.Argument(key, null, new torch.Tensor(obj))
  327. ]));
  328. continue;
  329. }
  330. if (key == 'modules') {
  331. continue;
  332. }
  333. if (obj.__class__ && obj.__class__.__module__ !== '' && obj.__class__.__name__ != 'LuaFunction') {
  334. continue;
  335. }
  336. const attribute = new torch.Attribute(metadata, type, key, obj);
  337. this._attributes.push(attribute);
  338. }
  339. }
  340. this._inputs = [];
  341. if (inputs.length == 0 && this._name) {
  342. inputs.push(new torch.Argument(this._name + ':in', null, null));
  343. }
  344. this._inputs.push(new torch.Parameter('input', true, inputs));
  345. if (outputs.length == 0 && this._name) {
  346. outputs.push(new torch.Argument(this._name, null, null));
  347. }
  348. this._outputs = [];
  349. this._outputs.push(new torch.Parameter('output', true, outputs));
  350. initializers = initializers.filter((argument) => {
  351. if (argument.name == 'weight') {
  352. this._inputs.push(argument);
  353. return false;
  354. }
  355. return true;
  356. });
  357. initializers = initializers.filter((argument) => {
  358. if (argument.name == 'bias') {
  359. this._inputs.push(argument);
  360. return false;
  361. }
  362. return true;
  363. });
  364. this._inputs = this._inputs.concat(initializers);
  365. }
  366. get name() {
  367. return this._name;
  368. }
  369. get type() {
  370. return this._type;
  371. }
  372. get group() {
  373. return this._group;
  374. }
  375. get attributes() {
  376. return this._attributes;
  377. }
  378. get inputs() {
  379. return this._inputs;
  380. }
  381. get outputs() {
  382. return this._outputs;
  383. }
  384. _updateSize(module, name) {
  385. if (Object.prototype.hasOwnProperty.call(module, name + 'W') &&
  386. Object.prototype.hasOwnProperty.call(module, name + 'H')) {
  387. module[name] = [ module[name + 'W'], module[name + 'H'] ];
  388. delete module[name + 'W'];
  389. delete module[name + 'H'];
  390. }
  391. }
  392. _updateBox(module, name) {
  393. if (Object.prototype.hasOwnProperty.call(module, name + '_t') &&
  394. Object.prototype.hasOwnProperty.call(module, name + '_r') &&
  395. Object.prototype.hasOwnProperty.call(module, name + '_b') &&
  396. Object.prototype.hasOwnProperty.call(module, name + '_l')) {
  397. module[name] = [ module[name + '_t'], module[name + '_r'], module[name + '_b'], module[name + '_l'] ];
  398. delete module[name + '_t'];
  399. delete module[name + '_r'];
  400. delete module[name + '_b'];
  401. delete module[name + '_l'];
  402. }
  403. }
  404. };
  405. torch.Attribute = class {
  406. constructor(metadata, type, name, value) {
  407. this._name = name;
  408. this._value = value;
  409. if (name == 'train') {
  410. this._visible = false;
  411. }
  412. const schema = metadata.attribute(type, name);
  413. if (schema) {
  414. if (Object.prototype.hasOwnProperty.call(schema, 'visible')) {
  415. this._visible = schema.visible;
  416. }
  417. else if (Object.prototype.hasOwnProperty.call(schema, 'default')) {
  418. if (JSON.stringify(schema.default) == JSON.stringify(this._value)) {
  419. this._visible = false;
  420. }
  421. }
  422. }
  423. }
  424. get name() {
  425. return this._name;
  426. }
  427. get value() {
  428. return this._value;
  429. }
  430. get visible() {
  431. return this._visible == false ? false : true;
  432. }
  433. };
  434. torch.Tensor = class {
  435. constructor(tensor) {
  436. this._type = new torch.TensorType(tensor);
  437. this._storage = tensor.storage;
  438. this._offset = tensor.storage_offset;
  439. }
  440. get type() {
  441. return this._type;
  442. }
  443. get state() {
  444. return this._context().state || null;
  445. }
  446. get value() {
  447. const context = this._context();
  448. if (context.state) {
  449. return null;
  450. }
  451. context.limit = Number.MAX_SAFE_INTEGER;
  452. return this._decode(context, 0);
  453. }
  454. toString() {
  455. const context = this._context();
  456. if (context.state) {
  457. return '';
  458. }
  459. context.limit = 1000;
  460. const value = this._decode(context, 0);
  461. return JSON.stringify(value, null, 4);
  462. }
  463. _context() {
  464. const context = {};
  465. context.state = null;
  466. context.index = 0;
  467. context.count = 0;
  468. if (!this._storage) {
  469. context.state = 'Tensor data is empty.';
  470. return context;
  471. }
  472. context.data = this._storage.data();
  473. context.index = this._offset;
  474. if (!context.data) {
  475. context.state = 'Tensor data is empty.';
  476. return context;
  477. }
  478. switch (this._type.dataType) {
  479. case 'uint8':
  480. case 'int8':
  481. case 'int16':
  482. case 'int32':
  483. case 'int64':
  484. case 'float32':
  485. case 'float64':
  486. break;
  487. default:
  488. context.state = 'Tensor data type is not implemented.';
  489. break;
  490. }
  491. context.dimensions = this._type.shape.dimensions;
  492. if (!context.dimensions && context.dimensions.length == 0) {
  493. context.state = 'Tensor has no dimensions.';
  494. return context;
  495. }
  496. return context;
  497. }
  498. _decode(context, dimension) {
  499. const results = [];
  500. const size = context.dimensions[dimension];
  501. if (dimension == context.dimensions.length - 1) {
  502. for (let i = 0; i < size; i++) {
  503. if (context.count > context.limit) {
  504. results.push('...');
  505. return results;
  506. }
  507. results.push(context.data[context.index]);
  508. context.index++;
  509. context.count++;
  510. }
  511. }
  512. else {
  513. for (let j = 0; j < size; j++) {
  514. if (context.count > context.limit) {
  515. results.push('...');
  516. return results;
  517. }
  518. results.push(this._decode(context, dimension + 1));
  519. }
  520. }
  521. return results;
  522. }
  523. };
  524. torch.TensorType = class {
  525. constructor(tensor) {
  526. this._dataType = tensor.dataType;
  527. this._shape = new torch.TensorShape(tensor.size);
  528. }
  529. get dataType() {
  530. return this._dataType;
  531. }
  532. get shape() {
  533. return this._shape;
  534. }
  535. toString() {
  536. return (this.dataType || '?') + this._shape.toString();
  537. }
  538. };
  539. torch.TensorShape = class {
  540. constructor(dimensions) {
  541. this._dimensions = dimensions;
  542. }
  543. get dimensions() {
  544. return this._dimensions;
  545. }
  546. toString() {
  547. if (this._dimensions) {
  548. if (this._dimensions.length == 0) {
  549. return '';
  550. }
  551. return '[' + this._dimensions.map((dimension) => dimension.toString()).join(',') + ']';
  552. }
  553. return '';
  554. }
  555. };
  556. torch.Metadata = class {
  557. static open(context) {
  558. if (torch.Metadata._metadata) {
  559. return Promise.resolve(torch.Metadata._metadata);
  560. }
  561. return context.request('torch-metadata.json', 'utf-8', null).then((data) => {
  562. torch.Metadata._metadata = new torch.Metadata(data);
  563. return torch.Metadata._metadata;
  564. }).catch(() => {
  565. torch.Metadata._metadata = new torch.Metadata(null);
  566. return torch.Metadata._metadata;
  567. });
  568. }
  569. constructor(data) {
  570. this._types = new Map();
  571. this._attributes = new Map();
  572. if (data) {
  573. const items = JSON.parse(data);
  574. for (const item of items) {
  575. this._types.set(item.name, item);
  576. }
  577. }
  578. }
  579. type(name) {
  580. if (!this._types.has(name)) {
  581. this._types.set(name, { name: name });
  582. }
  583. return this._types.get(name);
  584. }
  585. attribute(type, name) {
  586. const key = type + ':' + name;
  587. if (!this._attributes.has(key)) {
  588. this._attributes.set(key, null);
  589. const metadata = this.type(type);
  590. if (metadata && Array.isArray(metadata.attributes)) {
  591. for (const attribute of metadata.attributes) {
  592. this._attributes.set(type + ':' + attribute.name, attribute);
  593. }
  594. }
  595. }
  596. return this._attributes.get(key);
  597. }
  598. };
  599. torch.Error = class extends Error {
  600. constructor(message) {
  601. super(message);
  602. this.name = 'Error loading Torch model.';
  603. }
  604. };
  605. torch.T7Reader = class {
  606. static open(context) {
  607. const stream = context.stream;
  608. if (stream.length >= 4 && stream.peek(4).every((value, index) => value === 0x00 || (index == 0 && value <= 0x08))) {
  609. const reader = new torch.BinaryReader(stream);
  610. return new torch.T7Reader(reader);
  611. }
  612. if (stream.length >= 2) {
  613. const buffer = stream.peek(2);
  614. const value = String.fromCharCode(stream.peek(1)[0]);
  615. if (buffer[1] === 0x0a && (value >= '0' && value <= '8')) {
  616. const reader = new torch.TextReader(stream);
  617. return new torch.T7Reader(reader);
  618. }
  619. }
  620. return null;
  621. }
  622. constructor(reader) {
  623. this._reader = reader;
  624. this._memo = new Map();
  625. this._types = new Map();
  626. const Storage = class {
  627. constructor(dataType, itemSize) {
  628. this.dataType = dataType;
  629. this.itemSize = itemSize;
  630. }
  631. data() {
  632. if (this.reader) {
  633. const reader = this.reader;
  634. reader.reset();
  635. const dataType = this.dataType;
  636. const size = this.size;
  637. const array = new Array(size);
  638. for (let i = 0; i < size; i++) {
  639. switch (dataType) {
  640. case 'uint8':
  641. array[i] = reader.byte();
  642. break;
  643. case 'int8':
  644. array[i] = reader.int8();
  645. break;
  646. case 'int16':
  647. array[i] = reader.int16();
  648. break;
  649. case 'int32':
  650. array[i] = reader.int32();
  651. break;
  652. case 'int64':
  653. array[i] = reader.int64();
  654. break;
  655. case 'float32':
  656. array[i] = reader.float32();
  657. break;
  658. case 'float64':
  659. array[i] = reader.float64();
  660. break;
  661. default:
  662. throw new torch.Error("Unsupported data type '" + dataType + "'.");
  663. }
  664. }
  665. this._data = array;
  666. delete this.reader;
  667. }
  668. return this._data;
  669. }
  670. read(reader) {
  671. this.size = reader.int64();
  672. this.reader = reader.storage(this.size, this.itemSize, this.dataType);
  673. }
  674. };
  675. const Tensor = class {
  676. constructor(dataType) {
  677. this.dataType = dataType;
  678. }
  679. read(reader) {
  680. const dim = reader.int32();
  681. this.size = reader.int64s(dim);
  682. this.stride = reader.int64s(dim);
  683. this.storage_offset = reader.int64() - 1;
  684. this.storage = reader.read();
  685. }
  686. };
  687. this.register('bnn.Binary');
  688. this.register('bnn.SpatialConvolution');
  689. this.register('cudnn.BatchNormalization');
  690. this.register('cudnn.BatchBRNNReLU');
  691. this.register('cudnn.BLSTM');
  692. this.register('cudnn.ReLU');
  693. this.register('cudnn.RNN');
  694. this.register('cudnn.Sigmoid');
  695. this.register('cudnn.SoftMax');
  696. this.register('cudnn.LogSoftMax');
  697. this.register('cudnn.normal3DConv');
  698. this.register('cudnn.normal3DdeConv');
  699. this.register('cudnn.SpatialAveragePooling');
  700. this.register('cudnn.SpatialBatchNormalization');
  701. this.register('cudnn.SpatialConvolution');
  702. this.register('cudnn.SpatialFullConvolution');
  703. this.register('cudnn.SpatialMaxPooling');
  704. this.register('cudnn.SpatialSoftMax');
  705. this.register('cudnn.Tanh');
  706. this.register('cudnn.VolumetricAveragePooling');
  707. this.register('cudnn.VolumetricBatchNormalization');
  708. this.register('cudnn.VolumetricConvolution');
  709. this.register('cudnn.VolumetricMaxPooling');
  710. this.register('Dict');
  711. this.register('inn.ConstAffine');
  712. this.register('inn.SpatialMaxPooling');
  713. this.register('nn.Abs');
  714. this.register('nn.AddConstant');
  715. this.register('nn.BatchNormalization');
  716. this.register('nn.BilinearSamplerBHWD');
  717. this.register('nn.BinActiveZ'); // allenai/XNOR-Net
  718. this.register('nn.BCECriterion');
  719. this.register('nn.Bottle');
  720. this.register('nn.Clamp');
  721. this.register('nn.CMul');
  722. this.register('nn.CAddTable');
  723. this.register('nn.CDivTable');
  724. this.register('nn.CMulTable');
  725. this.register('nn.CSubTable');
  726. this.register('nn.Concat');
  727. this.register('nn.Copy');
  728. this.register('nn.ConcatTable');
  729. this.register('nn.Contiguous');
  730. this.register('nn.Constant');
  731. this.register('nn.CostVolMulti');
  732. this.register('nn.DataParallelTable');
  733. this.register('nn.DepthConcat');
  734. this.register('nn.Dropout');
  735. this.register('nn.Exp');
  736. this.register('nn.ExpOut');
  737. this.register('nn.FlattenTable');
  738. this.register('nn.GenNoise');
  739. this.register('nn.Identity');
  740. this.register('nn.Index');
  741. this.register('nn.Inception');
  742. this.register('nn.InstanceNormalization');
  743. this.register('nn.JoinTable');
  744. this.register('nn.JointTrain');
  745. this.register('nn.KeypointCoordinate');
  746. this.register('nn.LeakyReLU');
  747. this.register('nn.Linear');
  748. this.register('nn.LinearNoBias');
  749. this.register('nn.LogSoftMax');
  750. this.register('nn.LookupTable');
  751. this.register('nn.LSTM');
  752. this.register('nn.MaskZero');
  753. this.register('nn.MapTable');
  754. this.register('nn.Max');
  755. this.register('nn.Mean');
  756. this.register('nn.Min');
  757. this.register('nn.MulConstant');
  758. this.register('nn.MM');
  759. this.register('nn.MSECriterion');
  760. this.register('nn.Narrow');
  761. this.register('nn.NarrowTable');
  762. this.register('nn.Normalize');
  763. this.register('nn.Normalize2');
  764. this.register('nn.NoiseFill');
  765. this.register('nn.Padding');
  766. this.register('nn.Parallel');
  767. this.register('nn.ParallelCriterion');
  768. this.register('nn.ParallelTable');
  769. this.register('nn.PixelShuffle');
  770. this.register('nn.Power');
  771. this.register('nn.PReLU');
  772. this.register('nn.Recursor');
  773. this.register('nn.ReLU');
  774. this.register('nn.Replicate');
  775. this.register('nn.Reshape');
  776. this.register('nn.ShaveImage');
  777. this.register('nn.Select');
  778. this.register('nn.SelectTable');
  779. this.register('nn.Sequencer');
  780. this.register('nn.Sequential');
  781. this.register('nn.Sigmoid');
  782. this.register('nn.Sum');
  783. this.register('nn.SoftMax');
  784. this.register('nn.SpatialAveragePooling');
  785. this.register('nn.SpatialBatchNormalization');
  786. this.register('nn.SpatialConvolution');
  787. this.register('nn.SpatialConvolutionMM');
  788. this.register('nn.SpatialCrossMapLRN');
  789. this.register('nn.SpatialDilatedConvolution');
  790. this.register('nn.SpatialDropout');
  791. this.register('nn.SpatialFractionalMaxPooling');
  792. this.register('nn.SpatialFullConvolution');
  793. this.register('nn.SpatialLPPooling');
  794. this.register('nn.SpatialMaxPooling');
  795. this.register('nn.SpatialMaxUnpooling');
  796. this.register('nn.SpatialReflectionPadding');
  797. this.register('nn.SpatialReplicationPadding');
  798. this.register('nn.SpatialSoftMax');
  799. this.register('nn.SpatialSubtractiveNormalization');
  800. this.register('nn.SpatialUpSamplingBilinear');
  801. this.register('nn.SpatialUpSamplingNearest');
  802. this.register('nn.SpatialZeroPadding');
  803. this.register('nn.SplitTable');
  804. this.register('nn.Squeeze');
  805. this.register('nn.Square');
  806. this.register('nn.Sqrt');
  807. this.register('nn.StereoJoin');
  808. this.register('nn.Tanh');
  809. this.register('nn.Transpose');
  810. this.register('nn.TotalVariation');
  811. this.register('nn.Unpool');
  812. this.register('nn.View');
  813. this.register('nn.gModule');
  814. this.register('nngraph.Node');
  815. this.register('graph.Edge');
  816. this.register('graph.Graph');
  817. this.register('torch.ByteTensor', class extends Tensor { constructor() { super('uint8'); } });
  818. this.register('torch.CharTensor', class extends Tensor { constructor() { super('int8'); } });
  819. this.register('torch.ShortTensor', class extends Tensor { constructor() { super('int16'); } });
  820. this.register('torch.IntTensor', class extends Tensor { constructor() { super('int32'); } });
  821. this.register('torch.LongTensor', class extends Tensor { constructor() { super('int64'); } });
  822. this.register('torch.FloatTensor', class extends Tensor { constructor() { super('float32'); } });
  823. this.register('torch.DoubleTensor', class extends Tensor { constructor() { super('float64'); } });
  824. this.register('torch.CudaByteTensor', class extends Tensor { constructor() { super('uint8'); } });
  825. this.register('torch.CudaCharTensor', class extends Tensor { constructor() { super('int8'); } });
  826. this.register('torch.CudaShortTensor', class extends Tensor { constructor() { super('int16'); } });
  827. this.register('torch.CudaIntTensor', class extends Tensor { constructor() { super('int32'); } });
  828. this.register('torch.CudaLongTensor', class extends Tensor { constructor() { super('int64'); } });
  829. this.register('torch.CudaTensor', class extends Tensor { constructor() { super('float32'); } });
  830. this.register('torch.CudaDoubleTensor', class extends Tensor { constructor() { super('float64'); } });
  831. this.register('torch.ByteStorage', class extends Storage { constructor() { super('uint8', 1); } });
  832. this.register('torch.CharStorage', class extends Storage { constructor() { super('int8', 1); } });
  833. this.register('torch.ShortStorage', class extends Storage { constructor() { super('int16', 2); } });
  834. this.register('torch.IntStorage', class extends Storage { constructor() { super('int32', 4); } });
  835. this.register('torch.LongStorage', class extends Storage { constructor() { super('int64', 8); } });
  836. this.register('torch.FloatStorage', class extends Storage { constructor() { super('float32', 4); } });
  837. this.register('torch.DoubleStorage', class extends Storage { constructor() { super('float64', 8); } });
  838. this.register('torch.CudaByteStorage', class extends Storage { constructor() { super('uint8', 1); } });
  839. this.register('torch.CudaCharStorage', class extends Storage { constructor() { super('int8', 1); } });
  840. this.register('torch.CudaShortStorage', class extends Storage { constructor() { super('int16', 2); } });
  841. this.register('torch.CudaIntStorage', class extends Storage { constructor() { super('int32', 4); } });
  842. this.register('torch.CudaLongStorage', class extends Storage { constructor() { super('int64', 8); } });
  843. this.register('torch.CudaIntStorage', class extends Storage { constructor() { super('int32', 4); } });
  844. this.register('torch.CudaStorage', class extends Storage { constructor() { super('float32', 4); } });
  845. this.register('torch.CudaFloatStorage', class extends Storage { constructor() { super('float64', 8); } });
  846. this.register('w2nn.AuxiliaryLossTable');
  847. this.register('w2nn.InplaceClip01');
  848. this.register('w2nn.ScaleTable');
  849. this.register('LuaFunction', class {
  850. constructor(size, dumped, upvalues) {
  851. this.size = size;
  852. this.dumped = dumped;
  853. this.upvalues = upvalues;
  854. }
  855. });
  856. }
  857. register(name, type) {
  858. type = type || class {};
  859. const parts = name.split('.');
  860. type.__name__ = parts.pop();
  861. type.__module__ = parts.join('.');
  862. type.prototype.__class__ = type;
  863. this._types.set(name, type);
  864. }
  865. read() {
  866. const type = this.int32();
  867. switch (type) {
  868. case 0: return null;
  869. case 1: return this.float64();
  870. case 2: return this.string();
  871. case 3: return this.table();
  872. case 4: return this.object();
  873. case 5: return this.boolean();
  874. case 6: return this.function();
  875. case 7: return this.function();
  876. case 8: return this.function();
  877. default: throw new torch.Error("File format has invalid type '" + type + "'.");
  878. }
  879. }
  880. boolean() {
  881. return this._reader.boolean();
  882. }
  883. bytes(size) {
  884. return this._reader.bytes(size);
  885. }
  886. int32() {
  887. return this._reader.int32();
  888. }
  889. int64() {
  890. return this._reader.int64();
  891. }
  892. int64s(size) {
  893. return this._reader.int64s(size);
  894. }
  895. float64() {
  896. return this._reader.float64();
  897. }
  898. string() {
  899. return this._reader.string();
  900. }
  901. object() {
  902. const index = this.int32();
  903. if (this._memo.has(index)) {
  904. return this._memo.get(index);
  905. }
  906. let version = this.string();
  907. let name = null;
  908. if (version.startsWith('V ')) {
  909. name = this.string();
  910. version = Number(version.split(' ')[1]);
  911. }
  912. else {
  913. name = version;
  914. version = 0;
  915. }
  916. if (!this._types.has(name)) {
  917. this.callback(name);
  918. this.register(name);
  919. }
  920. const type = this._types.get(name);
  921. const obj = Reflect.construct(type, []);
  922. this._memo.set(index, obj);
  923. if (obj.read) {
  924. obj.read(this, version);
  925. }
  926. else {
  927. const attributes = this.read();
  928. if (attributes != null) {
  929. for (const entry of Object.entries(attributes)) {
  930. const key = entry[0];
  931. obj[key] = entry[1];
  932. }
  933. }
  934. }
  935. return obj;
  936. }
  937. table() {
  938. const index = this.int32();
  939. if (this._memo.has(index)) {
  940. return this._memo.get(index);
  941. }
  942. const table = {};
  943. this._memo.set(index, table);
  944. const size = this.int32();
  945. let convert = true;
  946. let sum = 0;
  947. for (let i = 0; i < size; i++) {
  948. const key = this.read();
  949. const value = this.read();
  950. table[key] = value;
  951. if (Number.isInteger(key) && key >= 0) {
  952. sum += key;
  953. }
  954. else {
  955. convert = false;
  956. }
  957. }
  958. const n = Object.keys(table).length;
  959. if (convert && (n * (n + 1)) == (2 * sum)) {
  960. const list = [];
  961. for (let j = 0; j < n; j++) {
  962. let item = table[j + 1];
  963. if (item == table) {
  964. item = list;
  965. }
  966. list.push(item);
  967. }
  968. this._memo.set(index, list);
  969. return list;
  970. }
  971. return table;
  972. }
  973. function() {
  974. const index = this.int32();
  975. if (this._memo.has(index)) {
  976. return this._memo.get(index);
  977. }
  978. const size = this.int32();
  979. const dumped = this.bytes(size);
  980. const upvalues = this.read();
  981. const type = this._types.get('LuaFunction');
  982. const obj = Reflect.construct(type, [ size, dumped, upvalues ]);
  983. this._memo.set(index, obj);
  984. return obj;
  985. }
  986. storage(size, itemSize, dataType) {
  987. return this._reader.storage(size, itemSize, dataType);
  988. }
  989. };
  /**
   * Sequential little-endian reader over a byte buffer.
   * Every read captures the current position, advances the cursor with a
   * bounds check, and then decodes from the captured position.
   */
  torch.BinaryReader = class {

      /**
       * @param {Uint8Array|Object} data - A Uint8Array, or an object whose peek() returns the buffer to read.
       */
      constructor(data) {
          this._buffer = data instanceof Uint8Array ? data : data.peek();
          this._dataView = new DataView(this._buffer.buffer, this._buffer.byteOffset, this._buffer.byteLength);
          this._position = 0;
          this._textDecoder = new TextDecoder('ascii');
      }

      /** Moves the cursor back to the start of the buffer. */
      reset() {
          this._position = 0;
      }

      /**
       * Advances the cursor by `offset` bytes.
       * The cursor is only updated once the target is known to be inside the
       * buffer, so a failed skip leaves the reader where it was.
       * @param {number} offset - Number of bytes to advance.
       * @throws {torch.Error} When the target lies beyond the end of the buffer.
       */
      skip(offset) {
          const position = this._position + offset;
          if (position > this._buffer.length) {
              throw new torch.Error('Expected ' + (position - this._buffer.length) + ' more bytes. The file might be corrupted. Unexpected end of file.');
          }
          this._position = position;
      }

      /** Reads an int32 and reports whether it equals 1. */
      boolean() {
          return this.int32() === 1;
      }

      /**
       * Returns a view (not a copy) of the next `length` bytes.
       * @param {number} length - Number of bytes to read.
       * @throws {torch.Error} When `length` is negative or exceeds the remaining bytes.
       */
      bytes(length) {
          // A negative length would silently move the cursor backwards and
          // cause already consumed data to be parsed again.
          if (length < 0) {
              throw new torch.Error("Invalid byte length '" + length + "'. The file might be corrupted.");
          }
          const position = this._position;
          this.skip(length);
          return this._buffer.subarray(position, this._position);
      }

      /** Reads a signed 8-bit integer. */
      int8() {
          const position = this._position;
          this.skip(1);
          // Single-byte reads take no endianness argument.
          return this._dataView.getInt8(position);
      }

      /** Reads a little-endian signed 16-bit integer. */
      int16() {
          const position = this._position;
          this.skip(2);
          return this._dataView.getInt16(position, true);
      }

      /** Reads a little-endian signed 32-bit integer. */
      int32() {
          const position = this._position;
          this.skip(4);
          return this._dataView.getInt32(position, true);
      }

      /**
       * Reads a little-endian signed 64-bit integer and converts it to a number.
       * NOTE(review): getInt64 is not a standard DataView method; it is
       * presumably provided elsewhere in the project - confirm.
       */
      int64() {
          const position = this._position;
          this.skip(8);
          return this._dataView.getInt64(position, true).toNumber();
      }

      /**
       * Reads `size` consecutive 64-bit integers.
       * @param {number} size - Number of values to read.
       * @returns {number[]}
       */
      int64s(size) {
          const array = [];
          for (let i = 0; i < size; i++) {
              array.push(this.int64());
          }
          return array;
      }

      /** Reads a little-endian 32-bit float. */
      float32() {
          const position = this._position;
          this.skip(4);
          return this._dataView.getFloat32(position, true);
      }

      /** Reads a little-endian 64-bit float. */
      float64() {
          const position = this._position;
          this.skip(8);
          return this._dataView.getFloat64(position, true);
      }

      /** Reads an int32 byte count followed by that many ASCII-decoded bytes. */
      string() {
          return this._textDecoder.decode(this.bytes(this.int32()));
      }

      /**
       * Returns a new reader over the next `size * itemSize` bytes.
       * @param {number} size - Number of items.
       * @param {number} itemSize - Size of one item in bytes.
       */
      storage(size, itemSize) {
          return new torch.BinaryReader(this.bytes(size * itemSize));
      }
  };
  /**
   * Sequential reader over a text buffer in which values are separated by a
   * single separator byte (line feed unless another one is supplied).
   */
  torch.TextReader = class {

      /**
       * @param {Uint8Array|Object} data - A Uint8Array, or an object whose peek() returns the buffer to read.
       * @param {number} [separator] - Separator byte; defaults to 0x0a (line feed).
       */
      constructor(data, separator) {
          this._buffer = data instanceof Uint8Array ? data : data.peek();
          this._position = 0;
          this._dataView = new DataView(this._buffer.buffer, this._buffer.byteOffset, this._buffer.byteLength);
          this._textDecoder = new TextDecoder('ascii');
          this._separator = separator || 0x0a;
      }

      /** Moves the cursor back to the start of the buffer. */
      reset() {
          this._position = 0;
      }

      /**
       * Reads bytes up to the next separator (which is consumed but not
       * returned) or up to the end of the buffer.
       * @param {number} size - Maximum number of bytes allowed before the separator.
       * @returns {Uint8Array} A copy of the bytes that were read.
       * @throws {torch.Error} When the cursor is already at the end of the
       * buffer, or no separator is found within `size` bytes.
       */
      line(size) {
          const start = this._position;
          // Running out of data is reported as such rather than as an
          // over-long line.
          if (start >= this._buffer.length) {
              throw new torch.Error('Unexpected end of file.');
          }
          while (this._position < this._buffer.length && size > -1) {
              const c = this._buffer[this._position++];
              if (c === this._separator) {
                  return this._buffer.slice(start, this._position - 1);
              }
              if (this._position === this._buffer.length) {
                  return this._buffer.slice(start, this._position);
              }
              size--;
          }
          throw new torch.Error('Line exceeded maximum length.');
      }

      /** Reads an integer token and reports whether it equals 1. */
      boolean() {
          return this.int32() === 1;
      }

      /**
       * Reads one separator-terminated run of at most `size` bytes.
       * @param {number} size - Maximum number of bytes before the separator.
       */
      bytes(size) {
          return this.line(size);
      }

      /** Reads an integer token; all integer widths share one textual form. */
      int8() {
          return this.int64();
      }

      /** Reads an integer token; all integer widths share one textual form. */
      int16() {
          return this.int64();
      }

      /** Reads an integer token; all integer widths share one textual form. */
      int32() {
          return this.int64();
      }

      /**
       * Reads one decimal integer token of at most 20 characters.
       * @throws {torch.Error} When the token is not a number.
       */
      int64() {
          return this._integer(this._textDecoder.decode(this.line(20)));
      }

      /**
       * Reads one line of space-separated decimal integers.
       * @param {number} size - Expected element count; nothing is read when it is not positive.
       * @returns {number[]} Every integer found on the line.
       */
      int64s(size) {
          if (size > 0) {
              const content = this._textDecoder.decode(this.line(Number.MAX_SAFE_INTEGER));
              return content.split(' ').map((token) => this._integer(token));
          }
          return [];
      }

      /** Reads a floating-point token; both float widths share one textual form. */
      float32() {
          return this.float64();
      }

      /**
       * Reads one floating-point token of at most 24 characters, accepting
       * the 'nan', '-nan', 'inf' and '-inf' prefixes.
       * @throws {torch.Error} When the token is not a number.
       */
      float64() {
          const token = this._textDecoder.decode(this.line(24));
          // JavaScript has a single NaN value, so the sign is irrelevant.
          if (token.startsWith('-nan') || token.startsWith('nan')) {
              return NaN;
          }
          if (token.startsWith('inf')) {
              return Infinity;
          }
          if (token.startsWith('-inf')) {
              return -Infinity;
          }
          const number = Number.parseFloat(token);
          if (Number.isNaN(token - number)) {
              throw new torch.Error("Couldn't parse float '" + token + "'.");
          }
          return number;
      }

      /**
       * Reads a length token followed by a run of that many characters.
       * A zero length returns '' without consuming anything further.
       * NOTE(review): the content is read with line(), so a string that
       * contains the separator byte fails the length check below.
       * @throws {torch.Error} When the decoded length differs from the declared one.
       */
      string() {
          const size = this.int32();
          if (size === 0) {
              return '';
          }
          const data = this.line(size);
          const content = this._textDecoder.decode(data);
          if (size !== content.length) {
              throw new torch.Error('Invalid string length.');
          }
          return content;
      }

      /**
       * Returns a reader over storage contents.
       * For 'uint8' the next `size` raw bytes plus a trailing separator are
       * consumed and wrapped in a binary reader; for any other type the rest
       * of the line is wrapped in a space-separated text reader.
       * @param {number} size - Number of items; must be positive.
       * @param {number} itemSize - Not used by the text format.
       * @param {string} dataType - Element type name.
       * @throws {torch.Error} When `size` is not positive or the data is truncated.
       */
      storage(size, itemSize, dataType) {
          if (size <= 0) {
              throw new torch.Error("Unsupported storage size '" + size + "'.");
          }
          if (dataType === 'uint8') {
              const start = this._position;
              const end = start + size;
              // Validate before moving the cursor so a truncated buffer is
              // reported directly and the position stays within bounds.
              if (end > this._buffer.length) {
                  throw new torch.Error('Unexpected end of file.');
              }
              this._position = end;
              const bytes = this._buffer.slice(start, end);
              // Consume the separator that follows the raw bytes.
              this.line(0);
              return new torch.BinaryReader(bytes);
          }
          const data = this.line(Number.MAX_SAFE_INTEGER);
          return new torch.TextReader(data, 0x20);
      }

      /** Parses one decimal integer token, rejecting tokens that are not numbers. */
      _integer(token) {
          const number = Number.parseInt(token, 10);
          // The subtraction coerces the whole token, so trailing garbage that
          // parseInt would ignore still yields NaN here.
          if (Number.isNaN(token - number)) {
              throw new torch.Error("Couldn't parse int64 '" + token + "'.");
          }
          return number;
      }
  };
  // Expose the model factory when a CommonJS-style `module.exports` object exists.
  if (typeof module !== 'undefined') {
      if (typeof module.exports === 'object') {
          module.exports.ModelFactory = torch.ModelFactory;
      }
  }