torch.js 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244
  /* jshint esversion: 6 */

  // Namespace object for the Torch7 (.t7) loader. `var` is deliberate: unlike let/const it
  // may be read before initialization, so an already existing `torch` value is reused and
  // extended, otherwise a fresh empty object is created.
  var torch = torch ? torch : {};
  /**
   * Entry point used by the host to recognise and open Torch7 (`.t7`) model files.
   */
  torch.ModelFactory = class {

    /**
     * Cheap format sniff.
     * @param {object} context - host context exposing `identifier` (file name) and `stream`.
     * @returns {string|undefined} 'torch' when the extension is `.t7` (any case) and the first
     *     byte of the stream is <= 58; otherwise undefined.
     */
    match(context) {
      const extension = context.identifier.split('.').pop().toLowerCase();
      if (extension === 't7') {
        const stream = context.stream;
        if (stream.length >= 1 && stream.peek(1)[0] <= 58) {
          return 'torch';
        }
      }
      return undefined;
    }

    /**
     * Loads the shared metadata, deserializes the whole stream with torch.T7Reader and wraps
     * the root object in a torch.Model.
     * Class names reported by the reader callback are forwarded to `context.exception`, except
     * `nn.JointTrainModule` and the `nn.MSDNet_*` / `onmt.*` families.
     * @param {object} context - host context (`identifier`, `stream`, `exception`, and whatever
     *     torch.Metadata.open needs).
     * @returns {Promise<torch.Model>}
     */
    open(context) {
      return torch.Metadata.open(context).then((metadata) => {
        const identifier = context.identifier;
        const buffer = context.stream.peek();
        const reader = new torch.T7Reader(buffer, (name) => {
          if (name && name !== 'nn.JointTrainModule' && !name.startsWith('nn.MSDNet_') && !name.startsWith('onmt.')) {
            context.exception(new torch.Error("Unknown type '" + name + "' in '" + identifier + "'."), false);
          }
          return null;
        });
        let root = reader.read();
        // A two-element array whose first entry is a typed object and whose second entry is
        // not is unwrapped to its first entry. Either entry may be null, so each one is checked
        // before `__type__` is read.
        if (Array.isArray(root) && root.length === 2 &&
            root[0] && root[0].__type__ &&
            !(root[1] && root[1].__type__)) {
          root = root[0];
        }
        return new torch.Model(metadata, root);
      });
    }
  };
  /**
   * Top-level model object handed back to the host; a Torch7 file yields exactly one graph.
   */
  torch.Model = class {

    /**
     * @param {torch.Metadata} metadata - operator metadata forwarded to the graph.
     * @param {object} root - deserialized root module.
     */
    constructor(metadata, root) {
      const graph = new torch.Graph(metadata, root);
      this._graphs = [ graph ];
    }

    /** @returns {torch.Graph[]} a one-element array. */
    get graphs() {
      return this._graphs;
    }

    /** @returns {string} human-readable format name. */
    get format() {
      return 'Torch v7';
    }
  };
  /**
   * Graph built by walking the module tree. Container modules (Sequential, Parallel, Concat,
   * ...) are flattened into plain nodes that are wired together by sharing torch.Argument
   * objects between the `inputs`/`outputs` arrays of adjacent nodes.
   */
  torch.Graph = class {

    /**
     * @param {torch.Metadata} metadata - forwarded to every node.
     * @param {object} root - deserialized root module; an object with an own `model`
     *     property is unwrapped to that property first.
     */
    constructor(metadata, root) {
      this._inputs = [];
      this._outputs = [];
      this._nodes = [];
      // Boolean flag: becomes true once any node is created inside a container group.
      this._groups = false;
      if (Object.prototype.hasOwnProperty.call(root, 'model')) {
        root = root.model;
      }
      const inputs = [];
      const outputs = [];
      this._loadModule(metadata, root, [], '', inputs, outputs);
      // Arguments left unconnected by the walk become the graph inputs/outputs, named
      // 'input', 'input2', 'input3', ... (likewise 'output', 'output2', ...).
      this._inputs = this._inputs.concat(inputs.map((input, index) => {
        return new torch.Parameter('input' + (index !== 0 ? (index + 1).toString() : ''), true, [ input ]);
      }));
      this._outputs = this._outputs.concat(outputs.map((output, index) => {
        return new torch.Parameter('output' + (index !== 0 ? (index + 1).toString() : ''), true, [ output ]);
      }));
    }

    get inputs() {
      return this._inputs;
    }

    get outputs() {
      return this._outputs;
    }

    get nodes() {
      return this._nodes;
    }

    /** @returns {boolean} true when at least one node sits inside a container group. */
    get groups() {
      return this._groups;
    }

    /**
     * Recursively turns `module` into nodes.
     * @param {torch.Metadata} metadata
     * @param {object} module - deserialized module; dispatch is on `module.__type__`.
     * @param {string[]} groups - container path; pushed/popped around container children.
     * @param {string} key - name of `module` within its parent container.
     * @param {torch.Argument[]} inputs - in/out: arguments feeding `module`. Arguments created
     *     for unconnected inputs are pushed into this array so the caller sees them.
     * @param {torch.Argument[]} outputs - in/out: arguments produced by `module` are pushed here.
     */
    _loadModule(metadata, module, groups, key, inputs, outputs) {
      if (groups.length > 0) {
        this._groups = true;
      }
      switch (module.__type__) {
        case 'nn.Sequential': {
          groups.push(key);
          // Chain children: each child's outputs feed the next child. The first child shares
          // the caller's inputs and the last child writes into the caller's outputs.
          let subInputs = inputs;
          let subOutputs = [];
          const length = module.modules.length;
          let index = 0;
          for (const subModule of module.modules) {
            if (index === length - 1) {
              subOutputs = outputs;
            }
            this._loadModule(metadata, subModule, groups, index.toString(), subInputs, subOutputs);
            subInputs = subOutputs;
            subOutputs = [];
            index++;
          }
          groups.pop();
          break;
        }
        case 'nn.Parallel':
        case 'nn.ParallelTable':
        case 'nn.JointTrain': {
          groups.push(key);
          // Every branch starts from its own copy of the caller's inputs/outputs. When the
          // caller had none, the arguments each branch created are collected and handed back.
          let newInputs = [];
          let newOutputs = [];
          let index = 0;
          for (const subModule of module.modules) {
            const subInputs = [].concat(inputs);
            const subOutputs = [].concat(outputs);
            this._loadModule(metadata, subModule, groups, index.toString(), subInputs, subOutputs);
            if (inputs.length === 0) {
              newInputs = newInputs.concat(subInputs);
            }
            if (outputs.length === 0) {
              newOutputs = newOutputs.concat(subOutputs);
            }
            index++;
          }
          // Push into the caller-owned arrays; rebinding the parameter would not be visible
          // to the caller.
          for (const newInput of newInputs) {
            inputs.push(newInput);
          }
          for (const newOutput of newOutputs) {
            outputs.push(newOutput);
          }
          groups.pop();
          break;
        }
        case 'nn.Concat':
        case 'nn.ConcatTable': {
          // All branches read the same inputs; their outputs become the inputs of a single
          // node that represents the concat module itself.
          const prefix = key;
          if (inputs.length === 0) {
            inputs.push(new torch.Argument(groups.join('/') + ':' + key + ':in', null, null));
          }
          let concatInputs = [];
          let index = 0;
          for (const subModule of module.modules) {
            const streamInputs = inputs.map((input) => input);
            const streamOutputs = [];
            this._loadModule(metadata, subModule, groups, prefix + '.' + index.toString(), streamInputs, streamOutputs);
            concatInputs = concatInputs.concat(streamOutputs);
            index++;
          }
          // The children were emitted as separate nodes; drop container-only fields before
          // building the concat node.
          delete module.modules;
          delete module.dimension;
          this._createNode(metadata, module, groups, key, concatInputs, outputs);
          break;
        }
        case 'nn.Inception': {
          // TODO: the inner branches are not expanded; the module is shown as a single node.
          delete module.modules;
          delete module.module;
          delete module.transfer;
          delete module.pool;
          this._createNode(metadata, module, groups, key, inputs, outputs);
          break;
        }
        default: {
          this._createNode(metadata, module, groups, key, inputs, outputs);
          break;
        }
      }
    }

    /** Creates a torch.Node for `module` and appends it to the node list. */
    _createNode(metadata, module, group, subIndex, inputs, outputs) {
      this._nodes.push(new torch.Node(metadata, module, group, subIndex, inputs, outputs));
    }
  };
  /**
   * Named slot on a node or graph (e.g. 'input', 'weight') holding a list of arguments.
   */
  torch.Parameter = class {

    /**
     * @param {string} name - slot name.
     * @param {boolean} visible - visibility flag, returned as-is by `visible`.
     * @param {torch.Argument[]} args - arguments bound to this slot.
     */
    constructor(name, visible, args) {
      Object.assign(this, {
        _name: name,
        _visible: visible,
        _arguments: args
      });
    }

    get name() {
      return this._name;
    }

    get visible() {
      return this._visible;
    }

    get arguments() {
      return this._arguments;
    }
  };
  /**
   * A value in the graph: a named connection between nodes or, when `initializer` is set,
   * a constant tensor.
   */
  torch.Argument = class {

    /**
     * @param {string} name - identifier; any non-string value is rejected.
     * @param {object|null} type - type reported when there is no initializer.
     * @param {torch.Tensor|null} initializer - constant tensor backing this argument, if any.
     * @throws {torch.Error} when `name` is not a string.
     */
    constructor(name, type, initializer) {
      const isValidName = typeof name === 'string';
      if (!isValidName) {
        throw new torch.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
      }
      this._name = name;
      this._type = type;
      this._initializer = initializer;
    }

    get name() {
      return this._name;
    }

    /** The initializer's type wins over the explicitly supplied one. */
    get type() {
      return this._initializer ? this._initializer.type : this._type;
    }

    get initializer() {
      return this._initializer;
    }
  };
  /**
   * One graph node created from a single deserialized nn module.
   *
   * NOTE: the constructor mutates `module`: storages are replaced by their data, a fixed list
   * of buffer fields is deleted, and paired fields (kW/kH, pad_t/pad_r/pad_b/pad_l, ...) are
   * folded into arrays. Tensor-valued fields then become extra inputs (weight first, then bias,
   * then the rest in key order); remaining plain fields become attributes.
   */
  torch.Node = class {

    /**
     * @param {torch.Metadata} metadata - provides the node type and attribute schemas.
     * @param {object} module - deserialized module (mutated in place).
     * @param {string[]} groups - container path, joined with '/' into the node group.
     * @param {string} name - name within the parent container; used when the module has no
     *     string `name` field.
     * @param {torch.Argument[]} inputs - incoming arguments; a placeholder is pushed when the
     *     array is empty and the node has a non-empty name.
     * @param {torch.Argument[]} outputs - outgoing arguments; same placeholder rule as inputs.
     */
    constructor(metadata, module, groups, name, inputs, outputs) {
      this._group = groups.join('/');
      if (module.name && typeof module.name === 'string') {
        this._name = module.name;
        delete module.name;
      }
      else {
        this._name = this._group ? (this._group + ':' + name) : name;
      }
      const type = module.__type__ || 'nn.Module';
      this._type = metadata.type(type);

      // Replace torch.*Storage objects by their raw data.
      for (const key of Object.keys(module)) {
        const obj = module[key];
        if (obj && obj.__type__ && obj.__type__.startsWith('torch.') && obj.__type__.endsWith('Storage')) {
          module[key] = obj.data();
        }
      }

      // Fields removed from every module so they are shown neither as attributes nor as
      // inputs (judging by their names: gradients, cached outputs and scratch buffers).
      const buffers = [
        'iSize', 'finput', 'fgradInput', 'output', 'gradInput', 'gradWeight', 'gradBias',
        'grad_tmp', 'scaleT', '_input', '_output', '_gradInput', '_gradOutput',
        'buffer', 'buffer2', 'tmp_in', 'tmp_out', 'accUpdateGradParameters'
      ];
      for (const buffer of buffers) {
        delete module[buffer];
      }

      // Type-specific clean-up and folding of paired size/padding fields.
      switch (this._type.name) {
        case 'nn.Linear':
          delete module.addBuffer;
          break;
        case 'nn.Normalize':
        case 'nn.Normalize2':
          delete module.addBuffer;
          delete module.normp;
          delete module.norm;
          break;
        case 'cudnn.SpatialConvolution':
        case 'cudnn.SpatialFullConvolution':
        case 'nn.SpatialConvolution':
        case 'nn.SpatialConvolutionMM':
        case 'nn.SpatialDilatedConvolution':
        case 'nn.SpatialFullConvolution':
          delete module.ones;
          delete module.input_slice;
          delete module.output_slice;
          delete module.convDescData;
          this._updateSize(module, 'adj');
          this._updateSize(module, 'd');
          this._updateSize(module, 'dilation');
          this._updateSize(module, 'k');
          this._updateSize(module, 'pad');
          break;
        case 'cudnn.BatchNormalization':
        case 'cudnn.SpatialBatchNormalization':
        case 'nn.BatchNormalization':
        case 'nn.SpatialBatchNormalization':
        case 'nn.InstanceNormalization':
          delete module.save_mean;
          delete module.save_std;
          delete module.normalized;
          delete module.centered;
          delete module.bn; // TODO InstanceNormalization
          break;
        case 'nn.SpatialCrossMapLRN':
          delete module.scale;
          break;
        case 'cudnn.SpatialMaxPooling':
        case 'cudnn.SpatialAveragePooling':
        case 'inn.SpatialMaxPooling':
        case 'nn.SpatialMaxPooling':
        case 'nn.SpatialAveragePooling':
          delete module.indices;
          this._updateSize(module, 'pad');
          this._updateSize(module, 'd');
          this._updateSize(module, 'k');
          break;
        case 'nn.SpatialZeroPadding':
        case 'nn.SpatialReflectionPadding':
        case 'nn.SpatialReplicationPadding':
          this._updateBox(module, 'pad');
          break;
        case 'nn.Dropout':
          delete module.noise;
          break;
        case 'nn.gModule':
          delete module.forwardnodes;
          delete module.backwardnodes;
          break;
        case 'nn.StereoJoin':
          delete module.output_L;
          break;
      }

      // Split the remaining fields into tensor initializers and attributes.
      this._attributes = [];
      const initializers = [];
      if (module.__type__) {
        for (const key of Object.keys(module)) {
          if (key === '__type__' || key === '_type') {
            continue;
          }
          const obj = module[key];
          // Arrays whose items are all nn.* modules (including empty arrays) are skipped.
          if (Array.isArray(obj) && obj.every((item) => item && item.__type__ && item.__type__.startsWith('nn.'))) {
            continue;
          }
          // Fields may hold null/undefined, which have no `__type__` to read.
          const objType = obj ? obj.__type__ : undefined;
          if (objType && objType.startsWith('torch.') && objType.endsWith('Tensor')) {
            initializers.push(new torch.Parameter(key, true, [
              new torch.Argument(key, null, new torch.Tensor(obj))
            ]));
            continue;
          }
          if (key === 'modules' || (objType && objType !== 'Function')) {
            continue;
          }
          this._attributes.push(new torch.Attribute(metadata, type, key, obj));
        }
      }

      this._inputs = [];
      if (inputs.length === 0 && this._name) {
        inputs.push(new torch.Argument(this._name + ':in', null, null));
      }
      this._inputs.push(new torch.Parameter('input', true, inputs));
      if (outputs.length === 0 && this._name) {
        outputs.push(new torch.Argument(this._name, null, null));
      }
      this._outputs = [ new torch.Parameter('output', true, outputs) ];

      // Input order: data input, weight, bias, then any other tensors in key order.
      const weights = initializers.filter((parameter) => parameter.name === 'weight');
      const biases = initializers.filter((parameter) => parameter.name === 'bias');
      const others = initializers.filter((parameter) => parameter.name !== 'weight' && parameter.name !== 'bias');
      this._inputs = this._inputs.concat(weights, biases, others);
    }

    get name() {
      return this._name;
    }

    get type() {
      return this._type;
    }

    get group() {
      return this._group;
    }

    get attributes() {
      return this._attributes;
    }

    get inputs() {
      return this._inputs;
    }

    get outputs() {
      return this._outputs;
    }

    /**
     * Folds own properties `<name>W` and `<name>H` into `module[name] = [W, H]` (only when
     * both exist) and deletes the originals.
     */
    _updateSize(module, name) {
      const has = (key) => Object.prototype.hasOwnProperty.call(module, key);
      if (has(name + 'W') && has(name + 'H')) {
        module[name] = [ module[name + 'W'], module[name + 'H'] ];
        delete module[name + 'W'];
        delete module[name + 'H'];
      }
    }

    /**
     * Folds own properties `<name>_t`, `_r`, `_b`, `_l` into `module[name] = [t, r, b, l]`
     * (only when all four exist) and deletes the originals.
     */
    _updateBox(module, name) {
      const sides = [ '_t', '_r', '_b', '_l' ].map((suffix) => name + suffix);
      if (sides.every((key) => Object.prototype.hasOwnProperty.call(module, key))) {
        module[name] = sides.map((key) => module[key]);
        for (const key of sides) {
          delete module[key];
        }
      }
    }
  };
  /**
   * A plain (non-tensor) module field shown on a node. Visibility is decided by, in order:
   * the schema's `visible` entry, then a match against the schema's `default`, and finally
   * the rule that `train` is hidden.
   */
  torch.Attribute = class {

    /**
     * @param {torch.Metadata} metadata - source of the attribute schema.
     * @param {string} type - module type name used for the schema lookup.
     * @param {string} name - field name.
     * @param {*} value - field value.
     */
    constructor(metadata, type, name, value) {
      this._name = name;
      this._value = value;
      if (name === 'train') {
        this._visible = false;
      }
      const schema = metadata.attribute(type, name);
      if (!schema) {
        return;
      }
      const has = (key) => Object.prototype.hasOwnProperty.call(schema, key);
      if (has('visible')) {
        this._visible = schema.visible;
      }
      else if (has('default') && JSON.stringify(schema.default) === JSON.stringify(this._value)) {
        this._visible = false;
      }
    }

    get name() {
      return this._name;
    }

    get value() {
      return this._value;
    }

    /** Loose comparison on purpose: any value loosely equal to false (e.g. 0) hides it. */
    get visible() {
      return this._visible != false;
    }
  };
  /**
   * Constant tensor (e.g. weight or bias) backed by a deserialized torch.*Tensor object.
   * Values are decoded from the storage on every access.
   */
  torch.Tensor = class {

    /**
     * @param {object} tensor - deserialized tensor; `storage` and `storage_offset` are read
     *     here and the object is passed on to torch.TensorType.
     */
    constructor(tensor) {
      this._type = new torch.TensorType(tensor);
      this._storage = tensor.storage;
      this._offset = tensor.storage_offset;
    }

    get type() {
      return this._type;
    }

    /** @returns {string|null} why the data cannot be decoded, or null when it can. */
    get state() {
      return this._context().state || null;
    }

    /** @returns {Array|null} all values as nested arrays, or null when not decodable. */
    get value() {
      const context = this._context();
      if (context.state) {
        return null;
      }
      context.limit = Number.MAX_SAFE_INTEGER;
      return this._decode(context, 0);
    }

    /** @returns {string} pretty-printed JSON, truncated after ~1000 values ('' when not decodable). */
    toString() {
      const context = this._context();
      if (context.state) {
        return '';
      }
      context.limit = 1000;
      const value = this._decode(context, 0);
      return JSON.stringify(value, null, 4);
    }

    /**
     * Builds the decoding cursor; `state` is a non-null message when decoding is impossible.
     */
    _context() {
      const context = { state: null, index: 0, count: 0 };
      if (!this._storage) {
        context.state = 'Tensor data is empty.';
        return context;
      }
      context.data = this._storage.data();
      context.index = this._offset;
      if (!context.data) {
        context.state = 'Tensor data is empty.';
        return context;
      }
      switch (this._type.dataType) {
        case 'uint8':
        case 'int8':
        case 'int16':
        case 'int32':
        case 'int64':
        case 'float32':
        case 'float64':
          break;
        default:
          context.state = 'Tensor data type is not implemented.';
          return context;
      }
      context.dimensions = this._type.shape.dimensions;
      // A missing shape must be reported rather than dereferenced; an empty shape has nothing to decode.
      if (!context.dimensions || context.dimensions.length === 0) {
        context.state = 'Tensor has no dimensions.';
      }
      return context;
    }

    /**
     * Recursively reads `context.data` starting at `context.index`, consuming consecutive
     * elements with the last dimension innermost (strides are not consulted). Emits '...'
     * once more than `context.limit` values have been read.
     */
    _decode(context, dimension) {
      const results = [];
      const size = context.dimensions[dimension];
      if (dimension === context.dimensions.length - 1) {
        for (let i = 0; i < size; i++) {
          if (context.count > context.limit) {
            results.push('...');
            return results;
          }
          results.push(context.data[context.index]);
          context.index++;
          context.count++;
        }
      }
      else {
        for (let j = 0; j < size; j++) {
          if (context.count > context.limit) {
            results.push('...');
            return results;
          }
          results.push(this._decode(context, dimension + 1));
        }
      }
      return results;
    }
  };
  /**
   * Element type plus shape of a tensor; `toString` concatenates the two, using '?' for an
   * unknown element type.
   */
  torch.TensorType = class {

    /** @param {object} tensor - object providing `dataType` and `size`. */
    constructor(tensor) {
      const { dataType, size } = tensor;
      this._dataType = dataType;
      this._shape = new torch.TensorShape(size);
    }

    get dataType() {
      return this._dataType;
    }

    get shape() {
      return this._shape;
    }

    toString() {
      const elementType = this.dataType || '?';
      return elementType + this._shape.toString();
    }
  };
  /**
   * List of dimension sizes, printed as '[d0,d1,...]', or '' when missing or empty.
   */
  torch.TensorShape = class {

    /** @param {number[]|undefined} dimensions - dimension sizes, if known. */
    constructor(dimensions) {
      this._dimensions = dimensions;
    }

    get dimensions() {
      return this._dimensions;
    }

    toString() {
      const dimensions = this._dimensions;
      if (!dimensions || dimensions.length === 0) {
        return '';
      }
      const parts = dimensions.map((dimension) => dimension.toString());
      return '[' + parts.join(',') + ']';
    }
  };
  /**
   * Operator metadata loaded from 'torch-metadata.json' (a JSON array of type descriptions,
   * each keyed by its `name`). Lookups for unknown types create and cache a `{ name }` stub.
   */
  torch.Metadata = class {

    /**
     * Returns the process-wide metadata instance, requesting it on first use. A failed
     * request or unparsable data deliberately yields an empty metadata object.
     * @param {object} context - host context providing `request(file, encoding, ...)`.
     * @returns {Promise<torch.Metadata>}
     */
    static open(context) {
      const cached = torch.Metadata._metadata;
      if (cached) {
        return Promise.resolve(cached);
      }
      const remember = (data) => {
        torch.Metadata._metadata = new torch.Metadata(data);
        return torch.Metadata._metadata;
      };
      return context.request('torch-metadata.json', 'utf-8', null)
        .then(remember)
        .catch(() => remember(null));
    }

    /** @param {string|null} data - JSON text, or a falsy value for empty metadata. */
    constructor(data) {
      this._types = new Map();
      this._attributes = new Map();
      if (!data) {
        return;
      }
      for (const item of JSON.parse(data)) {
        this._types.set(item.name, item);
      }
    }

    /** @returns {object} the type description, creating a `{ name }` stub when unknown. */
    type(name) {
      const types = this._types;
      if (types.has(name)) {
        return types.get(name);
      }
      const placeholder = { name: name };
      types.set(name, placeholder);
      return placeholder;
    }

    /**
     * @returns {object|null} the schema of attribute `name` on `type`, or null. The first
     *     lookup for a type indexes all of that type's attributes.
     */
    attribute(type, name) {
      const key = type + ':' + name;
      if (this._attributes.has(key)) {
        return this._attributes.get(key);
      }
      this._attributes.set(key, null);
      const schema = this.type(type);
      if (schema && Array.isArray(schema.attributes)) {
        for (const attribute of schema.attributes) {
          this._attributes.set(type + ':' + attribute.name, attribute);
        }
      }
      return this._attributes.get(key);
    }
  };
  /**
   * Error raised by the Torch loader; every instance carries the same fixed `name`.
   */
  torch.Error = class extends Error {

    /** @param {string} message - human-readable description of the failure. */
    constructor(message) {
      super(message);
      Object.assign(this, { name: 'Error loading Torch model.' });
    }
  };
  591. torch.T7Reader = class {
  592. constructor(buffer, callback) {
  593. this._callback = callback;
  594. this._memo = new Map();
  595. this._registry = {};
  596. this._registry['bnn.Binary'] = function(reader) { reader.nn(this); };
  597. this._registry['bnn.SpatialConvolution'] = function(reader) { reader.nn(this); };
  598. this._registry['cudnn.BatchNormalization'] = function(reader) { reader.nn(this); };
  599. this._registry['cudnn.BatchBRNNReLU'] = function(reader) { reader.nn(this); };
  600. this._registry['cudnn.BLSTM'] = function(reader) { reader.nn(this); };
  601. this._registry['cudnn.ReLU'] = function(reader) { reader.nn(this); };
  602. this._registry['cudnn.RNN'] = function(reader) { reader.nn(this); };
  603. this._registry['cudnn.Sigmoid'] = function(reader) { reader.nn(this); };
  604. this._registry['cudnn.SoftMax'] = function(reader) { reader.nn(this); };
  605. this._registry['cudnn.LogSoftMax'] = function(reader) { reader.nn(this); };
  606. this._registry['cudnn.SpatialAveragePooling'] = function(reader) { reader.nn(this); };
  607. this._registry['cudnn.SpatialBatchNormalization'] = function(reader) { reader.nn(this); };
  608. this._registry['cudnn.SpatialConvolution'] = function(reader) { reader.nn(this); };
  609. this._registry['cudnn.SpatialFullConvolution'] = function(reader) { reader.nn(this); };
  610. this._registry['cudnn.SpatialMaxPooling'] = function(reader) { reader.nn(this); };
  611. this._registry['cudnn.SpatialSoftMax'] = function(reader) { reader.nn(this); };
  612. this._registry['cudnn.Tanh'] = function(reader) { reader.nn(this); };
  613. this._registry['cudnn.VolumetricAveragePooling'] = function(reader) { reader.nn(this); };
  614. this._registry['cudnn.VolumetricBatchNormalization'] = function(reader) { reader.nn(this); };
  615. this._registry['cudnn.VolumetricConvolution'] = function(reader) { reader.nn(this); };
  616. this._registry['cudnn.VolumetricMaxPooling'] = function(reader) { reader.nn(this); };
  617. this._registry['Dict'] = function(reader) { reader.nn(this); };
  618. this._registry['inn.ConstAffine'] = function(reader) { reader.nn(this); };
  619. this._registry['inn.SpatialMaxPooling'] = function(reader) { reader.nn(this); };
  620. this._registry['nn.Abs'] = function(reader) { reader.nn(this); };
  621. this._registry['nn.AddConstant'] = function(reader) { reader.nn(this); };
  622. this._registry['nn.BatchNormalization'] = function(reader) { reader.nn(this); };
  623. this._registry['nn.BilinearSamplerBHWD'] = function(reader) { reader.nn(this); };
  624. this._registry['nn.BinActiveZ'] = function(reader) { reader.nn(this); }; // allenai/XNOR-Net
  625. this._registry['nn.BCECriterion'] = function(reader) { reader.nn(this); };
  626. this._registry['nn.Bottle'] = function(reader) { reader.nn(this); };
  627. this._registry['nn.Clamp'] = function(reader) { reader.nn(this); };
  628. this._registry['nn.CMul'] = function(reader) { reader.nn(this); };
  629. this._registry['nn.CAddTable'] = function(reader) { reader.nn(this); };
  630. this._registry['nn.CDivTable'] = function(reader) { reader.nn(this); };
  631. this._registry['nn.CMulTable'] = function(reader) { reader.nn(this); };
  632. this._registry['nn.CSubTable'] = function(reader) { reader.nn(this); };
  633. this._registry['nn.Concat'] = function(reader) { reader.nn(this); };
  634. this._registry['nn.Copy'] = function(reader) { reader.nn(this); };
  635. this._registry['nn.ConcatTable'] = function(reader) { reader.nn(this); };
  636. this._registry['nn.Contiguous'] = function(reader) { reader.nn(this); };
  637. this._registry['nn.Constant'] = function(reader) { reader.nn(this); };
  638. this._registry['nn.CostVolMulti'] = function(reader) { reader.nn(this); };
  639. this._registry['nn.DataParallelTable'] = function(reader) { reader.nn(this); };
  640. this._registry['nn.DepthConcat'] = function(reader) { reader.nn(this); };
  641. this._registry['nn.Dropout'] = function(reader) { reader.nn(this); };
  642. this._registry['nn.Exp'] = function(reader) { reader.nn(this); };
  643. this._registry['nn.ExpOut'] = function(reader) { reader.nn(this); };
  644. this._registry['nn.FlattenTable'] = function(reader) { reader.nn(this); };
  645. this._registry['nn.GenNoise'] = function(reader) { reader.nn(this); };
  646. this._registry['nn.Identity'] = function(reader) { reader.nn(this); };
  647. this._registry['nn.Index'] = function(reader) { reader.nn(this); };
  648. this._registry['nn.Inception'] = function(reader) { reader.nn(this); };
  649. this._registry['nn.InstanceNormalization'] = function(reader) { reader.nn(this); };
  650. this._registry['nn.JoinTable'] = function(reader) { reader.nn(this); };
  651. this._registry['nn.JointTrain'] = function(reader) { reader.nn(this); };
  652. this._registry['nn.KeypointCoordinate'] = function(reader) { reader.nn(this); };
  653. this._registry['nn.LeakyReLU'] = function(reader) { reader.nn(this); };
  654. this._registry['nn.Linear'] = function(reader) { reader.nn(this); };
  655. this._registry['nn.LinearNoBias'] = function(reader) { reader.nn(this); };
  656. this._registry['nn.LogSoftMax'] = function(reader) { reader.nn(this); };
  657. this._registry['nn.LookupTable'] = function(reader) { reader.nn(this); };
  658. this._registry['nn.LSTM'] = function(reader) { reader.nn(this); };
  659. this._registry['nn.MaskZero'] = function(reader) { reader.nn(this); };
  660. this._registry['nn.MapTable'] = function(reader) { reader.nn(this); };
  661. this._registry['nn.Max'] = function(reader) { reader.nn(this); };
  662. this._registry['nn.Mean'] = function(reader) { reader.nn(this); };
  663. this._registry['nn.Min'] = function(reader) { reader.nn(this); };
  664. this._registry['nn.MulConstant'] = function(reader) { reader.nn(this); };
  665. this._registry['nn.MM'] = function(reader) { reader.nn(this); };
  666. this._registry['nn.MSECriterion'] = function(reader) { reader.nn(this); };
  667. this._registry['nn.Narrow'] = function(reader) { reader.nn(this); };
  668. this._registry['nn.NarrowTable'] = function(reader) { reader.nn(this); };
  669. this._registry['nn.Normalize'] = function(reader) { reader.nn(this); };
  670. this._registry['nn.Normalize2'] = function(reader) { reader.nn(this); };
  671. this._registry['nn.NoiseFill'] = function(reader) { reader.nn(this); };
  672. this._registry['nn.Padding'] = function(reader) { reader.nn(this); };
  673. this._registry['nn.Parallel'] = function(reader) { reader.nn(this); };
  674. this._registry['nn.ParallelCriterion'] = function(reader) { reader.nn(this); };
  675. this._registry['nn.ParallelTable'] = function(reader) { reader.nn(this); };
  676. this._registry['nn.PixelShuffle'] = function(reader) { reader.nn(this); };
  677. this._registry['nn.Power'] = function(reader) { reader.nn(this); };
  678. this._registry['nn.PReLU'] = function(reader) { reader.nn(this); };
  679. this._registry['nn.Recursor'] = function(reader) { reader.nn(this); };
  680. this._registry['nn.ReLU'] = function(reader) { reader.nn(this); };
  681. this._registry['nn.Replicate'] = function(reader) { reader.nn(this); };
  682. this._registry['nn.Reshape'] = function(reader) { reader.nn(this); };
  683. this._registry['nn.ShaveImage'] = function(reader) { reader.nn(this); };
  684. this._registry['nn.Select'] = function(reader) { reader.nn(this); };
  685. this._registry['nn.SelectTable'] = function(reader) { reader.nn(this); };
  686. this._registry['nn.Sequencer'] = function(reader) { reader.nn(this); };
  687. this._registry['nn.Sequential'] = function(reader) { reader.nn(this); };
  688. this._registry['nn.Sigmoid'] = function(reader) { reader.nn(this); };
  689. this._registry['nn.Sum'] = function(reader) { reader.nn(this); };
  690. this._registry['nn.SoftMax'] = function(reader) { reader.nn(this); };
  691. this._registry['nn.SpatialAveragePooling'] = function(reader) { reader.nn(this); };
  692. this._registry['nn.SpatialBatchNormalization'] = function(reader) { reader.nn(this); };
  693. this._registry['nn.SpatialConvolution'] = function(reader) { reader.nn(this); };
  694. this._registry['nn.SpatialConvolutionMM'] = function(reader) { reader.nn(this); };
  695. this._registry['nn.SpatialCrossMapLRN'] = function(reader) { reader.nn(this); };
  696. this._registry['nn.SpatialDilatedConvolution'] = function(reader) { reader.nn(this); };
  697. this._registry['nn.SpatialDropout'] = function(reader) { reader.nn(this); };
  698. this._registry['nn.SpatialFractionalMaxPooling'] = function(reader) { reader.nn(this); };
  699. this._registry['nn.SpatialFullConvolution'] = function(reader) { reader.nn(this); };
  700. this._registry['nn.SpatialLPPooling'] = function(reader) { reader.nn(this); };
  701. this._registry['nn.SpatialMaxPooling'] = function(reader) { reader.nn(this); };
  702. this._registry['nn.SpatialMaxUnpooling'] = function(reader) { reader.nn(this); };
  703. this._registry['nn.SpatialReflectionPadding'] = function(reader) { reader.nn(this); };
  704. this._registry['nn.SpatialReplicationPadding'] = function(reader) { reader.nn(this); };
  705. this._registry['nn.SpatialSoftMax'] = function(reader) { reader.nn(this); };
  706. this._registry['nn.SpatialSubtractiveNormalization'] = function(reader) { reader.nn(this); };
  707. this._registry['nn.SpatialUpSamplingBilinear'] = function(reader) { reader.nn(this); };
  708. this._registry['nn.SpatialUpSamplingNearest'] = function(reader) { reader.nn(this); };
  709. this._registry['nn.SpatialZeroPadding'] = function(reader) { reader.nn(this); };
  710. this._registry['nn.SplitTable'] = function(reader) { reader.nn(this); };
  711. this._registry['nn.Squeeze'] = function(reader) { reader.nn(this); };
  712. this._registry['nn.Square'] = function(reader) { reader.nn(this); };
  713. this._registry['nn.Sqrt'] = function(reader) { reader.nn(this); };
  714. this._registry['nn.StereoJoin'] = function(reader) { reader.nn(this); };
  715. this._registry['nn.Tanh'] = function(reader) { reader.nn(this); };
  716. this._registry['nn.Transpose'] = function(reader) { reader.nn(this); };
  717. this._registry['nn.TotalVariation'] = function(reader) { reader.nn(this); };
  718. this._registry['nn.Unpool'] = function(reader) { reader.nn(this); };
  719. this._registry['nn.View'] = function(reader) { reader.nn(this); };
  720. this._registry['nn.gModule'] = function(reader) { reader.nn(this); };
  721. this._registry['nngraph.Node'] = function(reader) { reader.nn(this); };
  722. this._registry['graph.Edge'] = function(reader) { reader.nn(this); };
  723. this._registry['graph.Graph'] = function(reader) { reader.nn(this); };
  724. this._registry['torch.ByteTensor'] = function(reader) { reader.tensor(this, 'uint8'); };
  725. this._registry['torch.CharTensor'] = function(reader) { reader.tensor(this, 'int8'); };
  726. this._registry['torch.ShortTensor'] = function(reader) { reader.tensor(this, 'int16'); };
  727. this._registry['torch.IntTensor'] = function(reader) { reader.tensor(this, 'int32'); };
  728. this._registry['torch.LongTensor'] = function(reader) { reader.tensor(this, 'int64'); };
  729. this._registry['torch.FloatTensor'] = function(reader) { reader.tensor(this, 'float32'); };
  730. this._registry['torch.DoubleTensor'] = function(reader) { reader.tensor(this, 'float64'); };
  731. this._registry['torch.CudaByteTensor'] = function(reader) { reader.tensor(this, 'uint8'); };
  732. this._registry['torch.CudaCharTensor'] = function(reader) { reader.tensor(this, 'int8'); };
  733. this._registry['torch.CudaShortTensor'] = function(reader) { reader.tensor(this, 'int16'); };
  734. this._registry['torch.CudaIntTensor'] = function(reader) { reader.tensor(this, 'int32'); };
  735. this._registry['torch.CudaLongTensor'] = function(reader) { reader.tensor(this, 'int64'); };
  736. this._registry['torch.CudaTensor'] = function(reader) { reader.tensor(this, 'float32'); };
  737. this._registry['torch.CudaDoubleTensor'] = function(reader) { reader.tensor(this, 'float64'); };
  738. this._registry['torch.ByteStorage'] = function(reader) { reader.storage(this, 'uint8', 1); };
  739. this._registry['torch.CharStorage'] = function(reader) { reader.storage(this, 'int8', 1); };
  740. this._registry['torch.ShortStorage'] = function(reader) { reader.storage(this, 'int16', 2); };
  741. this._registry['torch.IntStorage'] = function(reader) { reader.storage(this, 'int32', 4); };
  742. this._registry['torch.LongStorage'] = function(reader) { reader.storage(this, 'int64', 8); };
  743. this._registry['torch.FloatStorage'] = function(reader) { reader.storage(this, 'float32', 4); };
  744. this._registry['torch.DoubleStorage'] = function(reader) { reader.storage(this, 'float64', 8); };
  745. this._registry['torch.CudaByteStorage'] = function(reader) { reader.storage(this, 'uint8', 1); };
  746. this._registry['torch.CudaCharStorage'] = function(reader) { reader.storage(this, 'int8', 1); };
  747. this._registry['torch.CudaShortStorage'] = function(reader) { reader.storage(this, 'int16', 2); };
  748. this._registry['torch.CudaIntStorage'] = function(reader) { reader.storage(this, 'int32', 4); };
  749. this._registry['torch.CudaLongStorage'] = function(reader) { reader.storage(this, 'int64', 8); };
  750. this._registry['torch.CudaIntStorage'] = function(reader) { reader.storage(this, 'int32', 4); };
  751. this._registry['torch.CudaStorage'] = function(reader) { reader.storage(this, 'float32', 4); };
  752. this._registry['torch.CudaFloatStorage'] = function(reader) { reader.storage(this, 'float64', 8); };
  753. this._registry['w2nn.AuxiliaryLossTable'] = function(reader) { reader.nn(this); };
  754. this._registry['w2nn.InplaceClip01'] = function(reader) { reader.nn(this); };
  755. this._registry['w2nn.ScaleTable'] = function(reader) { reader.nn(this); };
  756. if (buffer.length == 0) {
  757. throw new torch.Error('File is empty.');
  758. }
  759. if (buffer[0] <= 8) {
  760. this._reader = new torch.BinaryReader(buffer);
  761. }
  762. else {
  763. this._reader = new torch.TextReader(buffer);
  764. this._reader.int32();
  765. this._reader.reset();
  766. }
  767. }
  768. read() {
  769. const type = this.int32();
  770. switch (type) {
  771. case 0: return null;
  772. case 1: return this.float64();
  773. case 2: return this.string();
  774. case 3: return this.table();
  775. case 4: return this.object();
  776. case 5: return this.boolean();
  777. case 6: return this.function();
  778. case 7: return this.function();
  779. case 8: return this.function();
  780. default: throw new torch.Error("File format has invalid type '" + type + "'.");
  781. }
  782. }
  783. boolean() {
  784. return this._reader.boolean();
  785. }
  786. bytes(size) {
  787. return this._reader.bytes(size);
  788. }
  789. int32() {
  790. return this._reader.int32();
  791. }
  792. int64() {
  793. return this._reader.int64();
  794. }
  795. int64s(size) {
  796. return this._reader.int64s(size);
  797. }
  798. float64() {
  799. return this._reader.float64();
  800. }
  801. string() {
  802. return this._reader.string();
  803. }
  804. object() {
  805. const index = this.int32();
  806. if (this._memo.has(index)) {
  807. return this._memo.get(index);
  808. }
  809. let version = this.string();
  810. let name = null;
  811. if (version.startsWith('V ')) {
  812. name = this.string();
  813. version = Number(version.split(' ')[1]);
  814. }
  815. else {
  816. name = version;
  817. version = 0;
  818. }
  819. const obj = { __type__: name };
  820. this._memo.set(index, obj);
  821. let constructor = this._registry[name];
  822. if (constructor) {
  823. constructor.apply(obj, [ this, version ]);
  824. }
  825. else {
  826. constructor = this._callback(name);
  827. if (constructor) {
  828. constructor.apply(obj, [ this, version ]);
  829. }
  830. this.nn(obj);
  831. }
  832. return obj;
  833. }
  834. table() {
  835. const index = this.int32();
  836. if (this._memo.has(index)) {
  837. return this._memo.get(index);
  838. }
  839. const table = {};
  840. this._memo.set(index, table);
  841. const size = this.int32();
  842. let convert = true;
  843. let sum = 0;
  844. for (let i = 0; i < size; i++) {
  845. const key = this.read();
  846. const value = this.read();
  847. table[key] = value;
  848. if (Number.isInteger(key) && key >= 0) {
  849. sum += key;
  850. }
  851. else {
  852. convert = false;
  853. }
  854. }
  855. const n = Object.keys(table).length;
  856. if (convert && (n * (n + 1)) == (2 * sum)) {
  857. const list = [];
  858. for (let j = 0; j < n; j++) {
  859. let item = table[j + 1];
  860. if (item == table) {
  861. item = list;
  862. }
  863. list.push(item);
  864. }
  865. this._memo.set(index, list);
  866. return list;
  867. }
  868. return table;
  869. }
  870. function() {
  871. const index = this.int32();
  872. if (this._memo.has(index)) {
  873. return this._memo.get(index);
  874. }
  875. const size = this.int32();
  876. const dumped = this.bytes(size);
  877. const upvalues = this.read();
  878. const func = { __type__: 'Function', size: size, dumped: dumped, upvalues: upvalues };
  879. this._memo.set(index, func);
  880. return func;
  881. }
  882. nn(obj) {
  883. const attributes = this.read();
  884. if (attributes != null) {
  885. for (const key of Object.keys(attributes)) {
  886. obj[key] = attributes[key];
  887. }
  888. }
  889. }
  890. tensor(obj, dataType) {
  891. const dim = this.int32();
  892. obj.dataType = dataType;
  893. obj.size = this.int64s(dim);
  894. obj.stride = this.int64s(dim);
  895. obj.storage_offset = this.int64() - 1;
  896. obj.storage = this.read();
  897. }
  898. storage(obj, dataType, itemSize) {
  899. obj.dataType = dataType;
  900. obj.itemSize = itemSize;
  901. obj.size = this.int64();
  902. obj.reader = this._reader.storage(obj.size, obj.itemSize, dataType);
  903. obj.data = function() {
  904. if (this.reader) {
  905. const reader = this.reader;
  906. reader.reset();
  907. const size = obj.size;
  908. const array = new Array(size);
  909. for (let i = 0; i < size; i++) {
  910. switch (dataType) {
  911. case 'uint8':
  912. array[i] = this.reader.byte();
  913. break;
  914. case 'int8':
  915. array[i] = this.reader.int8();
  916. break;
  917. case 'int16':
  918. array[i] = this.reader.int16();
  919. break;
  920. case 'int32':
  921. array[i] = this.reader.int32();
  922. break;
  923. case 'int64':
  924. array[i] = this.reader.int64();
  925. break;
  926. case 'float32':
  927. array[i] = this.reader.float32();
  928. break;
  929. case 'float64':
  930. array[i] = this.reader.float64();
  931. break;
  932. }
  933. }
  934. obj._data = array;
  935. delete obj.reader;
  936. }
  937. return obj._data;
  938. };
  939. }
  940. };
  /**
   * Little-endian cursor over the binary Torch7 format.
   * Every read advances `_position`; reading past the end throws torch.Error.
   */
  torch.BinaryReader = class {
    /** @param {Uint8Array} buffer */
    constructor(buffer) {
      this._buffer = buffer;
      this._dataView = new DataView(this._buffer.buffer, this._buffer.byteOffset, this._buffer.byteLength);
      this._position = 0;
      this._textDecoder = new TextDecoder('ascii');
    }
    reset() {
      this._position = 0;
    }
    /** Advances the cursor by `offset` bytes, throwing if that passes the end of the buffer. */
    skip(offset) {
      this._position += offset;
      if (this._position > this._buffer.length) {
        throw new torch.Error('Expected ' + (this._position - this._buffer.length) + ' more bytes. The file might be corrupted. Unexpected end of file.');
      }
    }
    boolean() {
      return this.int32() === 1;
    }
    /**
     * Returns a view (not a copy) of the next `length` bytes.
     * @throws {torch.Error} if `length` is negative/non-integer or exceeds the remaining data.
     */
    bytes(length) {
      // Lengths come from size fields in the file; a corrupted negative value would
      // otherwise move the cursor backwards and silently yield an empty result.
      if (!Number.isInteger(length) || length < 0) {
        throw new torch.Error("Invalid byte length '" + length + "'.");
      }
      const position = this._position;
      this.skip(length);
      return this._buffer.subarray(position, this._position);
    }
    int8() {
      const position = this._position;
      this.skip(1);
      return this._dataView.getInt8(position);
    }
    int16() {
      const position = this._position;
      this.skip(2);
      return this._dataView.getInt16(position, true);
    }
    int32() {
      const position = this._position;
      this.skip(4);
      return this._dataView.getInt32(position, true);
    }
    int64() {
      const position = this._position;
      this.skip(8);
      // NOTE(review): getInt64()/toNumber() are not standard DataView APIs; presumably
      // a project-wide extension defined elsewhere — confirm before changing.
      return this._dataView.getInt64(position, true).toNumber();
    }
    int64s(size) {
      const array = [];
      for (let i = 0; i < size; i++) {
        array.push(this.int64());
      }
      return array;
    }
    float32() {
      const position = this._position;
      this.skip(4);
      return this._dataView.getFloat32(position, true);
    }
    float64() {
      const position = this._position;
      this.skip(8);
      return this._dataView.getFloat64(position, true);
    }
    /** Reads an int32 byte count followed by that many bytes, decoded as text. */
    string() {
      return this._textDecoder.decode(this.bytes(this.int32()));
    }
    /** Returns a new reader over the next `size * itemSize` bytes. */
    storage(size, itemSize) {
      return new torch.BinaryReader(this.bytes(size * itemSize));
    }
  };
  /**
   * Cursor over the ASCII Torch7 format: numbers are separator-terminated tokens
   * (newline by default, space for nested storage readers), while raw char/byte
   * payloads (strings, function bytecode, byte storages) are stored verbatim and
   * followed by one separator.
   */
  torch.TextReader = class {
    /**
     * @param {Uint8Array} buffer
     * @param {number} [separator] - token terminator byte; defaults to 0x0a.
     */
    constructor(buffer, separator) {
      this._buffer = buffer;
      this._position = 0;
      this._dataView = new DataView(this._buffer.buffer, this._buffer.byteOffset, this._buffer.byteLength);
      this._textDecoder = new TextDecoder('ascii');
      this._separator = separator || 0x0a;
    }
    reset() {
      this._position = 0;
    }
    /**
     * Returns the bytes up to (not including) the next separator, consuming the
     * separator; the final token may end at the end of the buffer instead.
     * @throws {torch.Error} if no terminator is found within `size` + 1 bytes.
     */
    line(size) {
      const start = this._position;
      while (this._position < this._buffer.length && size > -1) {
        const c = this._buffer[this._position++];
        if (c === this._separator) {
          return this._buffer.slice(start, this._position - 1);
        }
        else if (this._position === this._buffer.length) {
          return this._buffer.slice(start, this._position);
        }
        size--;
      }
      throw new torch.Error('Line exceeded maximum length.');
    }
    boolean() {
      return this.int32() === 1;
    }
    /**
     * Reads exactly `size` raw bytes. Raw payloads (e.g. function bytecode) may
     * contain the separator byte, so they must not be read as a line.
     */
    bytes(size) {
      return this._raw(size);
    }
    int8() {
      return this.int64();
    }
    int16() {
      return this.int64();
    }
    int32() {
      return this.int64();
    }
    int64() {
      const token = this._textDecoder.decode(this.line(20));
      const number = Number.parseInt(token, 10);
      // Rejects tokens with trailing garbage ('12abc' - 12 is NaN).
      if (Number.isNaN(token - number)) {
        throw new torch.Error("Couldn't parse int64 '" + token + "'.");
      }
      return number;
    }
    /** Reads `size` space-separated integers from a single line. */
    int64s(size) {
      const array = [];
      if (size > 0) {
        const text = this._textDecoder.decode(this.line(Number.MAX_SAFE_INTEGER));
        for (const token of text.split(' ')) {
          const number = Number.parseInt(token, 10);
          if (Number.isNaN(token - number)) {
            throw new torch.Error("Couldn't parse int64 '" + token + "'.");
          }
          array.push(number);
        }
      }
      return array;
    }
    float32() {
      return this.float64();
    }
    float64() {
      const token = this._textDecoder.decode(this.line(24));
      if (token.startsWith('-nan')) {
        return -NaN;
      }
      if (token.startsWith('nan')) {
        return NaN;
      }
      if (token.startsWith('inf')) {
        return Infinity;
      }
      if (token.startsWith('-inf')) {
        return -Infinity;
      }
      const number = Number.parseFloat(token);
      if (Number.isNaN(token - number)) {
        throw new torch.Error("Couldn't parse float '" + token + "'.");
      }
      return number;
    }
    /** Reads an int32 character count followed by exactly that many raw characters. */
    string() {
      const size = this.int32();
      if (size === 0) {
        return '';
      }
      const data = this._raw(size);
      const text = this._textDecoder.decode(data);
      if (size !== text.length) {
        throw new torch.Error('Invalid text length.');
      }
      return text;
    }
    /**
     * Returns a reader over a storage payload: byte storages are raw bytes (wrapped
     * in a torch.BinaryReader); other types are one line of space-separated tokens.
     */
    storage(size, itemSize, dataType) {
      if (size <= 0) {
        throw new torch.Error("Unsupported storage size '" + size + "'.");
      }
      if (dataType === 'uint8') {
        return new torch.BinaryReader(this._raw(size));
      }
      const data = this.line(Number.MAX_SAFE_INTEGER);
      return new torch.TextReader(data, 0x20);
    }
    /**
     * Copies exactly `size` bytes from the cursor and, for a non-empty payload,
     * consumes one trailing separator if present.
     * @throws {torch.Error} on a negative/non-integer size or truncated data.
     */
    _raw(size) {
      if (!Number.isInteger(size) || size < 0) {
        throw new torch.Error("Invalid byte length '" + size + "'.");
      }
      const start = this._position;
      const end = start + size;
      if (end > this._buffer.length) {
        throw new torch.Error('Expected ' + (end - this._buffer.length) + ' more bytes. The file might be corrupted. Unexpected end of file.');
      }
      this._position = end;
      if (size > 0 && end < this._buffer.length && this._buffer[end] === this._separator) {
        this._position++;
      }
      return this._buffer.slice(start, end);
    }
  };
  // Expose the model factory when loaded as a CommonJS module; a no-op otherwise.
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
    const target = module.exports;
    target.ModelFactory = torch.ModelFactory;
  }