2
0
Lutz Roeder 3 жил өмнө
parent
commit
1d38bbbed1
2 өөрчлөгдсөн 122 нэмэгдсэн , 195 устгасан
  1. 101 184
      source/python.js
  2. 21 11
      source/view.js

+ 101 - 184
source/python.js

@@ -2728,7 +2728,7 @@ python.Execution = class {
         this.registerType('sklearn.utils.deprecation.DeprecationDict', class {});
         this.registerType('pickle.Unpickler', class {
             constructor(data) {
-                this._reader = data instanceof Uint8Array ? new python.Unpickler.BinaryReader(data) : new python.Unpickler.StreamReader(data);
+                this._reader = data instanceof Uint8Array ? new python.BinaryReader(data) : new python.StreamReader(data);
                 this.persistent_load = () => {
                     throw new python.Error('Unsupported persistent id.');
                 };
@@ -2738,36 +2738,32 @@ python.Execution = class {
                 const marker = [];
                 let stack = [];
                 const memo = new Map();
-                const OpCode = python.Unpickler.OpCode;
                 while (reader.position < reader.length) {
                     const opcode = reader.byte();
                     // console.log((reader.position - 1).toString() + ' ' + Object.entries(OpCode).find((entry) => entry[1] === opcode)[0]);
+                    // https://svn.python.org/projects/python/trunk/Lib/pickletools.py
+                    // https://github.com/python/cpython/blob/master/Lib/pickle.py
                     switch (opcode) {
-                        case OpCode.PROTO: {
+                        case 128: { // PROTO
                             const version = reader.byte();
                             if (version > 5) {
                                 throw new python.Error("Unsupported protocol version '" + version + "'.");
                             }
                             break;
                         }
-                        case OpCode.GLOBAL: {
+                        case 99: { // GLOBAL 'c'
                             const module = reader.line();
                             const name = reader.line();
                             stack.push(this.find_class(module, name));
                             break;
                         }
-                        case OpCode.STACK_GLOBAL: {
+                        case 147: { // STACK_GLOBAL '\x93' (Protocol 4)
                             const name = stack.pop();
                             const module = stack.pop();
                             stack.push(this.find_class(module, name));
                             break;
                         }
-                        case OpCode.PUT: {
-                            const index = parseInt(reader.line(), 10);
-                            memo.set(index, stack[stack.length - 1]);
-                            break;
-                        }
-                        case OpCode.OBJ: {
+                        case 111: { // OBJ 'o'
                             const args = stack;
                             const cls = args.pop();
                             stack = marker.pop();
@@ -2775,56 +2771,59 @@ python.Execution = class {
                             stack.push(obj);
                             break;
                         }
-                        case OpCode.GET: {
+                        case 112 : { // PUT 'p'
+                            const index = parseInt(reader.line(), 10);
+                            memo.set(index, stack[stack.length - 1]);
+                            break;
+                        }
+                        case 103: { // GET 'g'
                             const index = parseInt(reader.line(), 10);
                             stack.push(memo.get(index));
                             break;
                         }
-                        case OpCode.POP:
+                        case 48: // POP '0'
                             stack.pop();
                             break;
-                        case OpCode.POP_MARK:
+                        case 49: // POP_MARK '1'
                             stack = marker.pop();
                             break;
-                        case OpCode.DUP:
+                        case 50: // DUP '2'
                             stack.push(stack[stack.length-1]);
                             break;
-                        case OpCode.PERSID:
+                        case 80: // PERSID 'P'
                             stack.push(this.persistent_load(reader.line()));
                             break;
-                        case OpCode.BINPERSID:
+                        case 81: // BINPERSID 'Q'
                             stack.push(this.persistent_load(stack.pop()));
                             break;
-                        case OpCode.REDUCE: {
+                        case 82: { // REDUCE 'R'
                             const args = stack.pop();
                             const func = stack.pop();
-                            stack.push(execution.invoke(func, args));
+                            stack.push(this._reduce(func, args));
                             break;
                         }
-                        case OpCode.NEWOBJ: {
+                        case 129: { // NEWOBJ
                             const args = stack.pop();
                             const cls = stack.pop();
-                            // TODO resolved
-                            // cls.__new__(cls, args)
-                            const obj = execution.invoke(cls, args);
+                            const obj = this._newobj(cls, args);
                             stack.push(obj);
                             break;
                         }
-                        case OpCode.NEWOBJ_EX: {
+                        case 146: { // NEWOBJ_EX '\x92' (Protocol 4)
                             const kwargs = stack.pop();
                             const args = stack.pop();
                             const cls = stack.pop();
                             if (Object.entries(kwargs).length > 0) {
                                 throw new python.Error("Unpickle 'NEWOBJ_EX' not implemented.");
                             }
-                            const obj = execution.invoke(cls, args);
+                            const obj = this._newobj(cls, args);
                             stack.push(obj);
                             break;
                         }
-                        case OpCode.BINGET:
+                        case 104: // BINGET 'h'
                             stack.push(memo.get(reader.byte()));
                             break;
-                        case OpCode.INST: {
+                        case 105: { // INST 'i'
                             const module = reader.line();
                             const name = reader.line();
                             const args = stack;
@@ -2836,43 +2835,43 @@ python.Execution = class {
                             stack.push(obj);
                             break;
                         }
-                        case OpCode.LONG_BINGET:
+                        case 106: // LONG_BINGET 'j'
                             stack.push(memo.get(reader.uint32()));
                             break;
-                        case OpCode.BINPUT:
+                        case 113: // BINPUT 'q'
                             memo.set(reader.byte(), stack[stack.length - 1]);
                             break;
-                        case OpCode.LONG_BINPUT:
+                        case 114: // LONG_BINPUT 'r'
                             memo.set(reader.uint32(), stack[stack.length - 1]);
                             break;
-                        case OpCode.BININT:
+                        case 74: // BININT 'J'
                             stack.push(reader.int32());
                             break;
-                        case OpCode.BININT1:
+                        case 75: // BININT1 'K'
                             stack.push(reader.byte());
                             break;
-                        case OpCode.LONG:
+                        case 76: // LONG 'L'
                             stack.push(parseInt(reader.line(), 10));
                             break;
-                        case OpCode.BININT2:
+                        case 77: // BININT2 'M'
                             stack.push(reader.uint16());
                             break;
-                        case OpCode.BINBYTES:
+                        case 66: // BINBYTES 'B' (Protocol 3)
                             stack.push(reader.read(reader.int32()));
                             break;
-                        case OpCode.BINBYTES8:
-                            stack.push(reader.read(reader.int64()));
-                            break;
-                        case OpCode.SHORT_BINBYTES:
+                        case 67: // SHORT_BINBYTES 'C' (Protocol 3)
                             stack.push(reader.read(reader.byte()));
                             break;
-                        case OpCode.FLOAT:
+                        case 142: // BINBYTES8 '\x8e' (Protocol 4)
+                            stack.push(reader.read(reader.int64()));
+                            break;
+                        case 70: // FLOAT 'F'
                             stack.push(parseFloat(reader.line()));
                             break;
-                        case OpCode.BINFLOAT:
+                        case 71: // BINFLOAT 'G'
                             stack.push(reader.float64());
                             break;
-                        case OpCode.INT: {
+                        case 73: { // INT 'I'
                             const value = reader.line();
                             if (value == '01') {
                                 stack.push(true);
@@ -2885,16 +2884,16 @@ python.Execution = class {
                             }
                             break;
                         }
-                        case OpCode.EMPTY_LIST:
+                        case 93: // EMPTY_LIST ']'
                             stack.push([]);
                             break;
-                        case OpCode.EMPTY_TUPLE:
+                        case 41: // EMPTY_TUPLE ')'
                             stack.push([]);
                             break;
-                        case OpCode.EMPTY_SET:
+                        case 143: // EMPTY_SET '\x8f' (Protocol 4)
                             stack.push([]);
                             break;
-                        case OpCode.ADDITEMS: {
+                        case 144: { // ADDITEMS '\x90' (Protocol 4)
                             const items = stack;
                             stack = marker.pop();
                             const obj = stack[stack.length - 1];
@@ -2903,13 +2902,13 @@ python.Execution = class {
                             }
                             break;
                         }
-                        case OpCode.FROZENSET: {
+                        case 145: { // FROZENSET '\x91' (Protocol 4)
                             const items = stack;
                             stack = marker.pop();
                             stack.push(items);
                             break;
                         }
-                        case OpCode.DICT: {
+                        case 100: { // DICT 'd'
                             const items = stack;
                             stack = marker.pop();
                             const dict = {};
@@ -2919,19 +2918,36 @@ python.Execution = class {
                             stack.push(dict);
                             break;
                         }
-                        case OpCode.LIST: {
+                        case 108: { // LIST 'l'
                             const items = stack;
                             stack = marker.pop();
                             stack.push(items);
                             break;
                         }
-                        case OpCode.TUPLE: {
+                        case 116: { // TUPLE 't'
                             const items = stack;
                             stack = marker.pop();
                             stack.push(items);
                             break;
                         }
-                        case OpCode.SETITEM: {
+                        case 133: { // TUPLE1 '\x85'
+                            stack.push([ stack.pop() ]);
+                            break;
+                        }
+                        case 134: { // TUPLE2 '\x86'
+                            const b = stack.pop();
+                            const a = stack.pop();
+                            stack.push([ a, b ]);
+                            break;
+                        }
+                        case 135: { // TUPLE3 '\x87'
+                            const c = stack.pop();
+                            const b = stack.pop();
+                            const a = stack.pop();
+                            stack.push([ a, b, c ]);
+                            break;
+                        }
+                        case 115: { // SETITEM 's'
                             const value = stack.pop();
                             const key = stack.pop();
                             const obj = stack[stack.length - 1];
@@ -2943,7 +2959,7 @@ python.Execution = class {
                             }
                             break;
                         }
-                        case OpCode.SETITEMS: {
+                        case 117: { // SETITEMS 'u'
                             const items = stack;
                             stack = marker.pop();
                             const obj = stack[stack.length - 1];
@@ -2957,42 +2973,42 @@ python.Execution = class {
                             }
                             break;
                         }
-                        case OpCode.EMPTY_DICT:
+                        case 125: // EMPTY_DICT '}'
                             stack.push({});
                             break;
-                        case OpCode.APPEND: {
+                        case 97: { // APPEND 'a'
                             const append = stack.pop();
                             stack[stack.length-1].push(append);
                             break;
                         }
-                        case OpCode.APPENDS: {
+                        case 101: { // APPENDS 'e'
                             const appends = stack;
                             stack = marker.pop();
                             const list = stack[stack.length - 1];
                             list.push.apply(list, appends);
                             break;
                         }
-                        case OpCode.STRING: {
+                        case 83: { // STRING 'S'
                             const str = reader.line();
                             stack.push(str.substr(1, str.length - 2));
                             break;
                         }
-                        case OpCode.BINSTRING:
+                        case 84: // BINSTRING 'T'
                             stack.push(reader.string(reader.uint32()));
                             break;
-                        case OpCode.SHORT_BINSTRING:
+                        case 85 : // SHORT_BINSTRING 'U'
                             stack.push(reader.string(reader.byte()));
                             break;
-                        case OpCode.UNICODE:
+                        case 86: // UNICODE 'V'
                             stack.push(reader.line());
                             break;
-                        case OpCode.BINUNICODE:
+                        case 88: // BINUNICODE 'X'
                             stack.push(reader.string(reader.uint32(), 'utf-8'));
                             break;
-                        case OpCode.SHORT_BINUNICODE:
+                        case 140: // SHORT_BINUNICODE '\x8c' (Protocol 4)
                             stack.push(reader.string(reader.byte(), 'utf-8'));
                             break;
-                        case OpCode.BUILD: {
+                        case 98: { // BUILD 'b'
                             const state = stack.pop();
                             let obj = stack.pop();
                             if (obj.__setstate__) {
@@ -3020,17 +3036,17 @@ python.Execution = class {
                             stack.push(obj);
                             break;
                         }
-                        case OpCode.MARK:
+                        case 40: // MARK '('
                             marker.push(stack);
                             stack = [];
                             break;
-                        case OpCode.NEWTRUE:
+                        case 136: // NEWTRUE '\x88'
                             stack.push(true);
                             break;
-                        case OpCode.NEWFALSE:
+                        case 137: // NEWFALSE '\x89'
                             stack.push(false);
                             break;
-                        case OpCode.LONG1: {
+                        case 138: { // LONG1 '\x8a'
                             const data = reader.read(reader.byte());
                             let number = 0;
                             switch (data.length) {
@@ -3045,41 +3061,28 @@ python.Execution = class {
                             stack.push(number);
                             break;
                         }
-                        case OpCode.LONG4:
+                        case 139: // LONG4 '\x8b'
                             // TODO decode LONG4
                             stack.push(reader.read(reader.uint32()));
                             break;
-                        case OpCode.TUPLE1:
-                            stack.push([ stack.pop() ]);
-                            break;
-                        case OpCode.TUPLE2: {
-                            const b = stack.pop();
-                            const a = stack.pop();
-                            stack.push([ a, b ]);
-                            break;
-                        }
-                        case OpCode.TUPLE3: {
-                            const c = stack.pop();
-                            const b = stack.pop();
-                            const a = stack.pop();
-                            stack.push([ a, b, c ]);
-                            break;
-                        }
-                        case OpCode.MEMOIZE:
+                        case 148: // MEMOIZE '\x94' (Protocol 4)
                             memo.set(memo.size, stack[stack.length - 1]);
                             break;
-                        case OpCode.FRAME:
+                        case 149: // FRAME '\x95' (Protocol 4)
                             reader.read(8);
                             break;
-                        case OpCode.BYTEARRAY8: {
+                        case 150: { // BYTEARRAY8 '\x96' (Protocol 5)
                             stack.push(reader.read(reader.int64()));
                             break;
                         }
-                        case OpCode.NONE:
+                        case 78: // NONE 'N'
                             stack.push(null);
                             break;
-                        case OpCode.STOP:
+                        case 46: // STOP '.'
                             return stack.pop();
+                        case 141: // BINUNICODE8 '\x8d' (Protocol 4)
+                        case 151: // NEXT_BUFFER '\x97' (Protocol 5)
+                        case 152: // READONLY_BUFFER '\x98' (Protocol 5)
                         default:
                             throw new python.Error('Unknown opcode ' + opcode + ' at position ' + (reader.position - 1).toString() + '.');
                     }
@@ -3093,6 +3096,13 @@ python.Execution = class {
             _instantiate(cls, args) {
                 return execution.invoke(cls, args);
             }
+            _newobj(cls, args) {
+                // cls.__new__(cls, args)
+                return execution.invoke(cls, args);
+            }
+            _reduce(func, args) {
+                return execution.invoke(func, args);
+            }
             read(size) {
                 return this._reader.read(size);
             }
@@ -6190,99 +6200,7 @@ python.Execution.Context = class {
     }
 };
 
-python.Unpickler = class {
-
-    static open(data, execution) {
-        const reader = data instanceof Uint8Array ? new python.Unpickler.BinaryReader(data) : new python.Unpickler.StreamReader(data);
-        if (reader.length > 2) {
-            const head = reader.peek(2);
-            if (head[0] === 0x80 && head[1] < 7) {
-                execution = typeof execution === 'function' ? execution() : execution;
-                return execution.invoke('pickle.Unpickler', [ data ]);
-            }
-            reader.seek(-1);
-            const tail = reader.peek(1);
-            reader.seek(0);
-            if (tail[0] === 0x2e) {
-                execution = typeof execution === 'function' ? execution() : execution;
-                return execution.invoke('pickle.Unpickler', [ data ]);
-            }
-        }
-        return null;
-    }
-};
-
-// https://svn.python.org/projects/python/trunk/Lib/pickletools.py
-// https://github.com/python/cpython/blob/master/Lib/pickle.py
-python.Unpickler.OpCode = {
-    MARK: 40,              // '('
-    EMPTY_TUPLE: 41,       // ')'
-    STOP: 46,              // '.'
-    POP: 48,               // '0'
-    POP_MARK: 49,          // '1'
-    DUP: 50,               // '2'
-    BINBYTES: 66,          // 'B' (Protocol 3)
-    SHORT_BINBYTES: 67,    // 'C' (Protocol 3)
-    FLOAT: 70,             // 'F'
-    BINFLOAT: 71,          // 'G'
-    INT: 73,               // 'I'
-    BININT: 74,            // 'J'
-    BININT1: 75,           // 'K'
-    LONG: 76,              // 'L'
-    BININT2: 77,           // 'M'
-    NONE: 78,              // 'N'
-    PERSID: 80,            // 'P'
-    BINPERSID: 81,         // 'Q'
-    REDUCE: 82,            // 'R'
-    STRING: 83,            // 'S'
-    BINSTRING: 84,         // 'T'
-    SHORT_BINSTRING: 85,   // 'U'
-    UNICODE: 86,           // 'V'
-    BINUNICODE: 88,        // 'X'
-    EMPTY_LIST: 93,        // ']'
-    APPEND: 97,            // 'a'
-    BUILD: 98,             // 'b'
-    GLOBAL: 99,            // 'c'
-    DICT: 100,             // 'd'
-    APPENDS: 101,          // 'e'
-    GET: 103,              // 'g'
-    BINGET: 104,           // 'h'
-    INST: 105,             // 'i'
-    LONG_BINGET: 106,      // 'j'
-    LIST: 108,             // 'l'
-    OBJ: 111,              // 'o'
-    PUT: 112,              // 'p'
-    BINPUT: 113,           // 'q'
-    LONG_BINPUT: 114,      // 'r'
-    SETITEM: 115,          // 's'
-    TUPLE: 116,            // 't'
-    SETITEMS: 117,         // 'u'
-    EMPTY_DICT: 125,       // '}'
-    PROTO: 128,
-    NEWOBJ: 129,
-    TUPLE1: 133,           // '\x85'
-    TUPLE2: 134,           // '\x86'
-    TUPLE3: 135,           // '\x87'
-    NEWTRUE: 136,          // '\x88'
-    NEWFALSE: 137,         // '\x89'
-    LONG1: 138,            // '\x8a'
-    LONG4: 139,            // '\x8b'
-    SHORT_BINUNICODE: 140, // '\x8c' (Protocol 4)
-    BINUNICODE8: 141,      // '\x8d' (Protocol 4)
-    BINBYTES8: 142,        // '\x8e' (Protocol 4)
-    EMPTY_SET: 143,        // '\x8f' (Protocol 4)
-    ADDITEMS: 144,         // '\x90' (Protocol 4)
-    FROZENSET: 145,        // '\x91' (Protocol 4)
-    NEWOBJ_EX: 146,        // '\x92' (Protocol 4)
-    STACK_GLOBAL: 147,     // '\x93' (Protocol 4)
-    MEMOIZE: 148,          // '\x94' (Protocol 4)
-    FRAME: 149,            // '\x95' (Protocol 4)
-    BYTEARRAY8: 150,       // '\x96' (Protocol 5)
-    NEXT_BUFFER: 151,      // '\x97' (Protocol 5)
-    READONLY_BUFFER: 152   // '\x98' (Protocol 5)
-};
-
-python.Unpickler.BinaryReader = class {
+python.BinaryReader = class {
 
     constructor(buffer) {
         this._buffer = buffer;
@@ -6317,7 +6235,7 @@ python.Unpickler.BinaryReader = class {
 
     stream(length) {
         const buffer = this.read(length);
-        return new python.Unpickler.BinaryReader(buffer);
+        return new python.BinaryReader(buffer);
     }
 
     peek(length) {
@@ -6397,7 +6315,7 @@ python.Unpickler.BinaryReader = class {
     }
 };
 
-python.Unpickler.StreamReader = class {
+python.StreamReader = class {
 
     constructor(stream) {
         this._stream = stream;
@@ -6526,5 +6444,4 @@ python.Error = class extends Error {
 
 if (typeof module !== 'undefined' && typeof module.exports === 'object') {
     module.exports.Execution = python.Execution;
-    module.exports.Unpickler = python.Unpickler;
 }

+ 21 - 11
source/view.js

@@ -4245,20 +4245,30 @@ view.ModelContext = class {
                         case 'pkl': {
                             let unpickler = null;
                             try {
-                                if (stream.length > 2) {
-                                    const archive = zip.Archive.open(stream, 'zlib');
-                                    const data = archive ? archive.entries.get('') : stream;
+                                const archive = zip.Archive.open(stream, 'zlib');
+                                const data = archive ? archive.entries.get('') : stream;
+                                let condition = false;
+                                if (data.length > 2) {
+                                    const head = data.peek(2);
+                                    condition = head[0] === 0x80 && head[1] < 7;
+                                    if (!condition) {
+                                        data.seek(-1);
+                                        const tail = data.peek(1);
+                                        data.seek(0);
+                                        condition = tail[0] === 0x2e;
+                                    }
+                                }
+                                if (condition) {
                                     const signature = [ 0x80, undefined, 0x63, 0x5F, 0x5F, 0x74, 0x6F, 0x72, 0x63, 0x68, 0x5F, 0x5F, 0x2E]; // __torch__.
                                     const torch = signature.length <= data.length && data.peek(signature.length).every((value, index) => signature[index] === undefined || signature[index] === value);
-                                    unpickler = python.Unpickler.open(data, () => {
-                                        const execution = new python.Execution();
-                                        execution.on('resolve', (_, name) => {
-                                            if (!torch || !name.startsWith('__torch__.')) {
-                                                this.exception(new view.Error("Unknown type name '" + name + "'."));
-                                            }
-                                        });
-                                        return execution;
+                                    const execution = new python.Execution();
+                                    execution.on('resolve', (_, name) => {
+                                        if (!torch || !name.startsWith('__torch__.')) {
+                                            this.exception(new view.Error("Unknown type name '" + name + "'."));
+                                        }
                                     });
+                                    const pickle = execution.__import__('pickle');
+                                    unpickler = new pickle.Unpickler(data);
                                 }
                             }
                             catch (err) {