|
|
@@ -371,33 +371,6 @@ pytorch.Node = class {
|
|
|
type.name = name.indexOf('::') === -1 ? name : name.split('::').pop().split('.')[0];
|
|
|
return type;
|
|
|
};
|
|
|
- const createAttribute = (metadata, name, value) => {
|
|
|
- let visible = true;
|
|
|
- let type = 'attribute';
|
|
|
- metadata = name === 'training' ? { type: 'boolean', visible: false } : metadata;
|
|
|
- if (metadata) {
|
|
|
- if (metadata.type) {
|
|
|
- type = metadata.type;
|
|
|
- }
|
|
|
- if (metadata.visible === false) {
|
|
|
- visible = false;
|
|
|
- } else if (metadata.default !== undefined) {
|
|
|
- if (Array.isArray(value)) {
|
|
|
- if (Array.isArray(metadata.default)) {
|
|
|
- visible = value.length !== metadata.default || !value.every((item, index) => item === metadata.default[index]);
|
|
|
- } else {
|
|
|
- visible = !value.every((item) => item === metadata.default);
|
|
|
- }
|
|
|
- } else {
|
|
|
- visible = value !== metadata.default;
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- if (Array.isArray(value) && value.length > 0 && value.every((obj) => obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__module__.startsWith('torch.nn'))) {
|
|
|
- value = '?';
|
|
|
- }
|
|
|
- return new pytorch.Argument(name, value, type, visible);
|
|
|
- };
|
|
|
let module = null;
|
|
|
if (pytorch.Utility.isInstance(obj, 'torch.Node')) {
|
|
|
const node = obj;
|
|
|
@@ -485,13 +458,13 @@ pytorch.Node = class {
|
|
|
type = type.getElementType();
|
|
|
}
|
|
|
let argument = null;
|
|
|
- if (arg && pytorch.Utility.isInstance(arg.real_type, 'torch.ClassType')) {
|
|
|
+ if (type && pytorch.Utility.isInstance(type, 'torch.ClassType')) {
|
|
|
const obj = input.value;
|
|
|
if (!array && initializers.has(obj)) {
|
|
|
const node = new pytorch.Node(metadata, name, type.qualified_name(), obj, initializers, values);
|
|
|
argument = new pytorch.Argument(name, node, 'object');
|
|
|
} else if (array && Array.isArray(obj) && obj.every((obj) => initializers.has(obj))) {
|
|
|
- const node = obj.map((obj) => new pytorch.Node(metadata, name, type, obj, initializers, values));
|
|
|
+ const node = obj.map((obj) => new pytorch.Node(metadata, name, type.qualified_name(), obj, initializers, values));
|
|
|
argument = new pytorch.Argument(name, node, 'object[]');
|
|
|
} else {
|
|
|
const identifier = input.unique().toString();
|
|
|
@@ -797,6 +770,33 @@ pytorch.Node = class {
|
|
|
const argument = new pytorch.Argument(name, node, 'object', visible);
|
|
|
this.inputs.push(argument);
|
|
|
} else {
|
|
|
+ const createAttribute = (metadata, name, value) => {
|
|
|
+ let visible = true;
|
|
|
+ let type = 'attribute';
|
|
|
+ metadata = name === 'training' ? { type: 'boolean', visible: false } : metadata;
|
|
|
+ if (metadata) {
|
|
|
+ if (metadata.type) {
|
|
|
+ type = metadata.type;
|
|
|
+ }
|
|
|
+ if (metadata.visible === false) {
|
|
|
+ visible = false;
|
|
|
+ } else if (metadata.default !== undefined) {
|
|
|
+ if (Array.isArray(value)) {
|
|
|
+ if (Array.isArray(metadata.default)) {
|
|
|
+ visible = value.length !== metadata.default || !value.every((item, index) => item === metadata.default[index]);
|
|
|
+ } else {
|
|
|
+ visible = !value.every((item) => item === metadata.default);
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ visible = value !== metadata.default;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (Array.isArray(value) && value.length > 0 && value.every((obj) => obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__module__.startsWith('torch.nn'))) {
|
|
|
+ value = '?';
|
|
|
+ }
|
|
|
+ return new pytorch.Argument(name, value, type, visible);
|
|
|
+ };
|
|
|
const argument = createAttribute(metadata.attribute(type, name), name, value);
|
|
|
this.inputs.push(argument);
|
|
|
}
|
|
|
@@ -2923,7 +2923,7 @@ pytorch.jit.Execution = class extends pytorch.Execution {
|
|
|
}
|
|
|
let type = parameter.type;
|
|
|
let optional = false;
|
|
|
- if (parameter.type.endsWith('?')) {
|
|
|
+ if (type.endsWith('?')) {
|
|
|
type = parameter.type.substring(0, parameter.type.length - 1);
|
|
|
optional = true;
|
|
|
}
|
|
|
@@ -3483,6 +3483,144 @@ pytorch.jit.Execution = class extends pytorch.Execution {
|
|
|
return result[0];
|
|
|
}
|
|
|
|
|
|
+ isNativeType(obj, type) {
|
|
|
+ const torch = this.torch;
|
|
|
+ switch (type.str()) {
|
|
|
+ case 'Tensor':
|
|
|
+ return !Array.isArray(obj) && (pytorch.Utility.isTensor(obj) || obj === null ||
|
|
|
+ (obj instanceof torch.Value && obj.type() instanceof torch.TensorType) ||
|
|
|
+ (obj instanceof torch.Value && obj.type() instanceof torch.OptionalType && obj.type().getElementType() instanceof torch.TensorType));
|
|
|
+ case 'Tensor[]':
|
|
|
+ return (Array.isArray(obj) && obj.length > 0 && obj.every((tensor) => pytorch.Utility.isTensor(tensor) || tensor === null || (tensor instanceof torch.Value && tensor.type() instanceof torch.TensorType))) ||
|
|
|
+ (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.TensorType);
|
|
|
+ case 'Scalar':
|
|
|
+ return (obj !== null && (obj !== Object(obj) || obj instanceof Number)) ||
|
|
|
+ (pytorch.Utility.isTensor(obj) && Array.isArray(obj.size()) && obj.size().length === 0) ||
|
|
|
+ (obj instanceof torch.Value && (obj.type() instanceof torch.IntType || obj.type() instanceof torch.FloatType || obj.type() instanceof torch.NumberType));
|
|
|
+ case 'bool':
|
|
|
+ return obj === true || obj === false || (pytorch.Utility.isInstance(obj, 'torch.Value') && obj.type() instanceof torch.BoolType);
|
|
|
+ case 'bool[]':
|
|
|
+ if (Array.isArray(obj) && obj.every((item) => item === true || item === false)) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ if (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.BoolType) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ return false;
|
|
|
+ case 'SymInt':
|
|
|
+ case 'int':
|
|
|
+ return Number.isInteger(obj) || typeof obj === 'bigint' ||
|
|
|
+ (typeof obj === 'number' && isNaN(obj)) || (obj instanceof Number) ||
|
|
|
+ (obj instanceof torch.Value && obj.type() instanceof torch.IntType) ||
|
|
|
+ (obj instanceof torch.Value && obj.type() instanceof torch.OptionalType && obj.type().getElementType() instanceof torch.IntType);
|
|
|
+ case 'SymInt[]':
|
|
|
+ case 'SymInt[2]':
|
|
|
+ case 'SymInt[3]':
|
|
|
+ case 'SymInt[4]':
|
|
|
+ case 'SymInt[5]':
|
|
|
+ case 'SymInt[6]':
|
|
|
+ if (Array.isArray(obj) && obj.every((item) => this.isNativeType(item, torch.SymIntType.get()) || item === undefined || (item.__class__ === 'number' && isNaN(item)))) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ if (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.IntType) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ return false;
|
|
|
+ case 'SymInt[1]':
|
|
|
+ return this.isNativeType(obj, torch.IntType.get()) || this.isNativeType(obj, torch.ListType.get(torch.IntType.get()));
|
|
|
+ case 'int[]':
|
|
|
+ case 'int[2]':
|
|
|
+ case 'int[3]':
|
|
|
+ return (Array.isArray(obj) && obj.every((item) => this.isNativeType(item, torch.IntType.get()) || item === undefined || (item.__class__ === 'number' && isNaN(item))) ||
|
|
|
+ (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.IntType)) ||
|
|
|
+ (obj instanceof torch.Value && obj.type() instanceof torch.OptionalType && obj.type().getElementType() instanceof torch.ListType && obj.type().getElementType().getElementType() instanceof torch.IntType);
|
|
|
+ case 'int[1]':
|
|
|
+ case 'float':
|
|
|
+ return obj !== null && (typeof obj === 'number' || obj instanceof Number) ||
|
|
|
+ (pytorch.Utility.isInstance(obj, 'torch.Value') && pytorch.Utility.isInstance(obj.type(), 'torch.FloatType'));
|
|
|
+ case 'float[]':
|
|
|
+ if (Array.isArray(obj) && obj.every((item) => (typeof item === 'number' || item instanceof Number) && !isNaN(item))) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ if (pytorch.Utility.isInstance(obj, 'torch.Value') && pytorch.Utility.isInstance(obj.type(), 'torch.ListType') && (pytorch.Utility.isInstance(obj.type().getElementType(), 'torch.IntType') || pytorch.Utility.isInstance(obj.type().getElementType(), 'torch.FloatType'))) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ return false;
|
|
|
+ case 'str':
|
|
|
+ return obj === null || typeof obj === 'string' ||
|
|
|
+ (obj instanceof torch.Value && obj.type() instanceof torch.StringType);
|
|
|
+ case 'str[]':
|
|
|
+ return (Array.isArray(obj) && obj.every((item) => item === null || typeof item === 'string')) ||
|
|
|
+ (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.StringType);
|
|
|
+ case 'str[][]':
|
|
|
+ return Array.isArray(obj) && obj.every((item) => Array.isArray(item) && item.every((item) => typeof item === 'string'));
|
|
|
+ case 'Layout':
|
|
|
+ case 'ScalarType':
|
|
|
+ case 'MemoryFormat':
|
|
|
+ return Number.isInteger(obj) || obj === null ||
|
|
|
+ (obj instanceof torch.Value && obj.type() instanceof torch.IntType) ||
|
|
|
+ (obj instanceof torch.Value && obj.type() instanceof torch.OptionalType && obj.type().getElementType() instanceof torch.IntType);
|
|
|
+ case 'Dimname':
|
|
|
+ return obj === null || (typeof obj === 'string' || obj instanceof String);
|
|
|
+ case 'Dimname[]':
|
|
|
+ return Array.isArray(obj) && obj.every((item) => item === null || typeof item === 'string');
|
|
|
+ case 'Device':
|
|
|
+ return obj === null || obj === Object(obj);
|
|
|
+ case 't[]':
|
|
|
+ return Array.isArray(obj) ||
|
|
|
+ (obj instanceof torch.Value && obj.type() instanceof torch.ListType) ||
|
|
|
+ (obj instanceof torch.Value && obj.type() instanceof torch.OptionalType && obj.type().getElementType() instanceof torch.ListType);
|
|
|
+ case 't':
|
|
|
+ return true;
|
|
|
+ case 'AnyEnumType':
|
|
|
+ return false;
|
|
|
+ case 'complex':
|
|
|
+ return obj instanceof torch.Value && obj.type() instanceof torch.ComplexType;
|
|
|
+ case 'Any[]':
|
|
|
+ if (Array.isArray(obj)) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ if (obj instanceof torch.Value && obj.type() instanceof torch.ListType) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ return false;
|
|
|
+ case 't1':
|
|
|
+ case 't2':
|
|
|
+ return true;
|
|
|
+ default: {
|
|
|
+ if (type instanceof torch.ClassType &&
|
|
|
+ obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__) {
|
|
|
+ return type.qualified_name() === `${obj.__class__.__module__}.${obj.__class__.__name__}`;
|
|
|
+ }
|
|
|
+ if (type instanceof torch.TupleType) {
|
|
|
+ throw new pytorch.Error('Not implemented.');
|
|
|
+ /*
|
|
|
+ if (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.TupleType) {
|
|
|
+ const elements = obj.type().getElementType().elements();
|
|
|
+ if (elements.length === 2) {
|
|
|
+ if (pytorch.Utility.toType(elements[0]) === match[1]) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return false;
|
|
|
+ */
|
|
|
+ }
|
|
|
+ if (type instanceof torch.DictType) {
|
|
|
+ if (obj instanceof torch.Value && obj.type() instanceof torch.DictType) {
|
|
|
+ if ((type.getKeyType().kind() === 'VarType' || type.getKeyType().str() === obj.type().getKeyType().str()) ||
|
|
|
+ (type.getValueType().kind() === 'VarType' || type.getValueType().str() === obj.type().getValueType().str())) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ // throw new pytorch.Error(`Unknown type '${type}'.`);
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
isType(obj, type) {
|
|
|
const torch = this.torch;
|
|
|
switch (type) {
|
|
|
@@ -3491,8 +3629,8 @@ pytorch.jit.Execution = class extends pytorch.Execution {
|
|
|
(obj instanceof torch.Value && obj.type() instanceof torch.TensorType) ||
|
|
|
(obj instanceof torch.Value && obj.type() instanceof torch.OptionalType && obj.type().getElementType() instanceof torch.TensorType));
|
|
|
case 'Tensor[]':
|
|
|
- return Array.isArray(obj) && obj.length > 0 &&
|
|
|
- obj.every((tensor) => pytorch.Utility.isTensor(tensor) || tensor === null || (tensor instanceof torch.Value && tensor.type() instanceof torch.TensorType));
|
|
|
+ return (Array.isArray(obj) && obj.length > 0 && obj.every((tensor) => pytorch.Utility.isTensor(tensor) || tensor === null || (tensor instanceof torch.Value && tensor.type() instanceof torch.TensorType))) ||
|
|
|
+ (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.TensorType);
|
|
|
case 'Scalar':
|
|
|
return (obj !== null && (obj !== Object(obj) || obj instanceof Number)) ||
|
|
|
(pytorch.Utility.isTensor(obj) && Array.isArray(obj.size()) && obj.size().length === 0) ||
|
|
|
@@ -3550,7 +3688,8 @@ pytorch.jit.Execution = class extends pytorch.Execution {
|
|
|
}
|
|
|
return false;
|
|
|
case 'string[]':
|
|
|
- return Array.isArray(obj) && obj.every((item) => item === null || typeof item === 'string');
|
|
|
+ return (Array.isArray(obj) && obj.every((item) => item === null || typeof item === 'string')) ||
|
|
|
+ (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.StringType);
|
|
|
case 'string[][]':
|
|
|
return Array.isArray(obj) && obj.every((item) => Array.isArray(item) && item.every((item) => typeof item === 'string'));
|
|
|
case 'Layout':
|
|
|
@@ -3627,27 +3766,27 @@ pytorch.jit.Execution = class extends pytorch.Execution {
|
|
|
const torch = this.torch;
|
|
|
const type = name ? `${moduleName}.${name}` : moduleName;
|
|
|
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
|
|
|
- let overloads = null;
|
|
|
+ let op_name = null;
|
|
|
if (type.startsWith('torch.')) {
|
|
|
- overloads = this._types.get(`aten::${type.substring(6)}`);
|
|
|
- } else if (type.startsWith('ops.prim.')) {
|
|
|
- overloads = this._types.get(`prim::${type.substring(9)}`);
|
|
|
+ op_name = `aten::${type.substring(6)}`;
|
|
|
+ } else if (type.startsWith('ops.')) {
|
|
|
+ op_name = type.substring(4).replace('.', '::');
|
|
|
} else if (type === 'int') {
|
|
|
- overloads = this._types.get(`aten::Int`);
|
|
|
+ op_name = 'aten::Int';
|
|
|
} else if (type === 'str') {
|
|
|
- overloads = this._types.get(`aten::str`);
|
|
|
+ op_name = 'aten::str';
|
|
|
} else if (type === 'bool') {
|
|
|
- overloads = this._types.get(`aten::Bool`);
|
|
|
+ op_name = 'aten::Bool';
|
|
|
} else if (type === 'float') {
|
|
|
- overloads = this._types.get(`aten::Float`);
|
|
|
+ op_name = 'aten::Float';
|
|
|
} else if (type === 'complex') {
|
|
|
- overloads = this._types.get(`aten::Complex`);
|
|
|
- } else if (type.startsWith('ops.') && !type.startsWith('ops.prim.')) {
|
|
|
- const path = type.split('.');
|
|
|
- if (path.length === 3) {
|
|
|
- overloads = this._types.get(`${path[1]}::${path[2]}`);
|
|
|
- }
|
|
|
- if (!overloads) {
|
|
|
+ op_name = 'aten::Complex';
|
|
|
+ }
|
|
|
+ this.native = false;
|
|
|
+ if (this.native && op_name) {
|
|
|
+ const overloads = torch._C._jit_get_schemas_for_operator(op_name);
|
|
|
+ /*
|
|
|
+ if (!overloads && type.startsWith('ops.') && !type.startsWith('ops.prim')) {
|
|
|
const module = this.import(moduleName);
|
|
|
if (!module || !module[name]) {
|
|
|
const metadata = {};
|
|
|
@@ -3674,53 +3813,83 @@ pytorch.jit.Execution = class extends pytorch.Execution {
|
|
|
overloads = [metadata];
|
|
|
}
|
|
|
}
|
|
|
- }
|
|
|
- if (!overloads) {
|
|
|
- if (type.startsWith('aten::') || type.startsWith('prim::')) {
|
|
|
- throw new pytorch.Error(`Unknown function '${type}'.`);
|
|
|
- }
|
|
|
- return null;
|
|
|
- }
|
|
|
- overloads = Array.isArray(overloads) ? overloads : [overloads];
|
|
|
- const evalArgs = args.map((argument) => {
|
|
|
- if (argument.type === '=' && argument.target && argument.target.type === 'id') {
|
|
|
- argument = argument.expression;
|
|
|
+ */
|
|
|
+ if (!overloads) {
|
|
|
+ if (type.startsWith('aten::') || type.startsWith('prim::')) {
|
|
|
+ throw new pytorch.Error(`Unknown function '${type}'.`);
|
|
|
+ }
|
|
|
+ return null;
|
|
|
}
|
|
|
- return this.expression(argument, context);
|
|
|
- });
|
|
|
- const matches = [];
|
|
|
- for (const schema of overloads) {
|
|
|
- const copyArgs = Array.prototype.slice.call(args);
|
|
|
- const copyEvalArgs = Array.prototype.slice.call(evalArgs);
|
|
|
- const parameters = Array.prototype.slice.call(schema.inputs || []).concat(Array.prototype.slice.call(schema.attributes || []));
|
|
|
- let next = false;
|
|
|
- let kwarg_only = false;
|
|
|
- while (copyEvalArgs.length > 0) {
|
|
|
- if (parameters.length <= 0) {
|
|
|
- next = !schema.name.startsWith('_caffe2::') && !schema.is_vararg;
|
|
|
- break;
|
|
|
+ const evalArgs = args.map((argument) => {
|
|
|
+ if (argument.type === '=' && argument.target && argument.target.type === 'id') {
|
|
|
+ argument = argument.expression;
|
|
|
}
|
|
|
- if (copyArgs.every((arg) => arg.type === '=' && arg.target && arg.target.type === 'id') &&
|
|
|
- parameters.every((parameter) => parameter.type !== 'Tensor' && parameter.type !== 'Tensor[]')) {
|
|
|
- const map = new Map(parameters.map((parameter) => [parameter.name, parameter]));
|
|
|
- while (copyArgs.length > 0) {
|
|
|
- const argument = copyArgs.shift();
|
|
|
- const arg = copyEvalArgs.shift();
|
|
|
- const parameter = map.get(argument.target.value);
|
|
|
- if (!parameter) {
|
|
|
+ return this.expression(argument, context);
|
|
|
+ });
|
|
|
+ const matches = [];
|
|
|
+ for (const schema of overloads) {
|
|
|
+ const parameters = schema.arguments || [];
|
|
|
+ let next = false;
|
|
|
+ let kwarg_only = false;
|
|
|
+ let position = 0;
|
|
|
+ let index = 0;
|
|
|
+ while (position < evalArgs.length) {
|
|
|
+ if (index >= parameters.length) {
|
|
|
+ next = !schema.name.startsWith('_caffe2::') && !schema.is_vararg;
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ const arg = parameters[index];
|
|
|
+ if (arg.kwarg_only) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ index++;
|
|
|
+ const value = evalArgs[position];
|
|
|
+ let type = arg.real_type;
|
|
|
+ let optional = false;
|
|
|
+ if (type instanceof torch.OptionalType) {
|
|
|
+ type = type.getElementType();
|
|
|
+ optional = true;
|
|
|
+ }
|
|
|
+ if (optional === true &&
|
|
|
+ (type instanceof torch.FloatType || type instanceof torch.BoolType || type instanceof torch.IntType || type instanceof torch.ComplexType || type.kind() === 'ScalarTypeType' || type instanceof torch.DeviceObjType || type.kind() === 'LayoutKind') &&
|
|
|
+ value instanceof torch.Value && value.type() instanceof torch.NoneType) {
|
|
|
+ position++;
|
|
|
+ } else if (!this.isNativeType(value, type) && value !== null) {
|
|
|
+ if (optional) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ next = true;
|
|
|
+ break;
|
|
|
+ } else if (args[position].type === '=') {
|
|
|
+ next = true;
|
|
|
+ break;
|
|
|
+ } else {
|
|
|
+ position++;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (next) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ if (args.every((arg, index) => index < position || (arg.type === '=' && arg.target && arg.target.type === 'id'))) {
|
|
|
+ const params = new Map(parameters.slice(index).map((a) => [a.name, a]));
|
|
|
+ while (position < args.length) {
|
|
|
+ const value = evalArgs[position];
|
|
|
+ const arg = params.get(args[position].target.value);
|
|
|
+ position++;
|
|
|
+ if (!arg) {
|
|
|
next = true;
|
|
|
break;
|
|
|
}
|
|
|
- if (parameter.kwarg_only) {
|
|
|
+ if (arg.kwarg_only) {
|
|
|
kwarg_only = true;
|
|
|
}
|
|
|
- let type = parameter.type;
|
|
|
+ let type = arg.real_type;
|
|
|
let optional = false;
|
|
|
- if (parameter.type.endsWith('?')) {
|
|
|
- type = parameter.type.substring(0, parameter.type.length - 1);
|
|
|
+ if (type instanceof torch.OptionalType) {
|
|
|
+ type = type.getElementType();
|
|
|
optional = true;
|
|
|
}
|
|
|
- if (!this.isType(arg, type)) {
|
|
|
+ if (!this.isNativeType(value, type)) {
|
|
|
if (optional) {
|
|
|
continue;
|
|
|
}
|
|
|
@@ -3728,134 +3897,154 @@ pytorch.jit.Execution = class extends pytorch.Execution {
|
|
|
break;
|
|
|
}
|
|
|
}
|
|
|
- continue;
|
|
|
}
|
|
|
if (next) {
|
|
|
- break;
|
|
|
+ continue;
|
|
|
}
|
|
|
- const parameter = parameters.shift();
|
|
|
- if (parameter.kwarg_only) {
|
|
|
- kwarg_only = true;
|
|
|
+ if (position < evalArgs.length && !schema.is_vararg && !schema.name.startsWith('_caffe2::')) {
|
|
|
+ continue;
|
|
|
}
|
|
|
- const [argument] = copyEvalArgs;
|
|
|
- /* if (type === 'Tensor' || (type === 'Scalar' && pytorch.Utility.isTensor(argument))) {
|
|
|
- if (Array.isArray(argument) || (!pytorch.Utility.isTensor(argument) && argument !== null && argument !== undefined)) {
|
|
|
- if (optional) {
|
|
|
- continue;
|
|
|
- }
|
|
|
- next = true;
|
|
|
- } else {
|
|
|
- copyArgs.shift();
|
|
|
- copyEvalArgs.shift();
|
|
|
+ if (!kwarg_only && parameters.slice(index).some((arg) => !arg.has_default_value())) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ matches.push(schema);
|
|
|
+ }
|
|
|
+ if (matches.length > 1) {
|
|
|
+ const keys = new Map([['IntType', 1], ['FloatType', 2], ['TensorType', 3], ['NumberType', 4]]);
|
|
|
+ matches.sort((a, b) => {
|
|
|
+ let keyA = keys.get(a.arguments[0].real_type.kind()) || 5;
|
|
|
+ let keyB = keys.get(b.arguments[0].real_type.kind()) || 5;
|
|
|
+ if (keyA === keyB && a.arguments.length > 1 && b.arguments.length > 1) {
|
|
|
+ keyA = keys.get(a.arguments[1].real_type.kind()) || 5;
|
|
|
+ keyB = keys.get(b.arguments[1].real_type.kind()) || 5;
|
|
|
+ }
|
|
|
+ return keyA - keyB;
|
|
|
+ });
|
|
|
+ }
|
|
|
+ if (matches.length === 0) {
|
|
|
+ throw new pytorch.Error(`Unknown function '${op_name}'.`);
|
|
|
+ }
|
|
|
+ // return [matches[0], evalArgs];
|
|
|
+ }
|
|
|
+ let overloads = this._types.get(op_name);
|
|
|
+ if (!overloads && type.startsWith('ops.') && !type.startsWith('ops.prim')) {
|
|
|
+ const module = this.import(moduleName);
|
|
|
+ if (!module || !module[name]) {
|
|
|
+ const metadata = {};
|
|
|
+ metadata.name = type;
|
|
|
+ metadata.inputs = [];
|
|
|
+ metadata.outputs = [];
|
|
|
+ for (let i = 0; i < args.length; i++) {
|
|
|
+ const input = {};
|
|
|
+ let argument = args[i];
|
|
|
+ input.name = i.toString();
|
|
|
+ if (argument.type === '=' && argument.target && argument.target.type === 'id') {
|
|
|
+ input.name = this.expression(argument.target, context);
|
|
|
+ argument = argument.expression;
|
|
|
}
|
|
|
- } else */
|
|
|
- let type = parameter.type;
|
|
|
+ const obj = this.expression(argument, context);
|
|
|
+ input.type = pytorch.Utility.getType(obj);
|
|
|
+ metadata.inputs.push(input);
|
|
|
+ }
|
|
|
+ const count = context.target.length > 0 ? context.target[context.target.length - 1].length : 0;
|
|
|
+ for (let i = 0; i < count; i++) {
|
|
|
+ metadata.outputs.push({ name: '', type: '' });
|
|
|
+ }
|
|
|
+ this._metadata.add(type, metadata);
|
|
|
+ overloads = [metadata];
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (!overloads) {
|
|
|
+ if (type.startsWith('aten::') || type.startsWith('prim::')) {
|
|
|
+ throw new pytorch.Error(`Unknown function '${type}'.`);
|
|
|
+ }
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ const evalArgs = args.map((argument) => {
|
|
|
+ if (argument.type === '=' && argument.target && argument.target.type === 'id') {
|
|
|
+ argument = argument.expression;
|
|
|
+ }
|
|
|
+ return this.expression(argument, context);
|
|
|
+ });
|
|
|
+ const matches = [];
|
|
|
+ for (const schema of overloads) {
|
|
|
+ const parameters = schema.inputs || [];
|
|
|
+ let next = false;
|
|
|
+ let kwarg_only = false;
|
|
|
+ let position = 0;
|
|
|
+ let index = 0;
|
|
|
+ while (position < evalArgs.length) {
|
|
|
+ if (index >= parameters.length) {
|
|
|
+ next = !schema.name.startsWith('_caffe2::') && !schema.is_vararg;
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ const arg = parameters[index];
|
|
|
+ if (arg.kwarg_only) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ index++;
|
|
|
+ const value = evalArgs[position];
|
|
|
+ let type = arg.type;
|
|
|
let optional = false;
|
|
|
- if (parameter.type.endsWith('?')) {
|
|
|
- type = parameter.type.substring(0, parameter.type.length - 1);
|
|
|
+ if (type.endsWith('?')) {
|
|
|
+ type = arg.type.substring(0, arg.type.length - 1);
|
|
|
optional = true;
|
|
|
}
|
|
|
if (optional === true &&
|
|
|
(type === 'float32' || type === 'boolean' || type === 'int64' || type === 'complex' || type === 'ScalarType' || type === 'Device' || type === 'Layout') &&
|
|
|
- argument instanceof torch.Value && argument.type() instanceof torch.NoneType) {
|
|
|
- copyArgs.shift();
|
|
|
- copyEvalArgs.shift();
|
|
|
- } else if (type === 'Tensor[]') {
|
|
|
- const [argument] = copyEvalArgs;
|
|
|
- if ((argument instanceof torch.Value && pytorch.Utility.toType(argument.type()) === 'Tensor[]') ||
|
|
|
- (Array.isArray(argument) && argument.every((item) => pytorch.Utility.isTensor(item) || item === null || (item instanceof torch.Value && item.type() instanceof torch.TensorType)))) {
|
|
|
- copyArgs.shift();
|
|
|
- copyEvalArgs.shift();
|
|
|
- } else {
|
|
|
- if (optional) {
|
|
|
- continue;
|
|
|
- }
|
|
|
- next = true;
|
|
|
+ value instanceof torch.Value && value.type() instanceof torch.NoneType) {
|
|
|
+ position++;
|
|
|
+ } else if (!this.isType(value, type) && value !== null) {
|
|
|
+ if (optional) {
|
|
|
+ continue;
|
|
|
}
|
|
|
- /* } else if (type === 't[]') {
|
|
|
- if (!Array.isArray(argument) && (argument instanceof torch.Value === false || argument.type() instanceof torch.ListType === false)) {
|
|
|
- if (optional) {
|
|
|
- continue;
|
|
|
- }
|
|
|
- next = true;
|
|
|
- } else {
|
|
|
- copyArgs.shift();
|
|
|
- copyEvalArgs.shift();
|
|
|
- }*/
|
|
|
+ next = true;
|
|
|
+ break;
|
|
|
+ } else if (args[position].type === '=' && args[position].target.value !== arg.name) {
|
|
|
+ next = true;
|
|
|
+ break;
|
|
|
} else {
|
|
|
- const [arg] = copyArgs;
|
|
|
- if (!this.isType(argument, type) && argument !== null) {
|
|
|
+ position++;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (next) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ if (args.every((arg, index) => index < position || (arg.type === '=' && arg.target && arg.target.type === 'id'))) {
|
|
|
+ const params = new Map(parameters.slice(index).map((a) => [a.name, a]));
|
|
|
+ while (position < args.length) {
|
|
|
+ const value = evalArgs[position];
|
|
|
+ const arg = params.get(args[position].target.value);
|
|
|
+ position++;
|
|
|
+ if (!arg) {
|
|
|
+ next = true;
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ if (arg.kwarg_only) {
|
|
|
+ kwarg_only = true;
|
|
|
+ }
|
|
|
+ let type = arg.type;
|
|
|
+ let optional = false;
|
|
|
+ if (type.endsWith('?')) {
|
|
|
+ type = arg.type.substring(0, arg.type.length - 1);
|
|
|
+ optional = true;
|
|
|
+ }
|
|
|
+ if (!this.isType(value, type)) {
|
|
|
if (optional) {
|
|
|
continue;
|
|
|
}
|
|
|
next = true;
|
|
|
- } else if (arg.type === '=') {
|
|
|
- next = true;
|
|
|
- // throw new pytorch.Error('Expected named argument.');
|
|
|
- } else {
|
|
|
- copyArgs.shift();
|
|
|
- copyEvalArgs.shift();
|
|
|
+ break;
|
|
|
}
|
|
|
}
|
|
|
- if (next) {
|
|
|
- break;
|
|
|
- }
|
|
|
}
|
|
|
if (next) {
|
|
|
continue;
|
|
|
}
|
|
|
- if (!kwarg_only && parameters.some((parameter) => parameter.default === undefined)) {
|
|
|
+ if (position < evalArgs.length && !schema.is_vararg && !schema.name.startsWith('_caffe2::')) {
|
|
|
continue;
|
|
|
}
|
|
|
- for (let i = 0; i < schema.outputs.length; i++) {
|
|
|
- const parameter = schema.outputs[i];
|
|
|
- switch (parameter.type) {
|
|
|
- case 'Scalar':
|
|
|
- case 'Tensor':
|
|
|
- case 'Tensor[]':
|
|
|
- case 'float32':
|
|
|
- case 'float32[]':
|
|
|
- case 'int64':
|
|
|
- case 'int64[]':
|
|
|
- case 'Device':
|
|
|
- case 'boolean':
|
|
|
- case 'boolean[]':
|
|
|
- case 't':
|
|
|
- case 't[]':
|
|
|
- case 'complex':
|
|
|
- case 'complex[]':
|
|
|
- case 'string':
|
|
|
- case 'string[]':
|
|
|
- case 'Dict(string, Tensor)':
|
|
|
- case 'Dict(Tensor, t)':
|
|
|
- case 'Dict(boolean, t)':
|
|
|
- case 'Dict(complex, t)':
|
|
|
- case 'Dict(float32, t)':
|
|
|
- case 'Dict(int64, t)':
|
|
|
- case 'Dict(string, t)':
|
|
|
- case 'Dict(Tensor, tVal)':
|
|
|
- case 'Dict(boolean, tVal)':
|
|
|
- case 'Dict(complex, tVal)':
|
|
|
- case 'Dict(float32, tVal)':
|
|
|
- case 'Dict(int64, tVal)':
|
|
|
- case 'Dict(string, tVal)':
|
|
|
- case '(string, t)[]':
|
|
|
- case 'Any':
|
|
|
- break;
|
|
|
- case '__torch__.torch.classes.xnnpack.LinearOpContext':
|
|
|
- case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
|
|
|
- case '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext':
|
|
|
- case '__torch__.torch.classes.rnn.CellParamsBase':
|
|
|
- case '__torch__.torch.classes.quantized.LinearPackedParamsBase':
|
|
|
- case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase':
|
|
|
- case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase':
|
|
|
- break;
|
|
|
- default: {
|
|
|
- throw new pytorch.Error(`Unknown return type '${parameter.type}'.`);
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- if (next) {
|
|
|
+ if (!kwarg_only && parameters.slice(index).some((parameter) => parameter.default === undefined)) {
|
|
|
continue;
|
|
|
}
|
|
|
matches.push(schema);
|
|
|
@@ -3872,10 +4061,10 @@ pytorch.jit.Execution = class extends pytorch.Execution {
|
|
|
return keyA - keyB;
|
|
|
});
|
|
|
}
|
|
|
- if (matches.length > 0) {
|
|
|
- return [matches[0], evalArgs];
|
|
|
+ if (matches.length === 0) {
|
|
|
+ throw new pytorch.Error(`Unknown function '${type}'.`);
|
|
|
}
|
|
|
- throw new pytorch.Error(`Unknown function '${type}'.`);
|
|
|
+ return [matches[0], evalArgs];
|
|
|
}
|
|
|
|
|
|
block(statements, context) {
|