torch.js 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940
  1. import * as base from './base.js';
  2. const torch = {};
  3. torch.ModelFactory = class {
  4. async match(context) {
  5. const reader = torch.T7Reader.open(context);
  6. if (reader) {
  7. return context.set('torch', reader);
  8. }
  9. return null;
  10. }
  11. async open(context) {
  12. const metadata = await context.metadata('torch-metadata.json');
  13. const reader = context.value;
  14. reader.callback = (name) => {
  15. if (name && name !== 'nn.JointTrainModule' && !name.startsWith('nn.MSDNet_') && !name.startsWith('onmt.')) {
  16. context.error(new torch.Error(`Unsupported type '${name}'.`));
  17. }
  18. return null;
  19. };
  20. const obj = reader.read();
  21. let graphs = [];
  22. if (obj && Array.isArray(obj) && obj.length >= 2 &&
  23. obj.slice(0, obj.length - 1).every((item) => item.__class__) &&
  24. !obj[obj.length - 1].__class__) {
  25. graphs = obj.slice(0, obj.length - 1);
  26. } else {
  27. graphs = [obj];
  28. }
  29. return new torch.Model(metadata, graphs);
  30. }
  31. };
  32. torch.Model = class {
  33. constructor(metadata, graphs) {
  34. this.format = 'Torch v7';
  35. this.modules = graphs.map((graph, index) => new torch.Graph(metadata, index.toString(), graph));
  36. }
  37. };
  38. torch.Graph = class {
  39. constructor(metadata, name, module) {
  40. this.name = name;
  41. this.inputs = [];
  42. this.outputs = [];
  43. this.nodes = [];
  44. this.groups = 'false';
  45. const values = new Map();
  46. values.map = (name, type, tensor) => {
  47. if (name.length === 0 && tensor) {
  48. return new torch.Value(name, type || null, tensor || null);
  49. }
  50. if (!values.has(name)) {
  51. values.set(name, new torch.Value(name, type || null, tensor || null));
  52. } else if (type || tensor) {
  53. throw new torch.Error(`Duplicate value '${name}'.`);
  54. }
  55. return values.get(name);
  56. };
  57. const node = new torch.Node(metadata, module, '', values);
  58. this.nodes.push(node);
  59. }
  60. };
  61. torch.Argument = class {
  62. constructor(name, value, type = null, visible = true) {
  63. this.name = name;
  64. this.value = value;
  65. this.type = type;
  66. this.visible = visible;
  67. }
  68. };
  69. torch.Value = class {
  70. constructor(name, type, initializer) {
  71. if (typeof name !== 'string') {
  72. throw new torch.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
  73. }
  74. this.name = name;
  75. this.type = initializer ? initializer.type : type;
  76. this.initializer = initializer;
  77. }
  78. };
  79. torch.Node = class {
  80. constructor(metadata, module, name, values) {
  81. this.name = name;
  82. this.inputs = [];
  83. this.outputs = [];
  84. const type = module.__class__ ? `${module.__class__.__module__}.${module.__class__.__name__}` : 'nn.Module';
  85. this.type = metadata.type(type);
  86. for (const [key, obj] of Object.entries(module)) {
  87. if (obj && obj.__class__ && obj.__class__.__module__ === 'torch' && obj.__class__.__name__.endsWith('Storage')) {
  88. module[key] = obj.data();
  89. }
  90. }
  91. delete module.iSize;
  92. delete module.finput;
  93. delete module.fgradInput;
  94. delete module.output;
  95. delete module.gradInput;
  96. delete module.gradWeight;
  97. delete module.gradBias;
  98. delete module.grad_tmp;
  99. delete module.scaleT;
  100. delete module._input;
  101. delete module._output;
  102. delete module._gradInput;
  103. delete module._gradOutput;
  104. delete module.buffer;
  105. delete module.buffer2;
  106. delete module.tmp_in;
  107. delete module.tmp_out;
  108. delete module.accUpdateGradParameters;
  109. this.attributes = [];
  110. for (const [name, obj] of Object.entries(module)) {
  111. if (name === '_type') {
  112. continue;
  113. }
  114. if (obj.__class__ && obj.__class__.__module__ === 'torch' && obj.__class__.__name__.endsWith('Tensor')) {
  115. const argument = new torch.Argument(name, [values.map('', null, new torch.Tensor(obj))]);
  116. this.inputs.push(argument);
  117. } else if (Array.isArray(obj) && obj.every((item) => item && item.__class__)) {
  118. const nodes = obj.map((module) => new torch.Node(metadata, module, '', values));
  119. const argument = new torch.Argument(name, nodes, 'object[]');
  120. this.inputs.push(argument);
  121. } else if ((Array.isArray(obj) && obj.every((obj) => typeof obj === 'number' || typeof obj === 'string' || typeof obj === 'boolean')) ||
  122. typeof obj === 'number' || typeof obj === 'string' || typeof obj === 'boolean') {
  123. let visible = name === 'train' ? false : true;
  124. const schema = metadata.attribute(type, name);
  125. if (schema) {
  126. if (schema.visible === false) {
  127. visible = false;
  128. } else if (schema.default !== undefined && Object.prototype.hasOwnProperty.call(schema, 'default')) {
  129. visible = false;
  130. }
  131. }
  132. const attribute = new torch.Argument(name, obj, 'attribute', visible);
  133. this.inputs.push(attribute);
  134. } else if (obj) {
  135. const node = new torch.Node(metadata, obj, '', values);
  136. const argument = new torch.Argument(name, node, 'object');
  137. this.inputs.push(argument);
  138. } else {
  139. throw new torch.Error(`Invalid input value '${name}'.`);
  140. }
  141. }
  142. }
  143. _updateSize(module, name) {
  144. if (Object.prototype.hasOwnProperty.call(module, `${name}W`) &&
  145. Object.prototype.hasOwnProperty.call(module, `${name}H`)) {
  146. module[name] = [module[`${name}W`], module[`${name}H`]];
  147. delete module[`${name}W`];
  148. delete module[`${name}H`];
  149. }
  150. }
  151. _updateBox(module, name) {
  152. if (Object.prototype.hasOwnProperty.call(module, `${name}_t`) &&
  153. Object.prototype.hasOwnProperty.call(module, `${name}_r`) &&
  154. Object.prototype.hasOwnProperty.call(module, `${name}_b`) &&
  155. Object.prototype.hasOwnProperty.call(module, `${name}_l`)) {
  156. module[name] = [module[`${name}_t`], module[`${name}_r`], module[`${name}_b`], module[`${name}_l`]];
  157. delete module[`${name}_t`];
  158. delete module[`${name}_r`];
  159. delete module[`${name}_b`];
  160. delete module[`${name}_l`];
  161. }
  162. }
  163. };
  164. torch.Tensor = class {
  165. constructor(tensor) {
  166. this.type = new torch.TensorType(tensor);
  167. this.encoding = '|';
  168. this._storage = tensor.storage;
  169. this._offset = tensor.storage_offset;
  170. }
  171. get values() {
  172. if (this.type.shape.dimensions.length === 0) {
  173. return [];
  174. }
  175. if (this._storage) {
  176. const data = this._storage.data();
  177. if (data) {
  178. const size = this.type.shape.dimensions.reduce((a, b) => a * Number(b), 1);
  179. return data.slice(this._offset, this._offset + size);
  180. }
  181. }
  182. return null;
  183. }
  184. };
  185. torch.TensorType = class {
  186. constructor(tensor) {
  187. this.dataType = tensor.dataType;
  188. this.shape = new torch.TensorShape(tensor.size);
  189. }
  190. toString() {
  191. return (this.dataType || '?') + this.shape.toString();
  192. }
  193. };
  194. torch.TensorShape = class {
  195. constructor(dimensions) {
  196. this.dimensions = dimensions;
  197. }
  198. toString() {
  199. if (this.dimensions) {
  200. if (this.dimensions.length === 0) {
  201. return '';
  202. }
  203. return `[${this.dimensions.map((dimension) => dimension.toString()).join(',')}]`;
  204. }
  205. return '';
  206. }
  207. };
  208. torch.T7Reader = class {
  209. static open(context) {
  210. const stream = context.stream;
  211. if (stream && stream.length >= 4 && stream.peek(4).every((value, index) => value === 0x00 || (index === 0 && value <= 0x08))) {
  212. const reader = new torch.BinaryReader(stream);
  213. return new torch.T7Reader(reader);
  214. }
  215. if (stream && stream.length >= 2) {
  216. const buffer = stream.peek(2);
  217. const value = String.fromCharCode(stream.peek(1)[0]);
  218. if (buffer[1] === 0x0a && (value >= '0' && value <= '8')) {
  219. const reader = new torch.TextReader(stream);
  220. return new torch.T7Reader(reader);
  221. }
  222. }
  223. return null;
  224. }
  225. constructor(reader) {
  226. // https://github.com/torch/torch7
  227. // https://github.com/torch/nngraph
  228. this._reader = reader;
  229. this._memo = new Map();
  230. this._types = new Map();
  231. const Storage = class {
  232. constructor(dataType, itemSize) {
  233. this.dataType = dataType;
  234. this.itemSize = itemSize;
  235. }
  236. read(reader) {
  237. this.size = reader.int64();
  238. this.reader = reader.storage(this.size, this.itemSize, this.dataType);
  239. }
  240. data() {
  241. if (this.reader) {
  242. const reader = this.reader;
  243. reader.seek(0);
  244. const dataType = this.dataType;
  245. const size = this.size;
  246. const array = new Array(size);
  247. for (let i = 0; i < size; i++) {
  248. switch (dataType) {
  249. case 'uint8':
  250. array[i] = reader.byte();
  251. break;
  252. case 'int8':
  253. array[i] = reader.int8();
  254. break;
  255. case 'int16':
  256. array[i] = reader.int16();
  257. break;
  258. case 'int32':
  259. array[i] = reader.int32();
  260. break;
  261. case 'int64':
  262. array[i] = reader.int64();
  263. break;
  264. case 'float32':
  265. array[i] = reader.float32();
  266. break;
  267. case 'float64':
  268. array[i] = reader.float64();
  269. break;
  270. default:
  271. throw new torch.Error(`Unsupported data type '${dataType}'.`);
  272. }
  273. }
  274. this._data = array;
  275. delete this.reader;
  276. }
  277. return this._data;
  278. }
  279. };
  280. const Tensor = class {
  281. constructor(dataType) {
  282. this.dataType = dataType;
  283. }
  284. read(reader) {
  285. const dim = reader.int32();
  286. this.size = reader.int64s(dim);
  287. this.stride = reader.int64s(dim);
  288. this.storage_offset = reader.int64() - 1;
  289. this.storage = reader.read();
  290. }
  291. };
  292. this.register('bnn.Binary');
  293. this.register('bnn.SpatialConvolution');
  294. this.register('cudnn.BatchNormalization');
  295. this.register('cudnn.BatchBRNNReLU');
  296. this.register('cudnn.BLSTM');
  297. this.register('cudnn.ReLU');
  298. this.register('cudnn.RNN');
  299. this.register('cudnn.Sigmoid');
  300. this.register('cudnn.SoftMax');
  301. this.register('cudnn.LogSoftMax');
  302. this.register('cudnn.normal3DConv');
  303. this.register('cudnn.normal3DdeConv');
  304. this.register('cudnn.SpatialAveragePooling');
  305. this.register('cudnn.SpatialBatchNormalization');
  306. this.register('cudnn.SpatialConvolution');
  307. this.register('cudnn.SpatialFullConvolution');
  308. this.register('cudnn.SpatialMaxPooling');
  309. this.register('cudnn.SpatialSoftMax');
  310. this.register('cudnn.Tanh');
  311. this.register('cudnn.VolumetricAveragePooling');
  312. this.register('cudnn.VolumetricBatchNormalization');
  313. this.register('cudnn.VolumetricConvolution');
  314. this.register('cudnn.VolumetricMaxPooling');
  315. this.register('Dict');
  316. this.register('inn.ConstAffine');
  317. this.register('inn.SpatialMaxPooling');
  318. this.register('nn.Abs');
  319. this.register('nn.AddConstant');
  320. this.register('nn.BatchNormalization');
  321. this.register('nn.BilinearSamplerBHWD');
  322. this.register('nn.BinActiveZ'); // allenai/XNOR-Net
  323. this.register('nn.BCECriterion');
  324. this.register('nn.Bottle');
  325. this.register('nn.Clamp');
  326. this.register('nn.CMul');
  327. this.register('nn.CAddTable');
  328. this.register('nn.CDivTable');
  329. this.register('nn.CMulTable');
  330. this.register('nn.CSubTable');
  331. this.register('nn.Concat');
  332. this.register('nn.Copy');
  333. this.register('nn.ConcatTable');
  334. this.register('nn.Contiguous');
  335. this.register('nn.Constant');
  336. this.register('nn.CostVolMulti');
  337. this.register('nn.DataParallelTable');
  338. this.register('nn.DepthConcat');
  339. this.register('nn.Dropout');
  340. this.register('nn.Exp');
  341. this.register('nn.ExpOut');
  342. this.register('nn.FlattenTable');
  343. this.register('nn.GenNoise');
  344. this.register('nn.Identity');
  345. this.register('nn.Index');
  346. this.register('nn.Inception');
  347. this.register('nn.InstanceNormalization');
  348. this.register('nn.JoinTable');
  349. this.register('nn.JointTrain');
  350. this.register('nn.KeypointCoordinate');
  351. this.register('nn.LeakyReLU');
  352. this.register('nn.Linear');
  353. this.register('nn.LinearNoBias');
  354. this.register('nn.LogSoftMax');
  355. this.register('nn.LookupTable');
  356. this.register('nn.LSTM');
  357. this.register('nn.MaskZero');
  358. this.register('nn.MapTable');
  359. this.register('nn.Max');
  360. this.register('nn.Mean');
  361. this.register('nn.Min');
  362. this.register('nn.MulConstant');
  363. this.register('nn.MM');
  364. this.register('nn.MSECriterion');
  365. this.register('nn.Narrow');
  366. this.register('nn.NarrowTable');
  367. this.register('nn.Normalize');
  368. this.register('nn.Normalize2');
  369. this.register('nn.NoiseFill');
  370. this.register('nn.Padding');
  371. this.register('nn.Parallel');
  372. this.register('nn.ParallelCriterion');
  373. this.register('nn.ParallelTable');
  374. this.register('nn.PixelShuffle');
  375. this.register('nn.Power');
  376. this.register('nn.PReLU');
  377. this.register('nn.Recursor');
  378. this.register('nn.ReLU');
  379. this.register('nn.Replicate');
  380. this.register('nn.Reshape');
  381. this.register('nn.ShaveImage');
  382. this.register('nn.Select');
  383. this.register('nn.SelectTable');
  384. this.register('nn.Sequencer');
  385. this.register('nn.Sequential');
  386. this.register('nn.Sigmoid');
  387. this.register('nn.Sum');
  388. this.register('nn.SoftMax');
  389. this.register('nn.SpatialAveragePooling');
  390. this.register('nn.SpatialBatchNormalization');
  391. this.register('nn.SpatialConvolution');
  392. this.register('nn.SpatialConvolution1_fw');
  393. this.register('nn.SpatialConvolutionMM');
  394. this.register('nn.SpatialCrossMapLRN');
  395. this.register('nn.SpatialDilatedConvolution');
  396. this.register('nn.SpatialDropout');
  397. this.register('nn.SpatialFractionalMaxPooling');
  398. this.register('nn.SpatialFullConvolution');
  399. this.register('nn.SpatialLPPooling');
  400. this.register('nn.SpatialMaxPooling');
  401. this.register('nn.SpatialMaxUnpooling');
  402. this.register('nn.SpatialReflectionPadding');
  403. this.register('nn.SpatialReplicationPadding');
  404. this.register('nn.SpatialSoftMax');
  405. this.register('nn.SpatialSubtractiveNormalization');
  406. this.register('nn.SpatialUpSamplingBilinear');
  407. this.register('nn.SpatialUpSamplingNearest');
  408. this.register('nn.SpatialZeroPadding');
  409. this.register('nn.SplitTable');
  410. this.register('nn.Squeeze');
  411. this.register('nn.Square');
  412. this.register('nn.Sqrt');
  413. this.register('nn.StereoJoin');
  414. this.register('nn.Tanh');
  415. this.register('nn.Transpose');
  416. this.register('nn.TotalVariation');
  417. this.register('nn.Unpool');
  418. this.register('nn.View');
  419. this.register('nn.gModule');
  420. this.register('nngraph.Node');
  421. this.register('graph.Edge');
  422. this.register('graph.Graph');
  423. this.register('torch.ByteTensor', class extends Tensor {
  424. constructor() {
  425. super('uint8');
  426. }
  427. });
  428. this.register('torch.CharTensor', class extends Tensor {
  429. constructor() {
  430. super('int8');
  431. }
  432. });
  433. this.register('torch.ShortTensor', class extends Tensor {
  434. constructor() {
  435. super('int16');
  436. }
  437. });
  438. this.register('torch.IntTensor', class extends Tensor {
  439. constructor() {
  440. super('int32');
  441. }
  442. });
  443. this.register('torch.LongTensor', class extends Tensor {
  444. constructor() {
  445. super('int64');
  446. }
  447. });
  448. this.register('torch.FloatTensor', class extends Tensor {
  449. constructor() {
  450. super('float32');
  451. }
  452. });
  453. this.register('torch.DoubleTensor', class extends Tensor {
  454. constructor() {
  455. super('float64');
  456. }
  457. });
  458. this.register('torch.CudaByteTensor', class extends Tensor {
  459. constructor() {
  460. super('uint8');
  461. }
  462. });
  463. this.register('torch.CudaCharTensor', class extends Tensor {
  464. constructor() {
  465. super('int8');
  466. }
  467. });
  468. this.register('torch.CudaShortTensor', class extends Tensor {
  469. constructor() {
  470. super('int16');
  471. }
  472. });
  473. this.register('torch.CudaIntTensor', class extends Tensor {
  474. constructor() {
  475. super('int32');
  476. }
  477. });
  478. this.register('torch.CudaLongTensor', class extends Tensor {
  479. constructor() {
  480. super('int64');
  481. }
  482. });
  483. this.register('torch.CudaTensor', class extends Tensor {
  484. constructor() {
  485. super('float32');
  486. }
  487. });
  488. this.register('torch.CudaDoubleTensor', class extends Tensor {
  489. constructor() {
  490. super('float64');
  491. }
  492. });
  493. this.register('torch.ByteStorage', class extends Storage {
  494. constructor() {
  495. super('uint8', 1);
  496. }
  497. });
  498. this.register('torch.CharStorage', class extends Storage {
  499. constructor() {
  500. super('int8', 1);
  501. }
  502. });
  503. this.register('torch.ShortStorage', class extends Storage {
  504. constructor() {
  505. super('int16', 2);
  506. }
  507. });
  508. this.register('torch.IntStorage', class extends Storage {
  509. constructor() {
  510. super('int32', 4);
  511. }
  512. });
  513. this.register('torch.LongStorage', class extends Storage {
  514. constructor() {
  515. super('int64', 8);
  516. }
  517. });
  518. this.register('torch.FloatStorage', class extends Storage {
  519. constructor() {
  520. super('float32', 4);
  521. }
  522. });
  523. this.register('torch.DoubleStorage', class extends Storage {
  524. constructor() {
  525. super('float64', 8);
  526. }
  527. });
  528. this.register('torch.CudaByteStorage', class extends Storage {
  529. constructor() {
  530. super('uint8', 1);
  531. }
  532. });
  533. this.register('torch.CudaCharStorage', class extends Storage {
  534. constructor() {
  535. super('int8', 1);
  536. }
  537. });
  538. this.register('torch.CudaShortStorage', class extends Storage {
  539. constructor() {
  540. super('int16', 2);
  541. }
  542. });
  543. this.register('torch.CudaIntStorage', class extends Storage {
  544. constructor() {
  545. super('int32', 4);
  546. }
  547. });
  548. this.register('torch.CudaLongStorage', class extends Storage {
  549. constructor() {
  550. super('int64', 8);
  551. }
  552. });
  553. this.register('torch.CudaIntStorage', class extends Storage {
  554. constructor() {
  555. super('int32', 4);
  556. }
  557. });
  558. this.register('torch.CudaStorage', class extends Storage {
  559. constructor() {
  560. super('float32', 4);
  561. }
  562. });
  563. this.register('torch.CudaFloatStorage', class extends Storage {
  564. constructor() {
  565. super('float64', 8);
  566. }
  567. });
  568. this.register('w2nn.AuxiliaryLossTable');
  569. this.register('w2nn.InplaceClip01');
  570. this.register('w2nn.ScaleTable');
  571. this.register('LuaFunction', class {
  572. constructor(size, dumped, upvalues) {
  573. this.size = size;
  574. this.dumped = dumped;
  575. this.upvalues = upvalues;
  576. }
  577. });
  578. }
  579. register(name, type) {
  580. type = type || class {};
  581. const parts = name.split('.');
  582. type.__name__ = parts.pop();
  583. type.__module__ = parts.join('.');
  584. type.prototype.__class__ = type;
  585. this._types.set(name, type);
  586. }
  587. read() {
  588. const type = this.int32();
  589. switch (type) {
  590. case 0: return null;
  591. case 1: return this.float64();
  592. case 2: return this.string();
  593. case 3: return this.table();
  594. case 4: return this.object();
  595. case 5: return this.boolean();
  596. case 6: return this.function();
  597. case 7: return this.function();
  598. case 8: return this.function();
  599. default: throw new torch.Error(`File format has invalid type '${type}'.`);
  600. }
  601. }
  602. boolean() {
  603. return this._reader.boolean();
  604. }
  605. int32() {
  606. return this._reader.int32();
  607. }
  608. int64() {
  609. return this._reader.int64();
  610. }
  611. int64s(size) {
  612. return this._reader.int64s(size);
  613. }
  614. float64() {
  615. return this._reader.float64();
  616. }
  617. string() {
  618. return this._reader.string();
  619. }
  620. object() {
  621. const index = this.int32();
  622. if (this._memo.has(index)) {
  623. return this._memo.get(index);
  624. }
  625. let version = this.string();
  626. let name = null;
  627. if (version.startsWith('V ')) {
  628. name = this.string();
  629. version = parseInt(version.split(' ')[1], 10);
  630. } else {
  631. name = version;
  632. version = 0;
  633. }
  634. if (!this._types.has(name)) {
  635. this.callback(name);
  636. this.register(name);
  637. }
  638. const type = this._types.get(name);
  639. const obj = Reflect.construct(type, []);
  640. this._memo.set(index, obj);
  641. if (obj.read) {
  642. obj.read(this, version);
  643. } else {
  644. const attributes = this.read();
  645. if (attributes !== null) {
  646. for (const [key, value] of Array.from(attributes)) {
  647. obj[key] = value;
  648. }
  649. }
  650. }
  651. return obj;
  652. }
  653. table() {
  654. const index = this.int32();
  655. if (this._memo.has(index)) {
  656. return this._memo.get(index);
  657. }
  658. const table = new Map();
  659. this._memo.set(index, table);
  660. const size = this.int32();
  661. let convert = true;
  662. let sum = 0;
  663. for (let i = 0; i < size; i++) {
  664. const key = this.read();
  665. const value = this.read();
  666. table.set(key, value);
  667. if (Number.isInteger(key) && key >= 0) {
  668. sum += key;
  669. } else {
  670. convert = false;
  671. }
  672. }
  673. const n = table.size;
  674. if (convert && (n * (n + 1)) === (2 * sum)) {
  675. const list = [];
  676. for (let i = 0; i < n; i++) {
  677. let item = table.get(i + 1);
  678. if (item === table) {
  679. item = list;
  680. }
  681. list.push(item);
  682. }
  683. this._memo.set(index, list);
  684. return list;
  685. }
  686. return table;
  687. }
  688. function() {
  689. const index = this.int32();
  690. if (this._memo.has(index)) {
  691. return this._memo.get(index);
  692. }
  693. const size = this.int32();
  694. const dumped = this._reader.read(size);
  695. const upvalues = this.read();
  696. const type = this._types.get('LuaFunction');
  697. const obj = Reflect.construct(type, [size, dumped, upvalues]);
  698. this._memo.set(index, obj);
  699. return obj;
  700. }
  701. storage(size, itemSize, dataType) {
  702. return this._reader.storage(size, itemSize, dataType);
  703. }
  704. };
  705. torch.BinaryReader = class {
  706. constructor(data) {
  707. this._reader = base.BinaryReader.open(data);
  708. this._textDecoder = new TextDecoder('ascii');
  709. }
  710. seek(position) {
  711. this._reader.seek(position);
  712. }
  713. skip(offset) {
  714. this._reader.skip(offset);
  715. }
  716. read(length) {
  717. return this._reader.read(length);
  718. }
  719. boolean() {
  720. return this.int32() === 1;
  721. }
  722. int32() {
  723. return this._reader.int32();
  724. }
  725. int64() {
  726. return this._reader.int64().toNumber();
  727. }
  728. int64s(size) {
  729. const array = [];
  730. for (let i = 0; i < size; i++) {
  731. array.push(this.int64());
  732. }
  733. return array;
  734. }
  735. float32() {
  736. return this._reader.float32();
  737. }
  738. float64() {
  739. return this._reader.float64();
  740. }
  741. string() {
  742. const size = this.int32();
  743. const buffer = this.read(size);
  744. return this._textDecoder.decode(buffer);
  745. }
  746. storage(size, itemSize) {
  747. const buffer = this.read(size * itemSize);
  748. return new torch.BinaryReader(buffer);
  749. }
  750. };
  751. torch.TextReader = class {
  752. constructor(data, separator) {
  753. this._buffer = data instanceof Uint8Array ? data : data.peek();
  754. this._position = 0;
  755. this._dataView = new DataView(this._buffer.buffer, this._buffer.byteOffset, this._buffer.byteLength);
  756. this._textDecoder = new TextDecoder('ascii');
  757. this._separator = separator || 0x0a;
  758. }
  759. seek(position) {
  760. this._position = position;
  761. }
  762. line(size) {
  763. const start = this._position;
  764. while (this._position < this._buffer.length && size > -1) {
  765. const c = this._buffer[this._position++];
  766. if (c === this._separator) {
  767. return this._buffer.slice(start, this._position - 1);
  768. } else if (this._position === this._buffer.length) {
  769. return this._buffer.slice(start, this._position);
  770. }
  771. size--;
  772. }
  773. throw new torch.Error('Line exceeded maximum length.');
  774. }
  775. boolean() {
  776. return this.int32() === 1;
  777. }
  778. read(size) {
  779. return this.line(size);
  780. }
  781. int8() {
  782. return this.int64();
  783. }
  784. int16() {
  785. return this.int64();
  786. }
  787. int32() {
  788. return this.int64();
  789. }
  790. int64() {
  791. const token = this._textDecoder.decode(this.line(20));
  792. const number = Number.parseInt(token, 10);
  793. if (Number.isNaN(token - number)) {
  794. throw new torch.Error(`Couldn't parse int64 '${token}'.`);
  795. }
  796. return number;
  797. }
  798. int64s(size) {
  799. const array = [];
  800. if (size > 0) {
  801. const content = this._textDecoder.decode(this.line(Number.MAX_SAFE_INTEGER));
  802. for (const token of content.split(' ')) {
  803. const number = Number.parseInt(token, 10);
  804. if (Number.isNaN(token - number)) {
  805. throw new torch.Error(`Couldn't parse int64 '${token}'.`);
  806. }
  807. array.push(number);
  808. }
  809. }
  810. return array;
  811. }
  812. float32() {
  813. return this.float64();
  814. }
  815. float64() {
  816. const token = this._textDecoder.decode(this.line(24));
  817. if (token.startsWith('-nan')) {
  818. return -NaN;
  819. }
  820. if (token.startsWith('nan')) {
  821. return NaN;
  822. }
  823. if (token.startsWith('inf')) {
  824. return Infinity;
  825. }
  826. if (token.startsWith('-inf')) {
  827. return -Infinity;
  828. }
  829. const number = Number.parseFloat(token);
  830. if (Number.isNaN(token - number)) {
  831. throw new torch.Error(`Couldn't parse float '${token}'.`);
  832. }
  833. return number;
  834. }
  835. string() {
  836. const size = this.int32();
  837. if (size === 0) {
  838. return '';
  839. }
  840. const data = this.line(size);
  841. const content = this._textDecoder.decode(data);
  842. if (size !== content.length) {
  843. throw new torch.Error('Invalid string length.');
  844. }
  845. return content;
  846. }
  847. storage(size, itemSize, dataType) {
  848. if (size <= 0) {
  849. throw new torch.Error(`Unsupported storage size '${size}'.`);
  850. }
  851. if (dataType === 'uint8') {
  852. const start = this._position;
  853. this._position += size;
  854. const bytes = this._buffer.slice(start, this._position);
  855. this.line(0);
  856. return new torch.BinaryReader(bytes);
  857. }
  858. const data = this.line(Number.MAX_SAFE_INTEGER);
  859. return new torch.TextReader(data, 0x20);
  860. }
  861. };
  862. torch.Error = class extends Error {
  863. constructor(message) {
  864. super(message);
  865. this.name = 'Error loading Torch model.';
  866. }
  867. };
  868. export const ModelFactory = torch.ModelFactory;