| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357 |
// Namespace object for the Torch loader. If a truthy `torch` binding already
// exists it is reused; otherwise the namespace starts as an empty object.
var torch = torch ? torch : {};
/**
 * Entry point for detecting and loading Torch 7 (`.t7`) files.
 */
torch.ModelFactory = class {

    /**
     * Checks whether the context's stream looks like a T7 file.
     * @param {object} context - open context exposing `stream`.
     * @returns {torch.T7Reader|null} a reader for the stream, or null when the
     *   content is not recognized.
     */
    match(context) {
        return torch.T7Reader.open(context);
    }

    /**
     * Loads the model using the reader produced by `match`.
     * @param {object} context - open context exposing `metadata()`,
     *   `identifier` and `exception()`.
     * @param {torch.T7Reader} match - reader returned by `match`.
     * @returns {Promise<torch.Model>}
     */
    open(context, match) {
        return context.metadata('torch-metadata.json').then((metadata) => {
            const identifier = context.identifier;
            const reader = match;
            // Unknown class names are reported through `context.exception`
            // (second argument `false` — presumably "non-fatal"; confirm
            // against the host API), except for a few module families that
            // are accepted silently.
            reader.callback = (name) => {
                if (name && name !== 'nn.JointTrainModule' && !name.startsWith('nn.MSDNet_') && !name.startsWith('onmt.')) {
                    context.exception(new torch.Error("Unsupported type '" + name + "' in '" + identifier + "'."), false);
                }
                return null;
            };
            let root = reader.read();
            // Some files store a pair [module, extra]; keep only the module.
            // Both elements are null-checked so a nil entry cannot raise a
            // TypeError here.
            if (Array.isArray(root) && root.length === 2 &&
                root[0] && root[0].__class__ &&
                !(root[1] && root[1].__class__)) {
                root = root[0];
            }
            return new torch.Model(metadata, root);
        });
    }
};
/**
 * A loaded Torch model. It always exposes exactly one graph, built from the
 * deserialized root object.
 */
torch.Model = class {

    /**
     * @param {object} metadata - operator metadata forwarded to the graph.
     * @param {object} root - deserialized root object.
     */
    constructor(metadata, root) {
        const graph = new torch.Graph(metadata, root);
        this._graphs = [ graph ];
    }

    /** @returns {string} the format label shown for this model. */
    get format() {
        return 'Torch v7';
    }

    /** @returns {torch.Graph[]} a one-element list containing the graph. */
    get graphs() {
        return this._graphs;
    }
};
/**
 * Graph view of a Torch module tree. Container modules (Sequential, Parallel,
 * Concat, ...) are walked recursively, and every other module becomes a
 * `torch.Node`. The `groups` stack records the key path of enclosing
 * containers.
 */
torch.Graph = class {

    /**
     * @param {object} metadata - operator metadata forwarded to nodes.
     * @param {object} root - root module, or an object whose own `model`
     *   property holds it.
     * @throws {torch.Error} when no root module is present.
     */
    constructor(metadata, root) {
        this._nodes = [];
        // Boolean flag that becomes true once any module is nested inside a
        // container. (It used to be the string 'false', which is truthy.)
        this._groups = false;
        if (root !== null && root !== undefined && Object.prototype.hasOwnProperty.call(root, 'model')) {
            root = root.model;
        }
        if (root === null || root === undefined) {
            throw new torch.Error('File does not contain a root module.');
        }
        // _loadModule and the nodes it creates push the graph-level
        // arguments into these arrays.
        const inputs = [];
        const outputs = [];
        this._loadModule(metadata, root, [], '', inputs, outputs);
        // Naming: 'input', 'input2', 'input3', ... (likewise for outputs).
        const suffix = (index) => index !== 0 ? (index + 1).toString() : '';
        this._inputs = inputs.map((input, index) => new torch.Parameter('input' + suffix(index), true, [ input ]));
        this._outputs = outputs.map((output, index) => new torch.Parameter('output' + suffix(index), true, [ output ]));
    }

    get inputs() {
        return this._inputs;
    }

    get outputs() {
        return this._outputs;
    }

    get nodes() {
        return this._nodes;
    }

    get groups() {
        return this._groups;
    }

    /**
     * Recursively converts `module` into nodes.
     * @param {object} metadata - operator metadata.
     * @param {object} module - module to load.
     * @param {string[]} groups - stack of enclosing container keys (mutated
     *   temporarily and restored).
     * @param {string} key - key of `module` within its parent.
     * @param {torch.Argument[]} inputs - arguments feeding the module; created
     *   placeholders are pushed into this array.
     * @param {torch.Argument[]} outputs - arguments produced by the module;
     *   created placeholders are pushed into this array.
     */
    _loadModule(metadata, module, groups, key, inputs, outputs) {
        if (groups.length > 0) {
            this._groups = true;
        }
        const type = module.__class__ ? module.__class__.__module__ + '.' + module.__class__.__name__ : '';
        switch (type) {
            case 'nn.Sequential': {
                // Chain the children: each child's outputs feed the next one,
                // and the last child writes into the container's outputs.
                groups.push(key);
                const modules = module.modules;
                let subInputs = inputs;
                for (let index = 0; index < modules.length; index++) {
                    const subOutputs = index === modules.length - 1 ? outputs : [];
                    this._loadModule(metadata, modules[index], groups, index.toString(), subInputs, subOutputs);
                    subInputs = subOutputs;
                }
                groups.pop();
                break;
            }
            case 'nn.Parallel':
            case 'nn.ParallelTable':
            case 'nn.JointTrain': {
                // Each child gets its own copy of the container's arguments.
                // When the container had none yet, the arguments created by
                // the children are propagated back to the caller.
                groups.push(key);
                const newInputs = [];
                const newOutputs = [];
                let index = 0;
                for (const subModule of module.modules) {
                    const subInputs = inputs.slice();
                    const subOutputs = outputs.slice();
                    this._loadModule(metadata, subModule, groups, index.toString(), subInputs, subOutputs);
                    if (inputs.length === 0) {
                        newInputs.push(...subInputs);
                    }
                    if (outputs.length === 0) {
                        newOutputs.push(...subOutputs);
                    }
                    index++;
                }
                // Mutate rather than rebind, so the caller sees the new
                // arguments. Previously `inputs` was reassigned locally and
                // the collected inputs were lost.
                inputs.push(...newInputs);
                outputs.push(...newOutputs);
                groups.pop();
                break;
            }
            case 'nn.Concat':
            case 'nn.ConcatTable': {
                // Fan out: every branch receives the same inputs (a shared
                // placeholder is created when there are none), and one node
                // for the container consumes all branch outputs.
                if (inputs.length === 0) {
                    inputs.push(new torch.Argument(groups.join('/') + ':' + key + ':in', null, null));
                }
                let concatInputs = [];
                let index = 0;
                for (const subModule of module.modules) {
                    const streamInputs = inputs.slice();
                    const streamOutputs = [];
                    this._loadModule(metadata, subModule, groups, key + '.' + index.toString(), streamInputs, streamOutputs);
                    concatInputs = concatInputs.concat(streamOutputs);
                    index++;
                }
                delete module.modules;
                delete module.dimension;
                this._createNode(metadata, module, groups, key, concatInputs, outputs);
                break;
            }
            case 'nn.Inception': {
                delete module.modules; // TODO
                delete module.module; // TODO
                delete module.transfer; // TODO
                delete module.pool; // TODO
                this._createNode(metadata, module, groups, key, inputs, outputs);
                break;
            }
            case 'nn.gModule': {
                // TODO: expand the nngraph sub-modules instead of showing the
                // gModule as a single node.
                this._createNode(metadata, module, groups, key, inputs, outputs);
                break;
            }
            default: {
                this._createNode(metadata, module, groups, key, inputs, outputs);
                break;
            }
        }
    }

    /** Creates a `torch.Node` for `module` and appends it to the graph. */
    _createNode(metadata, module, group, subIndex, inputs, outputs) {
        const node = new torch.Node(metadata, module, group, subIndex, inputs, outputs);
        this._nodes.push(node);
    }
};
/**
 * A named slot on a node or graph that holds a list of arguments.
 */
torch.Parameter = class {

    /**
     * @param {string} name - slot name.
     * @param {boolean} visible - visibility flag, returned unchanged.
     * @param {Array} args - arguments bound to the slot.
     */
    constructor(name, visible, args) {
        Object.assign(this, {
            _name: name,
            _visible: visible,
            _arguments: args
        });
    }

    get arguments() {
        return this._arguments;
    }

    get visible() {
        return this._visible;
    }

    get name() {
        return this._name;
    }
};
/**
 * A value flowing through the graph. An argument is identified by a string
 * name and may carry an initializer (a constant tensor), whose type takes
 * precedence over the declared type.
 */
torch.Argument = class {

    /**
     * @param {string} name - unique identifier.
     * @param {*} type - declared type, used when there is no initializer.
     * @param {*} initializer - optional constant value exposing `type`.
     * @throws {torch.Error} when `name` is not a string.
     */
    constructor(name, type, initializer) {
        const hasValidName = typeof name === 'string';
        if (!hasValidName) {
            throw new torch.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
        }
        this._name = name;
        this._type = type;
        this._initializer = initializer;
    }

    get initializer() {
        return this._initializer;
    }

    get type() {
        return this._initializer ? this._initializer.type : this._type;
    }

    get name() {
        return this._name;
    }
};
/**
 * A single Torch module shown as a graph node. The constructor mutates the
 * module:
 *  - it replaces storages with their decoded data;
 *  - it deletes buffer-like fields;
 *  - it folds W/H and t/r/b/l field pairs into arrays.
 * Tensor fields become initializer inputs, and the remaining plain fields
 * become attributes.
 */
torch.Node = class {

    /**
     * @param {object} metadata - provides `type(name)` and `attribute(type, name)`.
     * @param {object} module - deserialized module; MUTATED by this constructor.
     * @param {string[]} groups - key path of the enclosing containers.
     * @param {string} name - fallback name (the key within the parent).
     * @param {torch.Argument[]} inputs - input arguments; a placeholder is
     *   pushed when empty.
     * @param {torch.Argument[]} outputs - output arguments; a placeholder is
     *   pushed when empty.
     */
    constructor(metadata, module, groups, name, inputs, outputs) {
        this._group = groups.join('/');
        if (module.name && typeof module.name === 'string') {
            // An explicit module name wins over the positional name.
            this._name = module.name;
            delete module.name;
        }
        else {
            this._name = this._group ? (this._group + ':' + name) : name;
        }
        const type = module.__class__ ? module.__class__.__module__ + '.' + module.__class__.__name__ : 'nn.Module';
        this._type = metadata.type(type);
        // Replace storage objects with their decoded element arrays.
        for (const [ key, value ] of Object.entries(module)) {
            if (torch.Node._isTorchClass(value, 'Storage')) {
                module[key] = value.data();
            }
        }
        this._removeTransientFields(module);
        this._attributes = [];
        const initializers = [];
        if (module.__class__) {
            for (const [ key, value ] of Object.entries(module)) {
                if (key === '_type') {
                    continue;
                }
                // Lists of child `nn` modules are represented by the graph itself.
                if (Array.isArray(value) && value.every((item) => item && item.__class__ && item.__class__.__module__ === 'nn')) {
                    continue;
                }
                if (torch.Node._isTorchClass(value, 'Tensor')) {
                    initializers.push(new torch.Parameter(key, true, [
                        new torch.Argument(key, null, new torch.Tensor(value))
                    ]));
                    continue;
                }
                if (key === 'modules') {
                    continue;
                }
                // Skip instances of namespaced classes. Instances of classes
                // with an empty module name (such as Dict or LuaFunction) are
                // kept as attributes. `value` may be null/undefined, so the
                // class lookup is guarded.
                const cls = value ? value.__class__ : undefined;
                if (cls && cls.__module__ !== '' && cls.__name__ !== 'LuaFunction') {
                    continue;
                }
                this._attributes.push(new torch.Attribute(metadata, type, key, value));
            }
        }
        if (inputs.length === 0 && this._name) {
            inputs.push(new torch.Argument(this._name + ':in', null, null));
        }
        if (outputs.length === 0 && this._name) {
            outputs.push(new torch.Argument(this._name, null, null));
        }
        // Input order: data input, then `weight`, then `bias`, then any other tensors.
        const weights = initializers.filter((parameter) => parameter.name === 'weight');
        const biases = initializers.filter((parameter) => parameter.name === 'bias');
        const others = initializers.filter((parameter) => parameter.name !== 'weight' && parameter.name !== 'bias');
        this._inputs = [ new torch.Parameter('input', true, inputs) ].concat(weights, biases, others);
        this._outputs = [ new torch.Parameter('output', true, outputs) ];
    }

    get name() {
        return this._name;
    }

    get type() {
        return this._type;
    }

    get group() {
        return this._group;
    }

    get attributes() {
        return this._attributes;
    }

    get inputs() {
        return this._inputs;
    }

    get outputs() {
        return this._outputs;
    }

    /**
     * True when `value` is an instance of a `torch.*` class whose name ends
     * with `suffix` (e.g. 'Tensor', 'Storage'). Null-safe.
     */
    static _isTorchClass(value, suffix) {
        return Boolean(value && value.__class__ &&
            value.__class__.__module__ === 'torch' &&
            value.__class__.__name__.endsWith(suffix));
    }

    /**
     * Deletes fields that look like runtime/training buffers (gradients,
     * scratch tensors, cached outputs) rather than model parameters. It also
     * normalizes per-type size/padding fields. Mutates `module`.
     */
    _removeTransientFields(module) {
        const common = [
            'iSize', 'finput', 'fgradInput', 'output', 'gradInput', 'gradWeight', 'gradBias',
            'grad_tmp', 'scaleT', '_input', '_output', '_gradInput', '_gradOutput',
            'buffer', 'buffer2', 'tmp_in', 'tmp_out', 'accUpdateGradParameters'
        ];
        for (const key of common) {
            delete module[key];
        }
        switch (this._type.name) {
            case 'nn.Linear':
                delete module.addBuffer;
                break;
            case 'nn.Normalize':
            case 'nn.Normalize2':
                delete module.addBuffer;
                delete module.normp;
                delete module.norm;
                break;
            case 'cudnn.SpatialConvolution':
            case 'cudnn.SpatialFullConvolution':
            case 'nn.SpatialConvolution':
            case 'nn.SpatialConvolutionMM':
            case 'nn.SpatialDilatedConvolution':
            case 'nn.SpatialFullConvolution':
                delete module.ones;
                delete module.input_slice;
                delete module.output_slice;
                delete module.convDescData;
                this._updateSize(module, 'adj');
                this._updateSize(module, 'd');
                this._updateSize(module, 'dilation');
                this._updateSize(module, 'k');
                this._updateSize(module, 'pad');
                break;
            case 'cudnn.BatchNormalization':
            case 'cudnn.SpatialBatchNormalization':
            case 'nn.BatchNormalization':
            case 'nn.SpatialBatchNormalization':
            case 'nn.InstanceNormalization':
                delete module.save_mean;
                delete module.save_std;
                delete module.normalized;
                delete module.centered;
                delete module.bn; // TODO InstanceNormalization
                break;
            case 'nn.SpatialCrossMapLRN':
                delete module.scale;
                break;
            case 'cudnn.SpatialMaxPooling':
            case 'cudnn.SpatialAveragePooling':
            case 'inn.SpatialMaxPooling':
            case 'nn.SpatialMaxPooling':
            case 'nn.SpatialAveragePooling':
                delete module.indices;
                this._updateSize(module, 'pad');
                this._updateSize(module, 'd');
                this._updateSize(module, 'k');
                break;
            case 'nn.SpatialZeroPadding':
            case 'nn.SpatialReflectionPadding':
            case 'nn.SpatialReplicationPadding':
                this._updateBox(module, 'pad');
                break;
            case 'nn.Dropout':
                delete module.noise;
                break;
            case 'nn.gModule':
                delete module.forwardnodes;
                delete module.backwardnodes;
                break;
            case 'nn.StereoJoin':
                delete module.output_L;
                break;
            default:
                break;
        }
    }

    /** Folds `<name>W` / `<name>H` into `<name> = [W, H]` when both exist. */
    _updateSize(module, name) {
        if (Object.prototype.hasOwnProperty.call(module, name + 'W') &&
            Object.prototype.hasOwnProperty.call(module, name + 'H')) {
            module[name] = [ module[name + 'W'], module[name + 'H'] ];
            delete module[name + 'W'];
            delete module[name + 'H'];
        }
    }

    /** Folds `<name>_t/_r/_b/_l` into `<name> = [t, r, b, l]` when all four exist. */
    _updateBox(module, name) {
        const keys = [ '_t', '_r', '_b', '_l' ].map((suffix) => name + suffix);
        if (keys.every((key) => Object.prototype.hasOwnProperty.call(module, key))) {
            module[name] = keys.map((key) => module[key]);
            for (const key of keys) {
                delete module[key];
            }
        }
    }
};
/**
 * A plain (non-tensor) module field shown as a node attribute.
 * Visibility rules:
 *  - An attribute named `train` is hidden unless the schema overrides it.
 *  - A schema `visible` entry is used as-is.
 *  - Otherwise, the attribute is hidden when its value equals the schema
 *    `default` (compared through JSON serialization).
 */
torch.Attribute = class {

    /**
     * @param {object} metadata - provides `attribute(type, name)`.
     * @param {string} type - owning module type name.
     * @param {string} name - field name.
     * @param {*} value - field value.
     */
    constructor(metadata, type, name, value) {
        this._name = name;
        this._value = value;
        if (name == 'train') {
            this._visible = false;
        }
        const schema = metadata.attribute(type, name);
        if (!schema) {
            return;
        }
        const has = (property) => Object.prototype.hasOwnProperty.call(schema, property);
        if (has('visible')) {
            this._visible = schema.visible;
        }
        else if (has('default') && JSON.stringify(schema.default) == JSON.stringify(this._value)) {
            this._visible = false;
        }
    }

    get visible() {
        return !(this._visible == false);
    }

    get value() {
        return this._value;
    }

    get name() {
        return this._name;
    }
};
/**
 * Read-only view over the data of a deserialized Torch tensor. Elements are
 * decoded on demand from the backing storage. Decoding starts at the
 * tensor's storage offset and walks the shape in row-major order; strides
 * are not consulted.
 */
torch.Tensor = class {

    /**
     * @param {object} tensor - deserialized tensor exposing `dataType`, `size`,
     *   `storage` (with `data()`) and `storage_offset`.
     */
    constructor(tensor) {
        this._type = new torch.TensorType(tensor);
        this._storage = tensor.storage;
        this._offset = tensor.storage_offset;
    }

    get type() {
        return this._type;
    }

    /** @returns {string|null} why the data cannot be decoded, or null. */
    get state() {
        return this._context().state || null;
    }

    /** @returns {Array|null} all elements as nested arrays, or null when `state` is set. */
    get value() {
        const context = this._context();
        if (context.state) {
            return null;
        }
        context.limit = Number.MAX_SAFE_INTEGER;
        return this._decode(context, 0);
    }

    /** @returns {string} JSON of (roughly) the first 1000 elements, or '' when undecodable. */
    toString() {
        const context = this._context();
        if (context.state) {
            return '';
        }
        context.limit = 1000;
        return JSON.stringify(this._decode(context, 0), null, 4);
    }

    /** Builds the decoding cursor; sets `state` when the data cannot be decoded. */
    _context() {
        const context = { state: null, index: 0, count: 0 };
        if (!this._storage) {
            context.state = 'Tensor data is empty.';
            return context;
        }
        context.data = this._storage.data();
        context.index = this._offset;
        if (!context.data) {
            context.state = 'Tensor data is empty.';
            return context;
        }
        const supported = [ 'uint8', 'int8', 'int16', 'int32', 'int64', 'float32', 'float64' ];
        if (!supported.includes(this._type.dataType)) {
            context.state = 'Tensor data type is not implemented.';
        }
        context.dimensions = this._type.shape.dimensions;
        // A missing or empty shape cannot be decoded. (The old check used `&&`,
        // which threw on a missing shape and never flagged an empty one.)
        if (!context.dimensions || context.dimensions.length === 0) {
            context.state = 'Tensor has no dimensions.';
        }
        return context;
    }

    /** Recursively reads dimension `dimension`, advancing `context.index`. */
    _decode(context, dimension) {
        const results = [];
        const size = context.dimensions[dimension];
        const innermost = dimension === context.dimensions.length - 1;
        for (let i = 0; i < size; i++) {
            if (context.count > context.limit) {
                results.push('...');
                return results;
            }
            if (innermost) {
                results.push(context.data[context.index]);
                context.index++;
                context.count++;
            }
            else {
                results.push(this._decode(context, dimension + 1));
            }
        }
        return results;
    }
};
/**
 * Element type plus shape of a deserialized tensor.
 */
torch.TensorType = class {

    /** @param {object} tensor - object exposing `dataType` and `size`. */
    constructor(tensor) {
        const { dataType, size } = tensor;
        this._dataType = dataType;
        this._shape = new torch.TensorShape(size);
    }

    get dataType() {
        return this._dataType;
    }

    get shape() {
        return this._shape;
    }

    /** @returns {string} e.g. 'float32[2,3]'; '?' stands in for an unknown data type. */
    toString() {
        const prefix = this.dataType || '?';
        return prefix + this._shape.toString();
    }
};
/**
 * List of tensor dimensions.
 */
torch.TensorShape = class {

    /** @param {number[]|undefined} dimensions - size of each axis. */
    constructor(dimensions) {
        this._dimensions = dimensions;
    }

    get dimensions() {
        return this._dimensions;
    }

    /** @returns {string} '[d0,d1,...]', or '' when there are no dimensions. */
    toString() {
        const dims = this._dimensions;
        if (!dims || dims.length === 0) {
            return '';
        }
        const parts = dims.map((d) => d.toString());
        return '[' + parts.join(',') + ']';
    }
};
/**
 * Error thrown for malformed or unsupported Torch files.
 */
torch.Error = class extends Error {

    /** @param {string} text - human-readable description of the problem. */
    constructor(text) {
        super(text);
        this.name = 'Error loading Torch model.';
    }
};
/**
 * Deserializer for the Torch 7 object format. It sits on top of a low-level
 * binary or text reader and rebuilds Lua values: nil, numbers, strings,
 * tables, booleans, functions and instances of registered classes. Objects
 * and tables are memoized by their serialized index, so shared references
 * stay shared.
 *
 * An owner may set `callback(name)`; it is invoked for class names that are
 * not registered.
 */
torch.T7Reader = class {

    /**
     * Detects the T7 encoding of `context.stream`.
     * @returns {torch.T7Reader|null}
     */
    static open(context) {
        const stream = context.stream;
        // Binary: the first 4 bytes are a little-endian type tag in 0..8.
        if (stream && stream.length >= 4 && stream.peek(4).every((value, index) => value === 0x00 || (index === 0 && value <= 0x08))) {
            return new torch.T7Reader(new torch.BinaryReader(stream));
        }
        // Text: a single type-tag digit '0'..'8' followed by a newline.
        if (stream && stream.length >= 2) {
            const buffer = stream.peek(2);
            const value = String.fromCharCode(buffer[0]);
            if (buffer[1] === 0x0a && value >= '0' && value <= '8') {
                return new torch.T7Reader(new torch.TextReader(stream));
            }
        }
        return null;
    }

    /** @param {object} reader - low-level reader (binary or text). */
    constructor(reader) {
        this._reader = reader;
        this._memo = new Map();
        this._types = new Map();
        // Base class for storages. `read` records the element count and a
        // sub-reader over the raw data; `data()` decodes it on first use and
        // caches the result.
        const Storage = class {
            constructor(dataType, itemSize) {
                this.dataType = dataType;
                this.itemSize = itemSize;
            }
            data() {
                if (this.reader) {
                    const reader = this.reader;
                    reader.reset();
                    const dataType = this.dataType;
                    const size = this.size;
                    const array = new Array(size);
                    for (let i = 0; i < size; i++) {
                        switch (dataType) {
                            case 'uint8':
                                array[i] = reader.byte();
                                break;
                            case 'int8':
                                array[i] = reader.int8();
                                break;
                            case 'int16':
                                array[i] = reader.int16();
                                break;
                            case 'int32':
                                array[i] = reader.int32();
                                break;
                            case 'int64':
                                array[i] = reader.int64();
                                break;
                            case 'float32':
                                array[i] = reader.float32();
                                break;
                            case 'float64':
                                array[i] = reader.float64();
                                break;
                            default:
                                throw new torch.Error("Unsupported data type '" + dataType + "'.");
                        }
                    }
                    this._data = array;
                    delete this.reader;
                }
                return this._data;
            }
            read(reader) {
                this.size = reader.int64();
                this.reader = reader.storage(this.size, this.itemSize, this.dataType);
            }
        };
        // Base class for tensors: shape, strides, storage offset and the
        // (possibly shared) storage object.
        const Tensor = class {
            constructor(dataType) {
                this.dataType = dataType;
            }
            read(reader) {
                const dim = reader.int32();
                this.size = reader.int64s(dim);
                this.stride = reader.int64s(dim);
                // Stored offset is presumably Lua 1-based; converted to 0-based.
                this.storage_offset = reader.int64() - 1;
                this.storage = reader.read();
            }
        };
        // Classes whose instances are deserialized as plain attribute bags.
        const modules = [
            'bnn.Binary', 'bnn.SpatialConvolution',
            'cudnn.BatchNormalization', 'cudnn.BatchBRNNReLU', 'cudnn.BLSTM', 'cudnn.ReLU', 'cudnn.RNN',
            'cudnn.Sigmoid', 'cudnn.SoftMax', 'cudnn.LogSoftMax', 'cudnn.normal3DConv', 'cudnn.normal3DdeConv',
            'cudnn.SpatialAveragePooling', 'cudnn.SpatialBatchNormalization', 'cudnn.SpatialConvolution',
            'cudnn.SpatialFullConvolution', 'cudnn.SpatialMaxPooling', 'cudnn.SpatialSoftMax', 'cudnn.Tanh',
            'cudnn.VolumetricAveragePooling', 'cudnn.VolumetricBatchNormalization', 'cudnn.VolumetricConvolution',
            'cudnn.VolumetricMaxPooling',
            'Dict',
            'inn.ConstAffine', 'inn.SpatialMaxPooling',
            'nn.Abs', 'nn.AddConstant', 'nn.BatchNormalization', 'nn.BilinearSamplerBHWD',
            'nn.BinActiveZ', // allenai/XNOR-Net
            'nn.BCECriterion', 'nn.Bottle', 'nn.Clamp', 'nn.CMul', 'nn.CAddTable', 'nn.CDivTable', 'nn.CMulTable',
            'nn.CSubTable', 'nn.Concat', 'nn.Copy', 'nn.ConcatTable', 'nn.Contiguous', 'nn.Constant',
            'nn.CostVolMulti', 'nn.DataParallelTable', 'nn.DepthConcat', 'nn.Dropout', 'nn.Exp', 'nn.ExpOut',
            'nn.FlattenTable', 'nn.GenNoise', 'nn.Identity', 'nn.Index', 'nn.Inception', 'nn.InstanceNormalization',
            'nn.JoinTable', 'nn.JointTrain', 'nn.KeypointCoordinate', 'nn.LeakyReLU', 'nn.Linear', 'nn.LinearNoBias',
            'nn.LogSoftMax', 'nn.LookupTable', 'nn.LSTM', 'nn.MaskZero', 'nn.MapTable', 'nn.Max', 'nn.Mean', 'nn.Min',
            'nn.MulConstant', 'nn.MM', 'nn.MSECriterion', 'nn.Narrow', 'nn.NarrowTable', 'nn.Normalize',
            'nn.Normalize2', 'nn.NoiseFill', 'nn.Padding', 'nn.Parallel', 'nn.ParallelCriterion', 'nn.ParallelTable',
            'nn.PixelShuffle', 'nn.Power', 'nn.PReLU', 'nn.Recursor', 'nn.ReLU', 'nn.Replicate', 'nn.Reshape',
            'nn.ShaveImage', 'nn.Select', 'nn.SelectTable', 'nn.Sequencer', 'nn.Sequential', 'nn.Sigmoid', 'nn.Sum',
            'nn.SoftMax', 'nn.SpatialAveragePooling', 'nn.SpatialBatchNormalization', 'nn.SpatialConvolution',
            'nn.SpatialConvolutionMM', 'nn.SpatialCrossMapLRN', 'nn.SpatialDilatedConvolution', 'nn.SpatialDropout',
            'nn.SpatialFractionalMaxPooling', 'nn.SpatialFullConvolution', 'nn.SpatialLPPooling',
            'nn.SpatialMaxPooling', 'nn.SpatialMaxUnpooling', 'nn.SpatialReflectionPadding',
            'nn.SpatialReplicationPadding', 'nn.SpatialSoftMax', 'nn.SpatialSubtractiveNormalization',
            'nn.SpatialUpSamplingBilinear', 'nn.SpatialUpSamplingNearest', 'nn.SpatialZeroPadding', 'nn.SplitTable',
            'nn.Squeeze', 'nn.Square', 'nn.Sqrt', 'nn.StereoJoin', 'nn.Tanh', 'nn.Transpose', 'nn.TotalVariation',
            'nn.Unpool', 'nn.View', 'nn.gModule',
            'nngraph.Node',
            'graph.Edge', 'graph.Graph',
            'w2nn.AuxiliaryLossTable', 'w2nn.InplaceClip01', 'w2nn.ScaleTable'
        ];
        for (const name of modules) {
            this.register(name);
        }
        const tensors = [
            [ 'torch.ByteTensor', 'uint8' ],
            [ 'torch.CharTensor', 'int8' ],
            [ 'torch.ShortTensor', 'int16' ],
            [ 'torch.IntTensor', 'int32' ],
            [ 'torch.LongTensor', 'int64' ],
            [ 'torch.FloatTensor', 'float32' ],
            [ 'torch.DoubleTensor', 'float64' ],
            [ 'torch.CudaByteTensor', 'uint8' ],
            [ 'torch.CudaCharTensor', 'int8' ],
            [ 'torch.CudaShortTensor', 'int16' ],
            [ 'torch.CudaIntTensor', 'int32' ],
            [ 'torch.CudaLongTensor', 'int64' ],
            [ 'torch.CudaTensor', 'float32' ],
            [ 'torch.CudaDoubleTensor', 'float64' ]
        ];
        for (const [ name, dataType ] of tensors) {
            this.register(name, class extends Tensor {
                constructor() {
                    super(dataType);
                }
            });
        }
        const itemSizes = { uint8: 1, int8: 1, int16: 2, int32: 4, int64: 8, float32: 4, float64: 8 };
        const storages = [
            [ 'torch.ByteStorage', 'uint8' ],
            [ 'torch.CharStorage', 'int8' ],
            [ 'torch.ShortStorage', 'int16' ],
            [ 'torch.IntStorage', 'int32' ],
            [ 'torch.LongStorage', 'int64' ],
            [ 'torch.FloatStorage', 'float32' ],
            [ 'torch.DoubleStorage', 'float64' ],
            [ 'torch.CudaByteStorage', 'uint8' ],
            [ 'torch.CudaCharStorage', 'int8' ],
            [ 'torch.CudaShortStorage', 'int16' ],
            [ 'torch.CudaIntStorage', 'int32' ],
            [ 'torch.CudaLongStorage', 'int64' ],
            [ 'torch.CudaStorage', 'float32' ],
            // Backs torch.CudaDoubleTensor; was missing, so such tensors'
            // storages were mis-parsed as attribute tables.
            [ 'torch.CudaDoubleStorage', 'float64' ],
            // NOTE(review): 'Float' suggests float32, but this entry has always
            // been float64; kept unchanged for compatibility — confirm.
            [ 'torch.CudaFloatStorage', 'float64' ]
        ];
        for (const [ name, dataType ] of storages) {
            const itemSize = itemSizes[dataType];
            this.register(name, class extends Storage {
                constructor() {
                    super(dataType, itemSize);
                }
            });
        }
        this.register('LuaFunction', class {
            constructor(size, dumped, upvalues) {
                this.size = size;
                this.dumped = dumped;
                this.upvalues = upvalues;
            }
        });
    }

    /**
     * Registers `type` under a dotted class name. The type gets
     * `__name__`/`__module__` statics and a `prototype.__class__`
     * back-reference. A fresh empty class is used when `type` is omitted.
     */
    register(name, type) {
        const cls = type || class {};
        const parts = name.split('.');
        cls.__name__ = parts.pop();
        cls.__module__ = parts.join('.');
        cls.prototype.__class__ = cls;
        this._types.set(name, cls);
    }

    /** Reads one tagged value. @throws {torch.Error} on an unknown tag. */
    read() {
        const type = this.int32();
        switch (type) {
            case 0: return null;
            case 1: return this.float64();
            case 2: return this.string();
            case 3: return this.table();
            case 4: return this.object();
            case 5: return this.boolean();
            case 6:
            case 7:
            case 8: return this.function();
            default: throw new torch.Error("File format has invalid type '" + type + "'.");
        }
    }

    boolean() {
        return this._reader.boolean();
    }

    bytes(size) {
        return this._reader.bytes(size);
    }

    int32() {
        return this._reader.int32();
    }

    int64() {
        return this._reader.int64();
    }

    int64s(size) {
        return this._reader.int64s(size);
    }

    float64() {
        return this._reader.float64();
    }

    string() {
        return this._reader.string();
    }

    /** Reads a class instance (or returns the memoized one for its index). */
    object() {
        const index = this.int32();
        if (this._memo.has(index)) {
            return this._memo.get(index);
        }
        // Header is either "V <version>" followed by the class name, or just
        // the class name (version 0).
        let version = this.string();
        let name = null;
        if (version.startsWith('V ')) {
            name = this.string();
            version = Number(version.split(' ')[1]);
        }
        else {
            name = version;
            version = 0;
        }
        if (!this._types.has(name)) {
            // Report unknown classes if an owner asked to be told, then fall
            // back to a plain attribute-bag class.
            if (this.callback) {
                this.callback(name);
            }
            this.register(name);
        }
        const type = this._types.get(name);
        const obj = Reflect.construct(type, []);
        this._memo.set(index, obj);
        if (obj.read) {
            obj.read(this, version);
        }
        else {
            const attributes = this.read();
            if (attributes != null) {
                for (const [ key, value ] of Object.entries(attributes)) {
                    obj[key] = value;
                }
            }
        }
        return obj;
    }

    /**
     * Reads a Lua table. A table whose keys are exactly 1..n is returned as
     * a JS array; any other table is returned as a plain object.
     */
    table() {
        const index = this.int32();
        if (this._memo.has(index)) {
            return this._memo.get(index);
        }
        const table = {};
        this._memo.set(index, table);
        const size = this.int32();
        let convert = true;
        let sum = 0;
        for (let i = 0; i < size; i++) {
            const key = this.read();
            const value = this.read();
            table[key] = value;
            // Only positive integers can be sequence keys. Key 0 used to be
            // accepted, which let e.g. {0, 3} pass the sum test below.
            if (Number.isInteger(key) && key > 0) {
                sum += key;
            }
            else {
                convert = false;
            }
        }
        // n positive integers sum to n(n+1)/2 only when they are exactly 1..n.
        const n = Object.keys(table).length;
        if (convert && (n * (n + 1)) === (2 * sum)) {
            const list = [];
            for (let j = 0; j < n; j++) {
                let item = table[j + 1];
                // A self-reference to the table must point at the list instead.
                if (item === table) {
                    item = list;
                }
                list.push(item);
            }
            this._memo.set(index, list);
            return list;
        }
        return table;
    }

    /** Reads a serialized Lua function as an opaque LuaFunction record. */
    function() {
        const index = this.int32();
        if (this._memo.has(index)) {
            return this._memo.get(index);
        }
        const size = this.int32();
        const dumped = this.bytes(size);
        const upvalues = this.read();
        const type = this._types.get('LuaFunction');
        const obj = Reflect.construct(type, [ size, dumped, upvalues ]);
        this._memo.set(index, obj);
        return obj;
    }

    storage(size, itemSize, dataType) {
        return this._reader.storage(size, itemSize, dataType);
    }
};
/**
 * Sequential little-endian reader over an in-memory byte buffer.
 */
torch.BinaryReader = class {

    /**
     * @param {Uint8Array|Object} data - the bytes, or an object whose peek() returns them.
     */
    constructor(data) {
        this._buffer = data instanceof Uint8Array ? data : data.peek();
        this._dataView = new DataView(this._buffer.buffer, this._buffer.byteOffset, this._buffer.byteLength);
        this._position = 0;
        this._textDecoder = new TextDecoder('ascii');
    }

    /** Rewinds to the start of the buffer. */
    reset() {
        this._position = 0;
    }

    /**
     * Advances the cursor by `offset` bytes.
     * @throws {torch.Error} if the new position lies past the end of the buffer.
     */
    skip(offset) {
        this._position += offset;
        if (this._position > this._buffer.length) {
            throw new torch.Error('Expected ' + (this._position - this._buffer.length) + ' more bytes. The file might be corrupted. Unexpected end of file.');
        }
    }

    /** Reads an int32; only the value 1 counts as true. */
    boolean() {
        return this.int32() === 1;
    }

    /**
     * Returns a view (not a copy) of the next `length` bytes.
     * @throws {torch.Error} if `length` is negative or NaN, or runs past the end.
     */
    bytes(length) {
        // A negative length (e.g. a corrupted size prefix) would otherwise move the
        // cursor backwards and return an empty view; `!(x >= 0)` also rejects NaN,
        // which would leave the cursor at NaN.
        if (!(length >= 0)) {
            throw new torch.Error("Invalid byte length '" + length + "'. The file might be corrupted.");
        }
        const position = this._position;
        this.skip(length);
        return this._buffer.subarray(position, this._position);
    }

    int8() {
        const position = this._position;
        this.skip(1);
        // Single-byte reads have no byte order, so no littleEndian flag.
        return this._dataView.getInt8(position);
    }

    int16() {
        const position = this._position;
        this.skip(2);
        return this._dataView.getInt16(position, true);
    }

    int32() {
        const position = this._position;
        this.skip(4);
        return this._dataView.getInt32(position, true);
    }

    int64() {
        const position = this._position;
        this.skip(8);
        // NOTE(review): getInt64 is not a standard DataView method; presumably it is
        // installed elsewhere in the codebase and returns an object with toNumber().
        return this._dataView.getInt64(position, true).toNumber();
    }

    /** Reads `size` consecutive int64 values into an array. */
    int64s(size) {
        const array = [];
        for (let i = 0; i < size; i++) {
            array.push(this.int64());
        }
        return array;
    }

    float32() {
        const position = this._position;
        this.skip(4);
        return this._dataView.getFloat32(position, true);
    }

    float64() {
        const position = this._position;
        this.skip(8);
        return this._dataView.getFloat64(position, true);
    }

    /** Reads an int32 byte length followed by that many bytes, decoded as text. */
    string() {
        return this._textDecoder.decode(this.bytes(this.int32()));
    }

    /**
     * Returns a new BinaryReader over the next `size * itemSize` bytes.
     * Any further arguments (such as a data type) are ignored.
     */
    storage(size, itemSize) {
        return new torch.BinaryReader(this.bytes(size * itemSize));
    }
};
/**
 * Reader for the text serialization format: every scalar is a token of decimal
 * text terminated by a separator byte (newline by default). Non-uint8
 * storages are read as one line and parsed with a nested space-separated reader.
 */
torch.TextReader = class {

    /**
     * @param {Uint8Array|Object} data - the bytes, or an object whose peek() returns them.
     * @param {number} [separator] - token terminator byte; a falsy value selects 0x0a.
     */
    constructor(data, separator) {
        this._buffer = data instanceof Uint8Array ? data : data.peek();
        this._position = 0;
        this._dataView = new DataView(this._buffer.buffer, this._buffer.byteOffset, this._buffer.byteLength);
        this._textDecoder = new TextDecoder('ascii');
        this._separator = separator || 0x0a;
    }

    /** Rewinds to the start of the buffer. */
    reset() {
        this._position = 0;
    }

    /**
     * Reads the next token: the bytes before the next separator, or up to the end
     * of the buffer. The separator itself is consumed but not returned.
     * @param {number} size - maximum token length in bytes.
     * @returns {Uint8Array} a copy of the token bytes.
     * @throws {torch.Error} at end of file, or when no separator follows within `size` bytes.
     */
    line(size) {
        const start = this._position;
        // Report an exhausted buffer as end-of-file rather than as an over-long line.
        if (start >= this._buffer.length) {
            throw new torch.Error('Unexpected end of file.');
        }
        let remaining = size;
        while (this._position < this._buffer.length && remaining > -1) {
            const c = this._buffer[this._position++];
            if (c === this._separator) {
                return this._buffer.slice(start, this._position - 1);
            }
            if (this._position === this._buffer.length) {
                return this._buffer.slice(start, this._position);
            }
            remaining--;
        }
        throw new torch.Error('Line exceeded maximum length.');
    }

    /** Reads an integer token; only the value 1 counts as true. */
    boolean() {
        return this.int32() === 1;
    }

    /** In text form a byte run is simply the next token of at most `size` bytes. */
    bytes(size) {
        return this.line(size);
    }

    int8() {
        return this.int64();
    }

    int16() {
        return this.int64();
    }

    int32() {
        return this.int64();
    }

    /**
     * Parses one base-10 integer token (at most 20 characters).
     * @throws {torch.Error} if the token is not entirely numeric.
     */
    int64() {
        const token = this._textDecoder.decode(this.line(20));
        const number = Number.parseInt(token, 10);
        // parseInt tolerates trailing garbage ('12abc'); subtracting from the raw
        // token yields NaN unless the whole token is numeric.
        if (Number.isNaN(token - number)) {
            throw new torch.Error("Couldn't parse int64 '" + token + "'.");
        }
        return number;
    }

    /**
     * Reads one line of space-separated integers. Returns an empty array without
     * consuming input when `size` is not positive.
     */
    int64s(size) {
        const array = [];
        if (size > 0) {
            const content = this._textDecoder.decode(this.line(Number.MAX_SAFE_INTEGER));
            for (const token of content.split(' ')) {
                const number = Number.parseInt(token, 10);
                if (Number.isNaN(token - number)) {
                    throw new torch.Error("Couldn't parse int64 '" + token + "'.");
                }
                array.push(number);
            }
        }
        return array;
    }

    float32() {
        return this.float64();
    }

    /**
     * Parses one floating-point token (at most 24 characters), accepting the
     * 'nan', '-nan', 'inf' and '-inf' spellings by prefix.
     * @throws {torch.Error} if the token is not numeric.
     */
    float64() {
        const token = this._textDecoder.decode(this.line(24));
        // JavaScript has a single NaN value, so the sign of '-nan' cannot be kept.
        if (token.startsWith('-nan') || token.startsWith('nan')) {
            return NaN;
        }
        if (token.startsWith('inf')) {
            return Infinity;
        }
        if (token.startsWith('-inf')) {
            return -Infinity;
        }
        const number = Number.parseFloat(token);
        if (Number.isNaN(token - number)) {
            throw new torch.Error("Couldn't parse float '" + token + "'.");
        }
        return number;
    }

    /**
     * Reads an integer length token, then a token of exactly that many characters.
     * @throws {torch.Error} if the decoded length differs from the declared one.
     */
    string() {
        const size = this.int32();
        if (size === 0) {
            return '';
        }
        const data = this.line(size);
        const content = this._textDecoder.decode(data);
        if (size !== content.length) {
            throw new torch.Error('Invalid string length.');
        }
        return content;
    }

    /**
     * Reads a storage payload of `size` elements. 'uint8' storages are `size` raw
     * bytes followed by a separator and come back as a torch.BinaryReader; any
     * other type is one line returned as a space-separated TextReader.
     * `itemSize` is not used by the text format.
     * @throws {torch.Error} if `size` is not positive.
     */
    storage(size, itemSize, dataType) {
        if (size <= 0) {
            throw new torch.Error("Unsupported storage size '" + size + "'.");
        }
        if (dataType === 'uint8') {
            const start = this._position;
            this._position += size;
            const bytes = this._buffer.slice(start, this._position);
            // Consume the separator that follows the raw bytes.
            this.line(0);
            return new torch.BinaryReader(bytes);
        }
        const data = this.line(Number.MAX_SAFE_INTEGER);
        return new torch.TextReader(data, 0x20);
    }
};
// Publish the model factory when a CommonJS `module` object is present; no-op otherwise.
if (typeof module !== 'undefined' && typeof module.exports === 'object') {
    Object.assign(module.exports, { ModelFactory: torch.ModelFactory });
}
|