Просмотр исходного кода

Update PyTorch Package experiment (#928)

Lutz Roeder 3 лет назад
Родитель
Сommit
c0e14fff2a
3 измененных файлов с 149 добавлено и 57 удалено
  1. 130 32
      source/python.js
  2. 18 24
      source/pytorch.js
  3. 1 1
      test/models.json

+ 130 - 32
source/python.js

@@ -138,32 +138,32 @@ python.Parser = class {
 
         node = this._eat('id', 'global');
         if (node) {
-            node.variable = [];
+            node.names = [];
             do {
-                node.variable.push(this._parseName());
+                node.names.push(this._parseName(true).value);
             }
             while (this._tokenizer.eat(','));
             return node;
         }
         node = this._eat('id', 'nonlocal');
         if (node) {
-            node.variable = [];
+            node.names = [];
             do {
-                node.variable.push(this._parseName());
+                node.names.push(this._parseName(true).value);
             }
             while (this._tokenizer.eat(','));
             return node;
         }
         node = this._eat('id', 'import');
         if (node) {
-            node.modules = [];
+            node.names = [];
             do {
-                const module = this._node('module');
-                module.name = this._parseExpression(-1, [], false);
+                const alias = this._node('alias');
+                alias.name = this._parseDottedName();
                 if (this._tokenizer.eat('id', 'as')) {
-                    module.as = this._parseExpression(-1, [], false);
+                    alias.asname = this._parseName(true).value;
                 }
-                node.modules.push(module);
+                node.names.push(alias);
             }
             while (this._tokenizer.eat(','));
             return node;
@@ -171,24 +171,21 @@ python.Parser = class {
         node = this._eat('id', 'from');
         if (node) {
             node.type = 'import_from';
+            node.level = 0;
             const dots = this._tokenizer.peek();
             if (dots && Array.from(dots.type).every((c) => c == '.')) {
                 this._eat(dots.type);
                 node.level = Array.from(dots.type).length;
-                node.module = this._parseExpression();
-            }
-            else {
-                node.level = 0;
-                node.module = this._parseExpression();
             }
+            node.module = this._parseDottedName();
             this._tokenizer.expect('id', 'import');
             node.names = [];
             const close = this._tokenizer.eat('(');
             do {
                 const alias = this._node('alias');
-                alias.name = this._parseName();
+                alias.name = this._parseName(true).value;
                 if (this._tokenizer.eat('id', 'as')) {
-                    alias.asname = this._parseName();
+                    alias.asname = this._parseName(true).value;
                 }
                 node.names.push(alias);
             }
@@ -203,13 +200,13 @@ python.Parser = class {
 
         node = this._eat('id', 'class');
         if (node) {
-            node.name = this._parseName().value;
+            node.name = this._parseName(true).value;
             if (decorator_list) {
                 node.decorator_list = Array.from(decorator_list);
                 decorator_list = null;
             }
             if (this._tokenizer.peek().value === '(') {
-                node.base = this._parseArguments();
+                node.bases = this._parseArguments();
             }
             this._tokenizer.expect(':');
             node.body = this._parseSuite();
@@ -229,7 +226,7 @@ python.Parser = class {
             if (async) {
                 node.async = async;
             }
-            node.name = this._parseName().value;
+            node.name = this._parseName(true).value;
             if (decorator_list) {
                 node.decorator_list = Array.from(decorator_list);
                 decorator_list = null;
@@ -821,15 +818,27 @@ python.Parser = class {
         return node;
     }
 
-    _parseName() {
+    _parseName(required) {
         const token = this._tokenizer.peek();
         if (token.type == 'id' && !token.keyword) {
             this._tokenizer.read();
             return token;
         }
+        if (required) {
+            throw new python.Error("Invalid syntax" + this._tokenizer.location());
+        }
         return null;
     }
 
+    _parseDottedName() {
+        const list = [];
+        do {
+            list.push(this._parseName(true).value);
+        }
+        while (this._tokenizer.eat('.'));
+        return list.join('.');
+    }
+
     _parseLiteral() {
         const token = this._tokenizer.peek();
         if (token.type == 'string' || token.type == 'number' || token.type == 'boolean') {
@@ -1942,6 +1951,10 @@ python.Execution = class {
             }
         });
         this.registerType('keras.engine.sequential.Sequential', class {});
+        this.registerType('lasagne.layers.conv.Conv2DLayer', class {});
+        this.registerType('lasagne.layers.dense.DenseLayer', class {});
+        this.registerType('lasagne.layers.input.InputLayer', class {});
+        this.registerType('lasagne.layers.pool.MaxPool2DLayer', class {});
         this.registerType('lightgbm.sklearn.LGBMRegressor', class {});
         this.registerType('lightgbm.sklearn.LGBMClassifier', class {});
         this.registerType('lightgbm.basic.Booster', class {
@@ -2355,6 +2368,90 @@ python.Execution = class {
                 Object.assign(this, python.Unpickler.open(state).load((name, args) => self.invoke(name, args), null));
             }
         });
+        this.registerType('theano.compile.function_module._constructor_Function', class {});
+        this.registerType('theano.compile.function_module._constructor_FunctionMaker', class {});
+        this.registerType('theano.compile.function_module.Supervisor', class {});
+        this.registerType('theano.compile.io.In', class {});
+        this.registerType('theano.compile.io.SymbolicOutput', class {});
+        this.registerType('theano.compile.mode.Mode', class {});
+        this.registerType('theano.compile.ops.OutputGuard', class {});
+        this.registerType('theano.compile.ops.Shape', class {});
+        this.registerType('theano.compile.ops.Shape_i', class {});
+        this.registerType('theano.gof.destroyhandler.DestroyHandler', class {});
+        this.registerType('theano.gof.fg.FunctionGraph', class {});
+        this.registerType('theano.gof.graph.Apply', class {});
+        this.registerType('theano.gof.link.Container', class {});
+        this.registerType('theano.gof.opt._metadict', class {});
+        this.registerType('theano.gof.opt.ChangeTracker', class {});
+        this.registerType('theano.gof.opt.MergeFeature', class {});
+        this.registerType('theano.gof.optdb.Query', class {});
+        this.registerType('theano.gof.toolbox.PreserveVariableAttributes', class {});
+        this.registerType('theano.gof.toolbox.ReplaceValidate', class {});
+        this.registerType('theano.gof.utils.scratchpad', class {});
+        this.registerType('theano.misc.ordered_set.Link', class {});
+        this.registerType('theano.misc.ordered_set.OrderedSet', class {});
+        this.registerType('theano.sandbox.cuda.basic_ops.HostFromGpu', class {});
+        this.registerType('theano.sandbox.cuda.type.CudaNdarray_unpickler', class {});
+        this.registerType('theano.sandbox.cuda.type.CudaNdarrayType', class {});
+        this.registerType('theano.sandbox.cuda.var.CudaNdarraySharedVariable', class {});
+        this.registerType('theano.scalar.basic.Abs', class {});
+        this.registerType('theano.scalar.basic.Add', class {});
+        this.registerType('theano.scalar.basic.Cast', class {});
+        this.registerType('theano.scalar.basic.Composite', class {});
+        this.registerType('theano.scalar.basic.EQ', class {});
+        this.registerType('theano.scalar.basic.GE', class {});
+        this.registerType('theano.scalar.basic.Identity', class {});
+        this.registerType('theano.scalar.basic.IntDiv', class {});
+        this.registerType('theano.scalar.basic.Inv', class {});
+        this.registerType('theano.scalar.basic.LE', class {});
+        this.registerType('theano.scalar.basic.LT', class {});
+        this.registerType('theano.scalar.basic.Mul', class {});
+        this.registerType('theano.scalar.basic.Neg', class {});
+        this.registerType('theano.scalar.basic.Scalar', class {});
+        this.registerType('theano.scalar.basic.ScalarConstant', class {});
+        this.registerType('theano.scalar.basic.ScalarVariable', class {});
+        this.registerType('theano.scalar.basic.Second', class {});
+        this.registerType('theano.scalar.basic.Sgn', class {});
+        this.registerType('theano.scalar.basic.specific_out', class {});
+        this.registerType('theano.scalar.basic.Sub', class {});
+        this.registerType('theano.scalar.basic.Switch', class {});
+        this.registerType('theano.scalar.basic.Tanh', class {});
+        this.registerType('theano.scalar.basic.transfer_type', class {});
+        this.registerType('theano.scalar.basic.TrueDiv', class {});
+        this.registerType('theano.tensor.basic.Alloc', class {});
+        this.registerType('theano.tensor.basic.Dot', class {});
+        this.registerType('theano.tensor.basic.MaxAndArgmax', class {});
+        this.registerType('theano.tensor.basic.Reshape', class {});
+        this.registerType('theano.tensor.basic.ScalarFromTensor', class {});
+        this.registerType('theano.tensor.blas.Dot22', class {});
+        this.registerType('theano.tensor.blas.Dot22Scalar', class {});
+        this.registerType('theano.tensor.blas.Gemm', class {});
+        this.registerType('theano.tensor.elemwise.DimShuffle', class {});
+        this.registerType('theano.tensor.elemwise.Elemwise', class {});
+        this.registerType('theano.tensor.elemwise.Sum', class {});
+        this.registerType('theano.tensor.nnet.abstract_conv.AbstractConv2d', class {});
+        this.registerType('theano.tensor.nnet.abstract_conv.AbstractConv2d_gradInputs', class {});
+        this.registerType('theano.tensor.nnet.abstract_conv.AbstractConv2d_gradWeights', class {});
+        this.registerType('theano.tensor.nnet.corr.CorrMM', class {});
+        this.registerType('theano.tensor.nnet.corr.CorrMM_gradInputs', class {});
+        this.registerType('theano.tensor.nnet.corr.CorrMM_gradWeights', class {});
+        this.registerType('theano.tensor.nnet.nnet.CrossentropyCategorical1Hot', class {});
+        this.registerType('theano.tensor.nnet.nnet.CrossentropyCategorical1HotGrad', class {});
+        this.registerType('theano.tensor.nnet.nnet.CrossentropySoftmax1HotWithBiasDx', class {});
+        this.registerType('theano.tensor.nnet.nnet.CrossentropySoftmaxArgmax1HotWithBias', class {});
+        this.registerType('theano.tensor.nnet.nnet.Softmax', class {});
+        this.registerType('theano.tensor.nnet.nnet.SoftmaxGrad', class {});
+        this.registerType('theano.tensor.nnet.nnet.SoftmaxWithBias', class {});
+        this.registerType('theano.tensor.opt.MakeVector', class {});
+        this.registerType('theano.tensor.opt.ShapeFeature', class {});
+        this.registerType('theano.tensor.sharedvar.TensorSharedVariable', class {});
+        this.registerType('theano.tensor.signal.pool.MaxPoolGrad', class {});
+        this.registerType('theano.tensor.signal.pool.Pool', class {});
+        this.registerType('theano.tensor.subtensor.Subtensor', class {});
+        this.registerType('theano.tensor.type.TensorType', class {});
+        this.registerType('theano.tensor.var.TensorConstant', class {});
+        this.registerType('theano.tensor.var.TensorConstantSignature', class {});
+        this.registerType('theano.tensor.var.TensorVariable', class {});
         this.registerType('thinc.describe.Biases', class {
             __setstate__(state) {
                 Object.assign(this, state);
@@ -3159,31 +3256,32 @@ python.Execution = class {
                 break;
             }
             case 'import': {
-                for (const module of statement.modules) {
-                    const moduleName = python.Utility.target(module.name);
-                    const globals = this.package(moduleName);
-                    if (module.as) {
-                        context.set(module.as, globals);
+                for (const alias of statement.names) {
+                    const module = this.package(alias.name);
+                    if (alias.asname) {
+                        context.set(alias.asname, module);
+                    }
+                    else {
+                        context.setx(alias.name, module);
                     }
                 }
                 break;
             }
             case 'import_from': {
                 let module = null;
-                let moduleName = python.Utility.target(statement.module);
                 if (statement.level > 0) {
                     let paths = context.getx('__file__').split('/');
                     paths = paths.slice(0, paths.length - statement.level);
-                    paths.push(moduleName.replace('.', '/'));
-                    moduleName = paths.join('/');
-                    module = this.package(moduleName);
+                    paths.push(statement.module.replace('.', '/'));
+                    const name = paths.join('/');
+                    module = this.package(name);
                 }
                 else {
-                    module = this._package(moduleName, context);
+                    module = this._package(statement.module, context);
                 }
                 for (const entry of statement.names) {
-                    const name = entry.name.value;
-                    const asname = entry.asname ? entry.asname.value : null;
+                    const name = entry.name;
+                    const asname = entry.asname ? entry.asname : null;
                     context.set(asname ? asname : name, module[name]);
                 }
                 break;

+ 18 - 24
source/pytorch.js

@@ -156,12 +156,12 @@ pytorch.Graph = class {
                 break;
             }
             case 'module': {
-                this._type = (graph.obj.__module__ && graph.obj.__name__) ? (graph.obj.__module__ + '.' + graph.obj.__name__) : '';
-                this._loadModule(metadata, graph.obj, [], []);
+                this._type = (graph.data.__module__ && graph.data.__name__) ? (graph.data.__module__ + '.' + graph.data.__name__) : '';
+                this._loadModule(metadata, graph.data, [], []);
                 break;
             }
             case 'weights': {
-                for (const state_group of graph.layers) {
+                for (const state_group of graph.data) {
                     const attributes = state_group.attributes || [];
                     const inputs = state_group.states.map((parameter) => {
                         return new pytorch.Parameter(parameter.name, true,
@@ -2985,13 +2985,12 @@ pytorch.Container.Zip.Package = class extends pytorch.Container.Zip {
                     }
                 }
                 execution.registerFunction('torch.jit._script.unpackage_script_module', function(script_module_id) {
-                    // torch.jit._script.RecursiveScriptModule
-                    return script_module_id;
+                    return "torch.jit._script.RecursiveScriptModule('" + script_module_id + "')";
                 });
                 const unpickler = python.Unpickler.open(stream);
                 const root = unpickler.load((name, args) => execution.invoke(name, args), persistent_load);
-                if (root.model) {
-                    const location = {
+                /* if (root.model) {
+                    const location = {6
                         model: '.data/ts_code/' + root.model + '/data.pkl',
                         code: '.data/ts_code/code/',
                         data: '.data/',
@@ -2999,17 +2998,12 @@ pytorch.Container.Zip.Package = class extends pytorch.Container.Zip {
                     const graph = new pytorch.Container.Zip.Pickle.Script(this._entries, execution, location, name);
                     this._graphs.push(graph);
                 }
-                else {
-                    const obj = pytorch.Utility.findModule(root);
-                    if (Array.isArray(obj) && obj.length === 1) {
-                        obj[0].type = 'module';
-                        obj[0].name = obj[0].name || name;
-                        this._graphs.push(obj[0]);
-                    }
-                    else {
-                        throw new pytorch.Error('Unsupported packaged model.');
-                    }
-                }
+                else { */
+                this._graphs.push({
+                    name: name,
+                    type: 'module',
+                    data: root
+                });
             }
         }
         return this._graphs;
@@ -3915,11 +3909,11 @@ pytorch.Utility = class {
                 }
                 if (obj) {
                     if (obj._modules) {
-                        return [ { name: '', obj: obj } ];
+                        return [ { name: '', data: obj } ];
                     }
                     const objKeys = Object.keys(obj).filter((key) => obj[key] && obj[key]._modules);
                     if (objKeys.length > 1) {
-                        return objKeys.map((key) => { return { name: key, obj: obj[key] }; });
+                        return objKeys.map((key) => { return { name: key, data: obj[key] }; });
                     }
                 }
             }
@@ -3967,7 +3961,7 @@ pytorch.Utility = class {
             const argument = { id: '', value: obj };
             const parameter = { name: 'value', arguments: [ argument ] };
             layers.push({ states: [ parameter ] });
-            return [ { layers: layers } ];
+            return [ { data: layers } ];
         }
         return null;
     }
@@ -3989,7 +3983,7 @@ pytorch.Utility = class {
                     }
                 }
                 layers.push(layer);
-                return [ { layers: layers } ];
+                return [ { data: layers } ];
             }
             if (obj.every((item) => item && Object.values(item).filter((value) => pytorch.Utility.isTensor(value)).length > 0)) {
                 const layers = [];
@@ -4011,7 +4005,7 @@ pytorch.Utility = class {
                     }
                     layers.push(layer);
                 }
-                return [ { layers: layers } ];
+                return [ { data: layers } ];
             }
         }
         return null;
@@ -4200,7 +4194,7 @@ pytorch.Utility = class {
                 }
                 graphs.push({
                     name: graph_key,
-                    layers: layers.values()
+                    data: layers.values()
                 });
             }
             return graphs;

+ 1 - 1
test/models.json

@@ -4934,7 +4934,7 @@
     "type":     "pytorch",
     "target":   "v3_1_ru.pt",
     "source":   "https://github.com/lutzroeder/netron/files/9075630/v3_1_ru.pt.zip[v3_1_ru.pt]",
-    "error":    "Found non-callable @@iterator in 'v3_1_ru.pt'.",
+    "format":   "PyTorch Package v1.9",
     "link":     "https://github.com/lutzroeder/netron/issues/928"
   },
   {