소스 검색

Add NumPy test file (#711)

Lutz Roeder 5 달 전
부모
커밋
a11c22c2ca
3개의 변경된 파일170개의 추가작업 그리고 53개의 파일을 삭제
  1. 152 37
      source/python.js
  2. 11 15
      source/view.js
  3. 7 1
      test/models.json

+ 152 - 37
source/python.js

@@ -637,14 +637,38 @@ python.Execution = class {
                 this.type_ignores = type_ignores;
             }
         });
-        this.registerFunction('ast.parse', (source, filename, debug) => {
-            const parser =  new ast._Parser(source, filename, debug);
-            return parser.parse();
+        this.registerFunction('ast.parse', (source, filename, mode, debug) => {
+            const parser =  new ast._Parser();
+            const module = parser.parse(source, filename, debug, mode);
+            return module;
+        });
+        this.registerFunction('ast._convert_literal', (node) => {
+            if (node instanceof ast.Constant) {
+                return node.value;
+            }
+            if (node instanceof ast.Dict && node.keys.length === node.values.length) {
+                const keys = node.keys.map((k) => ast._convert_literal(k));
+                const values = node.values.map((v) => ast._convert_literal(v));
+                return Object.fromEntries(keys.map((k, i) => [k, values[i]]));
+            }
+            if (node instanceof ast.Tuple) {
+                return new builtins.tuple(node.elts.map((e) => ast._convert_literal(e)));
+            }
+            if (node instanceof ast.List) {
+                return new builtins.list(node.elts.map((e) => ast._convert_literal(e)));
+            }
+            throw new python.Error(`'${node.__class__.__name__}' not implemented.`);
+        });
+        this.registerFunction('ast.literal_eval', (node_or_string) => {
+            if (typeof node_or_string === 'string') {
+                node_or_string = ast.parse(node_or_string, '', 'eval').body;
+            } else {
+                throw new python.Error(`'ast.literal_eval' node eval not implemented.`);
+            }
+            return ast._convert_literal(node_or_string);
         });
         this.registerType('ast._Parser', class {
-            constructor(text, file, debug) {
-                this._tokenizer = new ast._Tokenizer(text, file);
-                this._debug = debug;
+            constructor() {
                 ast._Parser._precedence = ast._Parser._precedence || {
                     'or': 2, 'and': 3, 'not' : 4,
                     'in': 5, 'instanceof': 5, 'is': 5, '<': 5, '>': 5, '<=': 5, '>=': 5, '<>': 5, '==': 5, '!=': 5,
@@ -654,9 +678,11 @@ python.Execution = class {
                     '~': 13, '**': 14
                 };
             }
-            parse() {
+            parse(text, file, debug, mode) {
+                this._tokenizer = new ast._Tokenizer(text, file);
+                this._debug = debug;
                 const position = this._position();
-                const body = [];
+                let body = [];
                 while (!this._tokenizer.match('eof')) {
                     const statement = this._parseStatement();
                     if (statement) {
@@ -671,9 +697,15 @@ python.Execution = class {
                     }
                     throw new python.Error(`Unsupported statement ${this._location()}`);
                 }
-                const node = new ast.Module(body);
-                this._mark(node, position);
-                return node;
+                if (mode === 'eval') {
+                    if (body.length !== 1 || body[0] instanceof ast.Expr === false) {
+                        throw new python.Error('Expected expression.');
+                    }
+                    body = body[0].value;
+                }
+                const module = new ast.Module(body);
+                this._mark(module, position);
+                return module;
             }
             _parseSuite() {
                 const body = [];
@@ -1019,6 +1051,7 @@ python.Execution = class {
                         case 'Call':
                         case 'Compare':
                         case 'Constant':
+                        case 'Dict':
                         case 'Ellipsis':
                         case 'For':
                         case 'If':
@@ -2089,9 +2122,12 @@ python.Execution = class {
                     isDecimal = !decimal(c) && c !== '.' && c !== 'e' && c !== 'j';
                 }
                 if (isDecimal) {
-                    if (this._get(i) === 'j' || this._get(i) === 'J' || this._get(i) === 'l' || this._get(i) === 'L') {
+                    if (this._get(i) === 'j' || this._get(i) === 'J') {
                         return { 'type': 'complex', value: this._text.substring(this._position, i + 1) };
                     }
+                    // if (this._get(i) === 'l' || this._get(i) === 'L') {
+                    //     Python 2 long integer
+                    // }
                     const intText = this._text.substring(this._position, i);
                     if (!isNaN(parseInt(intText, 10))) {
                         return { type: 'int', value: intText };
@@ -2304,7 +2340,14 @@ python.Execution = class {
         this.registerType('builtins.dict', dict);
         this.registerType('builtins.ellipsis', class {});
         this.registerType('builtins.cell', class {});
-        this.registerType('builtins.list', class extends Array {});
+        this.registerType('builtins.list', class extends Array {
+            constructor(iterable) {
+                super();
+                if (Array.isArray(iterable)) {
+                    this.push(...iterable);
+                }
+            }
+        });
         this.registerType('builtins.number', class {});
         this.registerFunction('builtins.__import__', (name, globals, locals, fromlist, level) => {
             return execution.__import__(name, globals, locals, fromlist, level);
@@ -2571,6 +2614,17 @@ python.Execution = class {
         });
         this.registerType('numpy.dtype', class {
             constructor(obj, align, copy) {
+                if (typeof obj !== 'string' && obj && Array.isArray(obj.names)) {
+                    this.kind = 'V';
+                    this.byteorder = '|';
+                    this.itemsize = obj.itemsize;
+                    this.names = obj.names;
+                    this.fields = new Map();
+                    for (let i = 0; i < obj.names.length; i++) {
+                        this.fields.set(obj.names[i], new builtins.tuple([obj.formats[i], obj.offsets[i]]));
+                    }
+                    return;
+                }
                 if (typeof obj === 'string' && (obj.startsWith('<') || obj.startsWith('>') || obj.startsWith('|'))) {
                     this.byteorder = obj.substring(0, 1);
                     obj = obj.substring(1);
@@ -2711,6 +2765,8 @@ python.Execution = class {
         });
         this.registerType('numpy.generic', class {});
         this.registerType('numpy.inexact', class {});
+        this.registerType('numpy.flexible', class extends numpy.generic {});
+        this.registerType('numpy.void', class extends numpy.flexible {});
         this.registerType('numpy.bool_', class extends numpy.generic {});
         this.registerType('numpy.number', class extends numpy.generic {});
         this.registerType('numpy.integer', class extends numpy.number {});
@@ -4925,7 +4981,7 @@ python.Execution = class {
         this.register('numpy.core.numeric', numpy._core.numeric);
         numpy._core._multiarray_umath._reconstruct = numpy.core.multiarray._reconstruct;
         this.registerFunction('numpy.load', (file) => {
-            // https://github.com/numpy/numpy/blob/main/numpy/lib/format.py
+            // https://github.com/numpy/numpy/blob/main/numpy/lib/_format_impl.py
             const signature = [0x93, 0x4E, 0x55, 0x4D, 0x50, 0x59];
             if (!file.read(6).every((v, i) => v === signature[i])) {
                 throw new python.Error('Invalid signature.');
@@ -4935,22 +4991,7 @@ python.Execution = class {
             if (major > 3) {
                 throw new python.Error(`Invalid version '${[major, minor].join('.')}'.`);
             }
-            const buffer = new Uint8Array([0, 0, 0, 0]);
-            buffer.set(file.read(major >= 2 ? 4 : 2), 0);
-            const header_length = buffer[3] << 24 | buffer[2] << 16 | buffer[1] << 8 | buffer[0];
-            let header = file.read(header_length);
-            const decoder = new TextDecoder(major >= 3 ? 'utf-8' : 'ascii');
-            header = decoder.decode(header);
-            header = JSON.parse(header.replace(/\(/,'[').replace(/\)/,']').replace('[,','[1,]').replace(',]',']').replace(/'/g, '"').replace(/:\s*False\s*,/,':false,').replace(/:\s*True\s*,/,':true,').replace(/,\s*\}/, ' }'));
-            if (!header.descr || header.descr.length < 2) {
-                throw new python.Error("Missing property 'descr'.");
-            }
-            if (!header.shape) {
-                throw new python.Error("Missing property 'shape'.");
-            }
-            const shape = header.shape;
-            const dtype = self.invoke('numpy.dtype', [header.descr.substring(1)]);
-            dtype.byteorder = header.descr.substring(0, 1);
+            const [shape, fortran_order, dtype] = numpy.lib._format_impl._read_array_header(file, version);
             let data = null;
             switch (dtype.byteorder) {
                 case '|': {
@@ -4963,19 +5004,16 @@ python.Execution = class {
                 }
                 case '>':
                 case '<': {
-                    if (header.descr.length !== 3 && header.descr[1] !== 'U' && header.descr.substring(1) !== 'c16') {
-                        throw new python.Error(`Unsupported data type '${header.descr}'.`);
-                    }
                     const count = shape.length === 0 ? 1 : shape.reduce((a, b) => a * b, 1);
                     const stream = file.getbuffer().nbytes > 0x1000000;
                     data = file.read(dtype.itemsize * count, stream);
                     break;
                 }
                 default: {
-                    throw new python.Error(`Unsupported data type '${header.descr}'.`);
+                    throw new python.Error(`Unsupported data type '${dtype.str}'.`);
                 }
             }
-            if (header.fortran_order) {
+            if (fortran_order) {
                 data = null;
             }
             return self.invoke('numpy.ndarray', [shape, dtype, data]);
@@ -5007,6 +5045,83 @@ python.Execution = class {
             file.write(encoder.encode(header));
             file.write(arr.tobytes());
         });
+        this.registerFunction('numpy.lib._format_impl._read_array_header', (file, version) => {
+            const buffer = new Uint8Array([0, 0, 0, 0]);
+            const [major] = version;
+            buffer.set(file.read(major >= 2 ? 4 : 2), 0);
+            const header_length = buffer[3] << 24 | buffer[2] << 16 | buffer[1] << 8 | buffer[0];
+            let header = file.read(header_length);
+            const decoder = new TextDecoder(major >= 3 ? 'utf-8' : 'ascii');
+            header = decoder.decode(header).trim();
+            try {
+                header = ast.literal_eval(header);
+            } catch {
+                if (major <= 2) {
+                    header = numpy.lib._format_impl._filter_header(header);
+                    header = ast.literal_eval(header);
+                }
+            }
+            if (header.descr === undefined) {
+                throw new python.Error("Invalid 'descr'.");
+            }
+            if (!Array.isArray(header.shape)) {
+                throw new python.Error("Invalid 'shape'.");
+            }
+            const dtype = numpy.lib._format_impl.descr_to_dtype(header.descr);
+            return [header.shape, header.fortran_order, dtype];
+        });
+        this.registerFunction('numpy.lib._format_impl.descr_to_dtype', (descr) => {
+            if (typeof descr === 'string') {
+                return new numpy.dtype(descr);
+            } else if (descr instanceof builtins.tuple) {
+                const dt = numpy.lib._format_impl.descr_to_dtype(descr[0]);
+                return new numpy.dtype([dt, descr[1]]);
+            }
+            const titles = [];
+            const names = [];
+            const formats = [];
+            const offsets = [];
+            let offset = 0;
+            for (const field of descr) {
+                let name = null;
+                let dt = null;
+                let descr_str = null;
+                let shape = null;
+                let title = null;
+                if (field.length === 2) {
+                    [name, descr_str] = field;
+                    dt = numpy.lib._format_impl.descr_to_dtype(descr_str);
+                } else {
+                    [name, descr_str, shape] = field;
+                    dt = new numpy.dtype([numpy.lib._format_impl.descr_to_dtype(descr_str), shape]);
+                }
+                const is_pad = name === '' && dt.type === numpy.void && dt.names === null;
+                if (!is_pad) {
+                    [title, name] = name instanceof builtins.tuple ? name :  [null, name];
+                    titles.push(title);
+                    names.push(name);
+                    formats.push(dt);
+                    offsets.push(offset);
+                }
+                offset += dt.itemsize;
+            }
+            return new numpy.dtype({ names, formats, titles, offsets, itemsize: offset });
+        });
+        this.registerFunction('numpy.lib._format_impl._filter_header', (s) => {
+            const tokens = [];
+            const tokenizer = new ast._Tokenizer(s, '');
+            while (!tokenizer.match('eof')) {
+                const token = tokenizer.read();
+                if (token.type === 'int') {
+                    const next = tokenizer.peek();
+                    if (next.type === 'id' && next.value === 'L') {
+                        tokenizer.read();
+                    }
+                }
+                tokens.push(token.value);
+            }
+            return tokens.join('');
+        });
         this.registerFunction('numpy.amin');
         this.registerFunction('numpy.amax');
         this.registerFunction('numpy.std');
@@ -20188,7 +20303,7 @@ python.Execution = class {
 
     exec(code , context) {
         const ast = this.ast;
-        const program = ast.parse(code, '', null);
+        const program = ast.parse(code, '', null, null);
         if (!program) {
             throw new python.Error("Module '?' parse error.");
         }
@@ -20217,7 +20332,7 @@ python.Execution = class {
     parse(filename, buffer, debug) {
         const ast = this.ast;
         const source = this._utf8Decoder.decode(buffer);
-        const program = ast.parse(source, filename, debug);
+        const program = ast.parse(source, filename, null, debug);
         if (!program) {
             throw new python.Error(`Module '${filename}' parse error.`);
         }

+ 11 - 15
source/view.js

@@ -5998,22 +5998,18 @@ view.Context = class {
                             break;
                         }
                         case 'npz': {
-                            try {
-                                const content = new Map();
-                                const entries = await this.peek('zip');
-                                if (entries instanceof Map && entries.size > 0 &&
-                                    Array.from(entries.keys()).every((name) => name.endsWith('.npy'))) {
-                                    const python = await import('./python.js');
-                                    const execution = new python.Execution();
-                                    for (const [name, stream] of entries) {
-                                        const bytes = execution.invoke('io.BytesIO', [stream]);
-                                        const array = execution.invoke('numpy.load', [bytes]);
-                                        content.set(name, array);
-                                    }
-                                    this._content.set(type, content);
+                            const content = new Map();
+                            const entries = await this.peek('zip');
+                            if (entries instanceof Map && entries.size > 0 &&
+                                Array.from(entries.keys()).every((name) => name.endsWith('.npy'))) {
+                                const python = await import('./python.js');
+                                const execution = new python.Execution();
+                                for (const [name, stream] of entries) {
+                                    const bytes = execution.invoke('io.BytesIO', [stream]);
+                                    const array = execution.invoke('numpy.load', [bytes]);
+                                    content.set(name, array);
                                 }
-                            } catch {
-                                // continue regardless of error
+                                this._content.set(type, content);
                             }
                             break;
                         }

+ 7 - 1
test/models.json

@@ -3934,6 +3934,13 @@
     "format":   "NumPy Archive",
     "link":     "https://github.com/Xilinx/BNN-PYNQ"
   },
+  {
+    "type":     "numpy",
+    "target":   "face_images.npy",
+    "source":   "https://github.com/user-attachments/files/22711415/face_images.npy.zip[face_images.npy]",
+    "format":   "NumPy Array",
+    "link":     "https://github.com/lutzroeder/netron/issues/711"
+  },
   {
     "type":     "numpy",
     "target":   "float8_e5m2.npy",
@@ -3974,7 +3981,6 @@
     "target":   "struct.npy",
     "source":   "https://github.com/user-attachments/files/16075809/struct.npy.zip[struct.npy]",
     "format":   "NumPy Array",
-    "error":    "Unexpected token '(', ...\" \"<U10\"], (\"age\", \"<\"... is not valid JSON",
     "link":     "https://github.com/lutzroeder/netron/issues/711"
   },
   {