torch.js 42 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357
  1. var torch = torch || {};
  2. torch.ModelFactory = class {
  3. match(context) {
  4. return torch.T7Reader.open(context);
  5. }
  6. open(context, match) {
  7. return context.metadata('torch-metadata.json').then((metadata) => {
  8. const identifier = context.identifier;
  9. const reader = match;
  10. reader.callback = (name) => {
  11. if (name && name != 'nn.JointTrainModule' && !name.startsWith('nn.MSDNet_') && !name.startsWith('onmt.')) {
  12. context.exception(new torch.Error("Unsupported type '" + name + "' in '" + identifier + "'."), false);
  13. }
  14. return null;
  15. };
  16. let root = reader.read();
  17. if (root && Array.isArray(root) && root.length == 2 && root[0].__class__ && !root[1].__class__) {
  18. root = root[0];
  19. }
  20. return new torch.Model(metadata, root);
  21. });
  22. }
  23. };
  24. torch.Model = class {
  25. constructor(metadata, root) {
  26. this._graphs = [];
  27. this._graphs.push(new torch.Graph(metadata, root));
  28. }
  29. get graphs() {
  30. return this._graphs;
  31. }
  32. get format() {
  33. return 'Torch v7';
  34. }
  35. };
  36. torch.Graph = class {
  37. constructor(metadata, root) {
  38. this._inputs = [];
  39. this._outputs = [];
  40. this._nodes = [];
  41. this._groups = 'false';
  42. if (Object.prototype.hasOwnProperty.call(root, 'model')) {
  43. root = root.model;
  44. }
  45. const inputs = [];
  46. const outputs = [];
  47. this._loadModule(metadata, root, [], '', inputs, outputs);
  48. this._inputs = this._inputs.concat(inputs.map((input, index) => {
  49. return new torch.Parameter('input' + (index != 0 ? (index + 1).toString() : ''), true, [ input ]);
  50. }));
  51. this._outputs = this._outputs.concat(outputs.map((output, index) => {
  52. return new torch.Parameter('output' + (index != 0 ? (index + 1).toString() : ''), true, [ output ]);
  53. }));
  54. }
  55. get inputs() {
  56. return this._inputs;
  57. }
  58. get outputs() {
  59. return this._outputs;
  60. }
  61. get nodes() {
  62. return this._nodes;
  63. }
  64. get groups() {
  65. return this._groups;
  66. }
  67. _loadModule(metadata, module, groups, key, inputs, outputs) {
  68. if (groups.length > 0) {
  69. this._groups = true;
  70. }
  71. const type = module.__class__ ? module.__class__.__module__ + '.' + module.__class__.__name__ : '';
  72. switch (type) {
  73. case 'nn.Sequential': {
  74. groups.push(key);
  75. let subInputs = inputs;
  76. let subOutputs = [];
  77. const length = module.modules.length;
  78. let index = 0;
  79. for (const subModule of module.modules) {
  80. if (index == length - 1) {
  81. subOutputs = outputs;
  82. }
  83. this._loadModule(metadata, subModule, groups, index.toString(), subInputs, subOutputs);
  84. subInputs = subOutputs;
  85. subOutputs = [];
  86. index++;
  87. }
  88. groups.pop();
  89. break;
  90. }
  91. case 'nn.Parallel':
  92. case 'nn.ParallelTable':
  93. case 'nn.JointTrain': {
  94. groups.push(key);
  95. let newInputs = [];
  96. let newOutputs = [];
  97. let index = 0;
  98. for (const subModule of module.modules) {
  99. const subInputs = [].concat(inputs);
  100. const subOutputs = [].concat(outputs);
  101. this._loadModule(metadata, subModule, groups, index.toString(), subInputs, subOutputs);
  102. if (inputs.length == 0) {
  103. newInputs = newInputs.concat(subInputs);
  104. }
  105. if (outputs.length == 0) {
  106. newOutputs = newOutputs.concat(subOutputs);
  107. }
  108. index++;
  109. }
  110. inputs = inputs.concat(newInputs);
  111. for (const newOutput of newOutputs) {
  112. outputs.push(newOutput);
  113. }
  114. groups.pop();
  115. break;
  116. }
  117. case 'nn.Concat':
  118. case 'nn.ConcatTable': {
  119. const prefix = key;
  120. if (inputs.length == 0) {
  121. inputs.push(new torch.Argument(groups.join('/') + ':' + key + ':in', null, null));
  122. }
  123. let concatInputs = [];
  124. let index = 0;
  125. for (const subModule of module.modules) {
  126. const streamInputs = inputs.map((input) => input);
  127. const streamOutputs = [];
  128. this._loadModule(metadata, subModule, groups, prefix + '.' + index.toString(), streamInputs, streamOutputs);
  129. concatInputs = concatInputs.concat(streamOutputs);
  130. index++;
  131. }
  132. delete module.modules;
  133. delete module.dimension;
  134. this._createNode(metadata, module, groups, key, concatInputs, outputs);
  135. break;
  136. }
  137. case 'nn.Inception': {
  138. delete module.modules; // TODO
  139. delete module.module; // TODO
  140. delete module.transfer; // TODO
  141. delete module.pool; // TODO
  142. this._createNode(metadata, module, groups, key, inputs, outputs);
  143. break;
  144. }
  145. case 'nn.gModule': {
  146. /*
  147. let index = 0;
  148. for (const subModule of module.modules) {
  149. subModule.modules = [];
  150. this._loadModule(metadata, subModule, groups, index.toString(), [], []);
  151. index++;
  152. }
  153. */
  154. this._createNode(metadata, module, groups, key, inputs, outputs);
  155. break;
  156. }
  157. default: {
  158. this._createNode(metadata, module, groups, key, inputs, outputs);
  159. break;
  160. }
  161. }
  162. }
  163. _createNode(metadata, module, group, subIndex, inputs, outputs) {
  164. const node = new torch.Node(metadata, module, group, subIndex, inputs, outputs);
  165. this._nodes.push(node);
  166. }
  167. };
  168. torch.Parameter = class {
  169. constructor(name, visible, args) {
  170. this._name = name;
  171. this._visible = visible;
  172. this._arguments = args;
  173. }
  174. get name() {
  175. return this._name;
  176. }
  177. get visible() {
  178. return this._visible;
  179. }
  180. get arguments() {
  181. return this._arguments;
  182. }
  183. };
  184. torch.Argument = class {
  185. constructor(name, type, initializer) {
  186. if (typeof name !== 'string') {
  187. throw new torch.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
  188. }
  189. this._name = name;
  190. this._type = type;
  191. this._initializer = initializer;
  192. }
  193. get name() {
  194. return this._name;
  195. }
  196. get type() {
  197. if (this._initializer) {
  198. return this._initializer.type;
  199. }
  200. return this._type;
  201. }
  202. get initializer() {
  203. return this._initializer;
  204. }
  205. };
  206. torch.Node = class {
  207. constructor(metadata, module, groups, name, inputs, outputs) {
  208. this._group = groups.join('/');
  209. if (module.name && typeof module.name === 'string') {
  210. this._name = module.name;
  211. delete module.name;
  212. }
  213. else {
  214. this._name = this._group ? (this._group + ':' + name) : name;
  215. }
  216. const type = module.__class__ ? module.__class__.__module__ + '.' + module.__class__.__name__ : 'nn.Module';
  217. this._type = metadata.type(type);
  218. let initializers = [];
  219. for (const entry of Object.entries(module)) {
  220. const key = entry[0];
  221. const obj = entry[1];
  222. if (obj && obj.__class__ && obj.__class__.__module__ === 'torch' && obj.__class__.__name__.endsWith('Storage')) {
  223. module[key] = obj.data();
  224. }
  225. }
  226. delete module.iSize;
  227. delete module.finput;
  228. delete module.fgradInput;
  229. delete module.output;
  230. delete module.gradInput;
  231. delete module.gradWeight;
  232. delete module.gradBias;
  233. delete module.grad_tmp;
  234. delete module.scaleT;
  235. delete module._input;
  236. delete module._output;
  237. delete module._gradInput;
  238. delete module._gradOutput;
  239. delete module.buffer;
  240. delete module.buffer2;
  241. delete module.tmp_in;
  242. delete module.tmp_out;
  243. delete module.accUpdateGradParameters;
  244. switch (this._type.name) {
  245. case 'nn.Linear':
  246. delete module.addBuffer;
  247. break;
  248. case 'nn.Normalize':
  249. case 'nn.Normalize2':
  250. delete module.addBuffer;
  251. delete module.normp;
  252. delete module.norm;
  253. break;
  254. case 'cudnn.SpatialConvolution':
  255. case 'cudnn.SpatialFullConvolution':
  256. case 'nn.SpatialConvolution':
  257. case 'nn.SpatialConvolutionMM':
  258. case 'nn.SpatialDilatedConvolution':
  259. case 'nn.SpatialFullConvolution':
  260. delete module.ones;
  261. delete module.input_slice;
  262. delete module.output_slice;
  263. delete module.convDescData;
  264. this._updateSize(module, 'adj');
  265. this._updateSize(module, 'd');
  266. this._updateSize(module, 'dilation');
  267. this._updateSize(module, 'k');
  268. this._updateSize(module, 'pad');
  269. break;
  270. case 'cudnn.BatchNormalization':
  271. case 'cudnn.SpatialBatchNormalization':
  272. case 'nn.BatchNormalization':
  273. case 'nn.SpatialBatchNormalization':
  274. case 'nn.InstanceNormalization':
  275. delete module.save_mean;
  276. delete module.save_std;
  277. delete module.gradWeight;
  278. delete module.normalized;
  279. delete module.centered;
  280. delete module.bn; // TODO InstanceNormalization
  281. break;
  282. case 'nn.SpatialCrossMapLRN':
  283. delete module.scale;
  284. break;
  285. case 'cudnn.SpatialMaxPooling':
  286. case 'cudnn.SpatialAveragePooling':
  287. case 'inn.SpatialMaxPooling':
  288. case 'nn.SpatialMaxPooling':
  289. case 'nn.SpatialAveragePooling':
  290. delete module.indices;
  291. this._updateSize(module, 'pad');
  292. this._updateSize(module, 'd');
  293. this._updateSize(module, 'k');
  294. break;
  295. case 'nn.SpatialZeroPadding':
  296. case 'nn.SpatialReflectionPadding':
  297. case 'nn.SpatialReplicationPadding':
  298. this._updateBox(module, 'pad');
  299. break;
  300. case 'nn.Dropout':
  301. delete module.noise;
  302. break;
  303. case 'nn.gModule':
  304. delete module.forwardnodes;
  305. delete module.backwardnodes;
  306. break;
  307. case 'nn.StereoJoin':
  308. delete module.output_L;
  309. break;
  310. default:
  311. break;
  312. }
  313. this._attributes = [];
  314. if (module.__class__) {
  315. for (const entry of Object.entries(module)) {
  316. const key = entry[0];
  317. const obj = entry[1];
  318. if (key == '_type') {
  319. continue;
  320. }
  321. if (Array.isArray(obj) && obj.every(((item) => item && item.__class__ && item.__class__.__module__ === 'nn'))) {
  322. continue;
  323. }
  324. if (obj.__class__ && obj.__class__.__module__ === 'torch' && obj.__class__.__name__.endsWith('Tensor')) {
  325. initializers.push(new torch.Parameter(key, true, [
  326. new torch.Argument(key, null, new torch.Tensor(obj))
  327. ]));
  328. continue;
  329. }
  330. if (key == 'modules') {
  331. continue;
  332. }
  333. if (obj.__class__ && obj.__class__.__module__ !== '' && obj.__class__.__name__ != 'LuaFunction') {
  334. continue;
  335. }
  336. const attribute = new torch.Attribute(metadata, type, key, obj);
  337. this._attributes.push(attribute);
  338. }
  339. }
  340. this._inputs = [];
  341. if (inputs.length == 0 && this._name) {
  342. inputs.push(new torch.Argument(this._name + ':in', null, null));
  343. }
  344. this._inputs.push(new torch.Parameter('input', true, inputs));
  345. if (outputs.length == 0 && this._name) {
  346. outputs.push(new torch.Argument(this._name, null, null));
  347. }
  348. this._outputs = [];
  349. this._outputs.push(new torch.Parameter('output', true, outputs));
  350. initializers = initializers.filter((argument) => {
  351. if (argument.name == 'weight') {
  352. this._inputs.push(argument);
  353. return false;
  354. }
  355. return true;
  356. });
  357. initializers = initializers.filter((argument) => {
  358. if (argument.name == 'bias') {
  359. this._inputs.push(argument);
  360. return false;
  361. }
  362. return true;
  363. });
  364. this._inputs = this._inputs.concat(initializers);
  365. }
  366. get name() {
  367. return this._name;
  368. }
  369. get type() {
  370. return this._type;
  371. }
  372. get group() {
  373. return this._group;
  374. }
  375. get attributes() {
  376. return this._attributes;
  377. }
  378. get inputs() {
  379. return this._inputs;
  380. }
  381. get outputs() {
  382. return this._outputs;
  383. }
  384. _updateSize(module, name) {
  385. if (Object.prototype.hasOwnProperty.call(module, name + 'W') &&
  386. Object.prototype.hasOwnProperty.call(module, name + 'H')) {
  387. module[name] = [ module[name + 'W'], module[name + 'H'] ];
  388. delete module[name + 'W'];
  389. delete module[name + 'H'];
  390. }
  391. }
  392. _updateBox(module, name) {
  393. if (Object.prototype.hasOwnProperty.call(module, name + '_t') &&
  394. Object.prototype.hasOwnProperty.call(module, name + '_r') &&
  395. Object.prototype.hasOwnProperty.call(module, name + '_b') &&
  396. Object.prototype.hasOwnProperty.call(module, name + '_l')) {
  397. module[name] = [ module[name + '_t'], module[name + '_r'], module[name + '_b'], module[name + '_l'] ];
  398. delete module[name + '_t'];
  399. delete module[name + '_r'];
  400. delete module[name + '_b'];
  401. delete module[name + '_l'];
  402. }
  403. }
  404. };
  405. torch.Attribute = class {
  406. constructor(metadata, type, name, value) {
  407. this._name = name;
  408. this._value = value;
  409. if (name == 'train') {
  410. this._visible = false;
  411. }
  412. const schema = metadata.attribute(type, name);
  413. if (schema) {
  414. if (Object.prototype.hasOwnProperty.call(schema, 'visible')) {
  415. this._visible = schema.visible;
  416. }
  417. else if (Object.prototype.hasOwnProperty.call(schema, 'default')) {
  418. if (JSON.stringify(schema.default) == JSON.stringify(this._value)) {
  419. this._visible = false;
  420. }
  421. }
  422. }
  423. }
  424. get name() {
  425. return this._name;
  426. }
  427. get value() {
  428. return this._value;
  429. }
  430. get visible() {
  431. return this._visible == false ? false : true;
  432. }
  433. };
  434. torch.Tensor = class {
  435. constructor(tensor) {
  436. this._type = new torch.TensorType(tensor);
  437. this._storage = tensor.storage;
  438. this._offset = tensor.storage_offset;
  439. }
  440. get type() {
  441. return this._type;
  442. }
  443. get state() {
  444. return this._context().state || null;
  445. }
  446. get value() {
  447. const context = this._context();
  448. if (context.state) {
  449. return null;
  450. }
  451. context.limit = Number.MAX_SAFE_INTEGER;
  452. return this._decode(context, 0);
  453. }
  454. toString() {
  455. const context = this._context();
  456. if (context.state) {
  457. return '';
  458. }
  459. context.limit = 1000;
  460. const value = this._decode(context, 0);
  461. return JSON.stringify(value, null, 4);
  462. }
  463. _context() {
  464. const context = {};
  465. context.state = null;
  466. context.index = 0;
  467. context.count = 0;
  468. if (!this._storage) {
  469. context.state = 'Tensor data is empty.';
  470. return context;
  471. }
  472. context.data = this._storage.data();
  473. context.index = this._offset;
  474. if (!context.data) {
  475. context.state = 'Tensor data is empty.';
  476. return context;
  477. }
  478. switch (this._type.dataType) {
  479. case 'uint8':
  480. case 'int8':
  481. case 'int16':
  482. case 'int32':
  483. case 'int64':
  484. case 'float32':
  485. case 'float64':
  486. break;
  487. default:
  488. context.state = 'Tensor data type is not implemented.';
  489. break;
  490. }
  491. context.dimensions = this._type.shape.dimensions;
  492. if (!context.dimensions && context.dimensions.length == 0) {
  493. context.state = 'Tensor has no dimensions.';
  494. return context;
  495. }
  496. return context;
  497. }
  498. _decode(context, dimension) {
  499. const results = [];
  500. const size = context.dimensions[dimension];
  501. if (dimension == context.dimensions.length - 1) {
  502. for (let i = 0; i < size; i++) {
  503. if (context.count > context.limit) {
  504. results.push('...');
  505. return results;
  506. }
  507. results.push(context.data[context.index]);
  508. context.index++;
  509. context.count++;
  510. }
  511. }
  512. else {
  513. for (let j = 0; j < size; j++) {
  514. if (context.count > context.limit) {
  515. results.push('...');
  516. return results;
  517. }
  518. results.push(this._decode(context, dimension + 1));
  519. }
  520. }
  521. return results;
  522. }
  523. };
  524. torch.TensorType = class {
  525. constructor(tensor) {
  526. this._dataType = tensor.dataType;
  527. this._shape = new torch.TensorShape(tensor.size);
  528. }
  529. get dataType() {
  530. return this._dataType;
  531. }
  532. get shape() {
  533. return this._shape;
  534. }
  535. toString() {
  536. return (this.dataType || '?') + this._shape.toString();
  537. }
  538. };
  539. torch.TensorShape = class {
  540. constructor(dimensions) {
  541. this._dimensions = dimensions;
  542. }
  543. get dimensions() {
  544. return this._dimensions;
  545. }
  546. toString() {
  547. if (this._dimensions) {
  548. if (this._dimensions.length == 0) {
  549. return '';
  550. }
  551. return '[' + this._dimensions.map((dimension) => dimension.toString()).join(',') + ']';
  552. }
  553. return '';
  554. }
  555. };
  556. torch.Error = class extends Error {
  557. constructor(message) {
  558. super(message);
  559. this.name = 'Error loading Torch model.';
  560. }
  561. };
  562. torch.T7Reader = class {
  563. static open(context) {
  564. const stream = context.stream;
  565. if (stream && stream.length >= 4 && stream.peek(4).every((value, index) => value === 0x00 || (index == 0 && value <= 0x08))) {
  566. const reader = new torch.BinaryReader(stream);
  567. return new torch.T7Reader(reader);
  568. }
  569. if (stream && stream.length >= 2) {
  570. const buffer = stream.peek(2);
  571. const value = String.fromCharCode(stream.peek(1)[0]);
  572. if (buffer[1] === 0x0a && (value >= '0' && value <= '8')) {
  573. const reader = new torch.TextReader(stream);
  574. return new torch.T7Reader(reader);
  575. }
  576. }
  577. return null;
  578. }
  579. constructor(reader) {
  580. this._reader = reader;
  581. this._memo = new Map();
  582. this._types = new Map();
  583. const Storage = class {
  584. constructor(dataType, itemSize) {
  585. this.dataType = dataType;
  586. this.itemSize = itemSize;
  587. }
  588. data() {
  589. if (this.reader) {
  590. const reader = this.reader;
  591. reader.reset();
  592. const dataType = this.dataType;
  593. const size = this.size;
  594. const array = new Array(size);
  595. for (let i = 0; i < size; i++) {
  596. switch (dataType) {
  597. case 'uint8':
  598. array[i] = reader.byte();
  599. break;
  600. case 'int8':
  601. array[i] = reader.int8();
  602. break;
  603. case 'int16':
  604. array[i] = reader.int16();
  605. break;
  606. case 'int32':
  607. array[i] = reader.int32();
  608. break;
  609. case 'int64':
  610. array[i] = reader.int64();
  611. break;
  612. case 'float32':
  613. array[i] = reader.float32();
  614. break;
  615. case 'float64':
  616. array[i] = reader.float64();
  617. break;
  618. default:
  619. throw new torch.Error("Unsupported data type '" + dataType + "'.");
  620. }
  621. }
  622. this._data = array;
  623. delete this.reader;
  624. }
  625. return this._data;
  626. }
  627. read(reader) {
  628. this.size = reader.int64();
  629. this.reader = reader.storage(this.size, this.itemSize, this.dataType);
  630. }
  631. };
  632. const Tensor = class {
  633. constructor(dataType) {
  634. this.dataType = dataType;
  635. }
  636. read(reader) {
  637. const dim = reader.int32();
  638. this.size = reader.int64s(dim);
  639. this.stride = reader.int64s(dim);
  640. this.storage_offset = reader.int64() - 1;
  641. this.storage = reader.read();
  642. }
  643. };
  644. this.register('bnn.Binary');
  645. this.register('bnn.SpatialConvolution');
  646. this.register('cudnn.BatchNormalization');
  647. this.register('cudnn.BatchBRNNReLU');
  648. this.register('cudnn.BLSTM');
  649. this.register('cudnn.ReLU');
  650. this.register('cudnn.RNN');
  651. this.register('cudnn.Sigmoid');
  652. this.register('cudnn.SoftMax');
  653. this.register('cudnn.LogSoftMax');
  654. this.register('cudnn.normal3DConv');
  655. this.register('cudnn.normal3DdeConv');
  656. this.register('cudnn.SpatialAveragePooling');
  657. this.register('cudnn.SpatialBatchNormalization');
  658. this.register('cudnn.SpatialConvolution');
  659. this.register('cudnn.SpatialFullConvolution');
  660. this.register('cudnn.SpatialMaxPooling');
  661. this.register('cudnn.SpatialSoftMax');
  662. this.register('cudnn.Tanh');
  663. this.register('cudnn.VolumetricAveragePooling');
  664. this.register('cudnn.VolumetricBatchNormalization');
  665. this.register('cudnn.VolumetricConvolution');
  666. this.register('cudnn.VolumetricMaxPooling');
  667. this.register('Dict');
  668. this.register('inn.ConstAffine');
  669. this.register('inn.SpatialMaxPooling');
  670. this.register('nn.Abs');
  671. this.register('nn.AddConstant');
  672. this.register('nn.BatchNormalization');
  673. this.register('nn.BilinearSamplerBHWD');
  674. this.register('nn.BinActiveZ'); // allenai/XNOR-Net
  675. this.register('nn.BCECriterion');
  676. this.register('nn.Bottle');
  677. this.register('nn.Clamp');
  678. this.register('nn.CMul');
  679. this.register('nn.CAddTable');
  680. this.register('nn.CDivTable');
  681. this.register('nn.CMulTable');
  682. this.register('nn.CSubTable');
  683. this.register('nn.Concat');
  684. this.register('nn.Copy');
  685. this.register('nn.ConcatTable');
  686. this.register('nn.Contiguous');
  687. this.register('nn.Constant');
  688. this.register('nn.CostVolMulti');
  689. this.register('nn.DataParallelTable');
  690. this.register('nn.DepthConcat');
  691. this.register('nn.Dropout');
  692. this.register('nn.Exp');
  693. this.register('nn.ExpOut');
  694. this.register('nn.FlattenTable');
  695. this.register('nn.GenNoise');
  696. this.register('nn.Identity');
  697. this.register('nn.Index');
  698. this.register('nn.Inception');
  699. this.register('nn.InstanceNormalization');
  700. this.register('nn.JoinTable');
  701. this.register('nn.JointTrain');
  702. this.register('nn.KeypointCoordinate');
  703. this.register('nn.LeakyReLU');
  704. this.register('nn.Linear');
  705. this.register('nn.LinearNoBias');
  706. this.register('nn.LogSoftMax');
  707. this.register('nn.LookupTable');
  708. this.register('nn.LSTM');
  709. this.register('nn.MaskZero');
  710. this.register('nn.MapTable');
  711. this.register('nn.Max');
  712. this.register('nn.Mean');
  713. this.register('nn.Min');
  714. this.register('nn.MulConstant');
  715. this.register('nn.MM');
  716. this.register('nn.MSECriterion');
  717. this.register('nn.Narrow');
  718. this.register('nn.NarrowTable');
  719. this.register('nn.Normalize');
  720. this.register('nn.Normalize2');
  721. this.register('nn.NoiseFill');
  722. this.register('nn.Padding');
  723. this.register('nn.Parallel');
  724. this.register('nn.ParallelCriterion');
  725. this.register('nn.ParallelTable');
  726. this.register('nn.PixelShuffle');
  727. this.register('nn.Power');
  728. this.register('nn.PReLU');
  729. this.register('nn.Recursor');
  730. this.register('nn.ReLU');
  731. this.register('nn.Replicate');
  732. this.register('nn.Reshape');
  733. this.register('nn.ShaveImage');
  734. this.register('nn.Select');
  735. this.register('nn.SelectTable');
  736. this.register('nn.Sequencer');
  737. this.register('nn.Sequential');
  738. this.register('nn.Sigmoid');
  739. this.register('nn.Sum');
  740. this.register('nn.SoftMax');
  741. this.register('nn.SpatialAveragePooling');
  742. this.register('nn.SpatialBatchNormalization');
  743. this.register('nn.SpatialConvolution');
  744. this.register('nn.SpatialConvolutionMM');
  745. this.register('nn.SpatialCrossMapLRN');
  746. this.register('nn.SpatialDilatedConvolution');
  747. this.register('nn.SpatialDropout');
  748. this.register('nn.SpatialFractionalMaxPooling');
  749. this.register('nn.SpatialFullConvolution');
  750. this.register('nn.SpatialLPPooling');
  751. this.register('nn.SpatialMaxPooling');
  752. this.register('nn.SpatialMaxUnpooling');
  753. this.register('nn.SpatialReflectionPadding');
  754. this.register('nn.SpatialReplicationPadding');
  755. this.register('nn.SpatialSoftMax');
  756. this.register('nn.SpatialSubtractiveNormalization');
  757. this.register('nn.SpatialUpSamplingBilinear');
  758. this.register('nn.SpatialUpSamplingNearest');
  759. this.register('nn.SpatialZeroPadding');
  760. this.register('nn.SplitTable');
  761. this.register('nn.Squeeze');
  762. this.register('nn.Square');
  763. this.register('nn.Sqrt');
  764. this.register('nn.StereoJoin');
  765. this.register('nn.Tanh');
  766. this.register('nn.Transpose');
  767. this.register('nn.TotalVariation');
  768. this.register('nn.Unpool');
  769. this.register('nn.View');
  770. this.register('nn.gModule');
  771. this.register('nngraph.Node');
  772. this.register('graph.Edge');
  773. this.register('graph.Graph');
  774. this.register('torch.ByteTensor', class extends Tensor {
  775. constructor() {
  776. super('uint8');
  777. }
  778. });
  779. this.register('torch.CharTensor', class extends Tensor {
  780. constructor() {
  781. super('int8');
  782. }
  783. });
  784. this.register('torch.ShortTensor', class extends Tensor {
  785. constructor() {
  786. super('int16');
  787. }
  788. });
  789. this.register('torch.IntTensor', class extends Tensor {
  790. constructor() {
  791. super('int32');
  792. }
  793. });
  794. this.register('torch.LongTensor', class extends Tensor {
  795. constructor() {
  796. super('int64');
  797. }
  798. });
  799. this.register('torch.FloatTensor', class extends Tensor {
  800. constructor() {
  801. super('float32');
  802. }
  803. });
  804. this.register('torch.DoubleTensor', class extends Tensor {
  805. constructor() {
  806. super('float64');
  807. }
  808. });
  809. this.register('torch.CudaByteTensor', class extends Tensor {
  810. constructor() {
  811. super('uint8');
  812. }
  813. });
  814. this.register('torch.CudaCharTensor', class extends Tensor {
  815. constructor() {
  816. super('int8');
  817. }
  818. });
  819. this.register('torch.CudaShortTensor', class extends Tensor {
  820. constructor() {
  821. super('int16');
  822. }
  823. });
  824. this.register('torch.CudaIntTensor', class extends Tensor {
  825. constructor() {
  826. super('int32');
  827. }
  828. });
  829. this.register('torch.CudaLongTensor', class extends Tensor {
  830. constructor() {
  831. super('int64');
  832. }
  833. });
  834. this.register('torch.CudaTensor', class extends Tensor {
  835. constructor() {
  836. super('float32');
  837. }
  838. });
  839. this.register('torch.CudaDoubleTensor', class extends Tensor {
  840. constructor() {
  841. super('float64');
  842. }
  843. });
  844. this.register('torch.ByteStorage', class extends Storage {
  845. constructor() {
  846. super('uint8', 1);
  847. }
  848. });
  849. this.register('torch.CharStorage', class extends Storage {
  850. constructor() {
  851. super('int8', 1);
  852. }
  853. });
  854. this.register('torch.ShortStorage', class extends Storage {
  855. constructor() {
  856. super('int16', 2);
  857. }
  858. });
  859. this.register('torch.IntStorage', class extends Storage {
  860. constructor() {
  861. super('int32', 4);
  862. }
  863. });
  864. this.register('torch.LongStorage', class extends Storage {
  865. constructor() {
  866. super('int64', 8);
  867. }
  868. });
  869. this.register('torch.FloatStorage', class extends Storage {
  870. constructor() {
  871. super('float32', 4);
  872. }
  873. });
  874. this.register('torch.DoubleStorage', class extends Storage {
  875. constructor() {
  876. super('float64', 8);
  877. }
  878. });
  879. this.register('torch.CudaByteStorage', class extends Storage {
  880. constructor() {
  881. super('uint8', 1);
  882. }
  883. });
  884. this.register('torch.CudaCharStorage', class extends Storage {
  885. constructor() {
  886. super('int8', 1);
  887. }
  888. });
  889. this.register('torch.CudaShortStorage', class extends Storage {
  890. constructor() {
  891. super('int16', 2);
  892. }
  893. });
  894. this.register('torch.CudaIntStorage', class extends Storage {
  895. constructor() {
  896. super('int32', 4);
  897. }
  898. });
  899. this.register('torch.CudaLongStorage', class extends Storage {
  900. constructor() {
  901. super('int64', 8);
  902. }
  903. });
  904. this.register('torch.CudaIntStorage', class extends Storage {
  905. constructor() {
  906. super('int32', 4);
  907. }
  908. });
  909. this.register('torch.CudaStorage', class extends Storage {
  910. constructor() {
  911. super('float32', 4);
  912. }
  913. });
  914. this.register('torch.CudaFloatStorage', class extends Storage {
  915. constructor() {
  916. super('float64', 8);
  917. }
  918. });
  919. this.register('w2nn.AuxiliaryLossTable');
  920. this.register('w2nn.InplaceClip01');
  921. this.register('w2nn.ScaleTable');
  922. this.register('LuaFunction', class {
  923. constructor(size, dumped, upvalues) {
  924. this.size = size;
  925. this.dumped = dumped;
  926. this.upvalues = upvalues;
  927. }
  928. });
  929. }
  930. register(name, type) {
  931. type = type || class {};
  932. const parts = name.split('.');
  933. type.__name__ = parts.pop();
  934. type.__module__ = parts.join('.');
  935. type.prototype.__class__ = type;
  936. this._types.set(name, type);
  937. }
  938. read() {
  939. const type = this.int32();
  940. switch (type) {
  941. case 0: return null;
  942. case 1: return this.float64();
  943. case 2: return this.string();
  944. case 3: return this.table();
  945. case 4: return this.object();
  946. case 5: return this.boolean();
  947. case 6: return this.function();
  948. case 7: return this.function();
  949. case 8: return this.function();
  950. default: throw new torch.Error("File format has invalid type '" + type + "'.");
  951. }
  952. }
  953. boolean() {
  954. return this._reader.boolean();
  955. }
  956. bytes(size) {
  957. return this._reader.bytes(size);
  958. }
  959. int32() {
  960. return this._reader.int32();
  961. }
  962. int64() {
  963. return this._reader.int64();
  964. }
  965. int64s(size) {
  966. return this._reader.int64s(size);
  967. }
  968. float64() {
  969. return this._reader.float64();
  970. }
  971. string() {
  972. return this._reader.string();
  973. }
  974. object() {
  975. const index = this.int32();
  976. if (this._memo.has(index)) {
  977. return this._memo.get(index);
  978. }
  979. let version = this.string();
  980. let name = null;
  981. if (version.startsWith('V ')) {
  982. name = this.string();
  983. version = Number(version.split(' ')[1]);
  984. }
  985. else {
  986. name = version;
  987. version = 0;
  988. }
  989. if (!this._types.has(name)) {
  990. this.callback(name);
  991. this.register(name);
  992. }
  993. const type = this._types.get(name);
  994. const obj = Reflect.construct(type, []);
  995. this._memo.set(index, obj);
  996. if (obj.read) {
  997. obj.read(this, version);
  998. }
  999. else {
  1000. const attributes = this.read();
  1001. if (attributes != null) {
  1002. for (const entry of Object.entries(attributes)) {
  1003. const key = entry[0];
  1004. obj[key] = entry[1];
  1005. }
  1006. }
  1007. }
  1008. return obj;
  1009. }
  1010. table() {
  1011. const index = this.int32();
  1012. if (this._memo.has(index)) {
  1013. return this._memo.get(index);
  1014. }
  1015. const table = {};
  1016. this._memo.set(index, table);
  1017. const size = this.int32();
  1018. let convert = true;
  1019. let sum = 0;
  1020. for (let i = 0; i < size; i++) {
  1021. const key = this.read();
  1022. const value = this.read();
  1023. table[key] = value;
  1024. if (Number.isInteger(key) && key >= 0) {
  1025. sum += key;
  1026. }
  1027. else {
  1028. convert = false;
  1029. }
  1030. }
  1031. const n = Object.keys(table).length;
  1032. if (convert && (n * (n + 1)) == (2 * sum)) {
  1033. const list = [];
  1034. for (let j = 0; j < n; j++) {
  1035. let item = table[j + 1];
  1036. if (item == table) {
  1037. item = list;
  1038. }
  1039. list.push(item);
  1040. }
  1041. this._memo.set(index, list);
  1042. return list;
  1043. }
  1044. return table;
  1045. }
  1046. function() {
  1047. const index = this.int32();
  1048. if (this._memo.has(index)) {
  1049. return this._memo.get(index);
  1050. }
  1051. const size = this.int32();
  1052. const dumped = this.bytes(size);
  1053. const upvalues = this.read();
  1054. const type = this._types.get('LuaFunction');
  1055. const obj = Reflect.construct(type, [ size, dumped, upvalues ]);
  1056. this._memo.set(index, obj);
  1057. return obj;
  1058. }
  1059. storage(size, itemSize, dataType) {
  1060. return this._reader.storage(size, itemSize, dataType);
  1061. }
  1062. };
  1063. torch.BinaryReader = class {
  1064. constructor(data) {
  1065. this._buffer = data instanceof Uint8Array ? data : data.peek();
  1066. this._dataView = new DataView(this._buffer.buffer, this._buffer.byteOffset, this._buffer.byteLength);
  1067. this._position = 0;
  1068. this._textDecoder = new TextDecoder('ascii');
  1069. }
  1070. reset() {
  1071. this._position = 0;
  1072. }
  1073. skip(offset) {
  1074. this._position += offset;
  1075. if (this._position > this._buffer.length) {
  1076. throw new torch.Error('Expected ' + (this._position - this._buffer.length) + ' more bytes. The file might be corrupted. Unexpected end of file.');
  1077. }
  1078. }
  1079. boolean() {
  1080. return this.int32() == 1;
  1081. }
  1082. bytes(length) {
  1083. const position = this._position;
  1084. this.skip(length);
  1085. return this._buffer.subarray(position, this._position);
  1086. }
  1087. int8() {
  1088. const position = this._position;
  1089. this.skip(1);
  1090. return this._dataView.getInt8(position, true);
  1091. }
  1092. int16() {
  1093. const position = this._position;
  1094. this.skip(2);
  1095. return this._dataView.getInt16(position, true);
  1096. }
  1097. int32() {
  1098. const position = this._position;
  1099. this.skip(4);
  1100. return this._dataView.getInt32(position, true);
  1101. }
  1102. int64() {
  1103. const position = this._position;
  1104. this.skip(8);
  1105. return this._dataView.getInt64(position, true).toNumber();
  1106. }
  1107. int64s(size) {
  1108. const array = [];
  1109. for (let i = 0; i < size; i++) {
  1110. array.push(this.int64());
  1111. }
  1112. return array;
  1113. }
  1114. float32() {
  1115. const position = this._position;
  1116. this.skip(4);
  1117. return this._dataView.getFloat32(position, true);
  1118. }
  1119. float64() {
  1120. const position = this._position;
  1121. this.skip(8);
  1122. return this._dataView.getFloat64(position, true);
  1123. }
  1124. string() {
  1125. return this._textDecoder.decode(this.bytes(this.int32()));
  1126. }
  1127. storage(size, itemSize) {
  1128. return new torch.BinaryReader(this.bytes(size * itemSize));
  1129. }
  1130. };
  1131. torch.TextReader = class {
  1132. constructor(data, separator) {
  1133. this._buffer = data instanceof Uint8Array ? data : data.peek();
  1134. this._position = 0;
  1135. this._dataView = new DataView(this._buffer.buffer, this._buffer.byteOffset, this._buffer.byteLength);
  1136. this._textDecoder = new TextDecoder('ascii');
  1137. this._separator = separator || 0x0a;
  1138. }
  1139. reset() {
  1140. this._position = 0;
  1141. }
  1142. line(size) {
  1143. const start = this._position;
  1144. while (this._position < this._buffer.length && size > -1) {
  1145. const c = this._buffer[this._position++];
  1146. if (c == this._separator) {
  1147. return this._buffer.slice(start, this._position - 1);
  1148. }
  1149. else if (this._position == this._buffer.length) {
  1150. return this._buffer.slice(start, this._position);
  1151. }
  1152. size--;
  1153. }
  1154. throw new torch.Error('Line exceeded maximum length.');
  1155. }
  1156. boolean() {
  1157. return this.int32() == 1;
  1158. }
  1159. bytes(size) {
  1160. return this.line(size);
  1161. }
  1162. int8() {
  1163. return this.int64();
  1164. }
  1165. int16() {
  1166. return this.int64();
  1167. }
  1168. int32() {
  1169. return this.int64();
  1170. }
  1171. int64() {
  1172. const token = this._textDecoder.decode(this.line(20));
  1173. const number = Number.parseInt(token, 10);
  1174. if (Number.isNaN(token - number)) {
  1175. throw new torch.Error("Couldn't parse int64 '" + token + "'.");
  1176. }
  1177. return number;
  1178. }
  1179. int64s(size) {
  1180. const array = [];
  1181. if (size > 0) {
  1182. const content = this._textDecoder.decode(this.line(Number.MAX_SAFE_INTEGER));
  1183. for (const token of content.split(' ')) {
  1184. const number = Number.parseInt(token, 10);
  1185. if (Number.isNaN(token - number)) {
  1186. throw new torch.Error("Couldn't parse int64 '" + token + "'.");
  1187. }
  1188. array.push(number);
  1189. }
  1190. }
  1191. return array;
  1192. }
  1193. float32() {
  1194. return this.float64();
  1195. }
  1196. float64() {
  1197. const token = this._textDecoder.decode(this.line(24));
  1198. if (token.startsWith('-nan')) {
  1199. return -NaN;
  1200. }
  1201. if (token.startsWith('nan')) {
  1202. return NaN;
  1203. }
  1204. if (token.startsWith('inf')) {
  1205. return Infinity;
  1206. }
  1207. if (token.startsWith('-inf')) {
  1208. return -Infinity;
  1209. }
  1210. const number = Number.parseFloat(token);
  1211. if (Number.isNaN(token - number)) {
  1212. throw new torch.Error("Couldn't parse float '" + token + "'.");
  1213. }
  1214. return number;
  1215. }
  1216. string() {
  1217. const size = this.int32();
  1218. if (size == 0) {
  1219. return '';
  1220. }
  1221. const data = this.line(size);
  1222. const content = this._textDecoder.decode(data);
  1223. if (size != content.length) {
  1224. throw new torch.Error('Invalid string length.');
  1225. }
  1226. return content;
  1227. }
  1228. storage(size, itemSize, dataType) {
  1229. if (size <= 0) {
  1230. throw new torch.Error("Unsupported storage size '" + size + "'.");
  1231. }
  1232. if (dataType === 'uint8') {
  1233. const start = this._position;
  1234. this._position += size;
  1235. const bytes = this._buffer.slice(start, this._position);
  1236. this.line(0);
  1237. return new torch.BinaryReader(bytes);
  1238. }
  1239. const data = this.line(Number.MAX_SAFE_INTEGER);
  1240. return new torch.TextReader(data, 0x20);
  1241. }
  1242. };
  1243. if (typeof module !== 'undefined' && typeof module.exports === 'object') {
  1244. module.exports.ModelFactory = torch.ModelFactory;
  1245. }