Просмотр исходного кода

Add TensorFlow BundleEntryProto (#342)

Lutz Roeder 6 лет назад
Родитель
Сommit
e466093285
3 измененных файлов с 347 добавлено и 60 удалено
  1. 283 0
      src/tf-proto.js
  2. 61 59
      src/tf.js
  3. 3 1
      tools/tf

+ 283 - 0
src/tf-proto.js

@@ -5379,6 +5379,289 @@
             return TypeSpecProto;
         })();
     
+        tensorflow.BundleHeaderProto = (function() {
+    
+            function BundleHeaderProto(properties) {
+                if (properties)
+                    for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i)
+                        if (properties[keys[i]] != null)
+                            this[keys[i]] = properties[keys[i]];
+            }
+    
+            BundleHeaderProto.prototype.num_shards = 0;
+            BundleHeaderProto.prototype.endianness = 0;
+            BundleHeaderProto.prototype.version = null;
+    
+            BundleHeaderProto.decode = function decode(reader, length) {
+                if (!(reader instanceof $Reader))
+                    reader = $Reader.create(reader);
+                var end = length === undefined ? reader.len : reader.pos + length, message = new $root.tensorflow.BundleHeaderProto();
+                while (reader.pos < end) {
+                    var tag = reader.uint32();
+                    switch (tag >>> 3) {
+                    case 1:
+                        message.num_shards = reader.int32();
+                        break;
+                    case 2:
+                        message.endianness = reader.int32();
+                        break;
+                    case 3:
+                        message.version = $root.tensorflow.VersionDef.decode(reader, reader.uint32());
+                        break;
+                    default:
+                        reader.skipType(tag & 7);
+                        break;
+                    }
+                }
+                return message;
+            };
+    
+            BundleHeaderProto.decodeText = function decodeText(reader) {
+                var message = new $root.tensorflow.BundleHeaderProto();
+                reader.start();
+                while (!reader.end()) {
+                    var tag = reader.tag();
+                    switch (tag) {
+                    case "num_shards":
+                        message.num_shards = reader.int32();
+                        break;
+                    case "endianness":
+                        message.endianness = reader.enum($root.tensorflow.BundleHeaderProto.Endianness);
+                        break;
+                    case "version":
+                        message.version = $root.tensorflow.VersionDef.decodeText(reader, true);
+                        break;
+                    default:
+                        reader.field(tag, message);
+                        break;
+                    }
+                }
+                return message;
+            };
+    
+            BundleHeaderProto.Endianness = (function() {
+                var valuesById = {}, values = Object.create(valuesById);
+                values[valuesById[0] = "LITTLE"] = 0;
+                values[valuesById[1] = "BIG"] = 1;
+                return values;
+            })();
+    
+            return BundleHeaderProto;
+        })();
+    
+        tensorflow.BundleEntryProto = (function() {
+    
+            function BundleEntryProto(properties) {
+                this.slices = [];
+                if (properties)
+                    for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i)
+                        if (properties[keys[i]] != null)
+                            this[keys[i]] = properties[keys[i]];
+            }
+    
+            BundleEntryProto.prototype.dtype = 0;
+            BundleEntryProto.prototype.shape = null;
+            BundleEntryProto.prototype.shard_id = 0;
+            BundleEntryProto.prototype.offset = $util.Long ? $util.Long.fromBits(0,0,false) : 0;
+            BundleEntryProto.prototype.size = $util.Long ? $util.Long.fromBits(0,0,false) : 0;
+            BundleEntryProto.prototype.crc32c = 0;
+            BundleEntryProto.prototype.slices = $util.emptyArray;
+    
+            BundleEntryProto.decode = function decode(reader, length) {
+                if (!(reader instanceof $Reader))
+                    reader = $Reader.create(reader);
+                var end = length === undefined ? reader.len : reader.pos + length, message = new $root.tensorflow.BundleEntryProto();
+                while (reader.pos < end) {
+                    var tag = reader.uint32();
+                    switch (tag >>> 3) {
+                    case 1:
+                        message.dtype = reader.int32();
+                        break;
+                    case 2:
+                        message.shape = $root.tensorflow.TensorShapeProto.decode(reader, reader.uint32());
+                        break;
+                    case 3:
+                        message.shard_id = reader.int32();
+                        break;
+                    case 4:
+                        message.offset = reader.int64();
+                        break;
+                    case 5:
+                        message.size = reader.int64();
+                        break;
+                    case 6:
+                        message.crc32c = reader.fixed32();
+                        break;
+                    case 7:
+                        if (!(message.slices && message.slices.length))
+                            message.slices = [];
+                        message.slices.push($root.tensorflow.TensorSliceProto.decode(reader, reader.uint32()));
+                        break;
+                    default:
+                        reader.skipType(tag & 7);
+                        break;
+                    }
+                }
+                return message;
+            };
+    
+            BundleEntryProto.decodeText = function decodeText(reader) {
+                var message = new $root.tensorflow.BundleEntryProto();
+                reader.start();
+                while (!reader.end()) {
+                    var tag = reader.tag();
+                    switch (tag) {
+                    case "dtype":
+                        message.dtype = reader.enum($root.tensorflow.DataType);
+                        break;
+                    case "shape":
+                        message.shape = $root.tensorflow.TensorShapeProto.decodeText(reader, true);
+                        break;
+                    case "shard_id":
+                        message.shard_id = reader.int32();
+                        break;
+                    case "offset":
+                        message.offset = reader.int64();
+                        break;
+                    case "size":
+                        message.size = reader.int64();
+                        break;
+                    case "crc32c":
+                        message.crc32c = reader.fixed32();
+                        break;
+                    case "slices":
+                        if (!(message.slices && message.slices.length))
+                            message.slices = [];
+                        message.slices.push($root.tensorflow.TensorSliceProto.decodeText(reader, true));
+                        break;
+                    default:
+                        reader.field(tag, message);
+                        break;
+                    }
+                }
+                return message;
+            };
+    
+            return BundleEntryProto;
+        })();
+    
+        tensorflow.TensorSliceProto = (function() {
+    
+            function TensorSliceProto(properties) {
+                this.extent = [];
+                if (properties)
+                    for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i)
+                        if (properties[keys[i]] != null)
+                            this[keys[i]] = properties[keys[i]];
+            }
+    
+            TensorSliceProto.prototype.extent = $util.emptyArray;
+    
+            TensorSliceProto.decode = function decode(reader, length) {
+                if (!(reader instanceof $Reader))
+                    reader = $Reader.create(reader);
+                var end = length === undefined ? reader.len : reader.pos + length, message = new $root.tensorflow.TensorSliceProto();
+                while (reader.pos < end) {
+                    var tag = reader.uint32();
+                    switch (tag >>> 3) {
+                    case 1:
+                        if (!(message.extent && message.extent.length))
+                            message.extent = [];
+                        message.extent.push($root.tensorflow.TensorSliceProto.Extent.decode(reader, reader.uint32()));
+                        break;
+                    default:
+                        reader.skipType(tag & 7);
+                        break;
+                    }
+                }
+                return message;
+            };
+    
+            TensorSliceProto.decodeText = function decodeText(reader) {
+                var message = new $root.tensorflow.TensorSliceProto();
+                reader.start();
+                while (!reader.end()) {
+                    var tag = reader.tag();
+                    switch (tag) {
+                    case "extent":
+                        if (!(message.extent && message.extent.length))
+                            message.extent = [];
+                        message.extent.push($root.tensorflow.TensorSliceProto.Extent.decodeText(reader, true));
+                        break;
+                    default:
+                        reader.field(tag, message);
+                        break;
+                    }
+                }
+                return message;
+            };
+    
+            TensorSliceProto.Extent = (function() {
+    
+                function Extent(properties) {
+                    if (properties)
+                        for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i)
+                            if (properties[keys[i]] != null)
+                                this[keys[i]] = properties[keys[i]];
+                }
+    
+                Extent.prototype.start = $util.Long ? $util.Long.fromBits(0,0,false) : 0;
+                Extent.prototype.length = $util.Long ? $util.Long.fromBits(0,0,false) : 0;
+    
+                var $oneOfFields;
+    
+                Object.defineProperty(Extent.prototype, "has_length", {
+                    get: $util.oneOfGetter($oneOfFields = ["length"]),
+                    set: $util.oneOfSetter($oneOfFields)
+                });
+    
+                Extent.decode = function decode(reader, length) {
+                    if (!(reader instanceof $Reader))
+                        reader = $Reader.create(reader);
+                    var end = length === undefined ? reader.len : reader.pos + length, message = new $root.tensorflow.TensorSliceProto.Extent();
+                    while (reader.pos < end) {
+                        var tag = reader.uint32();
+                        switch (tag >>> 3) {
+                        case 1:
+                            message.start = reader.int64();
+                            break;
+                        case 2:
+                            message.length = reader.int64();
+                            break;
+                        default:
+                            reader.skipType(tag & 7);
+                            break;
+                        }
+                    }
+                    return message;
+                };
+    
+                Extent.decodeText = function decodeText(reader) {
+                    var message = new $root.tensorflow.TensorSliceProto.Extent();
+                    reader.start();
+                    while (!reader.end()) {
+                        var tag = reader.tag();
+                        switch (tag) {
+                        case "start":
+                            message.start = reader.int64();
+                            break;
+                        case "length":
+                            message.length = reader.int64();
+                            break;
+                        default:
+                            reader.field(tag, message);
+                            break;
+                        }
+                    }
+                    return message;
+                };
+    
+                return Extent;
+            })();
+    
+            return TensorSliceProto;
+        })();
+    
         return tensorflow;
     })();
     

+ 61 - 59
src/tf.js

@@ -14,69 +14,73 @@ tf.ModelFactory = class {
     match(context) {
         const identifier = context.identifier;
         const extension = identifier.split('.').pop().toLowerCase();
-        if (extension == 'meta') {
-            const tags = context.tags('pb');
-            if (tags.size === 0) {
-                return false;
-            }
-            return true;
-        }
-        if (extension == 'pb') {
-            if (identifier.endsWith('predict_net.pb') || identifier.endsWith('init_net.pb')) {
+        switch (extension) {
+            case 'meta': {
+                const tags = context.tags('pb');
+                if (tags.size !== 0) {
+                    return true;
+                }
                 return false;
             }
-            if (identifier == 'tfhub_module.pb') {
-                const buffer = context.buffer;
-                if (buffer && buffer.length == 2 && buffer[0] == 0x08 && buffer[1] == 0x03) {
+            case 'pb': {
+                if (identifier.endsWith('predict_net.pb') || identifier.endsWith('init_net.pb')) {
                     return false;
                 }
+                if (identifier == 'tfhub_module.pb') {
+                    const buffer = context.buffer;
+                    if (buffer && buffer.length == 2 && buffer[0] == 0x08 && buffer[1] == 0x03) {
+                        return false;
+                    }
+                }
+                const tags = context.tags('pb');
+                if (tags.size === 0) {
+                    const tags = context.tags('pbtxt');
+                    if (!tags.has('node') && !tags.has('saved_model_schema_version') && !tags.has('meta_graphs') && !tags.has('graph_def')) {
+                        return false;
+                    }
+                    if (tags.has('input_stream') || tags.has('output_stream')) {
+                        return false;
+                    }
+                }
+                else {
+                    // ignore input_0.pb, output_0.pb
+                    if (tags.has(1) && tags.get(1) === 0 && 
+                        tags.has(2) && tags.get(2) === 0 && 
+                        tags.has(9) && tags.get(9) === 2) {
+                        return false;
+                    }
+                    if (Array.from(tags.values()).some((v) => v === 5)) {
+                        return false;
+                    }
+                }
+                return true;    
             }
-            const tags = context.tags('pb');
-            if (tags.size === 0) {
-                const tags = context.tags('pbtxt');
-                if (!tags.has('node') && !tags.has('saved_model_schema_version') && !tags.has('meta_graphs') && !tags.has('graph_def')) {
+            case 'pbtxt':
+            case 'prototxt': {
+                if (identifier.endsWith('predict_net.pbtxt') || identifier.endsWith('predict_net.prototxt') ||
+                    identifier.endsWith('init_net.pbtxt') || identifier.endsWith('init_net.prototxt')) {
                     return false;
                 }
+                const tags = context.tags('pbtxt');
                 if (tags.has('input_stream') || tags.has('output_stream')) {
                     return false;
                 }
-            }
-            else {
-                // ignore input_0.pb, output_0.pb
-                if (tags.has(1) && tags.get(1) === 0 && 
-                    tags.has(2) && tags.get(2) === 0 && 
-                    tags.has(9) && tags.get(9) === 2) {
-                    return false;
-                }
-                if (Array.from(tags.values()).some((v) => v === 5)) {
+                if (!tags.has('node') && !tags.has('saved_model_schema_version') && !tags.has('meta_graphs') && !tags.has('graph_def')) {
                     return false;
                 }
+                return true;
             }
-            return true;
-        }
-        if (extension == 'pbtxt' || extension == 'prototxt') {
-            if (identifier.endsWith('predict_net.pbtxt') || identifier.endsWith('predict_net.prototxt') ||
-                identifier.endsWith('init_net.pbtxt') || identifier.endsWith('init_net.prototxt')) {
-                return false;
-            }
-            const tags = context.tags('pbtxt');
-            if (!tags.has('node') && !tags.has('saved_model_schema_version') && !tags.has('meta_graphs') && !tags.has('graph_def')) {
-                return false;
-            }
-            if (tags.has('input_stream') || tags.has('output_stream')) {
-                return false;
-            }
-            return true;
-        }
-        if (extension == 'json') {
-            try {
-                const root = JSON.parse(context.text);
-                if (root && root.format && root.format === 'graph-model' && root.modelTopology) {
-                    return true;
+            case 'json': {
+                try {
+                    const root = JSON.parse(context.text);
+                    if (root && root.format && root.format === 'graph-model' && root.modelTopology) {
+                        return true;
+                    }
                 }
-            }
-            catch (err) {
-                // continue regardless of error
+                catch (err) {
+                    // continue regardless of error
+                }
+                return false;
             }
         }
         return false;
@@ -85,8 +89,6 @@ tf.ModelFactory = class {
     open(context, host) { 
         return host.require('./tf-proto').then(() => {
             tf.proto = protobuf.roots.tf.tensorflow;
-            let graph_def = null;
-            let meta_graph = null;
             let saved_model = null;
             let format = null;
             let producer = null;
@@ -112,7 +114,7 @@ tf.ModelFactory = class {
                     else if (tags.has('graph_def')) {
                         try {
                             if (!saved_model) {
-                                meta_graph = tf.proto.MetaGraphDef.decodeText(prototxt.TextReader.create(context.text));
+                                const meta_graph = tf.proto.MetaGraphDef.decodeText(prototxt.TextReader.create(context.text));
                                 saved_model = new tf.proto.SavedModel();
                                 saved_model.meta_graphs.push(meta_graph);
                                 format = 'TensorFlow MetaGraph';
@@ -124,8 +126,8 @@ tf.ModelFactory = class {
                     }
                     else if (tags.has('node')) {
                         try {
-                            graph_def = tf.proto.GraphDef.decodeText(prototxt.TextReader.create(context.text));
-                            meta_graph = new tf.proto.MetaGraphDef();
+                            const graph_def = tf.proto.GraphDef.decodeText(prototxt.TextReader.create(context.text));
+                            let meta_graph = new tf.proto.MetaGraphDef();
                             meta_graph.graph_def = graph_def;
                             saved_model = new tf.proto.SavedModel();
                             saved_model.meta_graphs.push(meta_graph);
@@ -154,7 +156,7 @@ tf.ModelFactory = class {
                     }
                     try {
                         if (!saved_model && extension == 'meta') {
-                            meta_graph = tf.proto.MetaGraphDef.decode(context.buffer);
+                            const meta_graph = tf.proto.MetaGraphDef.decode(context.buffer);
                             saved_model = new tf.proto.SavedModel();
                             saved_model.meta_graphs.push(meta_graph);
                             format = 'TensorFlow MetaGraph';
@@ -165,8 +167,8 @@ tf.ModelFactory = class {
                     }
                     try {
                         if (!saved_model) {
-                            graph_def = tf.proto.GraphDef.decode(context.buffer);
-                            meta_graph = new tf.proto.MetaGraphDef();
+                            const graph_def = tf.proto.GraphDef.decode(context.buffer);
+                            let meta_graph = new tf.proto.MetaGraphDef();
                             meta_graph.graph_def = graph_def;
                             saved_model = new tf.proto.SavedModel();
                             saved_model.meta_graphs.push(meta_graph);
@@ -187,8 +189,8 @@ tf.ModelFactory = class {
             else {
                 try {
                     const root = JSON.parse(context.text);
-                    graph_def = new tf.proto.GraphDef();
-                    meta_graph = new tf.proto.MetaGraphDef();
+                    let graph_def = new tf.proto.GraphDef();
+                    let meta_graph = new tf.proto.MetaGraphDef();
                     meta_graph.graph_def = graph_def;
                     saved_model = new tf.proto.SavedModel();
                     saved_model.meta_graphs.push(meta_graph);

+ 3 - 1
tools/tf

@@ -74,7 +74,9 @@ schema() {
         ${third_party}/${identifier}/tensorflow/core/framework/resource_handle.proto \
         ${third_party}/${identifier}/tensorflow/core/protobuf/saved_object_graph.proto \
         ${third_party}/${identifier}/tensorflow/core/protobuf/trackable_object_graph.proto \
-        ${third_party}/${identifier}/tensorflow/core/protobuf/struct.proto
+        ${third_party}/${identifier}/tensorflow/core/protobuf/struct.proto \
+        ${third_party}/${identifier}/tensorflow/core/protobuf/tensor_bundle.proto \
+        ${third_party}/${identifier}/tensorflow/core/framework/tensor_slice.proto
 }
 
 metadata() {