torch.js 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138
  1. import * as base from './base.js';
// Namespace object for this module; every class below is attached to it
// (torch.ModelFactory, torch.Graph, torch.Node, torch.T7Reader, ...).
const torch = {};
  3. torch.ModelFactory = class {
  4. match(context) {
  5. const reader = torch.T7Reader.open(context);
  6. if (reader) {
  7. context.type = 'torch';
  8. context.target = reader;
  9. }
  10. }
  11. async open(context) {
  12. const metadata = await context.metadata('torch-metadata.json');
  13. const reader = context.target;
  14. reader.callback = (name) => {
  15. if (name && name != 'nn.JointTrainModule' && !name.startsWith('nn.MSDNet_') && !name.startsWith('onmt.')) {
  16. context.exception(new torch.Error(`Unsupported type '${name}'.`));
  17. }
  18. return null;
  19. };
  20. const obj = reader.read();
  21. let graphs = [];
  22. if (obj && Array.isArray(obj) && obj.length >= 2 &&
  23. obj.slice(0, obj.length - 1).every((item) => item.__class__) &&
  24. !obj[obj.length - 1].__class__) {
  25. graphs = obj.slice(0, obj.length - 1);
  26. } else {
  27. graphs = [ obj ];
  28. }
  29. return new torch.Model(metadata, graphs);
  30. }
  31. };
  32. torch.Model = class {
  33. constructor(metadata, graphs) {
  34. this.format = 'Torch v7';
  35. this.graphs = graphs.map((graph, index) => new torch.Graph(metadata, index.toString(), graph));
  36. }
  37. };
  38. torch.Graph = class {
  39. constructor(metadata, name, root) {
  40. this.name = name;
  41. this.inputs = [];
  42. this.outputs = [];
  43. this.nodes = [];
  44. this.groups = 'false';
  45. const values = new Map();
  46. values.map = (name, type, tensor) => {
  47. if (name.length === 0 && tensor) {
  48. return new torch.Value(name, type || null, tensor || null);
  49. }
  50. if (!values.has(name)) {
  51. values.set(name, new torch.Value(name, type || null, tensor || null));
  52. } else if (type || tensor) {
  53. throw new torch.Error(`Duplicate value '${name}'.`);
  54. }
  55. return values.get(name);
  56. };
  57. if (Object.prototype.hasOwnProperty.call(root, 'model')) {
  58. root = root.model;
  59. }
  60. const loadModule = (metadata, module, groups, key, inputs, outputs) => {
  61. if (groups.length > 0) {
  62. this.groups = true;
  63. }
  64. const type = module.__class__ ? `${module.__class__.__module__}.${module.__class__.__name__}` : '';
  65. switch (type) {
  66. case 'nn.Sequential': {
  67. groups.push(key);
  68. let subInputs = inputs;
  69. let subOutputs = [];
  70. const length = module.modules.length;
  71. let index = 0;
  72. for (const subModule of module.modules) {
  73. if (index == length - 1) {
  74. subOutputs = outputs;
  75. }
  76. loadModule(metadata, subModule, groups, index.toString(), subInputs, subOutputs);
  77. subInputs = subOutputs;
  78. subOutputs = [];
  79. index++;
  80. }
  81. groups.pop();
  82. break;
  83. }
  84. case 'nn.Parallel':
  85. case 'nn.ParallelTable':
  86. case 'nn.JointTrain': {
  87. groups.push(key);
  88. let newInputs = [];
  89. let newOutputs = [];
  90. let index = 0;
  91. for (const subModule of module.modules) {
  92. const subInputs = [].concat(inputs);
  93. const subOutputs = [].concat(outputs);
  94. loadModule(metadata, subModule, groups, index.toString(), subInputs, subOutputs);
  95. if (inputs.length == 0) {
  96. newInputs = newInputs.concat(subInputs);
  97. }
  98. if (outputs.length == 0) {
  99. newOutputs = newOutputs.concat(subOutputs);
  100. }
  101. index++;
  102. }
  103. // inputs = inputs.concat(newInputs);
  104. for (const newOutput of newOutputs) {
  105. outputs.push(newOutput);
  106. }
  107. groups.pop();
  108. break;
  109. }
  110. case 'nn.Concat':
  111. case 'nn.ConcatTable': {
  112. const prefix = key;
  113. if (inputs.length == 0) {
  114. inputs.push(values.map(`${groups.join('/')}:${key}:in`, null, null));
  115. }
  116. let concatInputs = [];
  117. let index = 0;
  118. for (const subModule of module.modules) {
  119. const streamInputs = inputs.map((input) => input);
  120. const streamOutputs = [];
  121. loadModule(metadata, subModule, groups, `${prefix}.${index}`, streamInputs, streamOutputs);
  122. concatInputs = concatInputs.concat(streamOutputs);
  123. index++;
  124. }
  125. delete module.modules;
  126. delete module.dimension;
  127. const node = new torch.Node(metadata, module, groups, key, inputs, outputs, values);
  128. this.nodes.push(node);
  129. break;
  130. }
  131. case 'nn.Inception': {
  132. delete module.modules; // TODO
  133. delete module.module; // TODO
  134. delete module.transfer; // TODO
  135. delete module.pool; // TODO
  136. const node = new torch.Node(metadata, module, groups, key, inputs, outputs, values);
  137. this.nodes.push(node);
  138. break;
  139. }
  140. case 'nn.gModule': {
  141. /*
  142. let index = 0;
  143. for (const subModule of module.modules) {
  144. subModule.modules = [];
  145. this.loadModule(metadata, subModule, groups, index.toString(), [], []);
  146. index++;
  147. }
  148. */
  149. const node = new torch.Node(metadata, module, groups, key, inputs, outputs, values);
  150. this.nodes.push(node);
  151. break;
  152. }
  153. default: {
  154. const node = new torch.Node(metadata, module, groups, key, inputs, outputs, values);
  155. this.nodes.push(node);
  156. break;
  157. }
  158. }
  159. };
  160. const inputs = [];
  161. const outputs = [];
  162. loadModule(metadata, root, [], '', inputs, outputs);
  163. this.inputs = this.inputs.concat(inputs.map((input, index) => {
  164. return new torch.Argument(`input${index != 0 ? (index + 1).toString() : ''}`, [ input ]);
  165. }));
  166. this.outputs = this.outputs.concat(outputs.map((output, index) => {
  167. return new torch.Argument(`output${index != 0 ? (index + 1).toString() : ''}`, [ output ]);
  168. }));
  169. }
  170. };
  171. torch.Argument = class {
  172. constructor(name, value) {
  173. this.name = name;
  174. this.value = value;
  175. }
  176. };
  177. torch.Value = class {
  178. constructor(name, type, initializer) {
  179. if (typeof name !== 'string') {
  180. throw new torch.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
  181. }
  182. this.name = name;
  183. this.type = initializer ? initializer.type : type;
  184. this.initializer = initializer;
  185. }
  186. };
  187. torch.Node = class {
  188. constructor(metadata, module, groups, name, inputs, outputs, values) {
  189. this.group = groups.join('/');
  190. if (module.name && typeof module.name === 'string') {
  191. this.name = module.name;
  192. delete module.name;
  193. } else {
  194. this.name = this.group ? (`${this.group}:${name}`) : name;
  195. }
  196. const type = module.__class__ ? `${module.__class__.__module__}.${module.__class__.__name__}` : 'nn.Module';
  197. this.type = metadata.type(type);
  198. let initializers = [];
  199. for (const [key, obj] of Object.entries(module)) {
  200. if (obj && obj.__class__ && obj.__class__.__module__ === 'torch' && obj.__class__.__name__.endsWith('Storage')) {
  201. module[key] = obj.data();
  202. }
  203. }
  204. delete module.iSize;
  205. delete module.finput;
  206. delete module.fgradInput;
  207. delete module.output;
  208. delete module.gradInput;
  209. delete module.gradWeight;
  210. delete module.gradBias;
  211. delete module.grad_tmp;
  212. delete module.scaleT;
  213. delete module._input;
  214. delete module._output;
  215. delete module._gradInput;
  216. delete module._gradOutput;
  217. delete module.buffer;
  218. delete module.buffer2;
  219. delete module.tmp_in;
  220. delete module.tmp_out;
  221. delete module.accUpdateGradParameters;
  222. switch (this.type.name) {
  223. case 'nn.Linear':
  224. delete module.addBuffer;
  225. break;
  226. case 'nn.Normalize':
  227. case 'nn.Normalize2':
  228. delete module.addBuffer;
  229. delete module.normp;
  230. delete module.norm;
  231. break;
  232. case 'cudnn.SpatialConvolution':
  233. case 'cudnn.SpatialFullConvolution':
  234. case 'nn.SpatialConvolution':
  235. case 'nn.SpatialConvolutionMM':
  236. case 'nn.SpatialConvolution1_fw':
  237. case 'nn.SpatialDilatedConvolution':
  238. case 'nn.SpatialFullConvolution':
  239. delete module.ones;
  240. delete module.input_slice;
  241. delete module.output_slice;
  242. delete module.convDescData;
  243. this._updateSize(module, 'adj');
  244. this._updateSize(module, 'd');
  245. this._updateSize(module, 'dilation');
  246. this._updateSize(module, 'k');
  247. this._updateSize(module, 'pad');
  248. break;
  249. case 'cudnn.BatchNormalization':
  250. case 'cudnn.SpatialBatchNormalization':
  251. case 'nn.BatchNormalization':
  252. case 'nn.SpatialBatchNormalization':
  253. case 'nn.InstanceNormalization':
  254. delete module.save_mean;
  255. delete module.save_std;
  256. delete module.gradWeight;
  257. delete module.normalized;
  258. delete module.centered;
  259. delete module.bn; // TODO InstanceNormalization
  260. break;
  261. case 'nn.SpatialCrossMapLRN':
  262. delete module.scale;
  263. break;
  264. case 'cudnn.SpatialMaxPooling':
  265. case 'cudnn.SpatialAveragePooling':
  266. case 'inn.SpatialMaxPooling':
  267. case 'nn.SpatialMaxPooling':
  268. case 'nn.SpatialAveragePooling':
  269. delete module.indices;
  270. this._updateSize(module, 'pad');
  271. this._updateSize(module, 'd');
  272. this._updateSize(module, 'k');
  273. break;
  274. case 'nn.SpatialZeroPadding':
  275. case 'nn.SpatialReflectionPadding':
  276. case 'nn.SpatialReplicationPadding':
  277. this._updateBox(module, 'pad');
  278. break;
  279. case 'nn.Dropout':
  280. delete module.noise;
  281. break;
  282. case 'nn.gModule':
  283. delete module.forwardnodes;
  284. delete module.backwardnodes;
  285. break;
  286. case 'nn.StereoJoin':
  287. delete module.output_L;
  288. break;
  289. default:
  290. break;
  291. }
  292. this.attributes = [];
  293. if (module.__class__) {
  294. for (const [key, obj] of Object.entries(module)) {
  295. if (key == '_type') {
  296. continue;
  297. }
  298. if (Array.isArray(obj) && obj.every(((item) => item && item.__class__ && item.__class__.__module__ === 'nn'))) {
  299. continue;
  300. }
  301. if (obj.__class__ && obj.__class__.__module__ === 'torch' && obj.__class__.__name__.endsWith('Tensor')) {
  302. initializers.push(new torch.Argument(key, [ values.map('', null, new torch.Tensor(obj)) ]));
  303. continue;
  304. }
  305. if (key == 'modules') {
  306. continue;
  307. }
  308. if (obj.__class__ && obj.__class__.__module__ !== '' && obj.__class__.__name__ != 'LuaFunction') {
  309. continue;
  310. }
  311. const attribute = new torch.Attribute(metadata, type, key, obj);
  312. this.attributes.push(attribute);
  313. }
  314. }
  315. this.inputs = [];
  316. if (inputs.length == 0 && this.name) {
  317. inputs.push(values.map(`${this.name}:in`));
  318. }
  319. this.inputs.push(new torch.Argument('input', inputs));
  320. if (outputs.length == 0 && this.name) {
  321. outputs.push(values.map(this.name));
  322. }
  323. this.outputs = [];
  324. this.outputs.push(new torch.Argument('output', outputs));
  325. initializers = initializers.filter((argument) => {
  326. if (argument.name == 'weight') {
  327. this.inputs.push(argument);
  328. return false;
  329. }
  330. return true;
  331. });
  332. initializers = initializers.filter((argument) => {
  333. if (argument.name == 'bias') {
  334. this.inputs.push(argument);
  335. return false;
  336. }
  337. return true;
  338. });
  339. this.inputs = this.inputs.concat(initializers);
  340. }
  341. _updateSize(module, name) {
  342. if (Object.prototype.hasOwnProperty.call(module, `${name}W`) &&
  343. Object.prototype.hasOwnProperty.call(module, `${name}H`)) {
  344. module[name] = [ module[`${name}W`], module[`${name}H`] ];
  345. delete module[`${name}W`];
  346. delete module[`${name}H`];
  347. }
  348. }
  349. _updateBox(module, name) {
  350. if (Object.prototype.hasOwnProperty.call(module, `${name}_t`) &&
  351. Object.prototype.hasOwnProperty.call(module, `${name}_r`) &&
  352. Object.prototype.hasOwnProperty.call(module, `${name}_b`) &&
  353. Object.prototype.hasOwnProperty.call(module, `${name}_l`)) {
  354. module[name] = [ module[`${name}_t`], module[`${name}_r`], module[`${name}_b`], module[`${name}_l`] ];
  355. delete module[`${name}_t`];
  356. delete module[`${name}_r`];
  357. delete module[`${name}_b`];
  358. delete module[`${name}_l`];
  359. }
  360. }
  361. };
  362. torch.Attribute = class {
  363. constructor(metadata, type, name, value) {
  364. this.name = name;
  365. this.value = value;
  366. if (name == 'train') {
  367. this.visible = false;
  368. }
  369. metadata = metadata.attribute(type, name);
  370. if (metadata) {
  371. if (metadata.visible === false) {
  372. this.visible = false;
  373. } else if (Object.prototype.hasOwnProperty.call(metadata, 'default')) {
  374. if (JSON.stringify(metadata.default) == JSON.stringify(this.value)) {
  375. this.visible = false;
  376. }
  377. }
  378. }
  379. }
  380. };
  381. torch.Tensor = class {
  382. constructor(tensor) {
  383. this.type = new torch.TensorType(tensor);
  384. this.encoding = '|';
  385. this._storage = tensor.storage;
  386. this._offset = tensor.storage_offset;
  387. }
  388. get values() {
  389. if (this.type.shape.dimensions.length === 0) {
  390. return [];
  391. }
  392. if (this._storage) {
  393. const data = this._storage.data();
  394. if (data) {
  395. const size = this.type.shape.dimensions.reduce((a, b) => a * b, 1);
  396. return data.slice(this._offset, this._offset + size);
  397. }
  398. }
  399. return null;
  400. }
  401. };
  402. torch.TensorType = class {
  403. constructor(tensor) {
  404. this.dataType = tensor.dataType;
  405. this.shape = new torch.TensorShape(tensor.size);
  406. }
  407. toString() {
  408. return (this.dataType || '?') + this.shape.toString();
  409. }
  410. };
  411. torch.TensorShape = class {
  412. constructor(dimensions) {
  413. this.dimensions = dimensions;
  414. }
  415. toString() {
  416. if (this.dimensions) {
  417. if (this.dimensions.length == 0) {
  418. return '';
  419. }
  420. return `[${this.dimensions.map((dimension) => dimension.toString()).join(',')}]`;
  421. }
  422. return '';
  423. }
  424. };
  425. torch.Error = class extends Error {
  426. constructor(message) {
  427. super(message);
  428. this.name = 'Error loading Torch model.';
  429. }
  430. };
  431. torch.T7Reader = class {
  432. static open(context) {
  433. const stream = context.stream;
  434. if (stream && stream.length >= 4 && stream.peek(4).every((value, index) => value === 0x00 || (index == 0 && value <= 0x08))) {
  435. const reader = new torch.BinaryReader(stream);
  436. return new torch.T7Reader(reader);
  437. }
  438. if (stream && stream.length >= 2) {
  439. const buffer = stream.peek(2);
  440. const value = String.fromCharCode(stream.peek(1)[0]);
  441. if (buffer[1] === 0x0a && (value >= '0' && value <= '8')) {
  442. const reader = new torch.TextReader(stream);
  443. return new torch.T7Reader(reader);
  444. }
  445. }
  446. return null;
  447. }
  448. constructor(reader) {
  449. this._reader = reader;
  450. this._memo = new Map();
  451. this._types = new Map();
  452. const Storage = class {
  453. constructor(dataType, itemSize) {
  454. this.dataType = dataType;
  455. this.itemSize = itemSize;
  456. }
  457. read(reader) {
  458. this.size = reader.int64();
  459. this.reader = reader.storage(this.size, this.itemSize, this.dataType);
  460. }
  461. data() {
  462. if (this.reader) {
  463. const reader = this.reader;
  464. reader.seek(0);
  465. const dataType = this.dataType;
  466. const size = this.size;
  467. const array = new Array(size);
  468. for (let i = 0; i < size; i++) {
  469. switch (dataType) {
  470. case 'uint8':
  471. array[i] = reader.byte();
  472. break;
  473. case 'int8':
  474. array[i] = reader.int8();
  475. break;
  476. case 'int16':
  477. array[i] = reader.int16();
  478. break;
  479. case 'int32':
  480. array[i] = reader.int32();
  481. break;
  482. case 'int64':
  483. array[i] = reader.int64();
  484. break;
  485. case 'float32':
  486. array[i] = reader.float32();
  487. break;
  488. case 'float64':
  489. array[i] = reader.float64();
  490. break;
  491. default:
  492. throw new torch.Error(`Unsupported data type '${dataType}'.`);
  493. }
  494. }
  495. this._data = array;
  496. delete this.reader;
  497. }
  498. return this._data;
  499. }
  500. };
  501. const Tensor = class {
  502. constructor(dataType) {
  503. this.dataType = dataType;
  504. }
  505. read(reader) {
  506. const dim = reader.int32();
  507. this.size = reader.int64s(dim);
  508. this.stride = reader.int64s(dim);
  509. this.storage_offset = reader.int64() - 1;
  510. this.storage = reader.read();
  511. }
  512. };
  513. this.register('bnn.Binary');
  514. this.register('bnn.SpatialConvolution');
  515. this.register('cudnn.BatchNormalization');
  516. this.register('cudnn.BatchBRNNReLU');
  517. this.register('cudnn.BLSTM');
  518. this.register('cudnn.ReLU');
  519. this.register('cudnn.RNN');
  520. this.register('cudnn.Sigmoid');
  521. this.register('cudnn.SoftMax');
  522. this.register('cudnn.LogSoftMax');
  523. this.register('cudnn.normal3DConv');
  524. this.register('cudnn.normal3DdeConv');
  525. this.register('cudnn.SpatialAveragePooling');
  526. this.register('cudnn.SpatialBatchNormalization');
  527. this.register('cudnn.SpatialConvolution');
  528. this.register('cudnn.SpatialFullConvolution');
  529. this.register('cudnn.SpatialMaxPooling');
  530. this.register('cudnn.SpatialSoftMax');
  531. this.register('cudnn.Tanh');
  532. this.register('cudnn.VolumetricAveragePooling');
  533. this.register('cudnn.VolumetricBatchNormalization');
  534. this.register('cudnn.VolumetricConvolution');
  535. this.register('cudnn.VolumetricMaxPooling');
  536. this.register('Dict');
  537. this.register('inn.ConstAffine');
  538. this.register('inn.SpatialMaxPooling');
  539. this.register('nn.Abs');
  540. this.register('nn.AddConstant');
  541. this.register('nn.BatchNormalization');
  542. this.register('nn.BilinearSamplerBHWD');
  543. this.register('nn.BinActiveZ'); // allenai/XNOR-Net
  544. this.register('nn.BCECriterion');
  545. this.register('nn.Bottle');
  546. this.register('nn.Clamp');
  547. this.register('nn.CMul');
  548. this.register('nn.CAddTable');
  549. this.register('nn.CDivTable');
  550. this.register('nn.CMulTable');
  551. this.register('nn.CSubTable');
  552. this.register('nn.Concat');
  553. this.register('nn.Copy');
  554. this.register('nn.ConcatTable');
  555. this.register('nn.Contiguous');
  556. this.register('nn.Constant');
  557. this.register('nn.CostVolMulti');
  558. this.register('nn.DataParallelTable');
  559. this.register('nn.DepthConcat');
  560. this.register('nn.Dropout');
  561. this.register('nn.Exp');
  562. this.register('nn.ExpOut');
  563. this.register('nn.FlattenTable');
  564. this.register('nn.GenNoise');
  565. this.register('nn.Identity');
  566. this.register('nn.Index');
  567. this.register('nn.Inception');
  568. this.register('nn.InstanceNormalization');
  569. this.register('nn.JoinTable');
  570. this.register('nn.JointTrain');
  571. this.register('nn.KeypointCoordinate');
  572. this.register('nn.LeakyReLU');
  573. this.register('nn.Linear');
  574. this.register('nn.LinearNoBias');
  575. this.register('nn.LogSoftMax');
  576. this.register('nn.LookupTable');
  577. this.register('nn.LSTM');
  578. this.register('nn.MaskZero');
  579. this.register('nn.MapTable');
  580. this.register('nn.Max');
  581. this.register('nn.Mean');
  582. this.register('nn.Min');
  583. this.register('nn.MulConstant');
  584. this.register('nn.MM');
  585. this.register('nn.MSECriterion');
  586. this.register('nn.Narrow');
  587. this.register('nn.NarrowTable');
  588. this.register('nn.Normalize');
  589. this.register('nn.Normalize2');
  590. this.register('nn.NoiseFill');
  591. this.register('nn.Padding');
  592. this.register('nn.Parallel');
  593. this.register('nn.ParallelCriterion');
  594. this.register('nn.ParallelTable');
  595. this.register('nn.PixelShuffle');
  596. this.register('nn.Power');
  597. this.register('nn.PReLU');
  598. this.register('nn.Recursor');
  599. this.register('nn.ReLU');
  600. this.register('nn.Replicate');
  601. this.register('nn.Reshape');
  602. this.register('nn.ShaveImage');
  603. this.register('nn.Select');
  604. this.register('nn.SelectTable');
  605. this.register('nn.Sequencer');
  606. this.register('nn.Sequential');
  607. this.register('nn.Sigmoid');
  608. this.register('nn.Sum');
  609. this.register('nn.SoftMax');
  610. this.register('nn.SpatialAveragePooling');
  611. this.register('nn.SpatialBatchNormalization');
  612. this.register('nn.SpatialConvolution');
  613. this.register('nn.SpatialConvolution1_fw');
  614. this.register('nn.SpatialConvolutionMM');
  615. this.register('nn.SpatialCrossMapLRN');
  616. this.register('nn.SpatialDilatedConvolution');
  617. this.register('nn.SpatialDropout');
  618. this.register('nn.SpatialFractionalMaxPooling');
  619. this.register('nn.SpatialFullConvolution');
  620. this.register('nn.SpatialLPPooling');
  621. this.register('nn.SpatialMaxPooling');
  622. this.register('nn.SpatialMaxUnpooling');
  623. this.register('nn.SpatialReflectionPadding');
  624. this.register('nn.SpatialReplicationPadding');
  625. this.register('nn.SpatialSoftMax');
  626. this.register('nn.SpatialSubtractiveNormalization');
  627. this.register('nn.SpatialUpSamplingBilinear');
  628. this.register('nn.SpatialUpSamplingNearest');
  629. this.register('nn.SpatialZeroPadding');
  630. this.register('nn.SplitTable');
  631. this.register('nn.Squeeze');
  632. this.register('nn.Square');
  633. this.register('nn.Sqrt');
  634. this.register('nn.StereoJoin');
  635. this.register('nn.Tanh');
  636. this.register('nn.Transpose');
  637. this.register('nn.TotalVariation');
  638. this.register('nn.Unpool');
  639. this.register('nn.View');
  640. this.register('nn.gModule');
  641. this.register('nngraph.Node');
  642. this.register('graph.Edge');
  643. this.register('graph.Graph');
  644. this.register('torch.ByteTensor', class extends Tensor {
  645. constructor() {
  646. super('uint8');
  647. }
  648. });
  649. this.register('torch.CharTensor', class extends Tensor {
  650. constructor() {
  651. super('int8');
  652. }
  653. });
  654. this.register('torch.ShortTensor', class extends Tensor {
  655. constructor() {
  656. super('int16');
  657. }
  658. });
  659. this.register('torch.IntTensor', class extends Tensor {
  660. constructor() {
  661. super('int32');
  662. }
  663. });
  664. this.register('torch.LongTensor', class extends Tensor {
  665. constructor() {
  666. super('int64');
  667. }
  668. });
  669. this.register('torch.FloatTensor', class extends Tensor {
  670. constructor() {
  671. super('float32');
  672. }
  673. });
  674. this.register('torch.DoubleTensor', class extends Tensor {
  675. constructor() {
  676. super('float64');
  677. }
  678. });
  679. this.register('torch.CudaByteTensor', class extends Tensor {
  680. constructor() {
  681. super('uint8');
  682. }
  683. });
  684. this.register('torch.CudaCharTensor', class extends Tensor {
  685. constructor() {
  686. super('int8');
  687. }
  688. });
  689. this.register('torch.CudaShortTensor', class extends Tensor {
  690. constructor() {
  691. super('int16');
  692. }
  693. });
  694. this.register('torch.CudaIntTensor', class extends Tensor {
  695. constructor() {
  696. super('int32');
  697. }
  698. });
  699. this.register('torch.CudaLongTensor', class extends Tensor {
  700. constructor() {
  701. super('int64');
  702. }
  703. });
  704. this.register('torch.CudaTensor', class extends Tensor {
  705. constructor() {
  706. super('float32');
  707. }
  708. });
  709. this.register('torch.CudaDoubleTensor', class extends Tensor {
  710. constructor() {
  711. super('float64');
  712. }
  713. });
  714. this.register('torch.ByteStorage', class extends Storage {
  715. constructor() {
  716. super('uint8', 1);
  717. }
  718. });
  719. this.register('torch.CharStorage', class extends Storage {
  720. constructor() {
  721. super('int8', 1);
  722. }
  723. });
  724. this.register('torch.ShortStorage', class extends Storage {
  725. constructor() {
  726. super('int16', 2);
  727. }
  728. });
  729. this.register('torch.IntStorage', class extends Storage {
  730. constructor() {
  731. super('int32', 4);
  732. }
  733. });
  734. this.register('torch.LongStorage', class extends Storage {
  735. constructor() {
  736. super('int64', 8);
  737. }
  738. });
  739. this.register('torch.FloatStorage', class extends Storage {
  740. constructor() {
  741. super('float32', 4);
  742. }
  743. });
  744. this.register('torch.DoubleStorage', class extends Storage {
  745. constructor() {
  746. super('float64', 8);
  747. }
  748. });
  749. this.register('torch.CudaByteStorage', class extends Storage {
  750. constructor() {
  751. super('uint8', 1);
  752. }
  753. });
  754. this.register('torch.CudaCharStorage', class extends Storage {
  755. constructor() {
  756. super('int8', 1);
  757. }
  758. });
  759. this.register('torch.CudaShortStorage', class extends Storage {
  760. constructor() {
  761. super('int16', 2);
  762. }
  763. });
  764. this.register('torch.CudaIntStorage', class extends Storage {
  765. constructor() {
  766. super('int32', 4);
  767. }
  768. });
  769. this.register('torch.CudaLongStorage', class extends Storage {
  770. constructor() {
  771. super('int64', 8);
  772. }
  773. });
  774. this.register('torch.CudaIntStorage', class extends Storage {
  775. constructor() {
  776. super('int32', 4);
  777. }
  778. });
  779. this.register('torch.CudaStorage', class extends Storage {
  780. constructor() {
  781. super('float32', 4);
  782. }
  783. });
  784. this.register('torch.CudaFloatStorage', class extends Storage {
  785. constructor() {
  786. super('float64', 8);
  787. }
  788. });
  789. this.register('w2nn.AuxiliaryLossTable');
  790. this.register('w2nn.InplaceClip01');
  791. this.register('w2nn.ScaleTable');
  792. this.register('LuaFunction', class {
  793. constructor(size, dumped, upvalues) {
  794. this.size = size;
  795. this.dumped = dumped;
  796. this.upvalues = upvalues;
  797. }
  798. });
  799. }
  800. register(name, type) {
  801. type = type || class {};
  802. const parts = name.split('.');
  803. type.__name__ = parts.pop();
  804. type.__module__ = parts.join('.');
  805. type.prototype.__class__ = type;
  806. this._types.set(name, type);
  807. }
  808. read() {
  809. const type = this.int32();
  810. switch (type) {
  811. case 0: return null;
  812. case 1: return this.float64();
  813. case 2: return this.string();
  814. case 3: return this.table();
  815. case 4: return this.object();
  816. case 5: return this.boolean();
  817. case 6: return this.function();
  818. case 7: return this.function();
  819. case 8: return this.function();
  820. default: throw new torch.Error(`File format has invalid type '${type}'.`);
  821. }
  822. }
  823. boolean() {
  824. return this._reader.boolean();
  825. }
  826. int32() {
  827. return this._reader.int32();
  828. }
  829. int64() {
  830. return this._reader.int64();
  831. }
  832. int64s(size) {
  833. return this._reader.int64s(size);
  834. }
  835. float64() {
  836. return this._reader.float64();
  837. }
  838. string() {
  839. return this._reader.string();
  840. }
  841. object() {
  842. const index = this.int32();
  843. if (this._memo.has(index)) {
  844. return this._memo.get(index);
  845. }
  846. let version = this.string();
  847. let name = null;
  848. if (version.startsWith('V ')) {
  849. name = this.string();
  850. version = Number(version.split(' ')[1]);
  851. } else {
  852. name = version;
  853. version = 0;
  854. }
  855. if (!this._types.has(name)) {
  856. this.callback(name);
  857. this.register(name);
  858. }
  859. const type = this._types.get(name);
  860. const obj = Reflect.construct(type, []);
  861. this._memo.set(index, obj);
  862. if (obj.read) {
  863. obj.read(this, version);
  864. } else {
  865. const attributes = this.read();
  866. if (attributes != null) {
  867. for (const [key, value] of Object.entries(attributes)) {
  868. obj[key] = value;
  869. }
  870. }
  871. }
  872. return obj;
  873. }
  874. table() {
  875. const index = this.int32();
  876. if (this._memo.has(index)) {
  877. return this._memo.get(index);
  878. }
  879. const table = {};
  880. this._memo.set(index, table);
  881. const size = this.int32();
  882. let convert = true;
  883. let sum = 0;
  884. for (let i = 0; i < size; i++) {
  885. const key = this.read();
  886. const value = this.read();
  887. table[key] = value;
  888. if (Number.isInteger(key) && key >= 0) {
  889. sum += key;
  890. } else {
  891. convert = false;
  892. }
  893. }
  894. const n = Object.keys(table).length;
  895. if (convert && (n * (n + 1)) == (2 * sum)) {
  896. const list = [];
  897. for (let j = 0; j < n; j++) {
  898. let item = table[j + 1];
  899. if (item == table) {
  900. item = list;
  901. }
  902. list.push(item);
  903. }
  904. this._memo.set(index, list);
  905. return list;
  906. }
  907. return table;
  908. }
  909. function() {
  910. const index = this.int32();
  911. if (this._memo.has(index)) {
  912. return this._memo.get(index);
  913. }
  914. const size = this.int32();
  915. const dumped = this._reader.read(size);
  916. const upvalues = this.read();
  917. const type = this._types.get('LuaFunction');
  918. const obj = Reflect.construct(type, [ size, dumped, upvalues ]);
  919. this._memo.set(index, obj);
  920. return obj;
  921. }
  922. storage(size, itemSize, dataType) {
  923. return this._reader.storage(size, itemSize, dataType);
  924. }
  925. };
  torch.BinaryReader = class extends base.BinaryReader {
    // Reader for the torch7 binary serialization format, layered on base.BinaryReader.
    constructor(data) {
      // Accepts a Uint8Array or a stream-like object exposing peek().
      // NOTE(review): the `true` flag is presumably little-endian; confirm in base.BinaryReader.
      const buffer = data instanceof Uint8Array ? data : data.peek();
      super(buffer, true);
      this._textDecoder = new TextDecoder('ascii');
    }
    // Booleans are stored as an int32 where 1 means true.
    boolean() {
      const value = this.int32();
      return value === 1;
    }
    // Returns a view (not a copy) of the next `length` bytes and advances past them.
    read(length) {
      const start = this._position;
      this.skip(length);
      const end = this._position;
      return this._buffer.subarray(start, end);
    }
    // Reads `size` consecutive int64 values.
    int64s(size) {
      const values = [];
      while (values.length < size) {
        values.push(this.int64());
      }
      return values;
    }
    // Strings are an int32 byte length followed by that many bytes.
    string() {
      const length = this.int32();
      const bytes = this.read(length);
      return this._textDecoder.decode(bytes);
    }
    // Returns a nested binary reader over size * itemSize bytes of storage data.
    storage(size, itemSize) {
      const byteLength = size * itemSize;
      const bytes = this.read(byteLength);
      return new torch.BinaryReader(bytes);
    }
  };
  torch.TextReader = class {
    // Reader for the torch7 ASCII serialization format. Numbers are stored as one decimal
    // token per line (see int64() / float64()). Raw byte runs (strings, function bytecode,
    // uint8 storages) are stored verbatim followed by a separator. They may contain the
    // separator byte themselves, so they are read by length, never by line.
    constructor(data, separator) {
      // Accepts a Uint8Array or a stream-like object exposing peek().
      this._buffer = data instanceof Uint8Array ? data : data.peek();
      this._position = 0;
      this._dataView = new DataView(this._buffer.buffer, this._buffer.byteOffset, this._buffer.byteLength);
      this._textDecoder = new TextDecoder('ascii');
      // Defaults to '\n'; nested storage readers created by storage() use ' ' (0x20).
      this._separator = separator || 0x0a;
    }
    seek(position) {
      this._position = position;
    }
    // Returns the bytes up to (not including) the next separator, or up to the end of the
    // buffer, and advances past the separator. At most size + 1 bytes are examined.
    // Throws torch.Error at end of input or when no separator is found in time.
    line(size) {
      const start = this._position;
      if (start >= this._buffer.length) {
        throw new torch.Error('Unexpected end of file.');
      }
      let remaining = size;
      while (this._position < this._buffer.length && remaining > -1) {
        const c = this._buffer[this._position++];
        if (c === this._separator) {
          return this._buffer.slice(start, this._position - 1);
        } else if (this._position === this._buffer.length) {
          return this._buffer.slice(start, this._position);
        }
        remaining--;
      }
      throw new torch.Error('Line exceeded maximum length.');
    }
    // Returns a copy of exactly `size` raw bytes. For a non-empty run, also consumes one
    // trailing separator if present, mirroring torch7's ASCII char/byte reader, which
    // tolerates a missing separator (e.g. at end of file).
    _raw(size) {
      if (!Number.isInteger(size) || size < 0) {
        throw new torch.Error(`Invalid data size '${size}'.`);
      }
      const start = this._position;
      const end = start + size;
      if (end > this._buffer.length) {
        throw new torch.Error('Unexpected end of file.');
      }
      this._position = end;
      if (size > 0 && end < this._buffer.length && this._buffer[end] === this._separator) {
        this._position++;
      }
      return this._buffer.slice(start, end);
    }
    boolean() {
      return this.int32() === 1;
    }
    // Reads exactly `size` raw bytes (e.g. Lua function bytecode), which may contain
    // newlines.
    read(size) {
      return this._raw(size);
    }
    int8() {
      return this.int64();
    }
    int16() {
      return this.int64();
    }
    int32() {
      return this.int64();
    }
    // Parses one decimal integer token; values beyond 2^53 lose precision as Numbers.
    int64() {
      const token = this._textDecoder.decode(this.line(20));
      const number = Number.parseInt(token, 10);
      // Rejects tokens with trailing garbage, which parseInt alone would ignore.
      if (Number.isNaN(token - number)) {
        throw new torch.Error(`Couldn't parse int64 '${token}'.`);
      }
      return number;
    }
    // Reads `size` space-separated integers from a single line (nothing when size <= 0).
    int64s(size) {
      const array = [];
      if (size > 0) {
        const content = this._textDecoder.decode(this.line(Number.MAX_SAFE_INTEGER));
        for (const token of content.split(' ')) {
          const number = Number.parseInt(token, 10);
          if (Number.isNaN(token - number)) {
            throw new torch.Error(`Couldn't parse int64 '${token}'.`);
          }
          array.push(number);
        }
      }
      return array;
    }
    float32() {
      return this.float64();
    }
    // Parses one floating-point token, including C printf spellings of nan/inf.
    float64() {
      const token = this._textDecoder.decode(this.line(24));
      if (token.startsWith('-nan')) {
        return -NaN;
      }
      if (token.startsWith('nan')) {
        return NaN;
      }
      if (token.startsWith('inf')) {
        return Infinity;
      }
      if (token.startsWith('-inf')) {
        return -Infinity;
      }
      const number = Number.parseFloat(token);
      if (Number.isNaN(token - number)) {
        throw new torch.Error(`Couldn't parse float '${token}'.`);
      }
      return number;
    }
    // Strings are an integer byte length followed by that many raw bytes. The bytes may
    // include newlines, so they are read by length rather than with line().
    string() {
      const size = this.int32();
      if (size === 0) {
        return '';
      }
      const data = this._raw(size);
      const content = this._textDecoder.decode(data);
      if (size !== content.length) {
        throw new torch.Error('Invalid string length.');
      }
      return content;
    }
    // uint8 storages are raw bytes and yield a binary reader over them. Any other storage
    // is one line of space-separated tokens and yields a nested text reader over it.
    storage(size, itemSize, dataType) {
      if (size <= 0) {
        throw new torch.Error(`Unsupported storage size '${size}'.`);
      }
      if (dataType === 'uint8') {
        return new torch.BinaryReader(this._raw(size));
      }
      const data = this.line(Number.MAX_SAFE_INTEGER);
      return new torch.TextReader(data, 0x20);
    }
  };
  // Public entry point of this module: the torch model factory.
  const { ModelFactory } = torch;
  export { ModelFactory };