Lutz Roeder 3 лет назад
Родитель
Сommit
e6d0e411db
2 измененных файлов с 118 добавлено и 71 удалено
  1. 117 71
      source/python.js
  2. 1 0
      source/pytorch.js

+ 117 - 71
source/python.js

@@ -1639,22 +1639,24 @@ python.Execution = class {
 
     constructor(sources, exceptionCallback) {
         const self = this;
+        const execution = self;
         this._sources = sources || new Map();
         this._exceptionCallback = exceptionCallback;
         this._utf8Decoder = new TextDecoder('utf-8');
         this._unresolved = new Map();
         const dict = class extends Map {};
-        this._registry = new Map();
         this._modules = new dict();
-        this._context = new python.Execution.Context();
+        this._registry = new Map();
         this._builtins = this.register('builtins');
         this._builtins.type = { __module__: 'builtins', __name__: 'type' };
         this._builtins.type.__class__ = this._builtins.type;
         this._builtins.module = { __module__: 'builtins', __name__: 'module', __class__: this._builtins.type };
         this._builtins.module.__type__ = this._builtins.module;
+        this._registry.set('__builtin__', this._builtins);
+        this._context = new python.Execution.Context();
+        this._context.setx('__builtins__', this.import('builtins'));
         const typing = this.register('typing');
         this._typing = typing;
-        this.register('__builtin__');
         this.register('_codecs');
         this.register('argparse');
         this.register('collections');
@@ -1677,6 +1679,9 @@ python.Execution = class {
         this.registerType('builtins.method', class {});
         this.registerType('builtins.dict', dict);
         this.registerType('builtins.list', class {});
+        this.registerFunction('builtins.__import__', function(name, globals, locals, fromlist, level) {
+            return execution.__import__(name, globals, locals, fromlist, level);
+        });
         this.registerFunction('builtins.bool', function(value) {
             if (value) {
                 if (value.__bool__) {
@@ -1984,7 +1989,7 @@ python.Execution = class {
                 }
                 const size = this.dtype.itemsize * this.shape.reduce((a, b) => a * b, 1);
                 this.data = unpickler.read(size);
-                return self.invoke(this.subclass, [ this.shape, this.dtype, this.data ]);
+                return execution.invoke(this.subclass, [ this.shape, this.dtype, this.data ]);
             }
         });
         this.registerType('keras.engine.sequential.Sequential', class {});
@@ -2295,7 +2300,7 @@ python.Execution = class {
                 }
                 const size = this.dtype.itemsize * this.shape.reduce((a, b) => a * b, 1);
                 this.data = unpickler.read(size);
-                return self.invoke(this.subclass, [ this.shape, this.dtype, this.data ]);
+                return execution.invoke(this.subclass, [ this.shape, this.dtype, this.data ]);
             }
         });
         this.registerType('sklearn.externals.joblib.numpy_pickle.NDArrayWrapper', class {
@@ -2307,7 +2312,7 @@ python.Execution = class {
                 this.allow_mmap = state.allow_mmap;
             }
             __read__(/* unpickler */) {
-                return this; // return self.invoke(this.subclass, [ this.shape, this.dtype, this.data ]);
+                return this; // return execution.invoke(this.subclass, [ this.shape, this.dtype, this.data ]);
             }
         });
         this.registerType('sklearn.ensemble._bagging.BaggingClassifier', class {});
@@ -2547,7 +2552,7 @@ python.Execution = class {
                         case OpCode.REDUCE: {
                             const items = stack.pop();
                             const type = stack.pop();
-                            stack.push(self.invoke(type, items));
+                            stack.push(execution.invoke(type, items));
                             break;
                         }
                         case OpCode.NEWOBJ: {
@@ -2555,7 +2560,7 @@ python.Execution = class {
                             const cls = stack.pop();
                             // TODO resolved
                             // cls.__new__(cls, args)
-                            stack.push(self.invoke(cls, args));
+                            stack.push(execution.invoke(cls, args));
                             break;
                         }
                         case OpCode.BINGET:
@@ -2824,12 +2829,12 @@ python.Execution = class {
                 throw new python.Error('Unexpected end of file.');
             }
             find_class(module, name) {
-                self.import(module, null, 0);
-                module = self._modules.get(module);
+                execution.import(module, null, 0);
+                module = execution._modules.get(module);
                 return module[name];
             }
             _instantiate(cls, args) {
-                return self.invoke(cls, args);
+                return execution.invoke(cls, args);
             }
             read(size) {
                 return this._reader.read(size);
@@ -2846,12 +2851,12 @@ python.Execution = class {
         });
         this.registerType('spacy._ml.PrecomputableAffine', class {
             __setstate__(state) {
-                Object.assign(this, python.Unpickler.open(state).load((name, args) => self.invoke(name, args), null));
+                Object.assign(this, python.Unpickler.open(state).load((name, args) => execution.invoke(name, args), null));
             }
         });
         this.registerType('spacy.syntax._parser_model.ParserModel', class {
             __setstate__(state) {
-                Object.assign(this, python.Unpickler.open(state).load((name, args) => self.invoke(name, args), null));
+                Object.assign(this, python.Unpickler.open(state).load((name, args) => execution.invoke(name, args), null));
             }
         });
         this.registerType('theano.compile.function_module._constructor_Function', class {});
@@ -2965,22 +2970,22 @@ python.Execution = class {
         });
         this.registerType('thinc.neural._classes.affine.Affine', class {
             __setstate__(state) {
-                Object.assign(this, python.Unpickler.open(state).load((name, args) => self.invoke(name, args), null));
+                Object.assign(this, python.Unpickler.open(state, execution).load());
             }
         });
         this.registerType('thinc.neural._classes.convolution.ExtractWindow', class {
             __setstate__(state) {
-                Object.assign(this, python.Unpickler.open(state).load((name, args) => self.invoke(name, args), null));
+                Object.assign(this, python.Unpickler.open(state, execution).load());
             }
         });
         this.registerType('thinc.neural._classes.feature_extracter.FeatureExtracter', class {
             __setstate__(state) {
-                Object.assign(this, python.Unpickler.open(state).load((name, args) => self.invoke(name, args), null));
+                Object.assign(this, python.Unpickler.open(state, execution).load());
             }
         });
         this.registerType('thinc.neural._classes.feed_forward.FeedForward', class {
             __setstate__(state) {
-                Object.assign(this, python.Unpickler.open(state).load((name, args) => self.invoke(name, args), null));
+                Object.assign(this, python.Unpickler.open(state, execution).load());
             }
         });
         this.registerType('thinc.neural._classes.function_layer.FunctionLayer', class {
@@ -3030,7 +3035,10 @@ python.Execution = class {
         this.registerType('xgboost.core.Booster', class {});
         this.registerType('xgboost.sklearn.XGBClassifier', class {});
         this.registerType('xgboost.sklearn.XGBRegressor', class {});
-        this.registerFunction('__builtin__.bytearray', function(source, encoding /*, errors */) {
+        this.registerFunction('_codecs.encode', function(obj /*, econding */) {
+            return obj;
+        });
+        this.registerFunction('builtins.bytearray', function(source, encoding /*, errors */) {
             if (source) {
                 if (Array.isArray(source) || source instanceof Uint8Array) {
                     const target = new Uint8Array(source.length);
@@ -3051,7 +3059,7 @@ python.Execution = class {
             }
             return [];
         });
-        this.registerFunction('__builtin__.bytes', function(source, encoding /*, errors */) {
+        this.registerFunction('builtins.bytes', function(source, encoding /*, errors */) {
             if (source) {
                 if (Array.isArray(source) || source instanceof Uint8Array) {
                     const target = new Uint8Array(source.length);
@@ -3071,30 +3079,9 @@ python.Execution = class {
             }
             return [];
         });
-        this.registerFunction('__builtin__.set', function(iterable) {
+        this.registerFunction('builtins.frozenset', function(iterable) {
             return iterable ? iterable : [];
         });
-        this.registerFunction('__builtin__.frozenset', function(iterable) {
-            return iterable ? iterable : [];
-        });
-        this.registerFunction('__builtin__.getattr', function(obj, name, defaultValue) {
-            if (Object.prototype.hasOwnProperty.call(obj, name)) {
-                return obj[name];
-            }
-            return defaultValue;
-        });
-        this.registerFunction('__builtin__.slice', function(start, stop , step) {
-            return [ start, stop, step ];
-        });
-        this.registerFunction('__builtin__.type', function(obj) {
-            return obj ? obj.__class__ : undefined;
-        });
-        this.registerFunction('_codecs.encode', function(obj /*, econding */) {
-            return obj;
-        });
-        this.registerFunction('builtins.bytearray', function(data) {
-            return { data: data };
-        });
         this.registerFunction('builtins.getattr', function(obj, name, defaultValue) {
             if (Object.prototype.hasOwnProperty.call(obj, name)) {
                 return obj[name];
@@ -3105,7 +3092,7 @@ python.Execution = class {
             return iterable ? iterable : [];
         });
         this.registerFunction('builtins.slice', function(start, stop, step) {
-            return { start: start, stop: stop, step: step };
+            return [ start, stop, step ];
         });
         this.registerFunction('cloudpickle.cloudpickle._builtin_type', function(name) {
             return name;
@@ -3477,24 +3464,30 @@ python.Execution = class {
         return null;
     }
 
-    import(name) {
+    import(name, current, level) {
+        if (level) {
+            let bits = current.split('.');
+            if (bits.length < level) {
+                throw new python.Error('Invalid relative import beyond top-level package.');
+            }
+            bits = bits.slice(0, bits.length - level);
+            const base = bits.join('.');
+            name = name ? [ base, name ].join('.') : base;
+        }
         const index = name.lastIndexOf('.');
+        let parent = null;
+        let child = null;
         if (index > 0) {
-            this.import(name.substring(0, index));
+            parent = name.substring(0, index);
+            child = name.substring(index + 1);
+            this.import(parent);
         }
         if (!this._modules.has(name)) {
-            const module = {};
+            const module = this._registry.get(name) || {};
             module.__class__ = this._builtins.module;
             module.__name__ = name;
+            module.__package__ = name;
             this._modules.set(name, module);
-            if (this._registry.has(name)) {
-                const entries = this._registry.get(name);
-                for (const entry of Object.entries(entries)) {
-                    const name = entry[0];
-                    const value = entry[1];
-                    module[name] = value;
-                }
-            }
             const file = name.split('.').join('/') + '.py';
             const program = this.parse(file);
             if (program) {
@@ -3502,10 +3495,66 @@ python.Execution = class {
                 const context = this._context.push(module);
                 this.block(program.body, context);
             }
+            if (parent) {
+                const parent_module = this._modules.get(parent);
+                parent_module[child] = module;
+            }
         }
         return this._modules.get(name);
     }
 
+    __import__(name, globals, locals, fromlist, level) {
+        let module = null;
+        level = level || 0;
+        if (level === 0) {
+            module = this.import(name);
+        }
+        else {
+            globals = globals || {};
+            let current = globals.getx('__package__');
+            if (!current) {
+                const spec = globals.getx('__spec__');
+                if (spec) {
+                    current = spec.parent;
+                }
+                else {
+                    const name = globals['__name__'];
+                    const bits = name.split('.');
+                    bits.pop();
+                    current = bits.join('.');
+                }
+            }
+            module = this.import(name, current, level);
+        }
+        if (!fromlist) {
+            if (level === 0) {
+                return this.import(name.split('.')[0]);
+            }
+            else if (name) {
+                throw new python.Error('');
+                // cut_off = len(name) - len(name.partition('.')[0])
+                // return sys.modules[module.__name__[:len(module.__name__)-cut_off]]
+            }
+        }
+        else if (module.__path__) {
+            const _handle_fromlist = (module, fromlist, import_, recursive) => {
+                for (const x of fromlist) {
+                    if (x == '*') {
+                        if (!recursive && module.__all__) {
+                            _handle_fromlist(module, module.__all__, import_, true);
+                        }
+                    }
+                    else if (!module[x]) {
+                        import_(module.__name__ + '.' + x);
+                    }
+                }
+                return module;
+            };
+            _handle_fromlist(module, fromlist, this.import);
+        }
+        return module;
+    }
+
     type(name) {
         const type = this._context.getx(name);
         if (type !== undefined) {
@@ -3641,17 +3690,7 @@ python.Execution = class {
                 const module = context.get('__name__');
                 const self = this;
                 const parent = context.get('__class__');
-                let type = null;
-                if (parent === this._builtins.type) {
-                    type = this._builtins.method;
-                }
-                else if (parent === this._builtins.module) {
-                    type = this._builtins.function;
-                }
-                else {
-                    type = this._builtins.method;
-                    // throw new python.Error('Invalid function scope.'); // TODO
-                }
+                const type = (parent === this._builtins.module) ? this._builtins.function : this._builtins.method;
                 const func = {
                     __class__: type,
                     __globals__: context,
@@ -3752,8 +3791,13 @@ python.Execution = class {
             }
             case 'import': {
                 for (const alias of statement.names) {
-                    const module = this.import(alias.name);
+                    let module = this.__import__(alias.name, context);
                     if (alias.asname) {
+                        const bits = alias.name.split('.').reverse();
+                        bits.pop();
+                        while (bits.length > 0) {
+                            module = module[bits.pop()];
+                        }
                         context.set(alias.asname, module);
                     }
                     else {
@@ -3765,14 +3809,16 @@ python.Execution = class {
             case 'import_from': {
                 let module = null;
                 if (statement.level > 0) {
-                    let paths = context.getx('__file__').split('/');
-                    paths = paths.slice(0, paths.length - statement.level);
-                    paths.push(statement.module.replace('.', '/'));
-                    const name = paths.join('/');
-                    module = this.import(name);
+                    const fromlist = statement.names.map((name) => name.name);
+                    module = this.__import__(statement.module, context, null, fromlist, statement.level);
                 }
                 else {
-                    module = this._package(statement.module, context);
+                    module = this.__import__(statement.module, context, null, 0);
+                    const bits = statement.module.split('.').reverse();
+                    bits.pop();
+                    while (bits.length > 0) {
+                        module = module[bits.pop()];
+                    }
                 }
                 for (const entry of statement.names) {
                     const name = entry.name;

+ 1 - 0
source/pytorch.js

@@ -1134,6 +1134,7 @@ pytorch.Execution = class extends python.Execution {
         this.registerType('torch.nn.utils.spectral_norm.SpectralNormLoadStateDictPreHook', class {});
         this.registerType('torch.nn.utils.weight_norm.WeightNorm', class {});
         this.registerType('torch.optim.adam.Adam', class {});
+        this.register('torch.optim').Adam = this._registry.get('torch.optim.adam').Adam;
         this.registerType('torch.optim.adamw.AdamW', class {});
         this.registerType('torch.optim.adagrad.Adagrad', class {});
         this.registerType('torch.optim.adadelta.Adadelta', class {});