torch.js 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241
  var torch = torch || {};

  /**
   * Entry point for detecting and opening Torch7 (.t7) model files.
   */
  torch.ModelFactory = class {

      /**
       * Checks whether the context's stream looks like a Torch7 file.
       * @param {object} context - host context exposing `stream`.
       * @returns {torch.T7Reader|null} a reader for the stream, or null when it does not match.
       */
      match(context) {
          return torch.T7Reader.open(context);
      }

      /**
       * Reads the root object with the reader returned by `match()` and wraps it in a model.
       * @param {object} context - host context providing `metadata()`, `identifier` and `exception()`.
       * @param {torch.T7Reader} match - reader returned by `match()`.
       * @returns {Promise<torch.Model>}
       */
      open(context, match) {
          return context.metadata('torch-metadata.json').then((metadata) => {
              const identifier = context.identifier;
              const reader = match;
              // NOTE(review): presumably invoked by the reader for class names it has no
              // registered type for — confirm in T7Reader.object(). Reported as non-fatal,
              // except for a few known families which are tolerated silently.
              reader.callback = (name) => {
                  if (name && name !== 'nn.JointTrainModule' && !name.startsWith('nn.MSDNet_') && !name.startsWith('onmt.')) {
                      context.exception(new torch.Error("Unsupported type '" + name + "' in '" + identifier + "'."), false);
                  }
                  return null;
              };
              let root = reader.read();
              // A two-element array whose first entry is a class instance and whose second is
              // not is unwrapped to the first entry. Either entry may be null, so guard both.
              if (Array.isArray(root) && root.length === 2 &&
                  root[0] && root[0].__class__ &&
                  !(root[1] && root[1].__class__)) {
                  root = root[0];
              }
              return new torch.Model(metadata, root);
          });
      }
  };
  /**
   * A Torch7 model. It holds exactly one graph, built from the root module.
   */
  torch.Model = class {

      /**
       * @param {object} metadata - operator metadata forwarded to the graph.
       * @param {*} root - root object read from the file.
       */
      constructor(metadata, root) {
          const graph = new torch.Graph(metadata, root);
          this._graphs = [ graph ];
      }

      /** @returns {torch.Graph[]} the single graph of this model. */
      get graphs() {
          return this._graphs;
      }

      /** @returns {string} the format name reported for this model. */
      get format() {
          return 'Torch v7';
      }
  };
  /**
   * Flattens a tree of Torch7 modules into a list of nodes wired together by named
   * arguments. Container modules are expanded recursively, and their keys form the
   * '/'-separated group path of the nodes they contain.
   */
  torch.Graph = class {

      /**
       * @param {object} metadata - operator metadata passed through to nodes.
       * @param {*} root - root object read from the file. If it has an own `model`
       *     property, that value is used as the root module instead.
       * @throws {torch.Error} if there is no root module to load.
       */
      constructor(metadata, root) {
          this._inputs = [];
          this._outputs = [];
          this._nodes = [];
          // Boolean flag; _loadModule sets it to true once any module is loaded inside a container.
          this._groups = false;
          if (root !== null && root !== undefined && Object.prototype.hasOwnProperty.call(root, 'model')) {
              root = root.model;
          }
          if (root === null || root === undefined) {
              throw new torch.Error('File does not contain a root module.');
          }
          const inputs = [];
          const outputs = [];
          this._loadModule(metadata, root, [], '', inputs, outputs);
          // Graph-level slots are named input, input2, input3, ... (likewise for outputs).
          const slotName = (prefix, index) => prefix + (index !== 0 ? (index + 1).toString() : '');
          this._inputs = inputs.map((input, index) => new torch.Parameter(slotName('input', index), true, [ input ]));
          this._outputs = outputs.map((output, index) => new torch.Parameter(slotName('output', index), true, [ output ]));
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get nodes() {
          return this._nodes;
      }

      /** @returns {boolean} true when at least one module was loaded inside a container. */
      get groups() {
          return this._groups;
      }

      /**
       * Recursively loads `module`, creating nodes for leaf (and unexpanded) modules.
       * `inputs` and `outputs` are mutated in place: arguments created for otherwise
       * unconnected inputs/outputs are appended so the caller can wire them up.
       * @param {object} metadata - operator metadata.
       * @param {object} module - deserialized module.
       * @param {string[]} groups - current container path (pushed/popped around children).
       * @param {string} key - this module's key within its parent.
       * @param {torch.Argument[]} inputs - arguments feeding this module.
       * @param {torch.Argument[]} outputs - arguments produced by this module.
       */
      _loadModule(metadata, module, groups, key, inputs, outputs) {
          if (groups.length > 0) {
              this._groups = true;
          }
          const type = module.__class__ ? module.__class__.__module__ + '.' + module.__class__.__name__ : '';
          switch (type) {
              case 'nn.Sequential': {
                  // Chain the children: each child's outputs become the next child's inputs,
                  // and the last child writes directly into the caller's `outputs`.
                  groups.push(key);
                  let subInputs = inputs;
                  let subOutputs = [];
                  const length = module.modules.length;
                  let index = 0;
                  for (const subModule of module.modules) {
                      if (index === length - 1) {
                          subOutputs = outputs;
                      }
                      this._loadModule(metadata, subModule, groups, index.toString(), subInputs, subOutputs);
                      subInputs = subOutputs;
                      subOutputs = [];
                      index++;
                  }
                  groups.pop();
                  break;
              }
              case 'nn.Parallel':
              case 'nn.ParallelTable':
              case 'nn.JointTrain': {
                  // Every branch gets its own copy of the inputs/outputs. When the caller has none
                  // yet, the arguments the branches create are collected and handed back.
                  groups.push(key);
                  const newInputs = [];
                  const newOutputs = [];
                  let index = 0;
                  for (const subModule of module.modules) {
                      const subInputs = inputs.slice();
                      const subOutputs = outputs.slice();
                      this._loadModule(metadata, subModule, groups, index.toString(), subInputs, subOutputs);
                      if (inputs.length === 0) {
                          newInputs.push(...subInputs);
                      }
                      if (outputs.length === 0) {
                          newOutputs.push(...subOutputs);
                      }
                      index++;
                  }
                  // Append in place; reassigning the parameter would hide them from the caller.
                  inputs.push(...newInputs);
                  outputs.push(...newOutputs);
                  groups.pop();
                  break;
              }
              case 'nn.Concat':
              case 'nn.ConcatTable': {
                  // Fan-out: every branch sees the same inputs, and the branch outputs become the
                  // inputs of a single node representing the concatenation itself.
                  if (inputs.length === 0) {
                      inputs.push(new torch.Argument(groups.join('/') + ':' + key + ':in', null, null));
                  }
                  let concatInputs = [];
                  let index = 0;
                  for (const subModule of module.modules) {
                      const streamInputs = inputs.slice();
                      const streamOutputs = [];
                      this._loadModule(metadata, subModule, groups, key + '.' + index.toString(), streamInputs, streamOutputs);
                      concatInputs = concatInputs.concat(streamOutputs);
                      index++;
                  }
                  // Children were expanded above; remove fields that should not become node attributes.
                  delete module.modules;
                  delete module.dimension;
                  this._createNode(metadata, module, groups, key, concatInputs, outputs);
                  break;
              }
              case 'nn.Inception': {
                  // TODO: expand the inner modules instead of showing a single opaque node.
                  delete module.modules;
                  delete module.module;
                  delete module.transfer;
                  delete module.pool;
                  this._createNode(metadata, module, groups, key, inputs, outputs);
                  break;
              }
              // TODO: nn.gModule (nngraph) is not expanded yet; like any other module it
              // becomes a single node.
              default: {
                  this._createNode(metadata, module, groups, key, inputs, outputs);
                  break;
              }
          }
      }

      /** Creates a node for `module` and appends it to the graph. */
      _createNode(metadata, module, group, subIndex, inputs, outputs) {
          const node = new torch.Node(metadata, module, group, subIndex, inputs, outputs);
          this._nodes.push(node);
      }
  };
  /**
   * A named input or output slot holding an ordered list of arguments.
   */
  torch.Parameter = class {

      /**
       * @param {string} name - slot name.
       * @param {boolean} visible - visibility flag for the slot.
       * @param {torch.Argument[]} args - arguments bound to the slot.
       */
      constructor(name, visible, args) {
          Object.assign(this, {
              _name: name,
              _visible: visible,
              _arguments: args
          });
      }

      get name() { return this._name; }

      get visible() { return this._visible; }

      get arguments() { return this._arguments; }
  };
  /**
   * A named value flowing between nodes, optionally backed by a constant initializer.
   */
  torch.Argument = class {

      /**
       * @param {string} name - unique argument identifier.
       * @param {*} type - explicit type, used when there is no initializer.
       * @param {torch.Tensor|null} initializer - constant tensor backing this argument.
       * @throws {torch.Error} if `name` is not a string.
       */
      constructor(name, type, initializer) {
          if (typeof name !== 'string') {
              const text = JSON.stringify(name);
              throw new torch.Error("Invalid argument identifier '" + text + "'.");
          }
          this._name = name;
          this._type = type;
          this._initializer = initializer;
      }

      get name() {
          return this._name;
      }

      /** The initializer's type takes precedence over the explicitly supplied type. */
      get type() {
          return this._initializer ? this._initializer.type : this._type;
      }

      get initializer() {
          return this._initializer;
      }
  };
  /**
   * A graph node created from a single Torch7 module instance.
   *
   * The module object is cleaned up in place:
   * - storages are replaced by their decoded data;
   * - runtime-only fields are deleted;
   * - W/H and t/r/b/l field pairs are merged into arrays.
   *
   * The remaining fields become tensor initializers (tensor-valued fields) or
   * attributes (everything else).
   */
  torch.Node = class {

      /**
       * @param {object} metadata - operator metadata (`type()` and `attribute()` lookups).
       * @param {object} module - deserialized module; mutated by this constructor.
       * @param {string[]} groups - container path, joined with '/' to form the node group.
       * @param {string} name - fallback name used when the module has no string `name` field.
       * @param {torch.Argument[]} inputs - input arguments; one is appended when empty.
       * @param {torch.Argument[]} outputs - output arguments; one is appended when empty.
       */
      constructor(metadata, module, groups, name, inputs, outputs) {
          this._group = groups.join('/');
          if (module.name && typeof module.name === 'string') {
              this._name = module.name;
              delete module.name;
          }
          else {
              this._name = this._group ? (this._group + ':' + name) : name;
          }
          const type = module.__class__ ? module.__class__.__module__ + '.' + module.__class__.__name__ : 'nn.Module';
          this._type = metadata.type(type);
          // Replace storage objects with their decoded data arrays.
          for (const [ key, obj ] of Object.entries(module)) {
              if (obj && obj.__class__ && obj.__class__.__module__ === 'torch' && obj.__class__.__name__.endsWith('Storage')) {
                  module[key] = obj.data();
              }
          }
          // Fields dropped from every module before attributes are collected.
          const transientFields = [
              'iSize', 'finput', 'fgradInput', 'output', 'gradInput', 'gradWeight', 'gradBias',
              'grad_tmp', 'scaleT', '_input', '_output', '_gradInput', '_gradOutput',
              'buffer', 'buffer2', 'tmp_in', 'tmp_out', 'accUpdateGradParameters'
          ];
          for (const field of transientFields) {
              delete module[field];
          }
          switch (this._type.name) {
              case 'nn.Linear':
                  delete module.addBuffer;
                  break;
              case 'nn.Normalize':
              case 'nn.Normalize2':
                  delete module.addBuffer;
                  delete module.normp;
                  delete module.norm;
                  break;
              case 'cudnn.SpatialConvolution':
              case 'cudnn.SpatialFullConvolution':
              case 'nn.SpatialConvolution':
              case 'nn.SpatialConvolutionMM':
              case 'nn.SpatialDilatedConvolution':
              case 'nn.SpatialFullConvolution':
                  delete module.ones;
                  delete module.input_slice;
                  delete module.output_slice;
                  delete module.convDescData;
                  this._updateSize(module, 'adj');
                  this._updateSize(module, 'd');
                  this._updateSize(module, 'dilation');
                  this._updateSize(module, 'k');
                  this._updateSize(module, 'pad');
                  break;
              case 'cudnn.BatchNormalization':
              case 'cudnn.SpatialBatchNormalization':
              case 'nn.BatchNormalization':
              case 'nn.SpatialBatchNormalization':
              case 'nn.InstanceNormalization':
                  delete module.save_mean;
                  delete module.save_std;
                  delete module.normalized;
                  delete module.centered;
                  delete module.bn; // TODO InstanceNormalization
                  break;
              case 'nn.SpatialCrossMapLRN':
                  delete module.scale;
                  break;
              case 'cudnn.SpatialMaxPooling':
              case 'cudnn.SpatialAveragePooling':
              case 'inn.SpatialMaxPooling':
              case 'nn.SpatialMaxPooling':
              case 'nn.SpatialAveragePooling':
                  delete module.indices;
                  this._updateSize(module, 'pad');
                  this._updateSize(module, 'd');
                  this._updateSize(module, 'k');
                  break;
              case 'nn.SpatialZeroPadding':
              case 'nn.SpatialReflectionPadding':
              case 'nn.SpatialReplicationPadding':
                  this._updateBox(module, 'pad');
                  break;
              case 'nn.Dropout':
                  delete module.noise;
                  break;
              case 'nn.gModule':
                  delete module.forwardnodes;
                  delete module.backwardnodes;
                  break;
              case 'nn.StereoJoin':
                  delete module.output_L;
                  break;
              default:
                  break;
          }
          this._attributes = [];
          const initializers = [];
          if (module.__class__) {
              for (const [ key, obj ] of Object.entries(module)) {
                  if (key === '_type') {
                      continue;
                  }
                  // Arrays made up entirely of 'nn' module instances are skipped.
                  if (Array.isArray(obj) && obj.every((item) => item && item.__class__ && item.__class__.__module__ === 'nn')) {
                      continue;
                  }
                  // Guard against null field values before touching `__class__`.
                  if (obj && obj.__class__ && obj.__class__.__module__ === 'torch' && obj.__class__.__name__.endsWith('Tensor')) {
                      initializers.push(new torch.Parameter(key, true, [
                          new torch.Argument(key, null, new torch.Tensor(obj))
                      ]));
                      continue;
                  }
                  if (key === 'modules') {
                      continue;
                  }
                  // Skip other class instances unless their module name is empty or they are a LuaFunction.
                  if (obj && obj.__class__ && obj.__class__.__module__ !== '' && obj.__class__.__name__ !== 'LuaFunction') {
                      continue;
                  }
                  this._attributes.push(new torch.Attribute(metadata, type, key, obj));
              }
          }
          if (inputs.length === 0 && this._name) {
              inputs.push(new torch.Argument(this._name + ':in', null, null));
          }
          if (outputs.length === 0 && this._name) {
              outputs.push(new torch.Argument(this._name, null, null));
          }
          this._inputs = [ new torch.Parameter('input', true, inputs) ];
          this._outputs = [ new torch.Parameter('output', true, outputs) ];
          // Tensor inputs are listed as weight first, then bias, then the remaining tensor fields.
          const weight = initializers.filter((parameter) => parameter.name === 'weight');
          const bias = initializers.filter((parameter) => parameter.name === 'bias');
          const others = initializers.filter((parameter) => parameter.name !== 'weight' && parameter.name !== 'bias');
          this._inputs.push(...weight, ...bias, ...others);
      }

      get name() {
          return this._name;
      }

      get type() {
          return this._type;
      }

      get group() {
          return this._group;
      }

      get attributes() {
          return this._attributes;
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      /** Merges `<name>W` / `<name>H` into `<name>: [W, H]`, only when both fields exist. */
      _updateSize(module, name) {
          if (Object.prototype.hasOwnProperty.call(module, name + 'W') &&
              Object.prototype.hasOwnProperty.call(module, name + 'H')) {
              module[name] = [ module[name + 'W'], module[name + 'H'] ];
              delete module[name + 'W'];
              delete module[name + 'H'];
          }
      }

      /** Merges `<name>_t/_r/_b/_l` into `<name>: [t, r, b, l]`, only when all four fields exist. */
      _updateBox(module, name) {
          const keys = [ '_t', '_r', '_b', '_l' ].map((suffix) => name + suffix);
          if (keys.every((key) => Object.prototype.hasOwnProperty.call(module, key))) {
              module[name] = keys.map((key) => module[key]);
              for (const key of keys) {
                  delete module[key];
              }
          }
      }
  };
  /**
   * A named, non-tensor module field shown as a node attribute.
   * Visibility rules:
   * - `train` is hidden by default.
   * - A schema `visible` flag overrides that default.
   * - Without such a flag, a value equal (by JSON text) to the schema default is hidden.
   */
  torch.Attribute = class {

      /**
       * @param {object} metadata - provides `attribute(type, name)` schema lookups.
       * @param {string} type - owning module type, e.g. 'nn.Linear'.
       * @param {string} name - field name.
       * @param {*} value - field value.
       */
      constructor(metadata, type, name, value) {
          this._name = name;
          this._value = value;
          if (name === 'train') {
              this._visible = false;
          }
          const schema = metadata.attribute(type, name);
          if (!schema) {
              return;
          }
          const has = (property) => Object.prototype.hasOwnProperty.call(schema, property);
          if (has('visible')) {
              this._visible = schema.visible;
          }
          else if (has('default') && JSON.stringify(schema.default) === JSON.stringify(this._value)) {
              this._visible = false;
          }
      }

      get name() {
          return this._name;
      }

      get value() {
          return this._value;
      }

      get visible() {
          // Loose comparison kept on purpose: any value loosely equal to false hides the attribute.
          return !(this._visible == false);
      }
  };
  /**
   * A constant tensor read from a Torch7 file, decoded lazily from its storage.
   * NOTE(review): decoding assumes a contiguous layout starting at `storage_offset`;
   * the tensor's `stride` is not consulted.
   */
  torch.Tensor = class {

      /**
       * @param {object} tensor - deserialized tensor with `dataType`, `size`, `storage`, `storage_offset`.
       */
      constructor(tensor) {
          this._type = new torch.TensorType(tensor);
          this._storage = tensor.storage;
          this._offset = tensor.storage_offset;
      }

      get type() {
          return this._type;
      }

      /** @returns {string|null} the reason the data cannot be decoded, or null when it can. */
      get state() {
          return this._context().state || null;
      }

      /** @returns {Array|null} all elements as nested arrays, or null when not decodable. */
      get value() {
          const context = this._context();
          if (context.state) {
              return null;
          }
          context.limit = Number.MAX_SAFE_INTEGER;
          return this._decode(context, 0);
      }

      /** @returns {string} JSON of about the first 1000 elements, or '' when not decodable. */
      toString() {
          const context = this._context();
          if (context.state) {
              return '';
          }
          context.limit = 1000;
          const value = this._decode(context, 0);
          return JSON.stringify(value, null, 4);
      }

      /** Builds the decoding cursor, or sets `state` to the reason decoding is impossible. */
      _context() {
          const context = { state: null, index: 0, count: 0 };
          if (!this._storage) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          context.data = this._storage.data();
          context.index = this._offset;
          if (!context.data) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          switch (this._type.dataType) {
              case 'uint8':
              case 'int8':
              case 'int16':
              case 'int32':
              case 'int64':
              case 'float32':
              case 'float64':
                  break;
              default:
                  // Return right away so a later check cannot overwrite this reason.
                  context.state = 'Tensor data type is not implemented.';
                  return context;
          }
          context.dimensions = this._type.shape.dimensions;
          // A missing or empty shape has nothing to decode.
          if (!context.dimensions || context.dimensions.length === 0) {
              context.state = 'Tensor has no dimensions.';
          }
          return context;
      }

      /** Recursively reads one dimension, appending '...' once `limit` elements were emitted. */
      _decode(context, dimension) {
          const results = [];
          const size = context.dimensions[dimension];
          const innermost = dimension === context.dimensions.length - 1;
          for (let i = 0; i < size; i++) {
              if (context.count > context.limit) {
                  results.push('...');
                  return results;
              }
              if (innermost) {
                  results.push(context.data[context.index]);
                  context.index++;
                  context.count++;
              }
              else {
                  results.push(this._decode(context, dimension + 1));
              }
          }
          return results;
      }
  };
  /**
   * Element type plus shape of a tensor, rendered as e.g. 'float32[64,3,7,7]'.
   */
  torch.TensorType = class {

      /** @param {object} tensor - deserialized tensor exposing `dataType` and `size`. */
      constructor(tensor) {
          const { dataType, size } = tensor;
          this._dataType = dataType;
          this._shape = new torch.TensorShape(size);
      }

      get dataType() {
          return this._dataType;
      }

      get shape() {
          return this._shape;
      }

      /** Unknown element types render as '?'. */
      toString() {
          const prefix = this.dataType || '?';
          return prefix + this._shape.toString();
      }
  };
  /**
   * Tensor dimensions, rendered as '[d0,d1,...]'; a missing or empty shape renders as ''.
   */
  torch.TensorShape = class {

      /** @param {Array|undefined} dimensions - per-axis sizes. */
      constructor(dimensions) {
          this._dimensions = dimensions;
      }

      get dimensions() {
          return this._dimensions;
      }

      toString() {
          const dimensions = this._dimensions;
          if (!dimensions || dimensions.length === 0) {
              return '';
          }
          const text = dimensions.map((dimension) => dimension.toString()).join(',');
          return '[' + text + ']';
      }
  };
  /**
   * Error raised by this loader; its `name` is set to a fixed descriptive title.
   */
  torch.Error = class extends Error {
      /** @param {string} message - description of the problem. */
      constructor(message) {
          super(message);
          this.name = 'Error loading Torch model.';
      }
  };
  562. torch.T7Reader = class {
  563. static open(context) {
  564. const stream = context.stream;
  565. if (stream.length >= 4 && stream.peek(4).every((value, index) => value === 0x00 || (index == 0 && value <= 0x08))) {
  566. const reader = new torch.BinaryReader(stream);
  567. return new torch.T7Reader(reader);
  568. }
  569. if (stream.length >= 2) {
  570. const buffer = stream.peek(2);
  571. const value = String.fromCharCode(stream.peek(1)[0]);
  572. if (buffer[1] === 0x0a && (value >= '0' && value <= '8')) {
  573. const reader = new torch.TextReader(stream);
  574. return new torch.T7Reader(reader);
  575. }
  576. }
  577. return null;
  578. }
  579. constructor(reader) {
  580. this._reader = reader;
  581. this._memo = new Map();
  582. this._types = new Map();
  583. const Storage = class {
  584. constructor(dataType, itemSize) {
  585. this.dataType = dataType;
  586. this.itemSize = itemSize;
  587. }
  588. data() {
  589. if (this.reader) {
  590. const reader = this.reader;
  591. reader.reset();
  592. const dataType = this.dataType;
  593. const size = this.size;
  594. const array = new Array(size);
  595. for (let i = 0; i < size; i++) {
  596. switch (dataType) {
  597. case 'uint8':
  598. array[i] = reader.byte();
  599. break;
  600. case 'int8':
  601. array[i] = reader.int8();
  602. break;
  603. case 'int16':
  604. array[i] = reader.int16();
  605. break;
  606. case 'int32':
  607. array[i] = reader.int32();
  608. break;
  609. case 'int64':
  610. array[i] = reader.int64();
  611. break;
  612. case 'float32':
  613. array[i] = reader.float32();
  614. break;
  615. case 'float64':
  616. array[i] = reader.float64();
  617. break;
  618. default:
  619. throw new torch.Error("Unsupported data type '" + dataType + "'.");
  620. }
  621. }
  622. this._data = array;
  623. delete this.reader;
  624. }
  625. return this._data;
  626. }
  627. read(reader) {
  628. this.size = reader.int64();
  629. this.reader = reader.storage(this.size, this.itemSize, this.dataType);
  630. }
  631. };
  632. const Tensor = class {
  633. constructor(dataType) {
  634. this.dataType = dataType;
  635. }
  636. read(reader) {
  637. const dim = reader.int32();
  638. this.size = reader.int64s(dim);
  639. this.stride = reader.int64s(dim);
  640. this.storage_offset = reader.int64() - 1;
  641. this.storage = reader.read();
  642. }
  643. };
  644. this.register('bnn.Binary');
  645. this.register('bnn.SpatialConvolution');
  646. this.register('cudnn.BatchNormalization');
  647. this.register('cudnn.BatchBRNNReLU');
  648. this.register('cudnn.BLSTM');
  649. this.register('cudnn.ReLU');
  650. this.register('cudnn.RNN');
  651. this.register('cudnn.Sigmoid');
  652. this.register('cudnn.SoftMax');
  653. this.register('cudnn.LogSoftMax');
  654. this.register('cudnn.normal3DConv');
  655. this.register('cudnn.normal3DdeConv');
  656. this.register('cudnn.SpatialAveragePooling');
  657. this.register('cudnn.SpatialBatchNormalization');
  658. this.register('cudnn.SpatialConvolution');
  659. this.register('cudnn.SpatialFullConvolution');
  660. this.register('cudnn.SpatialMaxPooling');
  661. this.register('cudnn.SpatialSoftMax');
  662. this.register('cudnn.Tanh');
  663. this.register('cudnn.VolumetricAveragePooling');
  664. this.register('cudnn.VolumetricBatchNormalization');
  665. this.register('cudnn.VolumetricConvolution');
  666. this.register('cudnn.VolumetricMaxPooling');
  667. this.register('Dict');
  668. this.register('inn.ConstAffine');
  669. this.register('inn.SpatialMaxPooling');
  670. this.register('nn.Abs');
  671. this.register('nn.AddConstant');
  672. this.register('nn.BatchNormalization');
  673. this.register('nn.BilinearSamplerBHWD');
  674. this.register('nn.BinActiveZ'); // allenai/XNOR-Net
  675. this.register('nn.BCECriterion');
  676. this.register('nn.Bottle');
  677. this.register('nn.Clamp');
  678. this.register('nn.CMul');
  679. this.register('nn.CAddTable');
  680. this.register('nn.CDivTable');
  681. this.register('nn.CMulTable');
  682. this.register('nn.CSubTable');
  683. this.register('nn.Concat');
  684. this.register('nn.Copy');
  685. this.register('nn.ConcatTable');
  686. this.register('nn.Contiguous');
  687. this.register('nn.Constant');
  688. this.register('nn.CostVolMulti');
  689. this.register('nn.DataParallelTable');
  690. this.register('nn.DepthConcat');
  691. this.register('nn.Dropout');
  692. this.register('nn.Exp');
  693. this.register('nn.ExpOut');
  694. this.register('nn.FlattenTable');
  695. this.register('nn.GenNoise');
  696. this.register('nn.Identity');
  697. this.register('nn.Index');
  698. this.register('nn.Inception');
  699. this.register('nn.InstanceNormalization');
  700. this.register('nn.JoinTable');
  701. this.register('nn.JointTrain');
  702. this.register('nn.KeypointCoordinate');
  703. this.register('nn.LeakyReLU');
  704. this.register('nn.Linear');
  705. this.register('nn.LinearNoBias');
  706. this.register('nn.LogSoftMax');
  707. this.register('nn.LookupTable');
  708. this.register('nn.LSTM');
  709. this.register('nn.MaskZero');
  710. this.register('nn.MapTable');
  711. this.register('nn.Max');
  712. this.register('nn.Mean');
  713. this.register('nn.Min');
  714. this.register('nn.MulConstant');
  715. this.register('nn.MM');
  716. this.register('nn.MSECriterion');
  717. this.register('nn.Narrow');
  718. this.register('nn.NarrowTable');
  719. this.register('nn.Normalize');
  720. this.register('nn.Normalize2');
  721. this.register('nn.NoiseFill');
  722. this.register('nn.Padding');
  723. this.register('nn.Parallel');
  724. this.register('nn.ParallelCriterion');
  725. this.register('nn.ParallelTable');
  726. this.register('nn.PixelShuffle');
  727. this.register('nn.Power');
  728. this.register('nn.PReLU');
  729. this.register('nn.Recursor');
  730. this.register('nn.ReLU');
  731. this.register('nn.Replicate');
  732. this.register('nn.Reshape');
  733. this.register('nn.ShaveImage');
  734. this.register('nn.Select');
  735. this.register('nn.SelectTable');
  736. this.register('nn.Sequencer');
  737. this.register('nn.Sequential');
  738. this.register('nn.Sigmoid');
  739. this.register('nn.Sum');
  740. this.register('nn.SoftMax');
  741. this.register('nn.SpatialAveragePooling');
  742. this.register('nn.SpatialBatchNormalization');
  743. this.register('nn.SpatialConvolution');
  744. this.register('nn.SpatialConvolutionMM');
  745. this.register('nn.SpatialCrossMapLRN');
  746. this.register('nn.SpatialDilatedConvolution');
  747. this.register('nn.SpatialDropout');
  748. this.register('nn.SpatialFractionalMaxPooling');
  749. this.register('nn.SpatialFullConvolution');
  750. this.register('nn.SpatialLPPooling');
  751. this.register('nn.SpatialMaxPooling');
  752. this.register('nn.SpatialMaxUnpooling');
  753. this.register('nn.SpatialReflectionPadding');
  754. this.register('nn.SpatialReplicationPadding');
  755. this.register('nn.SpatialSoftMax');
  756. this.register('nn.SpatialSubtractiveNormalization');
  757. this.register('nn.SpatialUpSamplingBilinear');
  758. this.register('nn.SpatialUpSamplingNearest');
  759. this.register('nn.SpatialZeroPadding');
  760. this.register('nn.SplitTable');
  761. this.register('nn.Squeeze');
  762. this.register('nn.Square');
  763. this.register('nn.Sqrt');
  764. this.register('nn.StereoJoin');
  765. this.register('nn.Tanh');
  766. this.register('nn.Transpose');
  767. this.register('nn.TotalVariation');
  768. this.register('nn.Unpool');
  769. this.register('nn.View');
  770. this.register('nn.gModule');
  771. this.register('nngraph.Node');
  772. this.register('graph.Edge');
  773. this.register('graph.Graph');
  774. this.register('torch.ByteTensor', class extends Tensor { constructor() { super('uint8'); } });
  775. this.register('torch.CharTensor', class extends Tensor { constructor() { super('int8'); } });
  776. this.register('torch.ShortTensor', class extends Tensor { constructor() { super('int16'); } });
  777. this.register('torch.IntTensor', class extends Tensor { constructor() { super('int32'); } });
  778. this.register('torch.LongTensor', class extends Tensor { constructor() { super('int64'); } });
  779. this.register('torch.FloatTensor', class extends Tensor { constructor() { super('float32'); } });
  780. this.register('torch.DoubleTensor', class extends Tensor { constructor() { super('float64'); } });
  781. this.register('torch.CudaByteTensor', class extends Tensor { constructor() { super('uint8'); } });
  782. this.register('torch.CudaCharTensor', class extends Tensor { constructor() { super('int8'); } });
  783. this.register('torch.CudaShortTensor', class extends Tensor { constructor() { super('int16'); } });
  784. this.register('torch.CudaIntTensor', class extends Tensor { constructor() { super('int32'); } });
  785. this.register('torch.CudaLongTensor', class extends Tensor { constructor() { super('int64'); } });
  786. this.register('torch.CudaTensor', class extends Tensor { constructor() { super('float32'); } });
  787. this.register('torch.CudaDoubleTensor', class extends Tensor { constructor() { super('float64'); } });
  788. this.register('torch.ByteStorage', class extends Storage { constructor() { super('uint8', 1); } });
  789. this.register('torch.CharStorage', class extends Storage { constructor() { super('int8', 1); } });
  790. this.register('torch.ShortStorage', class extends Storage { constructor() { super('int16', 2); } });
  791. this.register('torch.IntStorage', class extends Storage { constructor() { super('int32', 4); } });
  792. this.register('torch.LongStorage', class extends Storage { constructor() { super('int64', 8); } });
  793. this.register('torch.FloatStorage', class extends Storage { constructor() { super('float32', 4); } });
  794. this.register('torch.DoubleStorage', class extends Storage { constructor() { super('float64', 8); } });
  795. this.register('torch.CudaByteStorage', class extends Storage { constructor() { super('uint8', 1); } });
  796. this.register('torch.CudaCharStorage', class extends Storage { constructor() { super('int8', 1); } });
  797. this.register('torch.CudaShortStorage', class extends Storage { constructor() { super('int16', 2); } });
  798. this.register('torch.CudaIntStorage', class extends Storage { constructor() { super('int32', 4); } });
  799. this.register('torch.CudaLongStorage', class extends Storage { constructor() { super('int64', 8); } });
  800. this.register('torch.CudaIntStorage', class extends Storage { constructor() { super('int32', 4); } });
  801. this.register('torch.CudaStorage', class extends Storage { constructor() { super('float32', 4); } });
  802. this.register('torch.CudaFloatStorage', class extends Storage { constructor() { super('float64', 8); } });
  803. this.register('w2nn.AuxiliaryLossTable');
  804. this.register('w2nn.InplaceClip01');
  805. this.register('w2nn.ScaleTable');
  806. this.register('LuaFunction', class {
  807. constructor(size, dumped, upvalues) {
  808. this.size = size;
  809. this.dumped = dumped;
  810. this.upvalues = upvalues;
  811. }
  812. });
  813. }
  814. register(name, type) {
  815. type = type || class {};
  816. const parts = name.split('.');
  817. type.__name__ = parts.pop();
  818. type.__module__ = parts.join('.');
  819. type.prototype.__class__ = type;
  820. this._types.set(name, type);
  821. }
  822. read() {
  823. const type = this.int32();
  824. switch (type) {
  825. case 0: return null;
  826. case 1: return this.float64();
  827. case 2: return this.string();
  828. case 3: return this.table();
  829. case 4: return this.object();
  830. case 5: return this.boolean();
  831. case 6: return this.function();
  832. case 7: return this.function();
  833. case 8: return this.function();
  834. default: throw new torch.Error("File format has invalid type '" + type + "'.");
  835. }
  836. }
  837. boolean() {
  838. return this._reader.boolean();
  839. }
  840. bytes(size) {
  841. return this._reader.bytes(size);
  842. }
  843. int32() {
  844. return this._reader.int32();
  845. }
  846. int64() {
  847. return this._reader.int64();
  848. }
  849. int64s(size) {
  850. return this._reader.int64s(size);
  851. }
  852. float64() {
  853. return this._reader.float64();
  854. }
  855. string() {
  856. return this._reader.string();
  857. }
  858. object() {
  859. const index = this.int32();
  860. if (this._memo.has(index)) {
  861. return this._memo.get(index);
  862. }
  863. let version = this.string();
  864. let name = null;
  865. if (version.startsWith('V ')) {
  866. name = this.string();
  867. version = Number(version.split(' ')[1]);
  868. }
  869. else {
  870. name = version;
  871. version = 0;
  872. }
  873. if (!this._types.has(name)) {
  874. this.callback(name);
  875. this.register(name);
  876. }
  877. const type = this._types.get(name);
  878. const obj = Reflect.construct(type, []);
  879. this._memo.set(index, obj);
  880. if (obj.read) {
  881. obj.read(this, version);
  882. }
  883. else {
  884. const attributes = this.read();
  885. if (attributes != null) {
  886. for (const entry of Object.entries(attributes)) {
  887. const key = entry[0];
  888. obj[key] = entry[1];
  889. }
  890. }
  891. }
  892. return obj;
  893. }
  894. table() {
  895. const index = this.int32();
  896. if (this._memo.has(index)) {
  897. return this._memo.get(index);
  898. }
  899. const table = {};
  900. this._memo.set(index, table);
  901. const size = this.int32();
  902. let convert = true;
  903. let sum = 0;
  904. for (let i = 0; i < size; i++) {
  905. const key = this.read();
  906. const value = this.read();
  907. table[key] = value;
  908. if (Number.isInteger(key) && key >= 0) {
  909. sum += key;
  910. }
  911. else {
  912. convert = false;
  913. }
  914. }
  915. const n = Object.keys(table).length;
  916. if (convert && (n * (n + 1)) == (2 * sum)) {
  917. const list = [];
  918. for (let j = 0; j < n; j++) {
  919. let item = table[j + 1];
  920. if (item == table) {
  921. item = list;
  922. }
  923. list.push(item);
  924. }
  925. this._memo.set(index, list);
  926. return list;
  927. }
  928. return table;
  929. }
  930. function() {
  931. const index = this.int32();
  932. if (this._memo.has(index)) {
  933. return this._memo.get(index);
  934. }
  935. const size = this.int32();
  936. const dumped = this.bytes(size);
  937. const upvalues = this.read();
  938. const type = this._types.get('LuaFunction');
  939. const obj = Reflect.construct(type, [ size, dumped, upvalues ]);
  940. this._memo.set(index, obj);
  941. return obj;
  942. }
  943. storage(size, itemSize, dataType) {
  944. return this._reader.storage(size, itemSize, dataType);
  945. }
  946. };
  /**
   * Little-endian reader over an in-memory byte buffer; used for the binary
   * format and for raw storage payloads.
   */
  torch.BinaryReader = class {

    /**
     * @param {Uint8Array|{peek: function(): Uint8Array}} data - the bytes, or an object whose peek() returns them
     */
    constructor(data) {
      this._buffer = data instanceof Uint8Array ? data : data.peek();
      this._dataView = new DataView(this._buffer.buffer, this._buffer.byteOffset, this._buffer.byteLength);
      this._position = 0;
      this._textDecoder = new TextDecoder('ascii');
    }

    /** Rewinds to the start of the buffer. */
    reset() {
      this._position = 0;
    }

    /**
     * Advances the read position by `offset` bytes.
     * @throws {torch.Error} when the new position lies past the end of the buffer
     */
    skip(offset) {
      this._position += offset;
      if (this._position > this._buffer.length) {
        throw new torch.Error('Expected ' + (this._position - this._buffer.length) + ' more bytes. The file might be corrupted. Unexpected end of file.');
      }
    }

    /** Reads an int32 and returns whether it equals 1. */
    boolean() {
      return this.int32() == 1;
    }

    /**
     * Returns a view (not a copy) of the next `length` bytes.
     * @throws {torch.Error} for a negative/NaN length or when the buffer is too short
     */
    bytes(length) {
      // Lengths usually come from the file itself; a corrupt negative value would make
      // skip() move backwards and silently re-read earlier data. '!(>= 0)' also rejects NaN.
      if (!(length >= 0)) {
        throw new torch.Error("Invalid byte length '" + length + "'.");
      }
      const position = this._position;
      this.skip(length);
      return this._buffer.subarray(position, this._position);
    }

    int8() {
      const position = this._position;
      this.skip(1);
      return this._dataView.getInt8(position);
    }

    int16() {
      const position = this._position;
      this.skip(2);
      return this._dataView.getInt16(position, true);
    }

    int32() {
      const position = this._position;
      this.skip(4);
      return this._dataView.getInt32(position, true);
    }

    /**
     * Reads a little-endian 64-bit integer as a JS number.
     * NOTE(review): DataView has no standard getInt64; this relies on an extension
     * defined elsewhere that returns an object with toNumber() — confirm.
     */
    int64() {
      const position = this._position;
      this.skip(8);
      return this._dataView.getInt64(position, true).toNumber();
    }

    /** Reads `size` consecutive 64-bit integers. */
    int64s(size) {
      const array = [];
      for (let i = 0; i < size; i++) {
        array.push(this.int64());
      }
      return array;
    }

    float32() {
      const position = this._position;
      this.skip(4);
      return this._dataView.getFloat32(position, true);
    }

    float64() {
      const position = this._position;
      this.skip(8);
      return this._dataView.getFloat64(position, true);
    }

    /** Reads an int32 byte count followed by that many bytes, decoded as text. */
    string() {
      return this._textDecoder.decode(this.bytes(this.int32()));
    }

    /** Returns a new reader over the next `size * itemSize` bytes. */
    storage(size, itemSize) {
      return new torch.BinaryReader(this.bytes(size * itemSize));
    }
  };
  /**
   * Reader for the ASCII format: numbers are decimal tokens terminated by a
   * separator byte, while byte/char payloads (strings, dumped functions, uint8
   * storages) are written as raw bytes followed by a separator.
   */
  torch.TextReader = class {

    /**
     * @param {Uint8Array|{peek: function(): Uint8Array}} data - the bytes, or an object whose peek() returns them
     * @param {number} [separator=0x0a] - byte code that terminates each token
     */
    constructor(data, separator) {
      this._buffer = data instanceof Uint8Array ? data : data.peek();
      this._position = 0;
      this._dataView = new DataView(this._buffer.buffer, this._buffer.byteOffset, this._buffer.byteLength);
      this._textDecoder = new TextDecoder('ascii');
      this._separator = separator || 0x0a;
    }

    /** Rewinds to the start of the buffer. */
    reset() {
      this._position = 0;
    }

    /**
     * Returns the bytes up to (not including) the next separator and moves past it.
     * At most `size + 1` bytes are examined; the final byte of the buffer also ends a line.
     * @throws {torch.Error} when no separator is found within the limit (including at end of buffer)
     */
    line(size) {
      const start = this._position;
      while (this._position < this._buffer.length && size > -1) {
        const c = this._buffer[this._position++];
        if (c == this._separator) {
          return this._buffer.slice(start, this._position - 1);
        }
        else if (this._position == this._buffer.length) {
          return this._buffer.slice(start, this._position);
        }
        size--;
      }
      throw new torch.Error('Line exceeded maximum length.');
    }

    /**
     * Reads exactly `size` raw bytes (a copy), then consumes one trailing separator
     * if present. Raw payloads may themselves contain the separator, so they must be
     * read by count rather than with line(). Like Torch's own reader, a missing
     * separator is tolerated, and none is expected after an empty payload.
     * @throws {torch.Error} for a negative/NaN size or when the buffer is too short
     */
    _verbatim(size) {
      if (!(size >= 0)) {
        throw new torch.Error("Invalid length '" + size + "'.");
      }
      const start = this._position;
      const end = start + size;
      if (end > this._buffer.length) {
        throw new torch.Error('Expected ' + (end - this._buffer.length) + ' more bytes. The file might be corrupted. Unexpected end of file.');
      }
      this._position = end;
      if (size > 0 && this._buffer[this._position] === this._separator) {
        this._position++;
      }
      return this._buffer.slice(start, end);
    }

    /** Reads an integer token and returns whether it equals 1. */
    boolean() {
      return this.int32() == 1;
    }

    /** Reads exactly `size` raw bytes (e.g. a dumped function body). */
    bytes(size) {
      return this._verbatim(size);
    }

    int8() {
      return this.int64();
    }

    int16() {
      return this.int64();
    }

    int32() {
      return this.int64();
    }

    /**
     * Parses one decimal integer token (at most 21 bytes including the separator).
     * @throws {torch.Error} when the token is not entirely a base-10 integer
     */
    int64() {
      const token = this._textDecoder.decode(this.line(20));
      const number = Number.parseInt(token, 10);
      // parseInt ignores trailing garbage; 'token - number' coerces the whole token and is NaN if any remains.
      if (Number.isNaN(token - number)) {
        throw new torch.Error("Couldn't parse int64 '" + token + "'.");
      }
      return number;
    }

    /** Parses one line of space-separated integer tokens (nothing is read when size <= 0). */
    int64s(size) {
      const array = [];
      if (size > 0) {
        const content = this._textDecoder.decode(this.line(Number.MAX_SAFE_INTEGER));
        for (const token of content.split(' ')) {
          const number = Number.parseInt(token, 10);
          if (Number.isNaN(token - number)) {
            throw new torch.Error("Couldn't parse int64 '" + token + "'.");
          }
          array.push(number);
        }
      }
      return array;
    }

    float32() {
      return this.float64();
    }

    /**
     * Parses one floating-point token; 'nan', '-nan', 'inf' and '-inf' prefixes are recognised.
     * @throws {torch.Error} when the token is not a number
     */
    float64() {
      const token = this._textDecoder.decode(this.line(24));
      if (token.startsWith('-nan')) {
        return -NaN;
      }
      if (token.startsWith('nan')) {
        return NaN;
      }
      if (token.startsWith('inf')) {
        return Infinity;
      }
      if (token.startsWith('-inf')) {
        return -Infinity;
      }
      const number = Number.parseFloat(token);
      if (Number.isNaN(token - number)) {
        throw new torch.Error("Couldn't parse float '" + token + "'.");
      }
      return number;
    }

    /**
     * Reads an integer length token followed by exactly that many raw bytes.
     * The bytes may include the separator (e.g. a string containing '\n'),
     * which line()-based reading used to truncate.
     */
    string() {
      const size = this.int32();
      if (size == 0) {
        return '';
      }
      return this._textDecoder.decode(this._verbatim(size));
    }

    /**
     * Returns a reader over a storage payload of `size` elements: a BinaryReader
     * over the raw bytes for 'uint8', otherwise a TextReader over one line of
     * space-separated tokens.
     * NOTE(review): Torch appears to write CharStorage ('int8') raw as well; confirm
     * what consumers call on the returned reader before routing 'int8' here too.
     * @throws {torch.Error} when size <= 0
     */
    storage(size, itemSize, dataType) {
      if (size <= 0) {
        throw new torch.Error("Unsupported storage size '" + size + "'.");
      }
      if (dataType === 'uint8') {
        return new torch.BinaryReader(this._verbatim(size));
      }
      const data = this.line(Number.MAX_SAFE_INTEGER);
      return new torch.TextReader(data, 0x20);
    }
  };
  // Publish the model factory when a CommonJS `module.exports` object exists;
  // where `module` is undefined (e.g. an ES module or plain script), nothing is exported.
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
    Object.assign(module.exports, { ModelFactory: torch.ModelFactory });
  }