| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269 |
- /* jshint esversion: 6 */
- /* eslint "indent": [ "error", 4, { "SwitchCase": 1 } ] */
- // Experimental
- var torchscript = torchscript || {};
- var base = base || require('./base');
- var long = long || { Long: require('long') };
- var marked = marked || require('marked');
- var zip = zip || require('./zip');
torchscript.ModelFactory = class {

    /**
     * Decides whether this reader handles the given file: the file name must carry one of
     * the known PyTorch extensions and its entries must contain a `version` record together
     * with a sibling `model.json`.
     * @param {object} context - host open context (`identifier` is the file name, `entries` the archive entries).
     * @returns {boolean}
     */
    match(context) {
        const identifier = context.identifier;
        const extension = identifier.split('.').pop().toLowerCase();
        const extensions = [ 'pt', 'pth', 'pkl', 'h5', 't7', 'dms', 'model', 'ckpt' ];
        if (extensions.indexOf(extension) !== -1 || identifier.endsWith('.pth.tar')) {
            return torchscript.ModelFactory._openContainer(context) !== null;
        }
        return false;
    }

    /**
     * Loads the `./python` and `./pickle` helper modules, unpickles `attributes.pkl` when present,
     * and builds the model once the operator metadata is available.
     * @param {object} context - host open context.
     * @param {object} host - host services (`require`, `request`, `exception`).
     * @returns {Promise<torchscript.Model>} rejects with a torchscript.Error that names the file.
     */
    open(context, host) {
        return host.require('./python').then((python) => {
            return host.require('./pickle').then((pickle) => {
                const identifier = context.identifier;
                try {
                    const container = torchscript.ModelFactory._openContainer(context);
                    if (container.attributes) {
                        container.attributes = new pickle.Unpickler(container.attributes.data).load((name, args) => {
                            return { type: name, args: args[0] };
                        });
                    }
                    container.identifier = identifier;
                    return torchscript.Metadata.open(host).then((metadata) => {
                        try {
                            return new torchscript.Model(metadata, host, python, container);
                        }
                        catch (error) {
                            host.exception(error, false);
                            throw torchscript.ModelFactory._error(error, identifier);
                        }
                    });
                }
                catch (error) {
                    host.exception(error, false);
                    return Promise.reject(torchscript.ModelFactory._error(error, identifier));
                }
            });
        });
    }

    /**
     * Locates the TorchScript records among the archive entries.
     * The prefix is whatever precedes the `version` entry name; `attributes.pkl` and
     * `model.json` are looked up under that same prefix.
     * @param {object} context - host open context.
     * @returns {object|null} `{ version, prefix, attributes, model, entries }`, or null when
     *     there are no entries, no `version` record, or no `model.json`.
     */
    static _openContainer(context) {
        const entries = context.entries;
        if (!entries || entries.length === 0) {
            return null;
        }
        const version = entries.find((entry) => entry.name == 'version' || entry.name.endsWith('/version'));
        if (!version) {
            return null;
        }
        const prefix = version.name.substring(0, version.name.length - 'version'.length);
        const container = {
            version: version,
            prefix: prefix,
            attributes: entries.find((entry) => entry.name == prefix + 'attributes.pkl'),
            model: entries.find((entry) => entry.name == prefix + 'model.json'),
            entries: entries
        };
        return container.model ? container : null;
    }

    /**
     * Wraps any thrown value in a torchscript.Error whose message names the file.
     * Tolerates null/undefined and non-Error throwables (the old inline code called
     * `error.toString()`, which itself throws for null/undefined).
     * @param {*} error - the caught value.
     * @param {string} identifier - file name to append.
     * @returns {torchscript.Error}
     */
    static _error(error, identifier) {
        let message = error && error.message ? error.message : String(error);
        if (message.endsWith('.')) {
            message = message.substring(0, message.length - 1);
        }
        return new torchscript.Error(message + " in '" + identifier + "'.");
    }
};
torchscript.Model = class {

    /**
     * Decodes `model.json` and the `version` record of an opened container and builds
     * the single graph rooted at `model.mainModule`.
     * @param {torchscript.Metadata} metadata - operator schema lookup.
     * @param {object} host - host services, forwarded to the graph.
     * @param {object} python - loaded `./python` module (its `Parser` reads the TorchScript code).
     * @param {object} container - result of `ModelFactory._openContainer`.
     */
    constructor(metadata, host, python, container) {
        const decoder = new TextDecoder('utf-8');
        const readJson = (entry) => JSON.parse(decoder.decode(entry.data));
        const model = readJson(container.model);
        const version = readJson(container.version);
        this._format = 'TorchScript v' + version.toString();
        if (model.producerName) {
            this._producer = model.producerVersion ?
                model.producerName + ' v' + model.producerVersion :
                model.producerName;
        }
        const mainGraph = new torchscript.Graph(metadata, host, python, container, model.mainModule, model.tensors);
        this._graphs = [ mainGraph ];
    }

    /** @returns {string} e.g. 'TorchScript v1'. */
    get format() {
        return this._format;
    }

    /** @returns {string|undefined} producer name (and version) from model.json, if any. */
    get producer() {
        return this._producer;
    }

    /** @returns {torchscript.Graph[]} always a single graph. */
    get graphs() {
        return this._graphs;
    }
};
torchscript.Graph = class {

    /**
     * Builds the graph for the main module.
     *
     * Nodes come from two sources:
     * 1. calls parsed from the main module's TorchScript code (best effort; a failure is reported
     *    through `host.exception` and the graph falls back to module nodes only);
     * 2. one 'Module' node per module whose parameters were not absorbed by a parsed call.
     *
     * Side effects on `container`: sets `tensors` (decoded tensor definitions) and `parameters`
     * (map from a parameter's single output id to its definition). Side effects on the module tree:
     * each submodule gets a `parent` link and each parameter with a `tensorId` gets an `initializer`.
     *
     * @param {torchscript.Metadata} metadata
     * @param {object} host - host services (`exception` is used for reporting).
     * @param {object} python - loaded `./python` module.
     * @param {object} container - opened container.
     * @param {object} mainModule - `model.mainModule` from model.json.
     * @param {Array|undefined} tensors - `model.tensors` from model.json.
     */
    constructor(metadata, host, python, container, mainModule, tensors) {
        this._name = mainModule.name;
        this._inputs = [];
        this._outputs = [];
        this._nodes = [];
        // NOTE(review): model.json may omit `tensors` entirely (e.g. a model without tensors);
        // treat that as an empty list instead of crashing on `undefined.map`.
        container.tensors = (tensors || []).map((tensor) => new torchscript.Tensor(tensor, container));
        let context = null;
        try {
            context = new torchscript.GraphContext(container, python, mainModule);
        }
        catch (error) {
            // Code parsing is best effort: report and continue with module nodes only.
            let message = error && error.message ? error.message : String(error);
            if (message.endsWith('.')) {
                message = message.substring(0, message.length - 1);
            }
            host.exception(new torchscript.Error(message + " in '" + container.identifier + "'."), false);
        }
        container.parameters = {};
        const queue = [ mainModule ];
        while (queue.length > 0) {
            const module = queue.shift();
            for (const parameter of module.parameters || []) {
                if (parameter.tensorId) {
                    parameter.initializer = container.tensors[parseInt(parameter.tensorId, 10)];
                    if (parameter.outputs && parameter.outputs.length == 1) {
                        container.parameters[parameter.outputs[0]] = parameter;
                    }
                }
            }
            for (const submodule of module.submodules || []) {
                submodule.parent = module;
                queue.push(submodule);
            }
        }
        if (context) {
            for (const input of context.inputs) {
                this._inputs.push(new torchscript.Parameter(input, true, [
                    new torchscript.Argument(input, null, null)
                ]));
            }
            for (const output of context.outputs) {
                this._outputs.push(new torchscript.Parameter(output, true, [
                    new torchscript.Argument(output, null, null)
                ]));
            }
            for (const node of context.nodes) {
                this._nodes.push(new torchscript.Node(metadata, container, null, node));
            }
        }
        this._loadModule(metadata, container, mainModule);
    }

    /**
     * Recursively adds a 'Module' node for every module that has parameters and was not
     * marked `hide` by a call node that consumed all of its parameters.
     */
    _loadModule(metadata, container, module) {
        if (module.parameters && module.parameters.length > 0 && !module.hide) {
            this._nodes.push(new torchscript.Node(metadata, container, module, null));
        }
        for (const submodule of module.submodules || []) {
            this._loadModule(metadata, container, submodule);
        }
    }

    get type() {
        return this._type;
    }

    get name() {
        return this._name;
    }

    get groups() {
        return this._groups;
    }

    get inputs() {
        return this._inputs;
    }

    get outputs() {
        return this._outputs;
    }

    get nodes() {
        return this._nodes;
    }
};
torchscript.Parameter = class {

    /**
     * A named input or output slot of a graph or node.
     * @param {string} name - slot name shown to the user.
     * @param {boolean} visible - whether the slot is displayed.
     * @param {torchscript.Argument[]} args - values bound to this slot.
     */
    constructor(name, visible, args) {
        Object.assign(this, {
            _name: name,
            _visible: visible,
            _arguments: args
        });
    }

    get name() {
        return this._name;
    }

    get visible() {
        return this._visible;
    }

    get arguments() {
        return this._arguments;
    }
};
torchscript.Argument = class {

    /**
     * A value flowing along a graph edge.
     * @param {string} id - value identifier (variable name, constant id, or '' for module parameters).
     * @param {*} type - explicit type; ignored when an initializer is present.
     * @param {torchscript.Tensor|null} initializer - constant tensor backing this value, if any.
     */
    constructor(id, type, initializer) {
        this._id = id;
        this._type = type;
        this._initializer = initializer;
    }

    get id() {
        return this._id;
    }

    /** The initializer's type takes precedence over the explicit one. */
    get type() {
        const initializer = this._initializer;
        return initializer ? initializer.type : this._type;
    }

    get initializer() {
        return this._initializer;
    }
};
torchscript.Node = class {

    /**
     * Builds a display node either from a module definition (`module` set, `node` null), shown as
     * an opaque 'Module' with its parameters, or from a call parsed out of the TorchScript code
     * (`node` set, `module` null).
     *
     * A parsed call whose parameter inputs all belong to one module, and which consumes every
     * parameter of that module, is attributed to it: the module is marked `hide` (so no separate
     * 'Module' node is emitted for it), its tensors become the call's initializers, and the node is
     * named by the module's dotted path.
     *
     * @param {torchscript.Metadata} metadata - operator schema lookup.
     * @param {object} container - opened container; `container.parameters` maps a value id to the
     *     parameter definition producing it.
     * @param {object|null} module - module definition from model.json.
     * @param {object|null} node - parsed call `{ name, inputs: [[{ id, initializer? }]], outputs: [id], attributes: [expression] }`.
     */
    constructor(metadata, container, module, node) {
        this._metadata = metadata;
        this._attributes = [];
        this._inputs = [];
        this._outputs = [];
        if (module) {
            this._operator = 'Module';
            this._loadModuleParameters(module);
        }
        if (node) {
            this._operator = node.name;
            this._name = '';
            const schema = metadata.getSchema(this._operator);
            module = torchscript.Node._claimModule(container, node);
            this._loadCall(schema, node);
        }
        if (module && module.name) {
            this._name = torchscript.Node._qualifiedName(module);
        }
    }

    // Adds one input per module parameter (its tensor) and one output per parameter that produces values.
    _loadModuleParameters(module) {
        for (const parameter of module.parameters || []) {
            this._inputs.push(new torchscript.Parameter(parameter.name, true, [
                new torchscript.Argument('', null, parameter.initializer || null)
            ]));
            if (parameter.outputs) {
                this._outputs.push(new torchscript.Parameter(parameter.name, true,
                    parameter.outputs.map((id) => new torchscript.Argument(id, null, null))
                ));
            }
        }
    }

    // Adds inputs, outputs and attributes of a parsed call, naming them from the schema when available.
    _loadCall(schema, node) {
        for (let index = 0; index < node.inputs.length; index++) {
            const name = (schema && schema.inputs && schema.inputs.length > index) ? schema.inputs[index].name : index.toString();
            const args = node.inputs[index].map((input) => new torchscript.Argument(input.id, null, input.initializer || null));
            this._inputs.push(new torchscript.Parameter(name, true, args));
        }
        for (let index = 0; index < node.outputs.length; index++) {
            const name = (schema && schema.outputs && schema.outputs.length > index) ? schema.outputs[index].name : index.toString();
            this._outputs.push(new torchscript.Parameter(name, true, [
                new torchscript.Argument(node.outputs[index], null, null)
            ]));
        }
        for (let index = 0; index < node.attributes.length; index++) {
            let attributeSchema = null;
            let name = index.toString();
            let value = node.attributes[index];
            if (value && value.type === '=' && value.target.type == 'identifier') {
                // Keyword argument `name=value`: match the schema by name.
                name = value.target.value;
                value = value.expression;
                if (schema && schema.attributes) {
                    attributeSchema = schema.attributes.find((item) => item.name == name);
                }
            }
            else if (schema && schema.attributes && schema.attributes.length > index) {
                // Positional argument: match the schema by position.
                attributeSchema = schema.attributes[index];
                name = attributeSchema.name;
            }
            this._attributes.push(new torchscript.Attribute(this, attributeSchema, name, value));
        }
    }

    /**
     * Returns the module that owns every parameter input of `node`, provided the call consumes
     * exactly as many parameter arguments as that module has parameters; otherwise null.
     * On success marks the module `hide` and copies parameter tensors onto the call arguments.
     */
    static _claimModule(container, node) {
        let module = null;
        let count = 0;
        for (const input of node.inputs) {
            for (const argument of input) {
                const parameter = container.parameters[argument.id];
                if (parameter) {
                    if (parameter.module && (module == null || module == parameter.module)) {
                        module = parameter.module;
                        count++;
                    }
                    else {
                        // Parameters from several modules (or without an owner): no attribution.
                        return null;
                    }
                }
            }
        }
        if (!module || module.parameters.length != count) {
            return null;
        }
        module.hide = true;
        for (const input of node.inputs) {
            for (const argument of input) {
                const parameter = container.parameters[argument.id];
                if (parameter && parameter.initializer) {
                    argument.initializer = parameter.initializer;
                }
            }
        }
        return module;
    }

    // Joins the names from the root module down to `module` with '.' (via the `parent` links set by Graph).
    static _qualifiedName(module) {
        const names = [ module.name ];
        for (let current = module.parent; current != null; current = current.parent) {
            names.unshift(current.name);
        }
        return names.join('.');
    }

    get name() {
        return this._name;
    }

    get group() {
        return this._group;
    }

    get operator() {
        return this._operator;
    }

    get category() {
        const schema = this._metadata.getSchema(this._operator);
        return (schema && schema.category) ? schema.category : '';
    }

    /**
     * Returns a copy of the operator schema with its descriptions (and those of its attributes,
     * inputs and outputs) passed through `marked`, or '' when the operator has no schema.
     */
    get documentation() {
        const schema = this._metadata.getSchema(this._operator);
        if (!schema) {
            return '';
        }
        // Deep copy so rendering does not modify the shared metadata.
        const result = JSON.parse(JSON.stringify(schema));
        result.name = this._operator;
        if (result.description) {
            result.description = marked(result.description);
        }
        for (const key of [ 'attributes', 'inputs', 'outputs' ]) {
            for (const item of result[key] || []) {
                if (item.description) {
                    item.description = marked(item.description);
                }
            }
        }
        return result;
    }

    get function() {
        return false;
    }

    get attributes() {
        return this._attributes;
    }

    get inputs() {
        return this._inputs;
    }

    get outputs() {
        return this._outputs;
    }
};
torchscript.Attribute = class {

    /**
     * A node attribute taken from a call argument.
     * Literal expression nodes (number, string, boolean, identifier) are unwrapped to their `value`.
     * With a schema, the value is coerced to the schema `type`, and the attribute is hidden when
     * the schema says `visible: false` or the value equals the schema `default`.
     * @param {torchscript.Node} node - owning node.
     * @param {object|null} schema - attribute schema (`name`, `type`, `default`, `visible`).
     * @param {string} name - attribute name.
     * @param {*} value - parsed expression or plain value.
     */
    constructor(node, schema, name, value) {
        this._node = node;
        this._name = name;
        const isLiteral = value && [ 'number', 'string', 'boolean', 'identifier' ].indexOf(value.type) !== -1;
        this._value = isLiteral ? value.value : value;
        if (!schema) {
            return;
        }
        const hasKey = (key) => Object.prototype.hasOwnProperty.call(schema, key);
        if (hasKey('type')) {
            this._type = schema.type;
        }
        const current = this._value;
        switch (this._type) {
            case 'boolean':
                // Python booleans arrive as the text 'True' / 'False'.
                if (current == 'True' || current == 'False') {
                    this._value = current == 'True';
                }
                break;
            case 'int32':
            case 'int64':
                this._value = parseInt(current, 10);
                break;
            case 'float32':
            case 'float64':
                this._value = parseFloat(current);
                break;
            case 'int32[]':
            case 'int64[]':
                if (current.type == 'list' && current.value.every((item) => item.type === 'number')) {
                    this._value = current.value.map((item) => {
                        const integer = parseInt(item.value, 10);
                        return Number.isNaN(item.value - integer) ? item.value : integer;
                    });
                }
                break;
            default:
                break;
        }
        if (hasKey('visible') && !schema.visible) {
            this._visible = false;
        }
        else if (hasKey('default')) {
            const fallback = schema.default;
            const decoded = this._value;
            if (JSON.stringify(fallback) == JSON.stringify(decoded) ||
                (Array.isArray(decoded) && !Array.isArray(fallback) && decoded.every((item) => item == fallback))) {
                this._visible = false;
            }
        }
    }

    get type() {
        return this._type;
    }

    get name() {
        return this._name;
    }

    get value() {
        return this._value;
    }

    /** Hidden when marked invisible above, and always hidden for the 'training' flag. */
    get visible() {
        return !(this._visible === false || this.name == 'training');
    }
};
torchscript.Tensor = class {

    /**
     * A tensor from model.json whose raw bytes live in the archive entry
     * `container.prefix + tensor.data.key`.
     * @param {object} tensor - tensor definition (`dataType`, `dims`, `data.key`).
     * @param {object} container - opened container (`prefix`, `entries`).
     * @throws {torchscript.Error} for unsupported data types (from TensorType).
     */
    constructor(tensor, container) {
        this._type = new torchscript.TensorType(tensor.dataType, new torchscript.TensorShape(tensor.dims));
        this._name = tensor.data.key;
        const key = container.prefix + this._name;
        const entry = container.entries.find((item) => item.name == key);
        // A dangling data key leaves the tensor without data (reported through `state` as
        // 'Tensor data is empty.') instead of throwing and aborting the whole graph.
        this._data = entry ? entry.data : null;
        this._littleEndian = true;
    }

    get kind() {
        return 'Tensor';
    }

    get name() {
        return this._name;
    }

    get type() {
        return this._type;
    }

    /** @returns {string|null} why the tensor cannot be decoded, or null when it can. */
    get state() {
        return this._context().state;
    }

    /** @returns {Array|*|null} all elements as (nested) arrays, or null when not decodable. */
    get value() {
        const context = this._context();
        if (context.state) {
            return null;
        }
        context.limit = Number.MAX_SAFE_INTEGER;
        return this._decode(context, 0);
    }

    /** @returns {string} indented text of at most ~10000 elements ('...' marks truncation). */
    toString() {
        const context = this._context();
        if (context.state) {
            return '';
        }
        context.limit = 10000;
        const value = this._decode(context, 0);
        return torchscript.Tensor._stringify(value, '', '    ');
    }

    // Builds the decoding cursor; `state` holds a message when decoding is impossible.
    _context() {
        const context = {
            state: null,
            index: 0,
            count: 0
        };
        if (!this._type.dataType) {
            context.state = 'Tensor has no data type.';
            return context;
        }
        if (!this._type.shape) {
            context.state = 'Tensor has no dimensions.';
            return context;
        }
        if (!this._data) {
            context.state = 'Tensor data is empty.';
            return context;
        }
        context.data = this._data;
        context.dataType = this._type.dataType;
        context.dimensions = this._type.shape.dimensions;
        context.dataView = new DataView(context.data.buffer, context.data.byteOffset, context.data.byteLength);
        return context;
    }

    // Decodes dimension `dimension` and below; a rank-0 tensor is decoded as one scalar.
    _decode(context, dimension) {
        const results = [];
        const dimensions = context.dimensions.length == 0 ? [ 1 ] : context.dimensions;
        const size = dimensions[dimension];
        const view = context.dataView;
        const littleEndian = this._littleEndian;
        if (dimension == dimensions.length - 1) {
            for (let i = 0; i < size; i++) {
                if (context.count > context.limit) {
                    results.push('...');
                    return results;
                }
                switch (context.dataType) {
                    case 'uint8':
                        results.push(view.getUint8(context.index));
                        context.index += 1;
                        context.count++;
                        break;
                    case 'int8':
                        results.push(view.getInt8(context.index));
                        context.index += 1;
                        context.count++;
                        break;
                    case 'int16':
                        results.push(view.getInt16(context.index, littleEndian));
                        context.index += 2;
                        context.count++;
                        break;
                    case 'int32':
                        results.push(view.getInt32(context.index, littleEndian));
                        context.index += 4;
                        context.count++;
                        break;
                    case 'int64':
                        // Low and high 32-bit words, little-endian.
                        results.push(new long.Long(view.getUint32(context.index, true), view.getUint32(context.index + 4, true), false));
                        context.index += 8;
                        context.count++;
                        break;
                    case 'float16':
                        results.push(view.getFloat16(context.index, littleEndian));
                        context.index += 2;
                        context.count++;
                        break;
                    case 'float32':
                        results.push(view.getFloat32(context.index, littleEndian));
                        context.index += 4;
                        context.count++;
                        break;
                    case 'float64':
                        results.push(view.getFloat64(context.index, littleEndian));
                        context.index += 8;
                        context.count++;
                        break;
                    default:
                        break;
                }
            }
        }
        else {
            for (let j = 0; j < size; j++) {
                if (context.count > context.limit) {
                    results.push('...');
                    return results;
                }
                results.push(this._decode(context, dimension + 1));
            }
        }
        if (context.dimensions.length == 0) {
            return results[0];
        }
        return results;
    }

    // Renders nested arrays one element per line, indenting each nesting level by `indent`.
    static _stringify(value, indentation, indent) {
        if (Array.isArray(value)) {
            const result = [];
            result.push(indentation + '[');
            const items = value.map((item) => torchscript.Tensor._stringify(item, indentation + indent, indent));
            if (items.length > 0) {
                result.push(items.join(',\n'));
            }
            result.push(indentation + ']');
            return result.join('\n');
        }
        if (value && long.Long.isLong(value)) {
            return indentation + value.toString();
        }
        if (typeof value == 'string') {
            return indentation + value;
        }
        if (value == Infinity) {
            return indentation + 'Infinity';
        }
        if (value == -Infinity) {
            return indentation + '-Infinity';
        }
        if (isNaN(value)) {
            return indentation + 'NaN';
        }
        return indentation + value.toString();
    }
};
torchscript.TensorType = class {

    /**
     * @param {string} dataType - data type name as written in model.json (e.g. 'FLOAT', 'INT64').
     * @param {torchscript.TensorShape} shape
     * @throws {torchscript.Error} for data types this reader does not decode.
     */
    constructor(dataType, shape) {
        // Every target type here is one that Tensor._decode reads. UINT8/INT8/INT16 are decoded
        // with native DataView readers; before this, they threw and aborted the whole model load.
        const dataTypes = {
            'FLOAT': 'float32',
            'DOUBLE': 'float64',
            'INT32': 'int32',
            'INT64': 'int64',
            'UINT8': 'uint8',
            'INT8': 'int8',
            'INT16': 'int16'
        };
        if (!Object.prototype.hasOwnProperty.call(dataTypes, dataType)) {
            throw new torchscript.Error("Unknown tensor data type '" + dataType + "'.");
        }
        this._dataType = dataTypes[dataType];
        this._shape = shape;
    }

    get dataType() {
        return this._dataType;
    }

    get shape() {
        return this._shape;
    }

    toString() {
        return this._dataType + this._shape.toString();
    }
};
torchscript.TensorShape = class {

    /** @param {Array|undefined} dimensions - `dims` from model.json; a missing value is stored as []. */
    constructor(dimensions) {
        this._dimensions = dimensions || [];
    }

    get dimensions() {
        return this._dimensions;
    }

    /** @returns {string} '[d0,d1,...]', or '' for a rank-0 shape. */
    toString() {
        const dims = this._dimensions;
        const hasDims = dims && dims.length > 0;
        return hasDims ? '[' + dims.map((d) => d.toString()).join(',') + ']' : '';
    }
};
torchscript.Metadata = class {

    /**
     * Loads 'torchscript-metadata.json' once and caches the result on the class.
     * A failed request or parse yields an empty instance, so models still open without schemas.
     * @param {object} host - provides `request`.
     * @returns {Promise<torchscript.Metadata>}
     */
    static open(host) {
        if (torchscript.Metadata._metadata) {
            return Promise.resolve(torchscript.Metadata._metadata);
        }
        return host.request(null, 'torchscript-metadata.json', 'utf-8').then((data) => {
            torchscript.Metadata._metadata = new torchscript.Metadata(data);
            return torchscript.Metadata._metadata;
        }).catch(() => {
            torchscript.Metadata._metadata = new torchscript.Metadata(null);
            return torchscript.Metadata._metadata;
        });
    }

    /**
     * @param {string|null} data - JSON array of `{ name, schema }` items; null for no schemas.
     */
    constructor(data) {
        // Map instead of {} so operator names such as 'constructor' or 'toString' do not resolve
        // to Object.prototype members.
        this._map = new Map();
        this._attributeCache = new Map();
        if (data) {
            const items = JSON.parse(data);
            for (const item of items || []) {
                if (item.name && item.schema) {
                    this._map.set(item.name, item.schema);
                }
            }
        }
    }

    /** @returns {object|null} schema registered for `operator`. */
    getSchema(operator) {
        return this._map.get(operator) || null;
    }

    /** @returns {object|null} schema of attribute `name` of `operator` (per-operator lookup is cached). */
    getAttributeSchema(operator, name) {
        let map = this._attributeCache.get(operator);
        if (!map) {
            map = new Map();
            const schema = this.getSchema(operator);
            if (schema && schema.attributes) {
                for (const attribute of schema.attributes) {
                    map.set(attribute.name, attribute);
                }
            }
            this._attributeCache.set(operator, map);
        }
        return map.get(name) || null;
    }
};
- torchscript.GraphContext = class {
- constructor(container, python, mainModule) {
- this._container = container;
- this._mainModule = mainModule;
- this._inputs = [];
- this._outputs = [];
- this._nodes = [];
- this._moduleMap = {};
- this._argumentMap = {};
- this._numToTensorMap = {};
- if (mainModule.torchscriptArena && mainModule.torchscriptArena.key) {
- var codeKey = container.prefix + mainModule.torchscriptArena.key;
- var codeEntries = container.entries.filter((e) => e.name === codeKey);
- if (codeEntries.length == 1) {
- var codeEntry = codeEntries[0];
- var textDecoder = new TextDecoder('utf-8');
- var code = textDecoder.decode(codeEntry.data);
- var reader = new python.Parser(code);
- var program = reader.parse();
- var method = program.body.find((statement) => statement.type == 'def' && statement.name == 'forward');
- if (method) {
- this._body = method.body.statements;
- var methodParameters = method.parameters;
- if (methodParameters.length > 0 && methodParameters[0].name == 'self') {
- methodParameters.shift();
- }
- for (var parameter of methodParameters) {
- this._parameter(parameter);
- }
- if (this._body.length >= 2) {
- var returnStatement = this._body[this._body.length - 1];
- var assignStatement = this._body[this._body.length - 2];
- if (returnStatement.type == 'return' &&
- returnStatement.expression.type == 'identifier' &&
- assignStatement.target.type == 'identifier' &&
- assignStatement.target.value == returnStatement.expression.value) {
- returnStatement.expression = assignStatement.expression;
- this._body.pop();
- this._body.pop();
- this._body.push(returnStatement);
- }
- }
- while (this._body.length > 0) {
- var statement = this._body.shift();
- if (this._attributeStatement(statement)) {
- continue;
- }
- if (this._moduleStatement(statement)) {
- continue;
- }
- if (this._argumentStatement(statement)) {
- continue;
- }
- if (this._nodeStatement(statement)) {
- continue;
- }
- if (this._returnStatement(statement)) {
- continue;
- }
- throw new torchscript.Error("Unknown statement '" + JSON.stringify(statement) + "'.");
- }
- }
- }
- }
- }
- get inputs() {
- return this._inputs;
- }
- get outputs() {
- return this._outputs;
- }
- get nodes() {
- return this._nodes;
- }
- _parameter(parameter) {
- var type = parameter.parameterType;
- if (type.type == 'type' && type.value == 'Tuple' && type.arguments && type.arguments.length > 0) {
- if (this._body.length > 0) {
- var statement = this._body[0];
- if (statement.expression.type == 'identifier' && statement.expression.value == parameter.name) {
- if (statement.type === '=' && statement.target.type === 'tuple') {
- for (var input of statement.target.value) {
- if (input) {
- this._inputs.push(input.value);
- }
- }
- this._body.shift();
- }
- }
- }
- }
- else {
- this._inputs.push(parameter.name);
- }
- }
- _returnStatement(statement) {
- if (statement.type == 'return') {
- var variable = this._variable();
- if (this._nodeExpression(statement.expression, variable)) {
- this._outputs.push(variable.value);
- return true;
- }
- if (statement.expression.type == 'identifier') {
- this._outputs.push(statement.expression.value);
- return true;
- }
- if (statement.expression.type == 'tuple') {
- var outputs = [];
- for (var expression of statement.expression.value) {
- variable = this._variable();
- if (this._nodeExpression(expression, variable)) {
- outputs.push(variable.value);
- continue
- }
- if (expression.type == 'identifier') {
- outputs.push(expression.value);
- continue;
- }
- return false;
- }
- this._outputs = this._outputs.concat(outputs);
- return true;
- }
- }
- return false;
- }
/**
 * Tries to turn a `torch.*(...)` call expression into a graph node.
 *
 * Arguments are consumed from the front in two phases:
 * - Input phase: arguments that resolve to values become node inputs. These are identifiers,
 *   lists whose every element resolves, nested calls (recursively turned into nodes), references
 *   accepted by `_argumentExpression`, and `CONSTANTS.cN` tensors.
 * - Attribute phase: starts at the first `True`/`False`, literal, keyword (`=`) argument or
 *   unresolvable list. Every remaining argument becomes a node attribute, after
 *   `_attributeExpression` substitution.
 * `target` (an identifier, or a tuple of identifiers) names the node outputs.
 *
 * Side effects: mutates `expression.arguments` (shift and in-place replacement); nested nodes
 * are appended to `this._nodes` before this one; may allocate variables via `_variable()`.
 *
 * @param {object} expression - parsed expression.
 * @param {object} target - identifier or tuple that receives the result.
 * @returns {boolean} true if a node was created; false if `expression` is not a `torch.` call
 *     or `target` is not an identifier/tuple.
 * @throws {torchscript.Error} 'Unknown function argument.' for an input-phase argument that is not understood.
 */
_nodeExpression(expression, target) {
    if (expression.type == 'call' && (target.type == 'identifier' || target.type == 'tuple')) {
        var name = this._name(expression.target);
        var namespace = 'torch.';
        if (name.startsWith(namespace)) {
            var node = {};
            // Operator name without the 'torch.' prefix; this is the metadata key.
            node.name = name.substring(namespace.length);
            node.inputs = [];
            node.outputs = [];
            node.attributes = [];
            var args = expression.arguments;
            // Input phase.
            while (args.length > 0) {
                var argument = args[0];
                // `_moduleTensor` (defined elsewhere in this class) may rewrite the argument first.
                argument = this._moduleTensor(argument);
                if (argument.type == 'identifier' && this._argumentMap[argument.value]) {
                    argument = this._argumentMap[argument.value];
                    // NOTE(review): `argument` was reassigned on the previous line, so this deletes the key of
                    // the substituted expression, not the identifier that was looked up. The mapping for the
                    // original identifier therefore stays in place. Confirm which behavior is intended.
                    delete this._argumentMap[argument.value];
                }
                if (argument.type == 'identifier') {
                    // `True`/`False` reach here as identifiers; they begin the attribute phase.
                    if (argument.value === 'False' || argument.value === 'True') {
                        break;
                    }
                    node.inputs.push([ { id: argument.value } ]);
                    args.shift();
                    continue;
                }
                if (argument.type == 'list') {
                    // A list is a single input only if every element resolves to a value.
                    var list = [];
                    for (var input of argument.value) {
                        var variable = this._variable();
                        if (this._nodeExpression(input, variable)) {
                            list.push({ id: variable.value });
                        }
                        else if (this._argumentExpression(input, variable)) {
                            list.push({ id: variable.value });
                        }
                        else if (input.type == 'identifier') {
                            list.push({ id: input.value });
                        }
                        else {
                            list = null;
                            break;
                        }
                    }
                    if (list) {
                        node.inputs.push(list);
                        args.shift();
                        continue;
                    }
                }
                // An unresolvable list, literals and keyword arguments start the attribute phase.
                if (argument.type == 'list') {
                    break;
                }
                if (argument.type == 'number' || argument.type == 'string' || argument.type == 'boolean') {
                    break;
                }
                if (argument.type == '=') {
                    break;
                }
                variable = this._variable();
                if (this._nodeExpression(argument, variable)) {
                    node.inputs.push([ { id: variable.value } ]);
                    args.shift();
                    continue;
                }
                if (this._argumentExpression(argument, variable)) {
                    node.inputs.push([ { id: variable.value } ]);
                    args.shift();
                    continue;
                }
                // `CONSTANTS.cN` refers to container.tensors[N]; it becomes an input with that tensor as initializer.
                if (argument.type == '.' &&
                    argument.target.type == 'identifier' &&
                    argument.target.value == 'CONSTANTS' &&
                    argument.member.type == 'identifier' &&
                    argument.member.value.startsWith('c')) {
                    var constantId = [ argument.target.value, argument.member.value ].join('.');
                    var constantIndex = parseInt(argument.member.value.substring(1), 10);
                    var constantTensor = this._container.tensors[constantIndex];
                    node.inputs.push([ { id: constantId, initializer: constantTensor } ]);
                    args.shift();
                    continue;
                }
                throw new torchscript.Error('Unknown function argument.');
            }
            // Attribute phase: remaining arguments (and list elements) go through _attributeExpression.
            while (args.length > 0) {
                if (args[0].type == 'list') {
                    for (var i = 0; i < args[0].value.length; i++) {
                        args[0].value[i] = this._attributeExpression(args[0].value[i]);
                    }
                }
                var intExpression = this._attributeExpression(args[0]);
                if (intExpression) {
                    args[0] = intExpression;
                }
                node.attributes.push(args[0]);
                args.shift();
            }
            if (target.type == 'identifier') {
                node.outputs.push(target.value);
            }
            if (target.type == 'tuple') {
                for (var identifier of target.value) {
                    node.outputs.push(identifier.value);
                }
            }
            this._nodes.push(node);
            return true;
        }
    }
    return false;
}
- _nodeStatement(statement) {
- if (statement.type == '=') {
- if (this._nodeExpression(statement.expression, statement.target)) {
- return true;
- }
- }
- return false;
- }
- _attributeExpression(expression) {
- if (expression.type == 'identifier') {
- if (this._numToTensorMap[expression.value]) {
- return { type: 'number', value: this._numToTensorMap[expression.value] };
- }
- }
- if (expression.type == 'call' &&
- expression.target.type == 'identifier' &&
- expression.target.value == 'int' &&
- expression.arguments.length == 1)
- {
- var replace = this._attributeExpression(expression.arguments[0]);
- if (replace) {
- return replace;
- }
- }
- return expression;
- }
- _attributeStatement(statement) {
- if (statement.type == '=' &&
- statement.target.type == 'identifier') {
- if (statement.expression.type == 'call' &&
- this._name(statement.expression.target) == 'ops.prim.NumToTensor' &&
- statement.expression.arguments.length == 1) {
- var size = statement.expression.arguments[0];
- if (size.type == 'call' &&
- size.arguments.length == 2 &&
- this._name(size.target) == 'torch.size' &&
- size.arguments[0].type == 'identifier' &&
- size.arguments[1].type == 'number') {
- this._numToTensorMap[statement.target.value] = this._name(size.target) + '(' + size.arguments.map((a) => a.value.toString()).join(',') + ')';
- return true;
- }
- if (size.type == 'identifier') {
- var duplicate1 = this._numToTensorMap[size.value];
- if (duplicate1) {
- this._numToTensorMap[statement.target.value] = duplicate1;
- return true;
- }
- }
- }
- if (statement.expression.type == 'call' &&
- statement.expression.arguments.length == 2 &&
- this._name(statement.expression.target) == 'torch.size' &&
- statement.expression.arguments[0].type == 'identifier' &&
- statement.expression.arguments[1].type == 'number') {
- this._numToTensorMap[statement.target.value] = this._name(statement.expression.target) + '(' + statement.expression.arguments.map((a) => a.value.toString()).join(',') + ')';
- return true;
- }
- if (statement.expression.type == 'call' &&
- statement.expression.target.type == 'identifier' &&
- statement.expression.target.value == 'int' &&
- statement.expression.arguments.length == 1 &&
- statement.expression.arguments[0].type == 'identifier') {
- var duplicate2 = this._numToTensorMap[statement.expression.arguments[0].value];
- if (duplicate2) {
- this._numToTensorMap[statement.target.value] = duplicate2;
- return true;
- }
- }
- }
- return false;
- }
- _module(expression) {
- var module;
- var submodule;
- if (expression.type === '.') {
- module = this._module(expression.target);
- if (module && module.submodules) {
- for (submodule of module.submodules) {
- if (submodule.name === expression.member.value) {
- return submodule;
- }
- }
- }
- }
- if (expression.type == 'call' &&
- expression.target.type == 'identifier' && expression.target.value == 'getattr' && expression.arguments.length == 2) {
- module = this._module(expression.arguments[0]);
- if (!module) {
- return null;
- }
- var name = null;
- if (expression.arguments[1].type == 'string') {
- name = expression.arguments[1].value.substring(1, expression.arguments[1].value.length - 1);
- }
- if (module) {
- for (submodule of module.submodules) {
- if (submodule.name === name) {
- return submodule;
- }
- }
- }
- }
- if (expression.type == 'identifier') {
- if (expression.value == 'self') {
- return this._mainModule;
- }
- module = this._moduleMap[expression.value];
- if (module) {
- return module;
- }
- }
- return null;
- }
- _moduleStatement(statement) {
- if (statement.type == '=' &&
- statement.target.type === 'identifier') {
- var moduleName = statement.target.value;
- var module = this._module(statement.expression);
- if (module) {
- this._moduleMap[moduleName] = module;
- return true;
- }
- }
- return false;
- }
- _argumentExpression(expression, target) {
- expression = this._moduleTensor(expression);
- if (expression.type === '.' && expression.member.type == 'identifier') {
- var targetModule = this._module(expression.target);
- if (targetModule && targetModule.parameters) {
- for (var parameter of targetModule.parameters) {
- parameter.module = targetModule;
- if (parameter.name === expression.member.value) {
- parameter.outputs = parameter.outputs || [];
- parameter.outputs.push(target.value);
- return true;
- }
- }
- targetModule.unresolvedParameters = targetModule.unresolvedParameters || [];
- for (var unresolvedParameter of targetModule.unresolvedParameters) {
- unresolvedParameter.module = targetModule;
- if (unresolvedParameter.name === expression.member.value) {
- unresolvedParameter.outputs = unresolvedParameter.outputs || [];
- unresolvedParameter.outputs.push(target.value);
- return true;
- }
- }
- targetModule.unresolvedParameters.push({
- module: targetModule,
- name: expression.member.value,
- outputs: [ target.value ]
- });
- return true;
- }
- }
- return false;
- }
- _argumentStatement(statement) {
- if (statement.type === '=' && statement.target.type === 'identifier') {
- if (this._argumentExpression(statement.expression, statement.target)) {
- return true;
- }
- if (statement.target.type == 'identifier' &&
- statement.expression.type == 'list') {
- this._argumentMap[statement.target.value] = statement.expression;
- return true;
- }
- }
- return false;
- }
- _variable() {
- return { type: 'identifier', value: '_gen' + Math.random().toString(36).substring(7) };
- }
- _name(expression) {
- if (expression.type == 'identifier') {
- return expression.value;
- }
- if (expression.type == '.') {
- return [ this._name(expression.target), this._name(expression.member) ].join('.');
- }
- throw new torchscript.Error("Failed to resolve name '" + JSON.stringify(expression) + "'.");
- }
- _moduleTensor(expression) {
- if (expression.type == 'call' &&
- expression.arguments.length == 1 &&
- this._name(expression.target) == 'torch.t') {
- return expression.arguments[0];
- }
- return expression;
- }
- }
// Error type thrown by this loader (e.g. by _name on unresolvable expressions).
torchscript.Error = class extends Error {

    constructor(message) {
        super(message);
        this.name = 'Error loading TorchScript model.';
    }
};

// CommonJS export; skipped when no `module.exports` object is available.
const hasCommonJsExports = typeof module !== 'undefined' && typeof module.exports === 'object';
if (hasCommonJsExports) {
    module.exports.ModelFactory = torchscript.ModelFactory;
}
|