فهرست منبع

Add PaddlePaddle test files (#552)

Lutz Roeder 3 سال پیش
والد
کامیت
7506017b08
2فایلهای تغییر یافته به همراه132 افزوده شده و 103 حذف شده
  1. 111 103
      source/paddle.js
  2. 21 0
      test/models.json

+ 111 - 103
source/paddle.js

@@ -174,11 +174,12 @@ paddle.ModelFactory = class {
                                         return createModel(metadata, program.format, program.desc, weights);
                                     });
                                 };
-                                const openNumPyArrayPickle = (stream, weights) => {
+                                const openNumPyArrayPickle = (stream) => {
                                     const execution = new python.Execution(null);
                                     const unpickler = python.Unpickler.open(stream);
                                     const obj = unpickler.load((name, args) => execution.invoke(name, args));
-                                    paddle.Utility.openNumPyArrayList(obj, weights);
+                                    const container = new paddle.Pickle(obj);
+                                    return container.weights || new Map();
                                 };
                                 const program = openProgram(context.stream, match);
                                 if (extension === 'pdmodel') {
@@ -187,18 +188,21 @@ paddle.ModelFactory = class {
                                         const weights = mapParams(params, program);
                                         return createModel(metadata, program.format, program.desc, weights);
                                     }).catch((/* err */) => {
-                                        const weights = new Map();
                                         return context.request(base + '.pdparams', null).then((stream) => {
-                                            openNumPyArrayPickle(stream, weights);
+                                            const weights = openNumPyArrayPickle(stream);
                                             return context.request(base + '.pdopt', null).then((stream) => {
-                                                openNumPyArrayPickle(stream, weights);
+                                                for (const entry of openNumPyArrayPickle(stream)) {
+                                                    if (!weights.has(entry[0])) {
+                                                        weights.set(entry[0], entry[1]);
+                                                    }
+                                                }
                                                 return createModel(metadata, program.format, program.desc, weights);
                                             }).catch((/* err */) => {
                                                 return createModel(metadata, program.format, program.desc, weights);
                                             });
                                         }).catch((/* err */) => {
                                             return context.request(base + '.pdopt', null).then((stream) => {
-                                                openNumPyArrayPickle(stream, weights);
+                                                const weights = openNumPyArrayPickle(stream);
                                                 return createModel(metadata, program.format, program.desc, weights);
                                             }).catch((/* err */) => {
                                                 return loadEntries(context, program);
@@ -774,76 +778,6 @@ paddle.TensorShape = class {
     }
 };
 
-paddle.Utility = class {
-
-    static createTensorType(data_type, shape) {
-        if (!paddle.Utility._dataTypes) {
-            const length = Math.max.apply(null, Object.entries(paddle.DataType).map((entry) => entry[1]));
-            paddle.Utility._dataTypes = new Array(length);
-            const map = new Map([ [ 'bool', 'boolean' ], [ 'bf16', 'bfloat16' ], [ 'fp16', 'float16' ], [ 'fp32', 'float32' ], [ 'fp64', 'float64' ] ]);
-            for (const entry of Object.entries(paddle.DataType)) {
-                const index = entry[1];
-                const key = entry[0].toLowerCase();
-                paddle.Utility._dataTypes[index] = map.has(key) ? map.get(key) : key;
-            }
-        }
-        const dataType = data_type < paddle.Utility._dataTypes.length ? paddle.Utility._dataTypes[data_type] : '?';
-        return new paddle.TensorType(dataType, new paddle.TensorShape(shape));
-    }
-
-    static openTensorDesc(stream) {
-        const signature = stream.read(16);
-        if (!signature.every((value) => value === 0x00)) {
-            throw new paddle.Error('Invalid paddle.TensorDesc signature.');
-        }
-        const length = new base.BinaryReader(stream.read(4)).uint32();
-        const buffer = stream.read(length);
-        const reader = protobuf.BinaryReader.open(buffer);
-        const tensorDesc = paddle.proto.VarType.TensorDesc.decode(reader);
-        const size = tensorDesc.dims.reduce((a, b) => a * b.toNumber(), 1);
-        let itemsize = 0;
-        switch (tensorDesc.data_type) {
-            case paddle.DataType.FP16: itemsize = 2; break;
-            case paddle.DataType.FP32: itemsize = 4; break;
-            case paddle.DataType.FP64: itemsize = 8; break;
-            case paddle.DataType.INT8: itemsize = 1; break;
-            case paddle.DataType.INT16: itemsize = 2; break;
-            case paddle.DataType.INT32: itemsize = 4; break;
-            case paddle.DataType.INT64: itemsize = 8; break;
-            case paddle.DataType.UINT8: itemsize = 1; break;
-            default: throw new paddle.Error("Invalid inference params data type '" + tensorDesc.data_type + "'.");
-        }
-        const type = paddle.Utility.createTensorType(tensorDesc.data_type, tensorDesc.dims);
-        const data = stream.read(itemsize * size);
-        return new paddle.Tensor(type, data);
-    }
-
-    static openNumPyArrayList(obj, weights) {
-        const map = null; // this._data['StructuredToParameterName@@'];
-        for (const entry of Object.entries(obj)) {
-            const key = entry[0];
-            let value = entry[1];
-            if (Array.isArray(value) && value.length === 2 && value[0] === key) {
-                value = value[1];
-            }
-            if (paddle.Utility.isNumPyArray(value)) {
-                const name = map ? map[key] : key;
-                const type = new paddle.TensorType(value.dtype.__name__, new paddle.TensorShape(value.shape));
-                const data = value.data;
-                const tensor = new paddle.Tensor(type, data, 'NumPy Array');
-                weights.set(name, tensor);
-            }
-        }
-    }
-
-    static isNumPyArray(value, name) {
-        if (Array.isArray(value) && value.length === 2 && value[0] === name) {
-            value = value[1];
-        }
-        return value && !Array.isArray(value) && value.__class__ && value.__class__.__module__ === 'numpy' && value.__class__.__name__ === 'ndarray';
-    }
-};
-
 paddle.Entries = class {
 
     static open(context) {
@@ -865,17 +799,12 @@ paddle.Entries = class {
         return 'PaddlePaddle Weights';
     }
 
-    get model() {
-        this._initialize();
-        return this._model;
-    }
-
     get weights() {
-        this._initialize();
+        this._read();
         return this._weights;
     }
 
-    _initialize() {
+    _read() {
         if (!this._weights) {
             let rootFolder = null;
             for (const entry of this._data) {
@@ -903,36 +832,69 @@ paddle.Pickle = class {
 
     static open(context) {
         const obj = context.open('pkl');
-        if (obj && !Array.isArray(obj) && Object(obj) === obj &&
-            Object.entries(obj).filter((entry) => paddle.Utility.isNumPyArray(entry[1], entry[0])).length > 0) {
-            return new paddle.Pickle(obj);
-        }
-        return null;
+        const container = new paddle.Pickle(obj);
+        return container.weights !== null ? container : null;
     }
 
-    constructor(data) {
-        this._data = data;
+    constructor(obj) {
+        this._weights = null;
+        if (obj && !Array.isArray(obj) && (obj instanceof Map || Object(obj) === obj)) {
+            const entries = (obj) => {
+                return obj instanceof Map ? Array.from(obj) : Object(obj) === obj ? Object.entries(obj) : [];
+            };
+            const filter = (obj) => {
+                const list = [];
+                if (obj && !Array.isArray(obj)) {
+                    for (const entry of entries(obj)) {
+                        const name = entry[0];
+                        if (name !== 'StructuredToParameterName@@') {
+                            let value = entry[1];
+                            value = value && Array.isArray(value) && value.length === 2 && value[0] === name ? value[1] : value;
+                            if (value && !Array.isArray(value) && value.__class__ && value.__class__.__module__ === 'numpy' && value.__class__.__name__ === 'ndarray') {
+                                list.push([ name, value ]);
+                            }
+                        }
+                    }
+                }
+                return list;
+            };
+            const weights = filter(obj);
+            if (weights.length > 0) {
+                this._weights = weights;
+            }
+            else {
+                const list = entries(obj);
+                if (list.filter((entry) => entry[0] !== 'StructuredToParameterName@@').length === 1) {
+                    const weights = filter(list[0][1]);
+                    if (weights.length > 0) {
+                        this._weights = weights;
+                    }
+                }
+                if (this._weights === null && list.filter((entry) => entry[0] === 'StructuredToParameterName@@').length > 0) {
+                    this._weights = [];
+                }
+            }
+        }
     }
 
     get format() {
         return 'PaddlePaddle Pickle';
     }
 
-    get model() {
-        this._initialize();
-        return this._model;
-    }
-
     get weights() {
-        this._initialize();
-        return this._weights;
-    }
-
-    _initialize() {
-        if (!this._weights) {
-            this._weights = new Map();
-            paddle.Utility.openNumPyArrayList(this._data, this._weights);
+        if (this._weights && Array.isArray(this._weights)) {
+            const weights = new Map();
+            for (const entry of this._weights) {
+                const name = entry[0];
+                const value = entry[1];
+                const type = new paddle.TensorType(value.dtype.__name__, new paddle.TensorShape(value.shape));
+                const data = value.data;
+                const tensor = new paddle.Tensor(type, data, 'NumPy Array');
+                weights.set(name, tensor);
+            }
+            this._weights = weights;
         }
+        return this._weights;
     }
 };
 
@@ -1035,6 +997,52 @@ paddle.NaiveBuffer = class {
     }
 };
 
+
+paddle.Utility = class {
+
+    static createTensorType(data_type, shape) {
+        if (!paddle.Utility._dataTypes) {
+            const length = Math.max.apply(null, Object.entries(paddle.DataType).map((entry) => entry[1]));
+            paddle.Utility._dataTypes = new Array(length);
+            const map = new Map([ [ 'bool', 'boolean' ], [ 'bf16', 'bfloat16' ], [ 'fp16', 'float16' ], [ 'fp32', 'float32' ], [ 'fp64', 'float64' ] ]);
+            for (const entry of Object.entries(paddle.DataType)) {
+                const index = entry[1];
+                const key = entry[0].toLowerCase();
+                paddle.Utility._dataTypes[index] = map.has(key) ? map.get(key) : key;
+            }
+        }
+        const dataType = data_type < paddle.Utility._dataTypes.length ? paddle.Utility._dataTypes[data_type] : '?';
+        return new paddle.TensorType(dataType, new paddle.TensorShape(shape));
+    }
+
+    static openTensorDesc(stream) {
+        const signature = stream.read(16);
+        if (!signature.every((value) => value === 0x00)) {
+            throw new paddle.Error('Invalid paddle.TensorDesc signature.');
+        }
+        const length = new base.BinaryReader(stream.read(4)).uint32();
+        const buffer = stream.read(length);
+        const reader = protobuf.BinaryReader.open(buffer);
+        const tensorDesc = paddle.proto.VarType.TensorDesc.decode(reader);
+        const size = tensorDesc.dims.reduce((a, b) => a * b.toNumber(), 1);
+        let itemsize = 0;
+        switch (tensorDesc.data_type) {
+            case paddle.DataType.FP16: itemsize = 2; break;
+            case paddle.DataType.FP32: itemsize = 4; break;
+            case paddle.DataType.FP64: itemsize = 8; break;
+            case paddle.DataType.INT8: itemsize = 1; break;
+            case paddle.DataType.INT16: itemsize = 2; break;
+            case paddle.DataType.INT32: itemsize = 4; break;
+            case paddle.DataType.INT64: itemsize = 8; break;
+            case paddle.DataType.UINT8: itemsize = 1; break;
+            default: throw new paddle.Error("Invalid inference params data type '" + tensorDesc.data_type + "'.");
+        }
+        const type = paddle.Utility.createTensorType(tensorDesc.data_type, tensorDesc.dims);
+        const data = stream.read(itemsize * size);
+        return new paddle.Tensor(type, data);
+    }
+};
+
 paddle.DataType = {
     BOOL: 0,
     INT16: 1,

+ 21 - 0
test/models.json

@@ -4309,6 +4309,13 @@
     "format":   "OpenVINO IR",
     "link":     "https://download.01.org/openvinotoolkit"
   },
+  {
+    "type":     "paddle",
+    "target":   "adam.pdopt",
+    "source":   "https://github.com/lutzroeder/netron/files/8882900/adam.pdopt.zip[adam.pdopt]",
+    "format":   "PaddlePaddle Pickle",
+    "link":     "https://github.com/lutzroeder/netron/issues/552"
+  },
   {
     "type":     "paddle",
     "target":   "assign.pbtxt",
@@ -4337,6 +4344,20 @@
     "format":   "PaddlePaddle Pickle",
     "link":     "https://github.com/lutzroeder/netron/issues/552"
   },
+  {
+    "type":     "paddle",
+    "target":   "EDVR_M_wo_tsa_SRx4.pdparams",
+    "source":   "https://github.com/lutzroeder/netron/files/8882904/EDVR_M_wo_tsa_SRx4.pdparams.zip[EDVR_M_wo_tsa_SRx4.pdparams]",
+    "format":   "PaddlePaddle Pickle",
+    "link":     "https://github.com/lutzroeder/netron/issues/552"
+  },
+  {
+    "type":     "paddle",
+    "target":   "emb.pdparams",
+    "source":   "https://github.com/lutzroeder/netron/files/8882903/emb.pdparams.zip[emb.pdparams]",
+    "format":   "PaddlePaddle Pickle",
+    "link":     "https://github.com/lutzroeder/netron/issues/552"
+  },
   {
     "type":     "paddle",
     "target":   "lite_naive_model_opt.nb.tar.gz",