Browse Source

Add TensorFlow.js Gzip support (#294) (#563)

Lutz Roeder 4 years ago
parent
commit
8c231b067f
3 changed files with 62 additions and 9 deletions
  1. 33 7
      source/tf.js
  2. 22 2
      source/view.js
  3. 7 0
      test/models.json

+ 33 - 7
source/tf.js

@@ -4,6 +4,7 @@
 
 var tf = tf || {};
 var base = base || require('./base');
+var gzip = gzip || require('./gzip');
 var json = json || require('./json');
 var protobuf = protobuf || require('./protobuf');
 
@@ -173,9 +174,11 @@ tf.ModelFactory = class {
             }
         }
         if (extension === 'json') {
-            const obj = context.open('json');
-            if (obj && obj.modelTopology && (obj.format === 'graph-model' || Array.isArray(obj.modelTopology.node))) {
-                return 'tf.json';
+            for (const type of [ 'json', 'json.gz' ]) {
+                const obj = context.open(type);
+                if (obj && obj.modelTopology && (obj.format === 'graph-model' || Array.isArray(obj.modelTopology.node))) {
+                    return 'tf.' + type;
+                }
             }
         }
         if (extension === 'index' || extension === 'ckpt') {
@@ -338,9 +341,9 @@ tf.ModelFactory = class {
                 }
                 return openSavedModel(saved_model, format, producer);
             };
-            const openJson = (context) => {
+            const openJson = (context, type) => {
                 try {
-                    const obj = context.open('json');
+                    const obj = context.open(type);
                     const format = 'TensorFlow.js ' + (obj.format || 'graph-model');
                     const producer = obj.convertedBy || obj.generatedBy || '';
                     const meta_graph = new tf.proto.tensorflow.MetaGraphDef();
@@ -399,7 +402,28 @@ tf.ModelFactory = class {
                     };
                     return Promise.all(shards.values()).then((streams) => {
                         for (const key of shards.keys()) {
-                            shards.set(key, streams.shift().peek());
+                            const stream = streams.shift();
+                            const buffer = stream.peek();
+                            shards.set(key, buffer);
+                        }
+                        if (type === 'json.gz') {
+                            try {
+                                for (const key of shards.keys()) {
+                                    const stream = shards.get(key);
+                                    const archive = gzip.Archive.open(stream);
+                                    if (archive) {
+                                        const entries = archive.entries;
+                                        if (entries.size === 1) {
+                                            const stream = entries.values().next().value;
+                                            const buffer = stream.peek();
+                                            shards.set(key, buffer);
+                                        }
+                                    }
+                                }
+                            }
+                            catch (error) {
+                                // continue regardless of error
+                            }
                         }
                         return openShards(shards);
                     }).catch(() => {
@@ -528,7 +552,9 @@ tf.ModelFactory = class {
                 case 'tf.events':
                     return openEventFile(context);
                 case 'tf.json':
-                    return openJson(context);
+                    return openJson(context, 'json');
+                case 'tf.json.gz':
+                    return openJson(context, 'json.gz');
                 case 'tf.pbtxt.GraphDef':
                     return openTextGraphDef(context);
                 case 'tf.pbtxt.MetaGraphDef':

+ 22 - 2
source/view.js

@@ -1363,6 +1363,26 @@ view.ModelContext = class {
                         }
                         break;
                     }
+                    case 'json.gz': {
+                        try {
+                            const archive = gzip.Archive.open(stream);
+                            if (archive) {
+                                const entries = archive.entries;
+                                if (entries.size === 1) {
+                                    const stream = entries.values().next().value;
+                                    const reader = json.TextReader.open(stream);
+                                    if (reader) {
+                                        const obj = reader.read();
+                                        this._content.set(type, obj);
+                                    }
+                                }
+                            }
+                        }
+                        catch (err) {
+                            // continue regardless of error
+                        }
+                        break;
+                    }
                     case 'pkl': {
                         let unpickler = null;
                         try {
@@ -1595,8 +1615,8 @@ view.ModelFactoryService = class {
                     if (archive) {
                         const entries = archive.entries;
                         containers.set('gzip', entries);
-                        if (archive.entries.size === 1) {
-                            stream = archive.entries.values().next().value;
+                        if (entries.size === 1) {
+                            stream = entries.values().next().value;
                         }
                     }
                 }

+ 7 - 0
test/models.json

@@ -5616,6 +5616,13 @@
     "format": "TensorFlow.js graph-model",
     "link":   "https://github.com/intel/webml-polyfill/issues/880"
   },
+  {
+    "type":   "tfjs",
+    "target": "posenet_mobilenet_float_075_1_default_1.zip",
+    "source": "https://github.com/lutzroeder/netron/files/7204409/posenet_mobilenet_float_075_1_default_1.zip",
+    "format": "TensorFlow.js graph-model",
+    "link":   "https://github.com/lutzroeder/netron/issues/294"
+  },
   {
     "type":   "tfjs",
     "target": "sentiment_cnn_v1/model.json",