Răsfoiți Sursa

PaddlePaddle support (#198)

Lutz Roeder 4 ani în urmă
părinte
comite
7f54e78afd
3 a modificat fișierele cu 116 adăugiri și 44 ștergeri
  1. 32 0
      source/paddle-metadata.json
  2. 72 43
      source/paddle.js
  3. 12 1
      source/view.js

+ 32 - 0
source/paddle-metadata.json

@@ -28,6 +28,10 @@
       { "name": "strides", "default": [   1,   1 ] }
     ]
   },
+  {
+    "name": "conv2d_transpose",
+    "category": "Layer"
+  },
   {
     "name": "depthwise_conv2d",
     "category": "Layer",
@@ -70,10 +74,22 @@
       { "name": "paddings", "default": [   0,   0 ] }
     ]
   },
+  {
+    "name": "hard_swish",
+    "category": "Activation"
+  },
+  {
+    "name": "hard_sigmoid",
+    "category": "Activation"
+  },
   {
     "name": "relu",
     "category": "Activation"
   },
+  {
+    "name": "sigmoid",
+    "category": "Activation"
+  },
   {
     "name": "reshape",
     "category": "Shape"
@@ -82,6 +98,22 @@
     "name": "reshape2",
     "category": "Shape"
   },
+  {
+    "name": "scale",
+    "category": "Layer"
+  },
+  {
+    "name": "rnn",
+    "category": "Layer"
+  },
+  {
+    "name": "transpose",
+    "category": "Transform"
+  },
+  {
+    "name": "transpose2",
+    "category": "Transform"
+  },
   {
     "name": "softmax",
     "category": "Activation",

+ 72 - 43
source/paddle.js

@@ -8,7 +8,8 @@ paddle.ModelFactory = class {
     match(context) {
         const identifier = context.identifier;
         const extension = identifier.split('.').pop().toLowerCase();
-        if (identifier === '__model__' || extension === 'paddle' || extension === 'pdmodel') {
+        if (identifier === '__model__' || identifier === 'model' ||
+            extension === 'paddle' || extension === 'pdmodel') {
             return true;
         }
         if (extension === 'pbtxt' || extension === 'txt') {
@@ -20,6 +21,10 @@ paddle.ModelFactory = class {
         if (paddle.Container.open(context)) {
             return true;
         }
+        const stream = context.stream;
+        if (stream.length > 16 && stream.peek(16).every((value) => value === 0x00)) {
+            return true;
+        }
         return false;
     }
 
@@ -27,24 +32,21 @@ paddle.ModelFactory = class {
         return paddle.Metadata.open(context).then((metadata) => {
             return context.require('./paddle-proto').then(() => {
                 paddle.proto = protobuf.get('paddle').paddle.framework.proto;
-                const container = paddle.Container.open(context);
-                if (container) {
-                    return new paddle.Model(metadata, container.format, null, container.weights);
-                }
-                else {
-                    let programDesc = null;
-                    let format = 'PaddlePaddle';
-                    const identifier = context.identifier;
-                    const parts = identifier.split('.');
-                    const extension = parts.pop().toLowerCase();
-                    const base = parts.join('.');
+                const stream = context.stream;
+                const identifier = context.identifier;
+                const parts = identifier.split('.');
+                const extension = parts.pop().toLowerCase();
+                const base = parts.join('.');
+                const openProgram = (stream, extension) => {
+                    const program = {};
+                    program.format = 'PaddlePaddle';
                     switch (extension) {
                         case 'pbtxt':
                         case 'txt': {
                             try {
-                                const buffer = context.stream.peek();
+                                const buffer = stream.peek();
                                 const reader = protobuf.TextReader.create(buffer);
-                                programDesc = paddle.proto.ProgramDesc.decodeText(reader);
+                                program.desc = paddle.proto.ProgramDesc.decodeText(reader);
                             }
                             catch (error) {
                                 const message = error && error.message ? error.message : error.toString();
@@ -54,9 +56,9 @@ paddle.ModelFactory = class {
                         }
                         default: {
                             try {
-                                const buffer = context.stream.peek();
+                                const buffer = stream.peek();
                                 const reader = protobuf.Reader.create(buffer);
-                                programDesc = paddle.proto.ProgramDesc.decode(reader);
+                                program.desc = paddle.proto.ProgramDesc.decode(reader);
                             }
                             catch (error) {
                                 const message = error && error.message ? error.message : error.toString();
@@ -65,6 +67,7 @@ paddle.ModelFactory = class {
                             break;
                         }
                     }
+                    const programDesc = program.desc;
                     if (programDesc.version && programDesc.version.version && programDesc.version.version.toNumber) {
                         const version = programDesc.version.version.toNumber();
                         if (version > 0) {
@@ -75,7 +78,7 @@ paddle.ModelFactory = class {
                                     list.pop();
                                 }
                             }
-                            format += ' v' + list.map((item) => item.toString()).join('.');
+                            program.format += ' v' + list.map((item) => item.toString()).join('.');
                         }
                     }
                     const variables = new Set();
@@ -98,30 +101,56 @@ paddle.ModelFactory = class {
                             }
                         }
                     }
-                    const vars = Array.from(variables).sort();
+                    program.vars = Array.from(variables).sort();
+                    return program;
+                };
+                const loadParams = (metadata, program, stream) => {
                     const tensors = new Map();
-                    const load_entries = (programDesc) => {
-                        const promises = vars.map((name) => context.request(name, null));
+                    while (stream.position < stream.length) {
+                        tensors.set(program.vars.shift(), new paddle.Tensor(null, stream));
+                    }
+                    return new paddle.Model(metadata, program.format, program.desc, tensors);
+                };
+                const container = paddle.Container.open(context);
+                if (container) {
+                    return new paddle.Model(metadata, container.format, null, container.weights);
+                }
+                else if (stream.length > 16 && stream.peek(16).every((value) => value === 0x00)) {
+                    const file = identifier !== 'params' ? base + '.pdmodel' : 'model';
+                    return context.request(file, null).then((stream) => {
+                        const program = openProgram(stream, '');
+                        return loadParams(metadata, program, context.stream);
+                    });
+                }
+                else {
+                    const program = openProgram(context.stream, extension);
+                    const loadEntries = (context, program) => {
+                        const promises = program.vars.map((name) => context.request(name, null));
+                        const tensors = new Map();
                         return Promise.all(promises).then((streams) => {
-                            for (let i = 0; i < vars.length; i++) {
-                                tensors.set(vars[i], new paddle.Tensor(null, streams[i]));
+                            for (let i = 0; i < program.vars.length; i++) {
+                                tensors.set(program.vars[i], new paddle.Tensor(null, streams[i]));
                             }
-                            return new paddle.Model(metadata, format, programDesc, tensors);
+                            return new paddle.Model(metadata, program.format, program.desc, tensors);
                         }).catch((/* err */) => {
-                            return new paddle.Model(metadata, format, programDesc, tensors);
+                            return new paddle.Model(metadata, program.format, program.desc, tensors);
                         });
                     };
                     if (extension === 'pdmodel') {
                         return context.request(base + '.pdiparams', null).then((stream) => {
-                            while (stream.position < stream.length) {
-                                tensors.set(vars.shift(), new paddle.Tensor(null, stream));
-                            }
-                            return new paddle.Model(metadata, format, programDesc, tensors);
+                            return loadParams(metadata, program, stream);
+                        }).catch((/* err */) => {
+                            return loadEntries(context, program);
+                        });
+                    }
+                    if (identifier === 'model') {
+                        return context.request('params', null).then((stream) => {
+                            return loadParams(metadata, program, stream);
                         }).catch((/* err */) => {
-                            return load_entries(programDesc, null);
+                            return loadEntries(context, program);
                         });
                     }
-                    return load_entries(programDesc, null);
+                    return loadEntries(context, program);
                 }
             });
         });
@@ -498,7 +527,7 @@ paddle.Tensor = class {
     constructor(type, data) {
         this._type = type;
         if (data && !Array.isArray(data)) {
-            if (data.__module__ === 'numpy' && data.__name__ === 'ndarray') {
+            if (data.__class__ && data.__class__.__module__ === 'numpy' && data.__class__.__name__ === 'ndarray') {
                 this._type = new paddle.TensorType(data.dtype.name, new paddle.TensorShape(data.shape));
                 this._data = data.data;
                 this._kind = 'NumPy Array';
@@ -797,7 +826,7 @@ paddle.Container = class {
                     this._weights = new Map();
                     for (const key of Object.keys(this._data)) {
                         const value = this._data[key];
-                        if (value && !Array.isArray(value) && value.__module__ === 'numpy' && value.__name__ === 'ndarray') {
+                        if (value && !Array.isArray(value) && value.__class__ && value.__class__.__module__ === 'numpy' && value.__class__.__name__ === 'ndarray') {
                             const name = map ? map[key] : key;
                             this._weights.set(name, new paddle.Tensor(null, value));
                         }
@@ -826,8 +855,8 @@ paddle.Metadata = class {
     }
 
     constructor(data) {
-        this._map = {};
-        this._attributeCache = {};
+        this._map = new Map();
+        this._attributeCache = new Map();
         if (data) {
             const metadata = JSON.parse(data);
             this._map = new Map(metadata.map((item) => [ item.name, item ]));
@@ -835,22 +864,22 @@ paddle.Metadata = class {
     }
 
     type(name) {
-        return this._map[name] || null;
+        return this._map.get(name) || null;
     }
 
     attribute(type, name) {
-        let map = this._attributeCache[type];
+        let map = this._attributeCache.get(type);
         if (!map) {
-            map = {};
-            const schema = this.type(type);
-            if (schema && schema.attributes && schema.attributes.length > 0) {
-                for (const attribute of schema.attributes) {
-                    map[attribute.name] = attribute;
+            map = new Map();
+            const metadata = this.type(type);
+            if (metadata && metadata.attributes && metadata.attributes.length > 0) {
+                for (const attribute of metadata.attributes) {
+                    map.set(attribute.name, attribute);
                 }
             }
-            this._attributeCache[type] = map;
+            this._attributeCache.set(type, map);
         }
-        return map[name] || null;
+        return map.get(name) || null;
     }
 };
 

+ 12 - 1
source/view.js

@@ -1451,7 +1451,7 @@ view.ModelFactoryService = class {
         this.register('./sklearn', [ '.pkl', '.pickle', '.joblib', '.model', '.meta', '.pb', '.pt', '.h5', '.pkl.z', '.joblib.z' ]);
         this.register('./pickle', [ '.pkl', '.pickle', '.joblib', '.model', '.meta', '.pb', '.pt', '.h5', '.pkl.z', '.joblib.z' ]);
         this.register('./cntk', [ '.model', '.cntk', '.cmf', '.dnn' ]);
-        this.register('./paddle', [ '.pdmodel', '.pdparams', '.paddle', '__model__', '.pbtxt', '.txt', '.tar', '.tar.gz' ]);
+        this.register('./paddle', [ '.pdmodel', '.pdparams', '.pdiparams', '.paddle', '__model__', '.pbtxt', '.txt', '.tar', '.tar.gz', 'model', 'params' ]);
         this.register('./bigdl', [ '.model', '.bigdl' ]);
         this.register('./darknet', [ '.cfg', '.model', '.txt', '.weights' ]);
         this.register('./weka', [ '.model' ]);
@@ -1741,6 +1741,17 @@ view.ModelFactoryService = class {
                             matches.some((e) => e.name.toLowerCase().endsWith('.json'))) {
                             matches = matches.filter((e) => e.name.toLowerCase().endsWith('.json'));
                         }
+                        // Paddle
+                        if (matches.length > 0 &&
+                            matches.some((e) => e.name.toLowerCase().endsWith('.pdmodel')) &&
+                            matches.some((e) => e.name.toLowerCase().endsWith('.pdiparams'))) {
+                            matches = matches.filter((e) => e.name.toLowerCase().endsWith('.pdmodel'));
+                        }
+                        if (matches.length > 0 &&
+                            matches.some((e) => e.name.split('/').pop().toLowerCase() === 'model') &&
+                            matches.some((e) => e.name.split('/').pop().toLowerCase() === 'params')) {
+                            matches = matches.filter((e) => e.name.toLowerCase().split('/').pop().toLowerCase() === 'model');
+                        }
                         // TensorFlow Bundle
                         if (matches.length > 1 &&
                             matches.some((e) => e.name.toLowerCase().endsWith('.data-00000-of-00001'))) {