소스 검색

Update pytorch.js (#1061)

Lutz Roeder 1 년 전
부모
커밋
5eef3cd5da
5개의 변경된 파일450개의 추가작업 그리고 207개의 파일을 삭제
  1. 18 13
      source/python.js
  2. 29 0
      source/pytorch-metadata.json
  3. 381 192
      source/pytorch.js
  4. 2 2
      test/models.json
  5. 20 0
      tools/pytorch_script.py

+ 18 - 13
source/python.js

@@ -4377,6 +4377,10 @@ python.Execution = class {
             }
             throw new python.Error(`Schema '${op_name}.${overload_name}' not found.`);
         });
+        this.registerFunction('torch._C._jit_get_schemas_for_operator', (op_name) => {
+            const registry = torch._C._get_registry();
+            return registry.getAllOperatorsFor(op_name).map((op) => op.schema());
+        });
         this.registerFunction('torch._C._jit_get_operation', (op_name) => {
             const registry = torch._C._get_registry();
             const sortedOps = registry.getAllOperatorsFor(op_name);
@@ -6152,14 +6156,14 @@ python.Execution = class {
 
         });
         this.registerType('torch.Type', class {
-            constructor(kind, name) {
+            constructor(kind, annotation_str) {
                 this._kind = kind;
-                if (name) {
-                    this._name = name;
+                if (annotation_str) {
+                    this._annotation_str = annotation_str;
                 }
             }
-            static get(kind, name) {
-                return new torch.Type(kind, name);
+            static get(kind, annotation_str) {
+                return new torch.Type(kind, annotation_str);
             }
             kind() {
                 return this._kind;
@@ -6171,8 +6175,8 @@ python.Execution = class {
                 throw new python.Error(`Not implemented '${this.kind()}'.`);
             }
             str() {
-                if (this._kind === 'VarType' && this._name) {
-                    return this._name;
+                if (this._kind === 'VarType' && this._annotation_str) {
+                    return this._annotation_str;
                 } else if (this._kind === 'ScalarTypeType') {
                     return 'ScalarType';
                 } else if (this._kind === 'QSchemeType') {
@@ -6722,6 +6726,7 @@ python.Execution = class {
                     case 't': case 't1': case 't2': case 'tVal': return torch.Type.get('VarType', value);
                     case 'Any': return torch.AnyType.get();
                     case 'AnyEnumType': return torch.Type.get('AnyEnumType');
+                    case 'Dimname': return torch.StringType.get();
                     case 'QScheme': return torch.Type.get('QSchemeType');
                     case 'Stream': return torch.StreamObjType.get();
                     case 'Storage': return torch.Type.get('Storage');
@@ -7036,7 +7041,7 @@ python.Execution = class {
         });
         this.registerType('torch.FunctionSchema', class {
             constructor(name, overload_name, args, returns, is_vararg, is_varret) {
-                let index = name.indexOf('(');
+                const index = name.indexOf('(');
                 if (index === -1) {
                     this._name = name;
                     this._overload_name = overload_name;
@@ -7046,15 +7051,15 @@ python.Execution = class {
                     this._is_varret = is_varret;
                 } else {
                     const value = name.substring(0, index).trim();
-                    this._buffer = name.substring(index, name.length);
-                    index = value.indexOf('.');
-                    if (index === -1) {
+                    const dot = value.indexOf('.');
+                    if (dot === -1) {
                         this._name = value;
                         this._overload_name = '';
                     } else {
-                        this._name = value.substring(0, index);
-                        this._overload_name = value.substring(index + 1, value.length);
+                        this._name = value.substring(0, dot);
+                        this._overload_name = value.substring(dot + 1, value.length);
                     }
+                    this._buffer = name.substring(index, name.length);
                 }
             }
             static parse(schema) {

+ 29 - 0
source/pytorch-metadata.json

@@ -738,6 +738,16 @@
       { "type": "Tensor" }
     ]
   },
+  {
+    "name": "aten::__is__(t1 self, t2 obj) -> bool",
+    "inputs": [
+      { "name": "self", "type": "t1" },
+      { "name": "obj", "type": "t2" }
+    ],
+    "outputs": [
+      { "type": "boolean" }
+    ]
+  },
   {
     "name": "aten::__isnot__(t1 self, t2 obj) -> bool",
     "inputs": [
@@ -5084,6 +5094,25 @@
       { "type": "Tensor" }
     ]
   },
+  {
+    "name": "aten::device(str a) -> Device",
+    "inputs": [
+      { "name": "a", "type": "string" }
+    ],
+    "outputs": [
+      { "type": "Device" }
+    ]
+  },
+  {
+    "name": "aten::device.with_index(str type, int index) -> Device",
+    "inputs": [
+      { "name": "type", "type": "string" },
+      { "name": "index", "type": "int64" }
+    ],
+    "outputs": [
+      { "type": "Device" }
+    ]
+  },
   {
     "name": "aten::diag(Tensor self, int diagonal=0) -> Tensor",
     "inputs": [

+ 381 - 192
source/pytorch.js

@@ -371,33 +371,6 @@ pytorch.Node = class {
             type.name = name.indexOf('::') === -1 ? name : name.split('::').pop().split('.')[0];
             return type;
         };
-        const createAttribute = (metadata, name, value) => {
-            let visible = true;
-            let type = 'attribute';
-            metadata = name === 'training' ? { type: 'boolean', visible: false } : metadata;
-            if (metadata) {
-                if (metadata.type) {
-                    type = metadata.type;
-                }
-                if (metadata.visible === false) {
-                    visible = false;
-                } else if (metadata.default !== undefined) {
-                    if (Array.isArray(value)) {
-                        if (Array.isArray(metadata.default)) {
-                            visible = value.length !== metadata.default || !value.every((item, index) => item === metadata.default[index]);
-                        } else {
-                            visible = !value.every((item) => item === metadata.default);
-                        }
-                    } else {
-                        visible = value !== metadata.default;
-                    }
-                }
-            }
-            if (Array.isArray(value) && value.length > 0 && value.every((obj) => obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__module__.startsWith('torch.nn'))) {
-                value = '?';
-            }
-            return new pytorch.Argument(name, value, type, visible);
-        };
         let module = null;
         if (pytorch.Utility.isInstance(obj, 'torch.Node')) {
             const node = obj;
@@ -485,13 +458,13 @@ pytorch.Node = class {
                     type = type.getElementType();
                 }
                 let argument = null;
-                if (arg && pytorch.Utility.isInstance(arg.real_type, 'torch.ClassType')) {
+                if (type && pytorch.Utility.isInstance(type, 'torch.ClassType')) {
                     const obj = input.value;
                     if (!array && initializers.has(obj)) {
                         const node = new pytorch.Node(metadata, name, type.qualified_name(), obj, initializers, values);
                         argument = new pytorch.Argument(name, node, 'object');
                     } else if (array && Array.isArray(obj) && obj.every((obj) => initializers.has(obj))) {
-                        const node = obj.map((obj) => new pytorch.Node(metadata, name, type, obj, initializers, values));
+                        const node = obj.map((obj) => new pytorch.Node(metadata, name, type.qualified_name(), obj, initializers, values));
                         argument = new pytorch.Argument(name, node, 'object[]');
                     } else {
                         const identifier = input.unique().toString();
@@ -797,6 +770,33 @@ pytorch.Node = class {
                         const argument = new pytorch.Argument(name, node, 'object', visible);
                         this.inputs.push(argument);
                     } else {
+                        const createAttribute = (metadata, name, value) => {
+                            let visible = true;
+                            let type = 'attribute';
+                            metadata = name === 'training' ? { type: 'boolean', visible: false } : metadata;
+                            if (metadata) {
+                                if (metadata.type) {
+                                    type = metadata.type;
+                                }
+                                if (metadata.visible === false) {
+                                    visible = false;
+                                } else if (metadata.default !== undefined) {
+                                    if (Array.isArray(value)) {
+                                        if (Array.isArray(metadata.default)) {
+                                            visible = value.length !== metadata.default || !value.every((item, index) => item === metadata.default[index]);
+                                        } else {
+                                            visible = !value.every((item) => item === metadata.default);
+                                        }
+                                    } else {
+                                        visible = value !== metadata.default;
+                                    }
+                                }
+                            }
+                            if (Array.isArray(value) && value.length > 0 && value.every((obj) => obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__module__.startsWith('torch.nn'))) {
+                                value = '?';
+                            }
+                            return new pytorch.Argument(name, value, type, visible);
+                        };
                         const argument = createAttribute(metadata.attribute(type, name), name, value);
                         this.inputs.push(argument);
                     }
@@ -2923,7 +2923,7 @@ pytorch.jit.Execution = class extends pytorch.Execution {
                     }
                     let type = parameter.type;
                     let optional = false;
-                    if (parameter.type.endsWith('?')) {
+                    if (type.endsWith('?')) {
                         type = parameter.type.substring(0, parameter.type.length - 1);
                         optional = true;
                     }
@@ -3483,6 +3483,144 @@ pytorch.jit.Execution = class extends pytorch.Execution {
         return result[0];
     }
 
+    isNativeType(obj, type) {
+        const torch = this.torch;
+        switch (type.str()) {
+            case 'Tensor':
+                return !Array.isArray(obj) && (pytorch.Utility.isTensor(obj) || obj === null ||
+                    (obj instanceof torch.Value && obj.type() instanceof torch.TensorType) ||
+                    (obj instanceof torch.Value && obj.type() instanceof torch.OptionalType && obj.type().getElementType() instanceof torch.TensorType));
+            case 'Tensor[]':
+                return (Array.isArray(obj) && obj.length > 0 && obj.every((tensor) => pytorch.Utility.isTensor(tensor) || tensor === null || (tensor instanceof torch.Value && tensor.type() instanceof torch.TensorType))) ||
+                    (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.TensorType);
+            case 'Scalar':
+                return (obj !== null && (obj !== Object(obj) || obj instanceof Number)) ||
+                    (pytorch.Utility.isTensor(obj) && Array.isArray(obj.size()) && obj.size().length === 0) ||
+                    (obj instanceof torch.Value && (obj.type() instanceof torch.IntType || obj.type() instanceof torch.FloatType || obj.type() instanceof torch.NumberType));
+            case 'bool':
+                return obj === true || obj === false || (pytorch.Utility.isInstance(obj, 'torch.Value') && obj.type() instanceof torch.BoolType);
+            case 'bool[]':
+                if (Array.isArray(obj) && obj.every((item) => item === true || item === false)) {
+                    return true;
+                }
+                if (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.BoolType) {
+                    return true;
+                }
+                return false;
+            case 'SymInt':
+            case 'int':
+                return Number.isInteger(obj) || typeof obj === 'bigint' ||
+                    (typeof obj === 'number' && isNaN(obj)) || (obj instanceof Number) ||
+                    (obj instanceof torch.Value && obj.type() instanceof torch.IntType) ||
+                    (obj instanceof torch.Value && obj.type() instanceof torch.OptionalType && obj.type().getElementType() instanceof torch.IntType);
+            case 'SymInt[]':
+            case 'SymInt[2]':
+            case 'SymInt[3]':
+            case 'SymInt[4]':
+            case 'SymInt[5]':
+            case 'SymInt[6]':
+                if (Array.isArray(obj) && obj.every((item) => this.isNativeType(item, torch.SymIntType.get()) || item === undefined || (item.__class__ === 'number' && isNaN(item)))) {
+                    return true;
+                }
+                if (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.IntType) {
+                    return true;
+                }
+                return false;
+            case 'SymInt[1]':
+                return this.isNativeType(obj, torch.IntType.get()) || this.isNativeType(obj, torch.ListType.get(torch.IntType.get()));
+            case 'int[]':
+            case 'int[2]':
+            case 'int[3]':
+                return (Array.isArray(obj) && obj.every((item) => this.isNativeType(item, torch.IntType.get()) || item === undefined || (item.__class__ === 'number' && isNaN(item))) ||
+                    (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.IntType)) ||
+                    (obj instanceof torch.Value && obj.type() instanceof torch.OptionalType && obj.type().getElementType() instanceof torch.ListType && obj.type().getElementType().getElementType() instanceof torch.IntType);
+            case 'int[1]':
+            case 'float':
+                return obj !== null && (typeof obj === 'number' || obj instanceof Number) ||
+                    (pytorch.Utility.isInstance(obj, 'torch.Value') && pytorch.Utility.isInstance(obj.type(), 'torch.FloatType'));
+            case 'float[]':
+                if (Array.isArray(obj) && obj.every((item) => (typeof item === 'number' || item instanceof Number) && !isNaN(item))) {
+                    return true;
+                }
+                if (pytorch.Utility.isInstance(obj, 'torch.Value') && pytorch.Utility.isInstance(obj.type(), 'torch.ListType') && (pytorch.Utility.isInstance(obj.type().getElementType(), 'torch.IntType') || pytorch.Utility.isInstance(obj.type().getElementType(), 'torch.FloatType'))) {
+                    return true;
+                }
+                return false;
+            case 'str':
+                return obj === null || typeof obj === 'string' ||
+                    (obj instanceof torch.Value && obj.type() instanceof torch.StringType);
+            case 'str[]':
+                return (Array.isArray(obj) && obj.every((item) => item === null || typeof item === 'string')) ||
+                    (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.StringType);
+            case 'str[][]':
+                return Array.isArray(obj) && obj.every((item) => Array.isArray(item) && item.every((item) => typeof item === 'string'));
+            case 'Layout':
+            case 'ScalarType':
+            case 'MemoryFormat':
+                return Number.isInteger(obj) || obj === null ||
+                    (obj instanceof torch.Value && obj.type() instanceof torch.IntType) ||
+                    (obj instanceof torch.Value && obj.type() instanceof torch.OptionalType && obj.type().getElementType() instanceof torch.IntType);
+            case 'Dimname':
+                return obj === null || (typeof obj === 'string' || obj instanceof String);
+            case 'Dimname[]':
+                return Array.isArray(obj) && obj.every((item) => item === null || typeof item === 'string');
+            case 'Device':
+                return obj === null || obj === Object(obj);
+            case 't[]':
+                return Array.isArray(obj) ||
+                    (obj instanceof torch.Value && obj.type() instanceof torch.ListType) ||
+                    (obj instanceof torch.Value && obj.type() instanceof torch.OptionalType && obj.type().getElementType() instanceof torch.ListType);
+            case 't':
+                return true;
+            case 'AnyEnumType':
+                return false;
+            case 'complex':
+                return obj instanceof torch.Value && obj.type() instanceof torch.ComplexType;
+            case 'Any[]':
+                if (Array.isArray(obj)) {
+                    return true;
+                }
+                if (obj instanceof torch.Value && obj.type() instanceof torch.ListType) {
+                    return true;
+                }
+                return false;
+            case 't1':
+            case 't2':
+                return true;
+            default: {
+                if (type instanceof torch.ClassType &&
+                    obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__) {
+                    return type.qualified_name() === `${obj.__class__.__module__}.${obj.__class__.__name__}`;
+                }
+                if (type instanceof torch.TupleType) {
+                    throw new pytorch.Error('Not implemented.');
+                    /*
+                    if (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.TupleType) {
+                        const elements = obj.type().getElementType().elements();
+                        if (elements.length === 2) {
+                            if (pytorch.Utility.toType(elements[0]) === match[1]) {
+                                return true;
+                            }
+                        }
+                    }
+                    return false;
+                    */
+                }
+                if (type instanceof torch.DictType) {
+                    if (obj instanceof torch.Value && obj.type() instanceof torch.DictType) {
+                        if ((type.getKeyType().kind() === 'VarType' || type.getKeyType().str() === obj.type().getKeyType().str()) ||
+                            (type.getValueType().kind() === 'VarType' || type.getValueType().str() === obj.type().getValueType().str())) {
+                            return true;
+                        }
+                    }
+                    return false;
+                }
+                // throw new pytorch.Error(`Unknown type '${type}'.`);
+                return true;
+            }
+        }
+    }
+
     isType(obj, type) {
         const torch = this.torch;
         switch (type) {
@@ -3491,8 +3629,8 @@ pytorch.jit.Execution = class extends pytorch.Execution {
                     (obj instanceof torch.Value && obj.type() instanceof torch.TensorType) ||
                     (obj instanceof torch.Value && obj.type() instanceof torch.OptionalType && obj.type().getElementType() instanceof torch.TensorType));
             case 'Tensor[]':
-                return Array.isArray(obj) && obj.length > 0 &&
-                    obj.every((tensor) => pytorch.Utility.isTensor(tensor) || tensor === null || (tensor instanceof torch.Value && tensor.type() instanceof torch.TensorType));
+                return (Array.isArray(obj) && obj.length > 0 && obj.every((tensor) => pytorch.Utility.isTensor(tensor) || tensor === null || (tensor instanceof torch.Value && tensor.type() instanceof torch.TensorType))) ||
+                    (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.TensorType);
             case 'Scalar':
                 return (obj !== null && (obj !== Object(obj) || obj instanceof Number)) ||
                     (pytorch.Utility.isTensor(obj) && Array.isArray(obj.size()) && obj.size().length === 0) ||
@@ -3550,7 +3688,8 @@ pytorch.jit.Execution = class extends pytorch.Execution {
                 }
                 return false;
             case 'string[]':
-                return Array.isArray(obj) && obj.every((item) => item === null || typeof item === 'string');
+                return (Array.isArray(obj) && obj.every((item) => item === null || typeof item === 'string')) ||
+                    (obj instanceof torch.Value && obj.type() instanceof torch.ListType && obj.type().getElementType() instanceof torch.StringType);
             case 'string[][]':
                 return Array.isArray(obj) && obj.every((item) => Array.isArray(item) && item.every((item) => typeof item === 'string'));
             case 'Layout':
@@ -3627,27 +3766,27 @@ pytorch.jit.Execution = class extends pytorch.Execution {
         const torch = this.torch;
         const type = name ? `${moduleName}.${name}` : moduleName;
         // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
-        let overloads = null;
+        let op_name = null;
         if (type.startsWith('torch.')) {
-            overloads = this._types.get(`aten::${type.substring(6)}`);
-        } else if (type.startsWith('ops.prim.')) {
-            overloads = this._types.get(`prim::${type.substring(9)}`);
+            op_name = `aten::${type.substring(6)}`;
+        } else if (type.startsWith('ops.')) {
+            op_name = type.substring(4).replace('.', '::');
         } else if (type === 'int') {
-            overloads = this._types.get(`aten::Int`);
+            op_name = 'aten::Int';
         } else if (type === 'str') {
-            overloads = this._types.get(`aten::str`);
+            op_name = 'aten::str';
         } else if (type === 'bool') {
-            overloads = this._types.get(`aten::Bool`);
+            op_name = 'aten::Bool';
         } else if (type === 'float') {
-            overloads = this._types.get(`aten::Float`);
+            op_name = 'aten::Float';
         } else if (type === 'complex') {
-            overloads = this._types.get(`aten::Complex`);
-        } else if (type.startsWith('ops.') && !type.startsWith('ops.prim.')) {
-            const path = type.split('.');
-            if (path.length === 3) {
-                overloads = this._types.get(`${path[1]}::${path[2]}`);
-            }
-            if (!overloads) {
+            op_name = 'aten::Complex';
+        }
+        this.native = false;
+        if (this.native && op_name) {
+            const overloads = torch._C._jit_get_schemas_for_operator(op_name);
+            /*
+            if (!overloads && type.startsWith('ops.') && !type.startsWith('ops.prim')) {
                 const module = this.import(moduleName);
                 if (!module || !module[name]) {
                     const metadata = {};
@@ -3674,53 +3813,83 @@ pytorch.jit.Execution = class extends pytorch.Execution {
                     overloads = [metadata];
                 }
             }
-        }
-        if (!overloads) {
-            if (type.startsWith('aten::') || type.startsWith('prim::')) {
-                throw new pytorch.Error(`Unknown function '${type}'.`);
-            }
-            return null;
-        }
-        overloads = Array.isArray(overloads) ? overloads : [overloads];
-        const evalArgs = args.map((argument) => {
-            if (argument.type === '=' && argument.target && argument.target.type === 'id') {
-                argument = argument.expression;
+            */
+            if (!overloads) {
+                if (type.startsWith('aten::') || type.startsWith('prim::')) {
+                    throw new pytorch.Error(`Unknown function '${type}'.`);
+                }
+                return null;
             }
-            return this.expression(argument, context);
-        });
-        const matches = [];
-        for (const schema of overloads) {
-            const copyArgs = Array.prototype.slice.call(args);
-            const copyEvalArgs = Array.prototype.slice.call(evalArgs);
-            const parameters = Array.prototype.slice.call(schema.inputs || []).concat(Array.prototype.slice.call(schema.attributes || []));
-            let next = false;
-            let kwarg_only = false;
-            while (copyEvalArgs.length > 0) {
-                if (parameters.length <= 0) {
-                    next = !schema.name.startsWith('_caffe2::') && !schema.is_vararg;
-                    break;
+            const evalArgs = args.map((argument) => {
+                if (argument.type === '=' && argument.target && argument.target.type === 'id') {
+                    argument = argument.expression;
                 }
-                if (copyArgs.every((arg) => arg.type === '=' && arg.target && arg.target.type === 'id') &&
-                    parameters.every((parameter) => parameter.type !== 'Tensor' && parameter.type !== 'Tensor[]')) {
-                    const map = new Map(parameters.map((parameter) => [parameter.name, parameter]));
-                    while (copyArgs.length > 0) {
-                        const argument = copyArgs.shift();
-                        const arg = copyEvalArgs.shift();
-                        const parameter = map.get(argument.target.value);
-                        if (!parameter) {
+                return this.expression(argument, context);
+            });
+            const matches = [];
+            for (const schema of overloads) {
+                const parameters = schema.arguments || [];
+                let next = false;
+                let kwarg_only = false;
+                let position = 0;
+                let index = 0;
+                while (position < evalArgs.length) {
+                    if (index >= parameters.length) {
+                        next = !schema.name.startsWith('_caffe2::') && !schema.is_vararg;
+                        break;
+                    }
+                    const arg = parameters[index];
+                    if (arg.kwarg_only) {
+                        break;
+                    }
+                    index++;
+                    const value = evalArgs[position];
+                    let type = arg.real_type;
+                    let optional = false;
+                    if (type instanceof torch.OptionalType) {
+                        type = type.getElementType();
+                        optional = true;
+                    }
+                    if (optional === true &&
+                        (type instanceof torch.FloatType || type instanceof torch.BoolType || type instanceof torch.IntType || type instanceof torch.ComplexType || type.kind() === 'ScalarTypeType' || type instanceof torch.DeviceObjType || type.kind() === 'LayoutKind') &&
+                        value instanceof torch.Value && value.type() instanceof torch.NoneType) {
+                        position++;
+                    } else if (!this.isNativeType(value, type) && value !== null) {
+                        if (optional) {
+                            continue;
+                        }
+                        next = true;
+                        break;
+                    } else if (args[position].type === '=') {
+                        next = true;
+                        break;
+                    } else {
+                        position++;
+                    }
+                }
+                if (next) {
+                    continue;
+                }
+                if (args.every((arg, index) => index < position || (arg.type === '=' && arg.target && arg.target.type === 'id'))) {
+                    const params = new Map(parameters.slice(index).map((a) => [a.name, a]));
+                    while (position < args.length) {
+                        const value = evalArgs[position];
+                        const arg = params.get(args[position].target.value);
+                        position++;
+                        if (!arg) {
                             next = true;
                             break;
                         }
-                        if (parameter.kwarg_only) {
+                        if (arg.kwarg_only) {
                             kwarg_only = true;
                         }
-                        let type = parameter.type;
+                        let type = arg.real_type;
                         let optional = false;
-                        if (parameter.type.endsWith('?')) {
-                            type = parameter.type.substring(0, parameter.type.length - 1);
+                        if (type instanceof torch.OptionalType) {
+                            type = type.getElementType();
                             optional = true;
                         }
-                        if (!this.isType(arg, type)) {
+                        if (!this.isNativeType(value, type)) {
                             if (optional) {
                                 continue;
                             }
@@ -3728,134 +3897,154 @@ pytorch.jit.Execution = class extends pytorch.Execution {
                             break;
                         }
                     }
-                    continue;
                 }
                 if (next) {
-                    break;
+                    continue;
                 }
-                const parameter = parameters.shift();
-                if (parameter.kwarg_only) {
-                    kwarg_only = true;
+                if (position < evalArgs.length && !schema.is_vararg && !schema.name.startsWith('_caffe2::')) {
+                    continue;
                 }
-                const [argument] = copyEvalArgs;
-                /* if (type === 'Tensor' || (type === 'Scalar' && pytorch.Utility.isTensor(argument))) {
-                    if (Array.isArray(argument) || (!pytorch.Utility.isTensor(argument) && argument !== null && argument !== undefined)) {
-                        if (optional) {
-                            continue;
-                        }
-                        next = true;
-                    } else {
-                        copyArgs.shift();
-                        copyEvalArgs.shift();
+                if (!kwarg_only && parameters.slice(index).some((arg) => !arg.has_default_value())) {
+                    continue;
+                }
+                matches.push(schema);
+            }
+            if (matches.length > 1) {
+                const keys = new Map([['IntType', 1], ['FloatType', 2], ['TensorType', 3], ['NumberType', 4]]);
+                matches.sort((a, b) => {
+                    let keyA = keys.get(a.arguments[0].real_type.kind()) || 5;
+                    let keyB = keys.get(b.arguments[0].real_type.kind()) || 5;
+                    if (keyA === keyB && a.arguments.length > 1 && b.arguments.length > 1) {
+                        keyA = keys.get(a.arguments[1].real_type.kind()) || 5;
+                        keyB = keys.get(b.arguments[1].real_type.kind()) || 5;
+                    }
+                    return keyA - keyB;
+                });
+            }
+            if (matches.length === 0) {
+                throw new pytorch.Error(`Unknown function '${op_name}'.`);
+            }
+            // return [matches[0], evalArgs];
+        }
+        let overloads = this._types.get(op_name);
+        if (!overloads && type.startsWith('ops.') && !type.startsWith('ops.prim')) {
+            const module = this.import(moduleName);
+            if (!module || !module[name]) {
+                const metadata = {};
+                metadata.name = type;
+                metadata.inputs = [];
+                metadata.outputs = [];
+                for (let i = 0; i < args.length; i++) {
+                    const input = {};
+                    let argument = args[i];
+                    input.name = i.toString();
+                    if (argument.type === '=' && argument.target && argument.target.type === 'id') {
+                        input.name = this.expression(argument.target, context);
+                        argument = argument.expression;
                     }
-                } else */
-                let type = parameter.type;
+                    const obj = this.expression(argument, context);
+                    input.type = pytorch.Utility.getType(obj);
+                    metadata.inputs.push(input);
+                }
+                const count = context.target.length > 0 ? context.target[context.target.length - 1].length : 0;
+                for (let i = 0; i < count; i++) {
+                    metadata.outputs.push({ name: '', type: '' });
+                }
+                this._metadata.add(type, metadata);
+                overloads = [metadata];
+            }
+        }
+        if (!overloads) {
+            if (type.startsWith('aten::') || type.startsWith('prim::')) {
+                throw new pytorch.Error(`Unknown function '${type}'.`);
+            }
+            return null;
+        }
+        const evalArgs = args.map((argument) => {
+            if (argument.type === '=' && argument.target && argument.target.type === 'id') {
+                argument = argument.expression;
+            }
+            return this.expression(argument, context);
+        });
+        const matches = [];
+        for (const schema of overloads) {
+            const parameters = schema.inputs || [];
+            let next = false;
+            let kwarg_only = false;
+            let position = 0;
+            let index = 0;
+            while (position < evalArgs.length) {
+                if (index >= parameters.length) {
+                    next = !schema.name.startsWith('_caffe2::') && !schema.is_vararg;
+                    break;
+                }
+                const arg = parameters[index];
+                if (arg.kwarg_only) {
+                    break;
+                }
+                index++;
+                const value = evalArgs[position];
+                let type = arg.type;
                 let optional = false;
-                if (parameter.type.endsWith('?')) {
-                    type = parameter.type.substring(0, parameter.type.length - 1);
+                if (type.endsWith('?')) {
+                    type = arg.type.substring(0, arg.type.length - 1);
                     optional = true;
                 }
                 if (optional === true &&
                     (type === 'float32' || type === 'boolean' || type === 'int64' || type === 'complex' || type === 'ScalarType' || type === 'Device' || type === 'Layout') &&
-                    argument instanceof torch.Value && argument.type() instanceof torch.NoneType) {
-                    copyArgs.shift();
-                    copyEvalArgs.shift();
-                } else if (type === 'Tensor[]') {
-                    const [argument] = copyEvalArgs;
-                    if ((argument instanceof torch.Value && pytorch.Utility.toType(argument.type()) === 'Tensor[]') ||
-                        (Array.isArray(argument) && argument.every((item) => pytorch.Utility.isTensor(item) || item === null || (item instanceof torch.Value && item.type() instanceof torch.TensorType)))) {
-                        copyArgs.shift();
-                        copyEvalArgs.shift();
-                    } else {
-                        if (optional) {
-                            continue;
-                        }
-                        next = true;
+                    value instanceof torch.Value && value.type() instanceof torch.NoneType) {
+                    position++;
+                } else if (!this.isType(value, type) && value !== null) {
+                    if (optional) {
+                        continue;
                     }
-                /* } else if (type === 't[]') {
-                    if (!Array.isArray(argument) && (argument instanceof torch.Value === false || argument.type() instanceof torch.ListType === false)) {
-                        if (optional) {
-                            continue;
-                        }
-                        next = true;
-                    } else {
-                        copyArgs.shift();
-                        copyEvalArgs.shift();
-                    }*/
+                    next = true;
+                    break;
+                } else if (args[position].type === '=' && args[position].target.value !== arg.name) {
+                    next = true;
+                    break;
                 } else {
-                    const [arg] = copyArgs;
-                    if (!this.isType(argument, type) && argument !== null) {
+                    position++;
+                }
+            }
+            if (next) {
+                continue;
+            }
+            if (args.every((arg, index) => index < position || (arg.type === '=' && arg.target && arg.target.type === 'id'))) {
+                const params = new Map(parameters.slice(index).map((a) => [a.name, a]));
+                while (position < args.length) {
+                    const value = evalArgs[position];
+                    const arg = params.get(args[position].target.value);
+                    position++;
+                    if (!arg) {
+                        next = true;
+                        break;
+                    }
+                    if (arg.kwarg_only) {
+                        kwarg_only = true;
+                    }
+                    let type = arg.type;
+                    let optional = false;
+                    if (type.endsWith('?')) {
+                        type = arg.type.substring(0, arg.type.length - 1);
+                        optional = true;
+                    }
+                    if (!this.isType(value, type)) {
                         if (optional) {
                             continue;
                         }
                         next = true;
-                    } else if (arg.type === '=') {
-                        next = true;
-                        // throw new pytorch.Error('Expected named argument.');
-                    } else {
-                        copyArgs.shift();
-                        copyEvalArgs.shift();
+                        break;
                     }
                 }
-                if (next) {
-                    break;
-                }
             }
             if (next) {
                 continue;
             }
-            if (!kwarg_only && parameters.some((parameter) => parameter.default === undefined)) {
+            if (position < evalArgs.length && !schema.is_vararg && !schema.name.startsWith('_caffe2::')) {
                 continue;
             }
-            for (let i = 0; i < schema.outputs.length; i++) {
-                const parameter = schema.outputs[i];
-                switch (parameter.type) {
-                    case 'Scalar':
-                    case 'Tensor':
-                    case 'Tensor[]':
-                    case 'float32':
-                    case 'float32[]':
-                    case 'int64':
-                    case 'int64[]':
-                    case 'Device':
-                    case 'boolean':
-                    case 'boolean[]':
-                    case 't':
-                    case 't[]':
-                    case 'complex':
-                    case 'complex[]':
-                    case 'string':
-                    case 'string[]':
-                    case 'Dict(string, Tensor)':
-                    case 'Dict(Tensor, t)':
-                    case 'Dict(boolean, t)':
-                    case 'Dict(complex, t)':
-                    case 'Dict(float32, t)':
-                    case 'Dict(int64, t)':
-                    case 'Dict(string, t)':
-                    case 'Dict(Tensor, tVal)':
-                    case 'Dict(boolean, tVal)':
-                    case 'Dict(complex, tVal)':
-                    case 'Dict(float32, tVal)':
-                    case 'Dict(int64, tVal)':
-                    case 'Dict(string, tVal)':
-                    case '(string, t)[]':
-                    case 'Any':
-                        break;
-                    case '__torch__.torch.classes.xnnpack.LinearOpContext':
-                    case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
-                    case '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext':
-                    case '__torch__.torch.classes.rnn.CellParamsBase':
-                    case '__torch__.torch.classes.quantized.LinearPackedParamsBase':
-                    case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase':
-                    case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase':
-                        break;
-                    default: {
-                        throw new pytorch.Error(`Unknown return type '${parameter.type}'.`);
-                    }
-                }
-            }
-            if (next) {
+            if (!kwarg_only && parameters.slice(index).some((parameter) => parameter.default === undefined)) {
                 continue;
             }
             matches.push(schema);
@@ -3872,10 +4061,10 @@ pytorch.jit.Execution = class extends pytorch.Execution {
                 return keyA - keyB;
             });
         }
-        if (matches.length > 0) {
-            return [matches[0], evalArgs];
+        if (matches.length === 0) {
+            throw new pytorch.Error(`Unknown function '${type}'.`);
         }
-        throw new pytorch.Error(`Unknown function '${type}'.`);
+        return [matches[0], evalArgs];
     }
 
     block(statements, context) {

+ 2 - 2
test/models.json

@@ -6281,7 +6281,7 @@
     "target":   "TestSerialization.test_lstm.traced.pt",
     "source":   "https://github.com/user-attachments/files/16121906/TestSerialization.test_lstm.traced.pt.zip[TestSerialization.test_lstm.traced.pt]",
     "format":   "TorchScript v1.6",
-    "assert":   "model.graphs[0].nodes.length == 9",
+    "assert":   "model.graphs[0].nodes.length == 10",
     "link":     "https://github.com/lutzroeder/netron/issues/1067"
   },
   {
@@ -6289,7 +6289,7 @@
     "target":   "TFModel_traced_eager_quant.pt",
     "source":   "https://github.com/lutzroeder/netron/files/10867120/TFModel_traced_eager_quant.pt.zip[TFModel_traced_eager_quant.pt]",
     "format":   "TorchScript v1.6",
-    "assert":   "model.graphs[0].nodes.length == 46",
+    "assert":   "model.graphs[0].nodes.length == 51",
     "link":     "https://github.com/lutzroeder/netron/issues/1067"
   },
   {

+ 20 - 0
tools/pytorch_script.py

@@ -430,6 +430,15 @@ known_schema_definitions = [
     'aten::__and__.int(int a, int b) -> int',
     'aten::__and__.Scalar(Tensor self, Scalar other) -> Tensor',
     'aten::__and__.Tensor(Tensor self, Tensor other) -> Tensor',
+    'aten::__contains__.Tensor(Dict(Tensor, t) dict, Tensor key) -> bool',
+    'aten::__contains__.bool(Dict(bool, t) dict, bool key) -> bool',
+    'aten::__contains__.complex(Dict(complex, t) dict, complex key) -> bool',
+    'aten::__contains__.float(Dict(float, t) dict, float key) -> bool',
+    'aten::__contains__.float_list(float[] l, float item) -> bool',
+    'aten::__contains__.int(Dict(int, t) dict, int key) -> bool',
+    'aten::__contains__.int_list(int[] l, int item) -> bool',
+    'aten::__contains__.str(Dict(str, t) dict, str key) -> bool',
+    'aten::__contains__.str_list(str[] l, str item) -> bool',
     'aten::__getitem__.Dict_bool(Dict(bool, t) self, bool key) -> t(*)',
     'aten::__getitem__.Dict_complex(Dict(complex, t) self, complex key) -> t(*)',
     'aten::__getitem__.Dict_float(Dict(float, t) self, float key) -> t(*)',
@@ -438,6 +447,7 @@ known_schema_definitions = [
     'aten::__getitem__.Dict_Tensor(Dict(Tensor, t) self, Tensor key) -> t(*)',
     'aten::__getitem__.str(str s, int index) -> str',
     'aten::__getitem__.t(t[](a) list, int idx) -> t(*)',
+    'aten::__is__(t1 self, t2 obj) -> bool',
     'aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)',
     'aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)',
     'aten::_native_batch_norm_legit.no_stats_out(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!))',
@@ -447,6 +457,13 @@ known_schema_definitions = [
     'aten::_native_batch_norm_legit_no_training.out(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))',
     'aten::_native_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True, int? mask_type=None) -> (Tensor, Tensor)',
     'aten::_native_multi_head_attention.out(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True, int? mask_type=None, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))',
+    'aten::_set_item.Tensor(Dict(Tensor, t)(a!) l, Tensor(b -> *) idx, t(c -> *) v) -> ()',
+    'aten::_set_item.bool(Dict(bool, t)(a!) l, bool(b -> *) idx, t(c -> *) v) -> ()',
+    'aten::_set_item.complex(Dict(complex, t)(a!) l, complex(b -> *) idx, t(c -> *) v) -> ()',
+    'aten::_set_item.float(Dict(float, t)(a!) l, float(b -> *) idx, t(c -> *) v) -> ()',
+    'aten::_set_item.int(Dict(int, t)(a!) l, int(b -> *) idx, t(c -> *) v) -> ()',
+    'aten::_set_item.str(Dict(str, t)(a!) l, str(b -> *) idx, t(c -> *) v) -> ()',
+    'aten::_set_item.t(t[](a!) l, int idx, t(b -> *) el) -> t[](a!)',
     'aten::add(Scalar a, Scalar b) -> Scalar',
     'aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor',
     'aten::add.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)',
@@ -554,6 +571,8 @@ known_schema_definitions = [
     'aten::Complex.Tensor_int(Tensor x, int y) -> complex',
     'aten::Complex.Tensor_Tensor(Tensor a, Tensor b) -> complex',
     'aten::ComplexImplicit(Tensor a) -> complex',
+    'aten::device(str a) -> Device',
+    'aten::device.with_index(str type, int index) -> Device',
     'aten::dict.bool((bool, tVal)[] inputs) -> Dict(bool, tVal)',
     'aten::dict.complex((complex, tVal)[] inputs) -> Dict(complex, tVal)',
     'aten::dict.Dict_bool(Dict(bool, t)(a) self) -> Dict(bool, t)',
@@ -893,6 +912,7 @@ known_schema_definitions = [
     'aten::values.str(Dict(str, t) self) -> t[](*)',
     'aten::values.Tensor(Dict(Tensor, t) self) -> t[](*)',
     'aten::values(Tensor(a) self) -> Tensor(a)',
+    'aten::warn(str message, int stacklevel=2) -> ()',
     'prim::abs.complex(complex a) -> float',
     'prim::abs.float(float a) -> float',
     'prim::abs.int(int a) -> int',