|
|
@@ -538,7 +538,8 @@ pytorch.Node = class {
|
|
|
}
|
|
|
|
|
|
get operator() {
|
|
|
- return this._type;
|
|
|
+ const index = this._type.indexOf(':');
|
|
|
+ return index === -1 ? this._type : this._type.substring(0, index);
|
|
|
}
|
|
|
|
|
|
get category() {
|
|
|
@@ -960,15 +961,28 @@ pytorch.Metadata = class {
|
|
|
if (items) {
|
|
|
for (const item of items) {
|
|
|
if (item.name && item.schema) {
|
|
|
+ item.schema.name = item.name;
|
|
|
this._map.set(item.name, item.schema);
|
|
|
}
|
|
|
+ const index = item.name.indexOf(':');
|
|
|
+ if (index !== -1) {
|
|
|
+ const name = item.name.substring(0, index);
|
|
|
+ if (!this._map.has(name)) {
|
|
|
+ this._map.set(name, [])
|
|
|
+ }
|
|
|
+ this._map.get(name).push(item.name);
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
type(operator) {
|
|
|
- return this._map.get(operator) || null;
|
|
|
+ const schema = this._map.get(operator);
|
|
|
+ if (schema) {
|
|
|
+ return Array.isArray(schema) ? schema.map((name) => this._map.get(name)) : schema;
|
|
|
+ }
|
|
|
+ return null;
|
|
|
}
|
|
|
|
|
|
attribute(operator, name) {
|
|
|
@@ -1466,6 +1480,12 @@ pytorch.Execution = class {
|
|
|
this._registerFunction('ops.prim.RaiseException', function(message) {
|
|
|
throw new pytorch.Error(message);
|
|
|
});
|
|
|
+ this._registerFunction('torch.add', function(left, right) {
|
|
|
+ if (typeof left === 'number' && typeof right === 'number') {
|
|
|
+ return left * right;
|
|
|
+ }
|
|
|
+ throw new pytorch.Error('Unknown torch.add expression type.');
|
|
|
+ });
|
|
|
this._registerFunction('torch.__is__', function(left, right) {
|
|
|
if (left === null && right === null) {
|
|
|
return true;
|
|
|
@@ -1552,9 +1572,6 @@ pytorch.Execution = class {
|
|
|
if (typeof left === 'number' && typeof right === 'number') {
|
|
|
return left * right;
|
|
|
}
|
|
|
- if (pytorch.Utility.isTensor(left) && pytorch.Utility.isTensor(right)) {
|
|
|
- return { __module__: 'torch', __name__: 'Tensor', __origin__: 'torch.mul' };
|
|
|
- }
|
|
|
throw new pytorch.Error('Unknown torch.mul expression type.');
|
|
|
});
|
|
|
this._registerFunction('torch.ne', function(left, right) {
|
|
|
@@ -2863,12 +2880,18 @@ pytorch.Container.Zip = class {
|
|
|
let args = [ this.data ]; // self
|
|
|
if (this.data.forward.__code__ && this.data.forward.__code__.parameters) {
|
|
|
for (const parameter of this.data.forward.__code__.parameters) {
|
|
|
- if (parameter.name !== 'self' &&
|
|
|
- parameter.parameterType.type === 'type' &&
|
|
|
- parameter.parameterType.name.type === 'id' &&
|
|
|
- parameter.parameterType.name.value === 'Tensor') {
|
|
|
- this._inputs.push(parameter.name);
|
|
|
- args.push({ __module__: 'torch', __name__: 'Tensor', __variable__: parameter.name, __origin__: 'trace-input' });
|
|
|
+ if (parameter.name !== 'self') {
|
|
|
+ const type = parameter.parameterType;
|
|
|
+ if (type.type === 'type' && type.name.type) {
|
|
|
+ if (type.name.value === 'Tensor') {
|
|
|
+ this._inputs.push(parameter.name);
|
|
|
+ args.push({ __module__: 'torch', __name__: 'Tensor', __variable__: parameter.name, __origin__: 'trace-input' });
|
|
|
+ }
|
|
|
+ if (type.name.value === 'Tuple' && type.arguments.every((item) => item.type === 'type' && item.name.type === 'id' && item.name.value === 'Tensor')) {
|
|
|
+ this._inputs.push(parameter.name);
|
|
|
+ args.push(type.arguments.map(() => { return { __module__: 'torch', __name__: 'Tensor', __variable__: parameter.name, __origin__: 'trace-input' } }));
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -2919,87 +2942,101 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
|
|
|
let callArgs = Array.prototype.slice.call(args);
|
|
|
if (callTarget) {
|
|
|
const type = callTarget + '.' + name;
|
|
|
- const schema = this._metadata.type(type);
|
|
|
- if (schema) {
|
|
|
- args = Array.prototype.slice.call(args);
|
|
|
- let node = {
|
|
|
- type: type,
|
|
|
- inputs: [],
|
|
|
- attributes: [],
|
|
|
- outputs: []
|
|
|
- };
|
|
|
- const inputSchemas = Array.prototype.slice.call(schema.inputs);
|
|
|
- while (inputSchemas.length > 0) {
|
|
|
- let inputSchema = inputSchemas.shift();
|
|
|
- const argument = this.expression(callArgs.shift(), context);
|
|
|
- while (inputSchema.option === 'optional' && Array.isArray(argument) && inputSchema.type !== 'T[]' && inputSchemas.length > 0) {
|
|
|
- inputSchema = inputSchemas.shift();
|
|
|
- }
|
|
|
- const parameters = Array.isArray(argument) ? argument : [ argument ];
|
|
|
- let inputs = [];
|
|
|
- for (let parameter of parameters) {
|
|
|
- if (parameter) {
|
|
|
- if (!pytorch.Utility.isTensor(parameter)) {
|
|
|
- if (typeof parameter !== 'number' && isNaN(parameter)) {
|
|
|
- return super.call(target, name, args, context);
|
|
|
+ // ./aten/src/ATen/native/native_functions.yaml
|
|
|
+ let schemas = this._metadata.type(type);
|
|
|
+ if (schemas) {
|
|
|
+ if (!Array.isArray(schemas)) {
|
|
|
+ schemas = [ schemas ]
|
|
|
+ }
|
|
|
+ for (const schema of schemas) {
|
|
|
+ let node = {
|
|
|
+ type: schema.name,
|
|
|
+ inputs: [],
|
|
|
+ attributes: [],
|
|
|
+ outputs: []
|
|
|
+ };
|
|
|
+ let next = false;
|
|
|
+ const inputSchemas = Array.prototype.slice.call(schema.inputs);
|
|
|
+ while (inputSchemas.length > 0) {
|
|
|
+ let inputSchema = inputSchemas.shift();
|
|
|
+ const argument = this.expression(callArgs.shift(), context);
|
|
|
+ while (inputSchema.option === 'optional' && Array.isArray(argument) && inputSchema.type !== 'T[]' && inputSchemas.length > 0) {
|
|
|
+ inputSchema = inputSchemas.shift();
|
|
|
+ }
|
|
|
+ const parameters = Array.isArray(argument) ? argument : [ argument ];
|
|
|
+ let inputs = [];
|
|
|
+ for (let parameter of parameters) {
|
|
|
+ if (parameter !== undefined) {
|
|
|
+ if (!pytorch.Utility.isTensor(parameter) && parameter !== null) {
|
|
|
+ next = true;
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ if (parameter === null) {
|
|
|
+ parameter = {};
|
|
|
+ }
|
|
|
+ if (parameter.__variable__) {
|
|
|
+ inputs.push({ id: parameter.__variable__ });
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ const id = this._variable().value;
|
|
|
+ parameter.__variable__ = id;
|
|
|
+ parameter.__outputs__ = parameter.__outputs__ || [];
|
|
|
+ parameter.__outputs__.push(id);
|
|
|
+ inputs.push({ id: id });
|
|
|
}
|
|
|
- parameter = {};
|
|
|
- }
|
|
|
- if (parameter.__variable__) {
|
|
|
- inputs.push({ id: parameter.__variable__ });
|
|
|
- }
|
|
|
- else {
|
|
|
- const id = this._variable().value;
|
|
|
- parameter.__variable__ = id;
|
|
|
- parameter.__outputs__ = parameter.__outputs__ || [];
|
|
|
- parameter.__outputs__.push(id);
|
|
|
- inputs.push({ id: id });
|
|
|
}
|
|
|
}
|
|
|
+ if (next) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ node.inputs.push(inputs);
|
|
|
}
|
|
|
- node.inputs.push(inputs);
|
|
|
- }
|
|
|
- while (callArgs.length > 0 && callArgs[0].type !== '=') {
|
|
|
- const value = this.expression(callArgs.shift(), context);
|
|
|
- node.attributes.push(value);
|
|
|
- }
|
|
|
- while (callArgs.length > 0) {
|
|
|
- const arg = callArgs.shift();
|
|
|
- if (arg.type === '=' && arg.target && arg.target.type === 'id') {
|
|
|
- const value = this.expression(arg.expression, context);
|
|
|
- node.attributes.push({ type: '=', target: arg.target, expression: value });
|
|
|
+ if (next) {
|
|
|
+ callArgs = Array.prototype.slice.call(args);
|
|
|
+ continue;
|
|
|
}
|
|
|
- else {
|
|
|
- throw new pytorch.Attribute('Expected named argument.');
|
|
|
- }
|
|
|
- }
|
|
|
- let outputs = []
|
|
|
- for (let i = 0; i < schema.outputs.length; i++) {
|
|
|
- let parameter = { __module__: 'torch', __name__: 'Tensor', __origin__: 'invoke-output-' + type };
|
|
|
- switch (type) {
|
|
|
- case 'torch.cat':
|
|
|
- case 'torch.conv2d':
|
|
|
- case 'torch.flatten':
|
|
|
- case 'torch.quantize_per_tensor':
|
|
|
- case 'torch.relu_':
|
|
|
- case 'torch.dropout': {
|
|
|
- parameter.size = [ undefined, undefined, undefined, undefined ];
|
|
|
- break;
|
|
|
+ while (callArgs.length > 0 && callArgs[0].type !== '=') {
|
|
|
+ const value = this.expression(callArgs.shift(), context);
|
|
|
+ node.attributes.push(value);
|
|
|
+ }
|
|
|
+ while (callArgs.length > 0) {
|
|
|
+ const arg = callArgs.shift();
|
|
|
+ if (arg.type === '=' && arg.target && arg.target.type === 'id') {
|
|
|
+ const value = this.expression(arg.expression, context);
|
|
|
+ node.attributes.push({ type: '=', target: arg.target, expression: value });
|
|
|
}
|
|
|
- case 'torch.embedding': {
|
|
|
- parameter.size = [ undefined, undefined, undefined ];
|
|
|
- break;
|
|
|
+ else {
|
|
|
+ throw new pytorch.Attribute('Expected named argument.');
|
|
|
}
|
|
|
}
|
|
|
- parameter.__variable__ = this._variable().value;
|
|
|
- outputs.push(parameter)
|
|
|
- node.outputs.push(parameter.__variable__);
|
|
|
- }
|
|
|
- this._nodes.push(node);
|
|
|
- if (outputs.length > 1) {
|
|
|
- return outputs;
|
|
|
+ let outputs = []
|
|
|
+ for (let i = 0; i < schema.outputs.length; i++) {
|
|
|
+ let parameter = { __module__: 'torch', __name__: 'Tensor', __origin__: 'invoke-output-' + type };
|
|
|
+ switch (type) {
|
|
|
+ case 'torch.cat':
|
|
|
+ case 'torch.conv2d':
|
|
|
+ case 'torch.flatten':
|
|
|
+ case 'torch.quantize_per_tensor':
|
|
|
+ case 'torch.relu_':
|
|
|
+ case 'torch.dropout': {
|
|
|
+ parameter.size = [ undefined, undefined, undefined, undefined ];
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ case 'torch.embedding': {
|
|
|
+ parameter.size = [ undefined, undefined, undefined ];
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ parameter.__variable__ = this._variable().value;
|
|
|
+ outputs.push(parameter)
|
|
|
+ node.outputs.push(parameter.__variable__);
|
|
|
+ }
|
|
|
+ this._nodes.push(node);
|
|
|
+ if (outputs.length > 1) {
|
|
|
+ return outputs;
|
|
|
+ }
|
|
|
+ return outputs[0];
|
|
|
}
|
|
|
- return outputs[0];
|
|
|
}
|
|
|
}
|
|
|
return super.call(target, name, args, context);
|