| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245 |
/* jshint esversion: 6 */

// Namespace object that every Torch v7 class in this file is attached to.
// An already-truthy `torch` binding is reused rather than replaced.
var torch = torch ? torch : {};
torch.ModelFactory = class {

    /**
     * Cheap format sniff for Torch v7 files.
     * Accepts files with a `.t7` extension (case-insensitive) whose first byte
     * is at most 58; that range admits small binary values as well as the
     * ASCII digits '0'-'9' (48-57).
     * @param {Object} context - provides `identifier` and `stream`.
     * @returns {boolean}
     */
    match(context) {
        const extension = context.identifier.split('.').pop().toLowerCase();
        if (extension != 't7') {
            return false;
        }
        const stream = context.stream;
        return stream.length >= 1 && stream.peek(1)[0] <= 58;
    }

    /**
     * Loads metadata, deserializes the file and wraps the root in a torch.Model.
     * Unknown class names are reported through `context.exception` (non-fatal),
     * except for a few known third-party prefixes.
     * @param {Object} context
     * @returns {Promise<torch.Model>}
     */
    open(context) {
        return torch.Metadata.open(context).then((metadata) => {
            const identifier = context.identifier;
            const buffer = context.stream.peek();
            const reader = new torch.T7Reader(buffer, (name) => {
                if (name && name != 'nn.JointTrainModule' && !name.startsWith('nn.MSDNet_') && !name.startsWith('onmt.')) {
                    context.exception(new torch.Error("Unknown type '" + name + "' in '" + identifier + "'."), false);
                }
                return null;
            });
            let root = reader.read();
            // A two-element table whose first entry is a typed object and whose
            // second entry is not is unwrapped to its first entry. Either entry
            // may be null (a serialized nil), so guard before reading __type__.
            const isTyped = (value) => Boolean(value && value.__type__);
            if (root && Array.isArray(root) && root.length == 2 && isTyped(root[0]) && !isTyped(root[1])) {
                root = root[0];
            }
            return new torch.Model(metadata, root);
        });
    }
};
torch.Model = class {

    /**
     * A Torch v7 model always exposes exactly one graph built from `root`.
     * @param {Object} metadata - forwarded to torch.Graph.
     * @param {Object} root - deserialized root object.
     */
    constructor(metadata, root) {
        const graph = new torch.Graph(metadata, root);
        this._graphs = [ graph ];
    }

    get format() {
        return 'Torch v7';
    }

    get graphs() {
        return this._graphs;
    }
};
torch.Graph = class {

    /**
     * Flattens a (possibly nested) Torch v7 module tree into a list of nodes.
     * @param {Object} metadata - forwarded to every torch.Node.
     * @param {Object} root - deserialized root; an object with an own `model`
     *     property is unwrapped first. A null root yields an empty graph.
     */
    constructor(metadata, root) {
        this._inputs = [];
        this._outputs = [];
        this._nodes = [];
        // Must be a real boolean: the string 'false' is truthy and made every
        // graph look grouped. _loadModule flips this to true on nesting.
        this._groups = false;
        if (root && Object.prototype.hasOwnProperty.call(root, 'model')) {
            root = root.model;
        }
        const inputs = [];
        const outputs = [];
        if (root) {
            this._loadModule(metadata, root, [], '', inputs, outputs);
        }
        this._inputs = this._inputs.concat(inputs.map((input, index) => {
            return new torch.Parameter('input' + (index != 0 ? (index + 1).toString() : ''), true, [ input ]);
        }));
        this._outputs = this._outputs.concat(outputs.map((output, index) => {
            return new torch.Parameter('output' + (index != 0 ? (index + 1).toString() : ''), true, [ output ]);
        }));
    }

    get inputs() {
        return this._inputs;
    }

    get outputs() {
        return this._outputs;
    }

    get nodes() {
        return this._nodes;
    }

    get groups() {
        return this._groups;
    }

    /**
     * Recursively visits `module`, creating nodes for leaves and for
     * Concat-style containers. `inputs` / `outputs` are shared arrays: any
     * argument a child creates for an empty list is pushed back into the
     * caller's array so the enclosing scope (ultimately the graph) sees it.
     */
    _loadModule(metadata, module, groups, key, inputs, outputs) {
        if (groups.length > 0) {
            this._groups = true;
        }
        switch (module.__type__) {
            case 'nn.Sequential': {
                groups.push(key);
                let subInputs = inputs;
                let subOutputs = [];
                const length = module.modules.length;
                let index = 0;
                for (const subModule of module.modules) {
                    // The last child writes straight into the container's outputs.
                    if (index == length - 1) {
                        subOutputs = outputs;
                    }
                    this._loadModule(metadata, subModule, groups, index.toString(), subInputs, subOutputs);
                    subInputs = subOutputs;
                    subOutputs = [];
                    index++;
                }
                groups.pop();
                break;
            }
            case 'nn.Parallel':
            case 'nn.ParallelTable':
            case 'nn.JointTrain': {
                groups.push(key);
                let newInputs = [];
                let newOutputs = [];
                let index = 0;
                for (const subModule of module.modules) {
                    const subInputs = [].concat(inputs);
                    const subOutputs = [].concat(outputs);
                    this._loadModule(metadata, subModule, groups, index.toString(), subInputs, subOutputs);
                    if (inputs.length == 0) {
                        newInputs = newInputs.concat(subInputs);
                    }
                    if (outputs.length == 0) {
                        newOutputs = newOutputs.concat(subOutputs);
                    }
                    index++;
                }
                // Propagate to the caller's arrays (reassigning `inputs` would
                // only rebind the local parameter and drop the new inputs).
                for (const newInput of newInputs) {
                    inputs.push(newInput);
                }
                for (const newOutput of newOutputs) {
                    outputs.push(newOutput);
                }
                groups.pop();
                break;
            }
            case 'nn.Concat':
            case 'nn.ConcatTable': {
                const prefix = key;
                if (inputs.length == 0) {
                    inputs.push(new torch.Argument(groups.join('/') + ':' + key + ':in', null, null));
                }
                let concatInputs = [];
                let index = 0;
                for (const subModule of module.modules) {
                    const streamInputs = inputs.map((input) => input);
                    const streamOutputs = [];
                    this._loadModule(metadata, subModule, groups, prefix + '.' + index.toString(), streamInputs, streamOutputs);
                    concatInputs = concatInputs.concat(streamOutputs);
                    index++;
                }
                delete module.modules;
                delete module.dimension;
                this._createNode(metadata, module, groups, key, concatInputs, outputs);
                break;
            }
            case 'nn.Inception': {
                delete module.modules; // TODO
                delete module.module; // TODO
                delete module.transfer; // TODO
                delete module.pool; // TODO
                this._createNode(metadata, module, groups, key, inputs, outputs);
                break;
            }
            default: {
                this._createNode(metadata, module, groups, key, inputs, outputs);
                break;
            }
        }
    }

    _createNode(metadata, module, group, subIndex, inputs, outputs) {
        this._nodes.push(new torch.Node(metadata, module, group, subIndex, inputs, outputs));
    }
};
torch.Parameter = class {

    /**
     * Named, optionally hidden group of arguments attached to a graph or node.
     * @param {string} name
     * @param {boolean} visible
     * @param {Array} args - the arguments carried by this parameter.
     */
    constructor(name, visible, args) {
        Object.assign(this, { _name: name, _visible: visible, _arguments: args });
    }

    get name() {
        return this._name;
    }

    get visible() {
        return this._visible;
    }

    get arguments() {
        return this._arguments;
    }
};
torch.Argument = class {

    /**
     * A single value flowing along a graph edge, optionally backed by an
     * initializer tensor.
     * @param {string} name - must be a string; anything else throws torch.Error.
     * @param {*} type - type used when there is no initializer.
     * @param {*} initializer - when truthy, its `type` takes precedence.
     */
    constructor(name, type, initializer) {
        const isValidName = typeof name === 'string';
        if (!isValidName) {
            throw new torch.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
        }
        this._name = name;
        this._type = type;
        this._initializer = initializer;
    }

    get name() {
        return this._name;
    }

    get type() {
        return this._initializer ? this._initializer.type : this._type;
    }

    get initializer() {
        return this._initializer;
    }
};
torch.Node = class {

    /**
     * Wraps one deserialized Torch v7 module as a graph node.
     *
     * Side effects: `module` is modified in place (storages are replaced by
     * their data, transient fields are deleted, W/H and _t/_r/_b/_l field pairs
     * are folded into arrays), and a default argument is appended to `inputs`
     * and/or `outputs` when either array is empty.
     *
     * @param {Object} metadata - must provide `type(name)`; also passed to
     *     torch.Attribute.
     * @param {Object} module - deserialized module object.
     * @param {string[]} groups - enclosing container keys, joined with '/'.
     * @param {string} name - this module's key inside its container.
     * @param {Array} inputs - incoming arguments (may be appended to).
     * @param {Array} outputs - outgoing arguments (may be appended to).
     */
    constructor(metadata, module, groups, name, inputs, outputs) {
        this._group = groups.join('/');
        if (module.name && typeof module.name === 'string') {
            this._name = module.name;
            delete module.name;
        }
        else {
            this._name = this._group ? (this._group + ':' + name) : name;
        }
        const type = module.__type__ || 'nn.Module';
        this._type = metadata.type(type);
        // Replace torch.*Storage objects by their decoded data.
        for (const key of Object.keys(module)) {
            const obj = module[key];
            if (obj && obj.__type__ && obj.__type__.startsWith('torch.') && obj.__type__.endsWith('Storage')) {
                module[key] = obj.data();
            }
        }
        // Fields stripped from every module before attributes are collected.
        const transientKeys = [
            'iSize', 'finput', 'fgradInput', 'output', 'gradInput', 'gradWeight', 'gradBias',
            'grad_tmp', 'scaleT', '_input', '_output', '_gradInput', '_gradOutput',
            'buffer', 'buffer2', 'tmp_in', 'tmp_out', 'accUpdateGradParameters'
        ];
        for (const key of transientKeys) {
            delete module[key];
        }
        switch (this._type.name) {
            case 'nn.Linear':
                delete module.addBuffer;
                break;
            case 'nn.Normalize':
            case 'nn.Normalize2':
                delete module.addBuffer;
                delete module.normp;
                delete module.norm;
                break;
            case 'cudnn.SpatialConvolution':
            case 'cudnn.SpatialFullConvolution':
            case 'nn.SpatialConvolution':
            case 'nn.SpatialConvolutionMM':
            case 'nn.SpatialDilatedConvolution':
            case 'nn.SpatialFullConvolution':
                delete module.ones;
                delete module.input_slice;
                delete module.output_slice;
                delete module.convDescData;
                this._updateSize(module, 'adj');
                this._updateSize(module, 'd');
                this._updateSize(module, 'dilation');
                this._updateSize(module, 'k');
                this._updateSize(module, 'pad');
                break;
            case 'cudnn.BatchNormalization':
            case 'cudnn.SpatialBatchNormalization':
            case 'nn.BatchNormalization':
            case 'nn.SpatialBatchNormalization':
            case 'nn.InstanceNormalization':
                delete module.save_mean;
                delete module.save_std;
                delete module.normalized;
                delete module.centered;
                delete module.bn; // TODO InstanceNormalization
                break;
            case 'nn.SpatialCrossMapLRN':
                delete module.scale;
                break;
            case 'cudnn.SpatialMaxPooling':
            case 'cudnn.SpatialAveragePooling':
            case 'inn.SpatialMaxPooling':
            case 'nn.SpatialMaxPooling':
            case 'nn.SpatialAveragePooling':
                delete module.indices;
                this._updateSize(module, 'pad');
                this._updateSize(module, 'd');
                this._updateSize(module, 'k');
                break;
            case 'nn.SpatialZeroPadding':
            case 'nn.SpatialReflectionPadding':
            case 'nn.SpatialReplicationPadding':
                this._updateBox(module, 'pad');
                break;
            case 'nn.Dropout':
                delete module.noise;
                break;
            case 'nn.gModule':
                delete module.forwardnodes;
                delete module.backwardnodes;
                break;
            case 'nn.StereoJoin':
                delete module.output_L;
                break;
        }
        const initializers = [];
        this._attributes = [];
        if (module.__type__) {
            for (const key of Object.keys(module)) {
                if (key == '__type__' || key == '_type') {
                    continue;
                }
                const obj = module[key];
                // Child module lists are represented by the graph, not as attributes.
                if (Array.isArray(obj) && obj.every(((item) => item && item.__type__ && item.__type__.startsWith('nn.')))) {
                    continue;
                }
                // Fields may be null (a serialized nil); guard before reading __type__.
                const objType = obj ? obj.__type__ : undefined;
                if (objType && objType.startsWith('torch.') && objType.endsWith('Tensor')) {
                    initializers.push(new torch.Parameter(key, true, [
                        new torch.Argument(key, null, new torch.Tensor(obj))
                    ]));
                    continue;
                }
                if (key == 'modules' || (objType && objType != 'Function')) {
                    continue;
                }
                this._attributes.push(new torch.Attribute(metadata, type, key, obj));
            }
        }
        this._inputs = [];
        if (inputs.length == 0 && this._name) {
            inputs.push(new torch.Argument(this._name + ':in', null, null));
        }
        this._inputs.push(new torch.Parameter('input', true, inputs));
        if (outputs.length == 0 && this._name) {
            outputs.push(new torch.Argument(this._name, null, null));
        }
        this._outputs = [];
        this._outputs.push(new torch.Parameter('output', true, outputs));
        // Initializer order: 'weight', then 'bias', then the rest in key order.
        const weight = initializers.filter((parameter) => parameter.name == 'weight');
        const bias = initializers.filter((parameter) => parameter.name == 'bias');
        const rest = initializers.filter((parameter) => parameter.name != 'weight' && parameter.name != 'bias');
        this._inputs = this._inputs.concat(weight, bias, rest);
    }

    get name() {
        return this._name;
    }

    get type() {
        return this._type;
    }

    get group() {
        return this._group;
    }

    get attributes() {
        return this._attributes;
    }

    get inputs() {
        return this._inputs;
    }

    get outputs() {
        return this._outputs;
    }

    /** Folds `<name>W` / `<name>H` into `<name> = [W, H]` when both exist. */
    _updateSize(module, name) {
        if (Object.prototype.hasOwnProperty.call(module, name + 'W') &&
            Object.prototype.hasOwnProperty.call(module, name + 'H')) {
            module[name] = [ module[name + 'W'], module[name + 'H'] ];
            delete module[name + 'W'];
            delete module[name + 'H'];
        }
    }

    /** Folds `<name>_t/_r/_b/_l` into `<name> = [t, r, b, l]` when all exist. */
    _updateBox(module, name) {
        if (Object.prototype.hasOwnProperty.call(module, name + '_t') &&
            Object.prototype.hasOwnProperty.call(module, name + '_r') &&
            Object.prototype.hasOwnProperty.call(module, name + '_b') &&
            Object.prototype.hasOwnProperty.call(module, name + '_l')) {
            module[name] = [ module[name + '_t'], module[name + '_r'], module[name + '_b'], module[name + '_l'] ];
            delete module[name + '_t'];
            delete module[name + '_r'];
            delete module[name + '_b'];
            delete module[name + '_l'];
        }
    }
};
torch.Attribute = class {

    /**
     * A named module field shown on a node.
     * Visibility: the 'train' flag is hidden by default; a schema entry with an
     * own `visible` property decides explicitly; otherwise a schema `default`
     * equal (by JSON text) to the value hides the attribute.
     * @param {Object} metadata - provides `attribute(type, name)`.
     * @param {string} type - operator type name.
     * @param {string} name
     * @param {*} value
     */
    constructor(metadata, type, name, value) {
        this._name = name;
        this._value = value;
        if (name == 'train') {
            this._visible = false;
        }
        const schema = metadata.attribute(type, name);
        if (!schema) {
            return;
        }
        const hasOwn = (key) => Object.prototype.hasOwnProperty.call(schema, key);
        if (hasOwn('visible')) {
            this._visible = schema.visible;
        }
        else if (hasOwn('default') && JSON.stringify(schema.default) == JSON.stringify(value)) {
            this._visible = false;
        }
    }

    get name() {
        return this._name;
    }

    get value() {
        return this._value;
    }

    get visible() {
        // Loose comparison on purpose: anything == false hides the attribute.
        return this._visible != false;
    }
};
torch.Tensor = class {

    /**
     * View over a deserialized tensor object.
     * @param {Object} tensor - reads `dataType`/`size` (via torch.TensorType),
     *     `storage` (must expose `data()`) and `storage_offset`.
     */
    constructor(tensor) {
        this._type = new torch.TensorType(tensor);
        this._storage = tensor.storage;
        this._offset = tensor.storage_offset;
    }

    get type() {
        return this._type;
    }

    /** Reason the data cannot be decoded, or null when decodable. */
    get state() {
        return this._context().state || null;
    }

    /** Full nested-array decode, or null when not decodable. */
    get value() {
        const context = this._context();
        if (context.state) {
            return null;
        }
        context.limit = Number.MAX_SAFE_INTEGER;
        return this._decode(context, 0);
    }

    /** JSON preview truncated after roughly 1000 elements ('' if not decodable). */
    toString() {
        const context = this._context();
        if (context.state) {
            return '';
        }
        context.limit = 1000;
        const value = this._decode(context, 0);
        return JSON.stringify(value, null, 4);
    }

    /**
     * Prepares a decode cursor. The first failing check sets `state` and
     * returns immediately so its message is the one reported.
     */
    _context() {
        const context = {};
        context.state = null;
        context.index = 0;
        context.count = 0;
        if (!this._storage) {
            context.state = 'Tensor data is empty.';
            return context;
        }
        context.data = this._storage.data();
        context.index = this._offset;
        if (!context.data) {
            context.state = 'Tensor data is empty.';
            return context;
        }
        switch (this._type.dataType) {
            case 'uint8':
            case 'int8':
            case 'int16':
            case 'int32':
            case 'int64':
            case 'float32':
            case 'float64':
                break;
            default:
                context.state = 'Tensor data type is not implemented.';
                return context;
        }
        context.dimensions = this._type.shape.dimensions;
        // '||' is required: with '&&' a missing dimensions array was dereferenced.
        if (!context.dimensions || context.dimensions.length == 0) {
            context.state = 'Tensor has no dimensions.';
            return context;
        }
        return context;
    }

    /** Reads `context.data` sequentially (row-major) into nested arrays. */
    _decode(context, dimension) {
        const results = [];
        const size = context.dimensions[dimension];
        if (dimension == context.dimensions.length - 1) {
            for (let i = 0; i < size; i++) {
                if (context.count > context.limit) {
                    results.push('...');
                    return results;
                }
                results.push(context.data[context.index]);
                context.index++;
                context.count++;
            }
        }
        else {
            for (let j = 0; j < size; j++) {
                if (context.count > context.limit) {
                    results.push('...');
                    return results;
                }
                results.push(this._decode(context, dimension + 1));
            }
        }
        return results;
    }
};
torch.TensorType = class {

    /**
     * Element type plus shape of a deserialized tensor.
     * @param {Object} tensor - reads `dataType` and `size`.
     */
    constructor(tensor) {
        const { dataType, size } = tensor;
        this._dataType = dataType;
        this._shape = new torch.TensorShape(size);
    }

    get dataType() {
        return this._dataType;
    }

    get shape() {
        return this._shape;
    }

    /** e.g. 'float32[2,3]'; '?' stands in for an unknown data type. */
    toString() {
        const prefix = this.dataType || '?';
        return prefix + this._shape.toString();
    }
};
torch.TensorShape = class {

    /** @param {Array} dimensions - per-axis sizes (may be missing). */
    constructor(dimensions) {
        this._dimensions = dimensions;
    }

    get dimensions() {
        return this._dimensions;
    }

    /** '[d0,d1,...]', or '' when dimensions are missing or empty. */
    toString() {
        const dims = this._dimensions;
        if (!dims || dims.length == 0) {
            return '';
        }
        const text = dims.map((dimension) => dimension.toString()).join(',');
        return `[${text}]`;
    }
};
torch.Metadata = class {

    /**
     * Loads 'torch-metadata.json' once and caches it on the class. Any failure
     * (request or JSON parse) deliberately falls back to empty metadata so
     * models still open without schemas.
     * @param {Object} context - provides `request(name, encoding, base)`.
     * @returns {Promise<torch.Metadata>}
     */
    static open(context) {
        if (torch.Metadata._metadata) {
            return Promise.resolve(torch.Metadata._metadata);
        }
        return context.request('torch-metadata.json', 'utf-8', null).then((data) => {
            torch.Metadata._metadata = new torch.Metadata(data);
            return torch.Metadata._metadata;
        }).catch(() => {
            torch.Metadata._metadata = new torch.Metadata(null);
            return torch.Metadata._metadata;
        });
    }

    /** @param {?string} data - JSON array of `{ name, attributes, ... }` schemas. */
    constructor(data) {
        this._map = new Map();
        // type name -> Map(attribute name -> schema). Maps (not plain objects)
        // so names read from model files such as 'constructor' or 'toString'
        // never resolve to Object.prototype members.
        this._attributeCache = new Map();
        if (data) {
            const items = JSON.parse(data);
            for (const item of items) {
                this._map.set(item.name, item);
            }
        }
    }

    /** Schema for `name`; unknown names get (and keep) a stub `{ name }`. */
    type(name) {
        if (!this._map.has(name)) {
            this._map.set(name, { name: name });
        }
        return this._map.get(name);
    }

    /** Attribute schema for `type`/`name`, or null if none is declared. */
    attribute(type, name) {
        let map = this._attributeCache.get(type);
        if (!map) {
            map = new Map();
            const schema = this.type(type);
            if (schema && schema.attributes && schema.attributes.length > 0) {
                for (const attribute of schema.attributes) {
                    map.set(attribute.name, attribute);
                }
            }
            this._attributeCache.set(type, map);
        }
        return map.get(name) || null;
    }
};
/**
 * Error type raised while reading Torch v7 files; its `name` doubles as the
 * user-facing title.
 */
torch.Error = class extends Error {

    /** @param {string} message */
    constructor(message) {
        super(message);
        const title = 'Error loading Torch model.';
        this.name = title;
    }
};
- torch.T7Reader = class {
- constructor(buffer, callback) {
- this._callback = callback;
- this._memo = new Map();
- this._registry = {};
- this._registry['bnn.Binary'] = function(reader) { reader.nn(this); };
- this._registry['bnn.SpatialConvolution'] = function(reader) { reader.nn(this); };
- this._registry['cudnn.BatchNormalization'] = function(reader) { reader.nn(this); };
- this._registry['cudnn.BatchBRNNReLU'] = function(reader) { reader.nn(this); };
- this._registry['cudnn.BLSTM'] = function(reader) { reader.nn(this); };
- this._registry['cudnn.ReLU'] = function(reader) { reader.nn(this); };
- this._registry['cudnn.RNN'] = function(reader) { reader.nn(this); };
- this._registry['cudnn.Sigmoid'] = function(reader) { reader.nn(this); };
- this._registry['cudnn.SoftMax'] = function(reader) { reader.nn(this); };
- this._registry['cudnn.LogSoftMax'] = function(reader) { reader.nn(this); };
- this._registry['cudnn.SpatialAveragePooling'] = function(reader) { reader.nn(this); };
- this._registry['cudnn.SpatialBatchNormalization'] = function(reader) { reader.nn(this); };
- this._registry['cudnn.SpatialConvolution'] = function(reader) { reader.nn(this); };
- this._registry['cudnn.SpatialFullConvolution'] = function(reader) { reader.nn(this); };
- this._registry['cudnn.SpatialMaxPooling'] = function(reader) { reader.nn(this); };
- this._registry['cudnn.SpatialSoftMax'] = function(reader) { reader.nn(this); };
- this._registry['cudnn.Tanh'] = function(reader) { reader.nn(this); };
- this._registry['cudnn.VolumetricAveragePooling'] = function(reader) { reader.nn(this); };
- this._registry['cudnn.VolumetricBatchNormalization'] = function(reader) { reader.nn(this); };
- this._registry['cudnn.VolumetricConvolution'] = function(reader) { reader.nn(this); };
- this._registry['cudnn.VolumetricMaxPooling'] = function(reader) { reader.nn(this); };
- this._registry['Dict'] = function(reader) { reader.nn(this); };
- this._registry['inn.ConstAffine'] = function(reader) { reader.nn(this); };
- this._registry['inn.SpatialMaxPooling'] = function(reader) { reader.nn(this); };
- this._registry['nn.Abs'] = function(reader) { reader.nn(this); };
- this._registry['nn.AddConstant'] = function(reader) { reader.nn(this); };
- this._registry['nn.BatchNormalization'] = function(reader) { reader.nn(this); };
- this._registry['nn.BilinearSamplerBHWD'] = function(reader) { reader.nn(this); };
- this._registry['nn.BinActiveZ'] = function(reader) { reader.nn(this); }; // allenai/XNOR-Net
- this._registry['nn.BCECriterion'] = function(reader) { reader.nn(this); };
- this._registry['nn.Bottle'] = function(reader) { reader.nn(this); };
- this._registry['nn.Clamp'] = function(reader) { reader.nn(this); };
- this._registry['nn.CMul'] = function(reader) { reader.nn(this); };
- this._registry['nn.CAddTable'] = function(reader) { reader.nn(this); };
- this._registry['nn.CDivTable'] = function(reader) { reader.nn(this); };
- this._registry['nn.CMulTable'] = function(reader) { reader.nn(this); };
- this._registry['nn.CSubTable'] = function(reader) { reader.nn(this); };
- this._registry['nn.Concat'] = function(reader) { reader.nn(this); };
- this._registry['nn.Copy'] = function(reader) { reader.nn(this); };
- this._registry['nn.ConcatTable'] = function(reader) { reader.nn(this); };
- this._registry['nn.Contiguous'] = function(reader) { reader.nn(this); };
- this._registry['nn.Constant'] = function(reader) { reader.nn(this); };
- this._registry['nn.CostVolMulti'] = function(reader) { reader.nn(this); };
- this._registry['nn.DataParallelTable'] = function(reader) { reader.nn(this); };
- this._registry['nn.DepthConcat'] = function(reader) { reader.nn(this); };
- this._registry['nn.Dropout'] = function(reader) { reader.nn(this); };
- this._registry['nn.Exp'] = function(reader) { reader.nn(this); };
- this._registry['nn.ExpOut'] = function(reader) { reader.nn(this); };
- this._registry['nn.FlattenTable'] = function(reader) { reader.nn(this); };
- this._registry['nn.GenNoise'] = function(reader) { reader.nn(this); };
- this._registry['nn.Identity'] = function(reader) { reader.nn(this); };
- this._registry['nn.Index'] = function(reader) { reader.nn(this); };
- this._registry['nn.Inception'] = function(reader) { reader.nn(this); };
- this._registry['nn.InstanceNormalization'] = function(reader) { reader.nn(this); };
- this._registry['nn.JoinTable'] = function(reader) { reader.nn(this); };
- this._registry['nn.JointTrain'] = function(reader) { reader.nn(this); };
- this._registry['nn.KeypointCoordinate'] = function(reader) { reader.nn(this); };
- this._registry['nn.LeakyReLU'] = function(reader) { reader.nn(this); };
- this._registry['nn.Linear'] = function(reader) { reader.nn(this); };
- this._registry['nn.LinearNoBias'] = function(reader) { reader.nn(this); };
- this._registry['nn.LogSoftMax'] = function(reader) { reader.nn(this); };
- this._registry['nn.LookupTable'] = function(reader) { reader.nn(this); };
- this._registry['nn.LSTM'] = function(reader) { reader.nn(this); };
- this._registry['nn.MaskZero'] = function(reader) { reader.nn(this); };
- this._registry['nn.MapTable'] = function(reader) { reader.nn(this); };
- this._registry['nn.Max'] = function(reader) { reader.nn(this); };
- this._registry['nn.Mean'] = function(reader) { reader.nn(this); };
- this._registry['nn.Min'] = function(reader) { reader.nn(this); };
- this._registry['nn.MulConstant'] = function(reader) { reader.nn(this); };
- this._registry['nn.MM'] = function(reader) { reader.nn(this); };
- this._registry['nn.MSECriterion'] = function(reader) { reader.nn(this); };
- this._registry['nn.Narrow'] = function(reader) { reader.nn(this); };
- this._registry['nn.NarrowTable'] = function(reader) { reader.nn(this); };
- this._registry['nn.Normalize'] = function(reader) { reader.nn(this); };
- this._registry['nn.Normalize2'] = function(reader) { reader.nn(this); };
- this._registry['nn.NoiseFill'] = function(reader) { reader.nn(this); };
- this._registry['nn.Padding'] = function(reader) { reader.nn(this); };
- this._registry['nn.Parallel'] = function(reader) { reader.nn(this); };
- this._registry['nn.ParallelCriterion'] = function(reader) { reader.nn(this); };
- this._registry['nn.ParallelTable'] = function(reader) { reader.nn(this); };
- this._registry['nn.PixelShuffle'] = function(reader) { reader.nn(this); };
- this._registry['nn.Power'] = function(reader) { reader.nn(this); };
- this._registry['nn.PReLU'] = function(reader) { reader.nn(this); };
- this._registry['nn.Recursor'] = function(reader) { reader.nn(this); };
- this._registry['nn.ReLU'] = function(reader) { reader.nn(this); };
- this._registry['nn.Replicate'] = function(reader) { reader.nn(this); };
- this._registry['nn.Reshape'] = function(reader) { reader.nn(this); };
- this._registry['nn.ShaveImage'] = function(reader) { reader.nn(this); };
- this._registry['nn.Select'] = function(reader) { reader.nn(this); };
- this._registry['nn.SelectTable'] = function(reader) { reader.nn(this); };
- this._registry['nn.Sequencer'] = function(reader) { reader.nn(this); };
- this._registry['nn.Sequential'] = function(reader) { reader.nn(this); };
- this._registry['nn.Sigmoid'] = function(reader) { reader.nn(this); };
- this._registry['nn.Sum'] = function(reader) { reader.nn(this); };
- this._registry['nn.SoftMax'] = function(reader) { reader.nn(this); };
- this._registry['nn.SpatialAveragePooling'] = function(reader) { reader.nn(this); };
- this._registry['nn.SpatialBatchNormalization'] = function(reader) { reader.nn(this); };
- this._registry['nn.SpatialConvolution'] = function(reader) { reader.nn(this); };
- this._registry['nn.SpatialConvolutionMM'] = function(reader) { reader.nn(this); };
- this._registry['nn.SpatialCrossMapLRN'] = function(reader) { reader.nn(this); };
- this._registry['nn.SpatialDilatedConvolution'] = function(reader) { reader.nn(this); };
- this._registry['nn.SpatialDropout'] = function(reader) { reader.nn(this); };
- this._registry['nn.SpatialFractionalMaxPooling'] = function(reader) { reader.nn(this); };
- this._registry['nn.SpatialFullConvolution'] = function(reader) { reader.nn(this); };
- this._registry['nn.SpatialLPPooling'] = function(reader) { reader.nn(this); };
- this._registry['nn.SpatialMaxPooling'] = function(reader) { reader.nn(this); };
- this._registry['nn.SpatialMaxUnpooling'] = function(reader) { reader.nn(this); };
- this._registry['nn.SpatialReflectionPadding'] = function(reader) { reader.nn(this); };
- this._registry['nn.SpatialReplicationPadding'] = function(reader) { reader.nn(this); };
- this._registry['nn.SpatialSoftMax'] = function(reader) { reader.nn(this); };
- this._registry['nn.SpatialSubtractiveNormalization'] = function(reader) { reader.nn(this); };
- this._registry['nn.SpatialUpSamplingBilinear'] = function(reader) { reader.nn(this); };
- this._registry['nn.SpatialUpSamplingNearest'] = function(reader) { reader.nn(this); };
- this._registry['nn.SpatialZeroPadding'] = function(reader) { reader.nn(this); };
- this._registry['nn.SplitTable'] = function(reader) { reader.nn(this); };
- this._registry['nn.Squeeze'] = function(reader) { reader.nn(this); };
- this._registry['nn.Square'] = function(reader) { reader.nn(this); };
- this._registry['nn.Sqrt'] = function(reader) { reader.nn(this); };
- this._registry['nn.StereoJoin'] = function(reader) { reader.nn(this); };
- this._registry['nn.Tanh'] = function(reader) { reader.nn(this); };
- this._registry['nn.Transpose'] = function(reader) { reader.nn(this); };
- this._registry['nn.TotalVariation'] = function(reader) { reader.nn(this); };
- this._registry['nn.Unpool'] = function(reader) { reader.nn(this); };
- this._registry['nn.View'] = function(reader) { reader.nn(this); };
- this._registry['nn.gModule'] = function(reader) { reader.nn(this); };
- this._registry['nngraph.Node'] = function(reader) { reader.nn(this); };
- this._registry['graph.Edge'] = function(reader) { reader.nn(this); };
- this._registry['graph.Graph'] = function(reader) { reader.nn(this); };
- this._registry['torch.ByteTensor'] = function(reader) { reader.tensor(this, 'uint8'); };
- this._registry['torch.CharTensor'] = function(reader) { reader.tensor(this, 'int8'); };
- this._registry['torch.ShortTensor'] = function(reader) { reader.tensor(this, 'int16'); };
- this._registry['torch.IntTensor'] = function(reader) { reader.tensor(this, 'int32'); };
- this._registry['torch.LongTensor'] = function(reader) { reader.tensor(this, 'int64'); };
- this._registry['torch.FloatTensor'] = function(reader) { reader.tensor(this, 'float32'); };
- this._registry['torch.DoubleTensor'] = function(reader) { reader.tensor(this, 'float64'); };
- this._registry['torch.CudaByteTensor'] = function(reader) { reader.tensor(this, 'uint8'); };
- this._registry['torch.CudaCharTensor'] = function(reader) { reader.tensor(this, 'int8'); };
- this._registry['torch.CudaShortTensor'] = function(reader) { reader.tensor(this, 'int16'); };
- this._registry['torch.CudaIntTensor'] = function(reader) { reader.tensor(this, 'int32'); };
- this._registry['torch.CudaLongTensor'] = function(reader) { reader.tensor(this, 'int64'); };
- this._registry['torch.CudaTensor'] = function(reader) { reader.tensor(this, 'float32'); };
- this._registry['torch.CudaDoubleTensor'] = function(reader) { reader.tensor(this, 'float64'); };
- this._registry['torch.ByteStorage'] = function(reader) { reader.storage(this, 'uint8', 1); };
- this._registry['torch.CharStorage'] = function(reader) { reader.storage(this, 'int8', 1); };
- this._registry['torch.ShortStorage'] = function(reader) { reader.storage(this, 'int16', 2); };
- this._registry['torch.IntStorage'] = function(reader) { reader.storage(this, 'int32', 4); };
- this._registry['torch.LongStorage'] = function(reader) { reader.storage(this, 'int64', 8); };
- this._registry['torch.FloatStorage'] = function(reader) { reader.storage(this, 'float32', 4); };
- this._registry['torch.DoubleStorage'] = function(reader) { reader.storage(this, 'float64', 8); };
- this._registry['torch.CudaByteStorage'] = function(reader) { reader.storage(this, 'uint8', 1); };
- this._registry['torch.CudaCharStorage'] = function(reader) { reader.storage(this, 'int8', 1); };
- this._registry['torch.CudaShortStorage'] = function(reader) { reader.storage(this, 'int16', 2); };
- this._registry['torch.CudaIntStorage'] = function(reader) { reader.storage(this, 'int32', 4); };
- this._registry['torch.CudaLongStorage'] = function(reader) { reader.storage(this, 'int64', 8); };
- this._registry['torch.CudaIntStorage'] = function(reader) { reader.storage(this, 'int32', 4); };
- this._registry['torch.CudaStorage'] = function(reader) { reader.storage(this, 'float32', 4); };
- this._registry['torch.CudaFloatStorage'] = function(reader) { reader.storage(this, 'float64', 8); };
- this._registry['w2nn.AuxiliaryLossTable'] = function(reader) { reader.nn(this); };
- this._registry['w2nn.InplaceClip01'] = function(reader) { reader.nn(this); };
- this._registry['w2nn.ScaleTable'] = function(reader) { reader.nn(this); };
- if (buffer.length == 0) {
- throw new torch.Error('File is empty.');
- }
- if (buffer[0] <= 8) {
- this._reader = new torch.BinaryReader(buffer);
- }
- else {
- this._reader = new torch.TextReader(buffer);
- this._reader.int32();
- this._reader.reset();
- }
- }
- read() {
- const type = this.int32();
- switch (type) {
- case 0: return null;
- case 1: return this.float64();
- case 2: return this.string();
- case 3: return this.table();
- case 4: return this.object();
- case 5: return this.boolean();
- case 6: return this.function();
- case 7: return this.function();
- case 8: return this.function();
- default: throw new torch.Error("File format has invalid type '" + type + "'.");
- }
- }
- boolean() {
- return this._reader.boolean();
- }
- bytes(size) {
- return this._reader.bytes(size);
- }
- int32() {
- return this._reader.int32();
- }
- int64() {
- return this._reader.int64();
- }
- int64s(size) {
- return this._reader.int64s(size);
- }
- float64() {
- return this._reader.float64();
- }
- string() {
- return this._reader.string();
- }
- object() {
- const index = this.int32();
- if (this._memo.has(index)) {
- return this._memo.get(index);
- }
- let version = this.string();
- let name = null;
- if (version.startsWith('V ')) {
- name = this.string();
- version = Number(version.split(' ')[1]);
- }
- else {
- name = version;
- version = 0;
- }
- const obj = { __type__: name };
- this._memo.set(index, obj);
- let constructor = this._registry[name];
- if (constructor) {
- constructor.apply(obj, [ this, version ]);
- }
- else {
- constructor = this._callback(name);
- if (constructor) {
- constructor.apply(obj, [ this, version ]);
- }
- this.nn(obj);
- }
- return obj;
- }
- table() {
- const index = this.int32();
- if (this._memo.has(index)) {
- return this._memo.get(index);
- }
- const table = {};
- this._memo.set(index, table);
- const size = this.int32();
- let convert = true;
- let sum = 0;
- for (let i = 0; i < size; i++) {
- const key = this.read();
- const value = this.read();
- table[key] = value;
- if (Number.isInteger(key) && key >= 0) {
- sum += key;
- }
- else {
- convert = false;
- }
- }
- const n = Object.keys(table).length;
- if (convert && (n * (n + 1)) == (2 * sum)) {
- const list = [];
- for (let j = 0; j < n; j++) {
- let item = table[j + 1];
- if (item == table) {
- item = list;
- }
- list.push(item);
- }
- this._memo.set(index, list);
- return list;
- }
- return table;
- }
- function() {
- const index = this.int32();
- if (this._memo.has(index)) {
- return this._memo.get(index);
- }
- const size = this.int32();
- const dumped = this.bytes(size);
- const upvalues = this.read();
- const func = { __type__: 'Function', size: size, dumped: dumped, upvalues: upvalues };
- this._memo.set(index, func);
- return func;
- }
- nn(obj) {
- const attributes = this.read();
- if (attributes != null) {
- for (const key of Object.keys(attributes)) {
- obj[key] = attributes[key];
- }
- }
- }
- tensor(obj, dataType) {
- const dim = this.int32();
- obj.dataType = dataType;
- obj.size = this.int64s(dim);
- obj.stride = this.int64s(dim);
- obj.storage_offset = this.int64() - 1;
- obj.storage = this.read();
- }
- storage(obj, dataType, itemSize) {
- obj.dataType = dataType;
- obj.itemSize = itemSize;
- obj.size = this.int64();
- obj.reader = this._reader.storage(obj.size, obj.itemSize, dataType);
- obj.data = function() {
- if (this.reader) {
- const reader = this.reader;
- reader.reset();
- const size = obj.size;
- const array = new Array(size);
- for (let i = 0; i < size; i++) {
- switch (dataType) {
- case 'uint8':
- array[i] = this.reader.byte();
- break;
- case 'int8':
- array[i] = this.reader.int8();
- break;
- case 'int16':
- array[i] = this.reader.int16();
- break;
- case 'int32':
- array[i] = this.reader.int32();
- break;
- case 'int64':
- array[i] = this.reader.int64();
- break;
- case 'float32':
- array[i] = this.reader.float32();
- break;
- case 'float64':
- array[i] = this.reader.float64();
- break;
- }
- }
- obj._data = array;
- delete obj.reader;
- }
- return obj._data;
- };
- }
- };
torch.BinaryReader = class {
    // Sequential little-endian reader over a byte buffer. The buffer is assumed
    // to be a typed array (it needs .buffer/.byteOffset/.byteLength/.subarray).
    constructor(buffer) {
        this._buffer = buffer;
        this._dataView = new DataView(this._buffer.buffer, this._buffer.byteOffset, this._buffer.byteLength);
        this._position = 0;
        this._textDecoder = new TextDecoder('ascii');
    }
    reset() {
        this._position = 0;
    }
    // Advances the cursor by `offset` bytes; throws (leaving the cursor where it
    // was) if that would run past the end of the buffer.
    skip(offset) {
        const position = this._position + offset;
        if (position > this._buffer.length) {
            throw new torch.Error('Expected ' + (position - this._buffer.length) + ' more bytes. The file might be corrupted. Unexpected end of file.');
        }
        this._position = position;
    }
    boolean() {
        return this.int32() == 1;
    }
    // Returns a view (not a copy) of the next `length` bytes.
    bytes(length) {
        const position = this._position;
        this.skip(length);
        return this._buffer.subarray(position, this._position);
    }
    // Single unsigned byte; used by storage decoders that read uint8 elements one at a time.
    byte() {
        const position = this._position;
        this.skip(1);
        return this._dataView.getUint8(position);
    }
    int8() {
        const position = this._position;
        this.skip(1);
        return this._dataView.getInt8(position);
    }
    int16() {
        const position = this._position;
        this.skip(2);
        return this._dataView.getInt16(position, true);
    }
    int32() {
        const position = this._position;
        this.skip(4);
        return this._dataView.getInt32(position, true);
    }
    // NOTE(review): getInt64 is not a standard DataView method; presumably a
    // project-level extension returning an object with toNumber() — confirm.
    int64() {
        const position = this._position;
        this.skip(8);
        return this._dataView.getInt64(position, true).toNumber();
    }
    int64s(size) {
        const array = [];
        for (let i = 0; i < size; i++) {
            array.push(this.int64());
        }
        return array;
    }
    float32() {
        const position = this._position;
        this.skip(4);
        return this._dataView.getFloat32(position, true);
    }
    float64() {
        const position = this._position;
        this.skip(8);
        return this._dataView.getFloat64(position, true);
    }
    // Length-prefixed (int32) string.
    string() {
        return this._textDecoder.decode(this.bytes(this.int32()));
    }
    // Sub-reader over the next size * itemSize bytes.
    storage(size, itemSize) {
        return new torch.BinaryReader(this.bytes(size * itemSize));
    }
};
torch.TextReader = class {
    // Reader for the ASCII serialization: numbers are tokens terminated by
    // `separator` (newline by default), while strings, function bytes and uint8
    // storages are raw byte runs optionally followed by one separator.
    constructor(buffer, separator) {
        this._buffer = buffer;
        this._position = 0;
        this._dataView = new DataView(this._buffer.buffer, this._buffer.byteOffset, this._buffer.byteLength);
        this._textDecoder = new TextDecoder('ascii');
        this._separator = separator || 0x0a;
    }
    reset() {
        this._position = 0;
    }
    // Returns the bytes up to (excluding) the next separator, examining at most
    // size + 1 bytes; the end of the buffer also terminates a line. Throws if no
    // terminator is found within that window (or when already at the end).
    line(size) {
        const start = this._position;
        while (this._position < this._buffer.length && size > -1) {
            const c = this._buffer[this._position++];
            if (c == this._separator) {
                return this._buffer.slice(start, this._position - 1);
            }
            else if (this._position == this._buffer.length) {
                return this._buffer.slice(start, this._position);
            }
            size--;
        }
        throw new torch.Error('Line exceeded maximum length.');
    }
    // Copies exactly `size` raw bytes (fewer if the buffer ends first), then
    // skips a single trailing separator if one follows. Unlike line(), payload
    // bytes equal to the separator are kept.
    _raw(size) {
        if (size < 0) {
            throw new torch.Error("Invalid size '" + size + "'.");
        }
        const start = this._position;
        const end = Math.min(start + size, this._buffer.length);
        const data = this._buffer.slice(start, end);
        this._position = end;
        if (this._position < this._buffer.length && this._buffer[this._position] === this._separator) {
            this._position++;
        }
        return data;
    }
    boolean() {
        return this.int32() == 1;
    }
    bytes(size) {
        return this._raw(size);
    }
    int8() {
        return this.int64();
    }
    int16() {
        return this.int64();
    }
    int32() {
        return this.int64();
    }
    // One decimal integer token; rejects tokens with trailing garbage.
    int64() {
        const token = this._textDecoder.decode(this.line(20));
        const number = Number.parseInt(token, 10);
        if (Number.isNaN(token - number)) {
            throw new torch.Error("Couldn't parse int64 '" + token + "'.");
        }
        return number;
    }
    // One line of space-separated integers (nothing is read when size is 0).
    int64s(size) {
        const array = [];
        if (size > 0) {
            const text = this._textDecoder.decode(this.line(Number.MAX_SAFE_INTEGER));
            for (const token of text.split(' ')) {
                const number = Number.parseInt(token, 10);
                if (Number.isNaN(token - number)) {
                    throw new torch.Error("Couldn't parse int64 '" + token + "'.");
                }
                array.push(number);
            }
        }
        return array;
    }
    float32() {
        return this.float64();
    }
    // One float token; nan/-nan/inf/-inf prefixes are mapped explicitly.
    float64() {
        const token = this._textDecoder.decode(this.line(24));
        if (token.startsWith('-nan')) {
            return -NaN;
        }
        if (token.startsWith('nan')) {
            return NaN;
        }
        if (token.startsWith('inf')) {
            return Infinity;
        }
        if (token.startsWith('-inf')) {
            return -Infinity;
        }
        const number = Number.parseFloat(token);
        if (Number.isNaN(token - number)) {
            throw new torch.Error("Couldn't parse float '" + token + "'.");
        }
        return number;
    }
    // An int32 length followed by that many raw bytes; the payload may itself
    // contain separator bytes, so it is not read with line().
    string() {
        const size = this.int32();
        if (size === 0) {
            return '';
        }
        const text = this._textDecoder.decode(this._raw(size));
        if (size !== text.length) {
            throw new torch.Error('Invalid text length.');
        }
        return text;
    }
    // Returns a sub-reader over a storage payload: raw bytes (BinaryReader) for
    // uint8 and empty storages, otherwise one space-separated line of tokens.
    storage(size, itemSize, dataType) {
        if (size < 0) {
            throw new torch.Error("Unsupported storage size '" + size + "'.");
        }
        if (size === 0 || dataType === 'uint8') {
            return new torch.BinaryReader(this._raw(size));
        }
        const data = this.line(Number.MAX_SAFE_INTEGER);
        return new torch.TextReader(data, 0x20);
    }
};
// CommonJS export of the model factory. When `module` is not defined
// (non-CommonJS environments) nothing is exported.
if (typeof module !== 'undefined' && typeof module.exports === 'object') {
    Object.assign(module.exports, { ModelFactory: torch.ModelFactory });
}
|