Browse Source

Update paddle.js

Lutz Roeder 3 years ago
parent
commit
a01bb522e7
1 changed files with 18 additions and 13 deletions
  1. 18 13
      source/paddle.js

+ 18 - 13
source/paddle.js

@@ -130,7 +130,7 @@ paddle.ModelFactory = class {
                         const loadParams = (metadata, program, stream) => {
                             const weights = new Map();
                             while (stream.position < stream.length) {
-                                const tensor = paddle.Utility.openTensor(stream);
+                                const tensor = paddle.Utility.openTensorDesc(stream);
                                 weights.set(program.vars.shift(), tensor);
                             }
                             return weights;
@@ -159,7 +159,7 @@ paddle.ModelFactory = class {
                                     const promises = program.vars.map((name) => context.request(name, null));
                                     return Promise.all(promises).then((streams) => {
                                         for (let i = 0; i < program.vars.length; i++) {
-                                            const tensor = paddle.Utility.openTensor(streams[i]);
+                                            const tensor = paddle.Utility.openTensorDesc(streams[i]);
                                             tensors.set(program.vars[i], tensor);
                                         }
                                         return createModel(metadata, program.format, program.desc, tensors);
@@ -167,11 +167,11 @@ paddle.ModelFactory = class {
                                         return createModel(metadata, program.format, program.desc, tensors);
                                     });
                                 };
-                                const openPickle = (stream, weights) => {
+                                const openNumPyArrayPickle = (stream, weights) => {
                                     const execution = new python.Execution(null);
                                     const unpickler = python.Unpickler.open(stream);
                                     const obj = unpickler.load((name, args) => execution.invoke(name, args));
-                                    paddle.Utility.openPickle(obj, weights);
+                                    paddle.Utility.openNumPyArrayList(obj, weights);
                                 };
                                 const program = openProgram(context.stream, match);
                                 if (extension === 'pdmodel') {
@@ -181,16 +181,16 @@ paddle.ModelFactory = class {
                                     }).catch((/* err */) => {
                                         const weights = new Map();
                                         return context.request(base + '.pdparams', null).then((stream) => {
-                                            openPickle(stream, weights);
+                                            openNumPyArrayPickle(stream, weights);
                                             return context.request(base + '.pdopt', null).then((stream) => {
-                                                openPickle(stream, weights);
+                                                openNumPyArrayPickle(stream, weights);
                                                 return createModel(metadata, program.format, program.desc, weights);
                                             }).catch((/* err */) => {
                                                 return createModel(metadata, program.format, program.desc, weights);
                                             });
                                         }).catch((/* err */) => {
                                             return context.request(base + '.pdopt', null).then((stream) => {
-                                                openPickle(stream, weights);
+                                                openNumPyArrayPickle(stream, weights);
                                                 return createModel(metadata, program.format, program.desc, weights);
                                             }).catch((/* err */) => {
                                                 return loadEntries(context, program);
@@ -781,7 +781,7 @@ paddle.Utility = class {
         return new paddle.TensorType(dataType, new paddle.TensorShape(shape));
     }
 
-    static openTensor(stream) {
+    static openTensorDesc(stream) {
         const signature = stream.read(16);
         if (!signature.every((value) => value === 0x00)) {
             throw new paddle.Error('Invalid paddle.TensorDesc signature.');
@@ -806,12 +806,12 @@ paddle.Utility = class {
         return new paddle.Tensor(type, data);
     }
 
-    static openPickle(obj, weights) {
+    static openNumPyArrayList(obj, weights) {
         const map = null; // this._data['StructuredToParameterName@@'];
         for (const entry of Object.entries(obj)) {
             const key = entry[0];
             const value = entry[1];
-            if (value && !Array.isArray(value) && value.__class__ && value.__class__.__module__ === 'numpy' && value.__class__.__name__ === 'ndarray') {
+            if (paddle.Utility.isNumPyArray(value)) {
                 const name = map ? map[key] : key;
                 const type = new paddle.TensorType(value.dtype.__name__, new paddle.TensorShape(value.shape));
                 const data = value.data;
@@ -820,6 +820,10 @@ paddle.Utility = class {
             }
         }
     }
+
+    static isNumPyArray(value) {
+        return value && !Array.isArray(value) && value.__class__ && value.__class__.__module__ === 'numpy' && value.__class__.__name__ === 'ndarray';
+    }
 };
 
 paddle.Entries = class {
@@ -869,7 +873,7 @@ paddle.Entries = class {
                 if (entry[0].startsWith(rootFolder)) {
                     const name = entry[0].substring(rootFolder.length);
                     const stream = entry[1];
-                    const tensor = paddle.Utility.openTensor(stream);
+                    const tensor = paddle.Utility.openTensorDesc(stream);
                     this._weights.set(name, tensor);
                 }
             }
@@ -881,7 +885,8 @@ paddle.Pickle = class {
 
     static open(context) {
         const obj = context.open('pkl');
-        if (obj && !Array.isArray(obj) && Object(obj) === obj) {
+        if (obj && !Array.isArray(obj) && Object(obj) === obj &&
+            Object.entries(obj).filter((entry) => paddle.Utility.isNumPyArray(entry[1])).length > 0) {
             return new paddle.Pickle(obj);
         }
         return null;
@@ -908,7 +913,7 @@ paddle.Pickle = class {
     _initialize() {
         if (!this._weights) {
             this._weights = new Map();
-            paddle.Utility.openPickle(this._data, this._weights);
+            paddle.Utility.openNumPyArrayList(this._data, this._weights);
         }
     }
 };