Selaa lähdekoodia

TensorFlow Lite metadata prototype (#481)

Lutz Roeder 5 vuotta sitten
vanhempi
sitoutus
b7162d366c
10 muutettua tiedostoa jossa 286 lisäystä ja 188 poistoa
  1. 1 1
      src/armnn-schema.js
  2. 2 2
      src/armnn.js
  3. 1 1
      src/mnn-schema.js
  4. 2 2
      src/mnn.js
  5. 150 150
      src/tflite-schema.js
  6. 113 20
      src/tflite.js
  7. 5 5
      test/models.json
  8. 1 1
      tools/armnn
  9. 1 1
      tools/mnn
  10. 10 5
      tools/tflite

+ 1 - 1
src/armnn-schema.js

@@ -14694,5 +14694,5 @@ armnnSerializer.SerializedGraph.createSerializedGraph = function(builder, layers
 }
 
 if (typeof module !== 'undefined' && typeof module.exports === 'object') {
-  module.exports = armnnSerializer;
+  module.exports = { armnn_schema: armnnSerializer };
 }

+ 2 - 2
src/armnn.js

@@ -17,13 +17,13 @@ armnn.ModelFactory = class {
     }
 
     open(context, host) {
-        return host.require('./armnn-schema').then((armnn_schema) => {
+        return host.require('./armnn-schema').then((schema) => {
             const identifier = context.identifier;
             let model = null;
             try {
                 const buffer = context.buffer;
                 const byteBuffer = new flatbuffers.ByteBuffer(buffer);
-                armnn.schema = armnn_schema;
+                armnn.schema = schema.armnn_schema;
                 model = armnn.schema.SerializedGraph.getRootAsSerializedGraph(byteBuffer);
             }
             catch (error) {

+ 1 - 1
src/mnn-schema.js

@@ -18360,5 +18360,5 @@ MNN.Net.createNet = function(builder, bizCodeOffset, extraTensorDescribeOffset,
 }
 
 if (typeof module !== 'undefined' && typeof module.exports === 'object') {
-  module.exports = MNN;
+  module.exports = { mnn_schema: MNN };
 }

+ 2 - 2
src/mnn.js

@@ -16,11 +16,11 @@ mnn.ModelFactory = class {
     }
 
     open(context, host) {
-        return host.require('./mnn-schema').then((mnn_schema) => {
+        return host.require('./mnn-schema').then((schema) => {
             return mnn.Metadata.open(host).then((metadata) => {
                 const identifier = context.identifier;
                 try {
-                    mnn.schema = mnn_schema;
+                    mnn.schema = schema.mnn_schema;
                     const byteBuffer = new flatbuffers.ByteBuffer(context.buffer);
                     const net = mnn.schema.Net.getRootAsNet(byteBuffer);
                     return new mnn.Model(metadata, net);

Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 150 - 150
src/tflite-schema.js


+ 113 - 20
src/tflite.js

@@ -21,12 +21,13 @@ tflite.ModelFactory = class {
     }
 
     open(context, host) {
-        return host.require('./tflite-schema').then((tflite_schema) => {
+        return host.require('./tflite-schema').then((schema) => {
             return tflite.Metadata.open(host).then((metadata) => {
                 const identifier = context.identifier;
                 try {
                     const buffer = new flatbuffers.ByteBuffer(context.buffer);
-                    tflite.schema = tflite_schema;
+                    tflite.schema = schema.tflite_schema;
+                    tflite.metadata_schema = schema.tflite_metadata_schema;
                     if (!tflite.schema.Model.bufferHasIdentifier(buffer)) {
                         throw new tflite.Error("File format is not tflite.Model.");
                     }
@@ -72,12 +73,7 @@ tflite.Model = class {
             }
             operators.push(custom ? { name: name, custom: true } : { name: name });
         }
-        const subgraphsLength = model.subgraphsLength();
-        for (let i = 0; i < subgraphsLength; i++) {
-            const subgraph = model.subgraphs(i);
-            const name = subgraphsLength > 1 ? i.toString() : '';
-            this._graphs.push(new tflite.Graph(metadata, subgraph, name, operators, model));
-        }
+        let modelMetadata = null;
         for (let i = 0; i < model.metadataLength(); i++) {
             const metadata = model.metadata(i);
             switch (metadata.name()) {
@@ -86,20 +82,55 @@ tflite.Model = class {
                     this._runtime = data ? new TextDecoder().decode(data) : undefined;
                     break;
                 }
+                case 'TFLITE_METADATA': {
+                    const buffer = new flatbuffers.ByteBuffer(model.buffers(metadata.buffer()).dataArray() || []);
+                    if (tflite.metadata_schema.ModelMetadata.bufferHasIdentifier(buffer)) {
+                        modelMetadata = tflite.metadata_schema.ModelMetadata.getRootAsModelMetadata(buffer);
+                        this._name = modelMetadata.name() || '';
+                        this._version = modelMetadata.version() || '';
+                        this._description = modelMetadata.description() ? [ this.description, modelMetadata.description()].join(' ') : this._description;
+                        this._author = modelMetadata.author() || '';
+                        this._license = modelMetadata.license() || '';
+                    }
+                    break;
+                }
             }
         }
+        const subgraphsLength = model.subgraphsLength();
+        for (let i = 0; i < subgraphsLength; i++) {
+            const subgraph = model.subgraphs(i);
+            const name = subgraphsLength > 1 ? i.toString() : '';
+            const subgraphMetadata = modelMetadata && i < modelMetadata.subgraphMetadataLength() ? modelMetadata.subgraphMetadata(i) : null;
+            this._graphs.push(new tflite.Graph(metadata, subgraph, subgraphMetadata, name, operators, model));
+        }
     }
 
     get format() {
         return this._format;
     }
 
+    get runtime() {
+        return this._runtime;
+    }
+
+    get name() {
+        return this._name;
+    }
+
+    get version() {
+        return this._version;
+    }
+
     get description() {
         return this._description;
     }
 
-    get runtime() {
-        return this._runtime;
+    get author() {
+        return this._author;
+    }
+
+    get license() {
+        return this._license;
     }
 
     get graphs() {
@@ -109,31 +140,77 @@ tflite.Model = class {
 
 tflite.Graph = class {
 
-    constructor(metadata, graph, name, operators, model) {
-        this._name = graph.name() || name;
+    constructor(metadata, subgraph, subgraphMetadata, name, operators, model) {
+        this._name = subgraph.name() || name;
         this._nodes = [];
         this._inputs = [];
         this._outputs = [];
         const args = [];
         const tensorNames = [];
-        for (let i = 0; i < graph.tensorsLength(); i++) {
-            const tensor = graph.tensors(i);
+        for (let i = 0; i < subgraph.tensorsLength(); i++) {
+            const tensor = subgraph.tensors(i);
             const buffer = model.buffers(tensor.buffer());
             const is_variable = tensor.isVariable();
             const initializer = buffer.dataLength() > 0 || is_variable ? new tflite.Tensor(i, tensor, buffer, is_variable) : null;
             args.push(new tflite.Argument(i, tensor, initializer));
             tensorNames.push(tensor.name());
         }
-        for (let i = 0; i < graph.operatorsLength(); i++) {
-            const node = graph.operators(i);
+        for (let i = 0; i < subgraph.operatorsLength(); i++) {
+            const node = subgraph.operators(i);
             const index = node.opcodeIndex();
             const operator = index < operators.length ? operators[index] : { name: '(' + index.toString() + ')' };
             this._nodes.push(new tflite.Node(metadata, node, operator, i.toString(), args));
         }
-        const inputs = Array.from(graph.inputsArray() || []);
-        this._inputs = inputs.map((input) => new tflite.Parameter(tensorNames[input], true, [ args[input] ]));
-        const outputs = Array.from(graph.outputsArray() || []);
-        this._outputs = outputs.map((output) => new tflite.Parameter(tensorNames[output], true, [ args[output] ]));
+        const applyTensorMetadata = (argument, tensorMetadata) => {
+            if (tensorMetadata) {
+                const description = tensorMetadata.description();
+                if (description) {
+                    argument.description = description;
+                }
+                const content = tensorMetadata.content();
+                if (argument.type && content) {
+                    let denotation = null;
+                    switch (content.contentPropertiesType()) {
+                        case 1: {
+                            denotation = 'Feature';
+                            break;
+                        }
+                        case 2: {
+                            denotation = 'Image';
+                            const imageProperties = content.contentProperties(Reflect.construct(tflite.metadata_schema.ImageProperties, []));
+                            switch(imageProperties.colorSpace()) {
+                                case 1: denotation += '(RGB)'; break;
+                                case 2: denotation += '(Grayscale)'; break;
+                            }
+                            break;
+                        }
+                        case 3: {
+                            denotation = 'BoundingBox';
+                            break;
+                        }
+                    }
+                    if (denotation) {
+                        argument.type.denotation = denotation;
+                    }
+                }
+            }
+        };
+        for (let i = 0; i < subgraph.inputsLength(); i++) {
+            const input = subgraph.inputs(i);
+            const argument = args[input];
+            if (subgraphMetadata && i < subgraphMetadata.inputTensorMetadataLength()) {
+                applyTensorMetadata(argument, subgraphMetadata.inputTensorMetadata(i));
+            }
+            this._inputs.push(new tflite.Parameter(tensorNames[input], true, [ argument ]));
+        }
+        for (let i = 0; i < subgraph.outputsLength(); i++) {
+            const output = subgraph.outputs(i);
+            const argument = args[output];
+            if (subgraphMetadata && i < subgraphMetadata.outputTensorMetadataLength()) {
+                applyTensorMetadata(argument, subgraphMetadata.outputTensorMetadata(i));
+            }
+            this._outputs.push(new tflite.Parameter(tensorNames[output], true, [ argument ]));
+        }
     }
 
     get name() {
@@ -432,6 +509,14 @@ tflite.Argument = class {
         return this._quantization;
     }
 
+    set description(value) {
+        this._description = value;
+    }
+
+    get description() {
+        return this._description;
+    }
+
     get initializer() {
         return this._initializer;
     }
@@ -613,6 +698,14 @@ tflite.TensorType = class {
         return this._shape;
     }
 
+    set denotation(value) {
+        this._denotation = value;
+    }
+
+    get denotation() {
+        return this._denotation;
+    }
+
     toString() {
         return this.dataType + this._shape.toString();
     }

+ 5 - 5
test/models.json

@@ -5186,17 +5186,17 @@
   },
   {
     "type":   "tflite",
-    "target": "mobilenet_v1_1.0_224.lite",
-    "source": "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz[./mobilenet_v1_1.0_224.tflite]",
+    "target": "mobilenet_v1_1.0_224_quant.tflite",
+    "source": "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz[./mobilenet_v1_1.0_224_quant.tflite]",
     "format": "TensorFlow Lite v3",
     "link":   "https://www.tensorflow.org/lite/models"
   },
   {
     "type":   "tflite",
-    "target": "mobilenet_v1_1.0_224_quant.tflite",
-    "source": "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz[./mobilenet_v1_1.0_224_quant.tflite]",
+    "target": "mobilenet_v1_0.75_160_quantized.tflite",
+    "source": "https://github.com/lutzroeder/netron/files/4569400/mobilenet_v1_0.75_160_quantized.zip[mobilenet_v1_0.75_160_quantized.tflite]",
     "format": "TensorFlow Lite v3",
-    "link":   "https://www.tensorflow.org/lite/models"
+    "link":   "https://github.com/lutzroeder/netron/issues/481"
   },
   {
     "type":   "tflite",

+ 1 - 1
tools/armnn

@@ -57,7 +57,7 @@ schema() {
     mv ./tools/ArmnnSchema_generated.js ./src/armnn-schema.js
     cat <<EOT >> ./src/armnn-schema.js
 if (typeof module !== 'undefined' && typeof module.exports === 'object') {
-  module.exports = armnnSerializer;
+  module.exports = { armnn_schema: armnnSerializer };
 }
 EOT
     if [[ -n ${crlf} ]]; then

+ 1 - 1
tools/mnn

@@ -57,7 +57,7 @@ schema() {
     mv ./tools/MNN_generated.js ./src/mnn-schema.js
     cat <<EOT >> ./src/mnn-schema.js
 if (typeof module !== 'undefined' && typeof module.exports === 'object') {
-  module.exports = MNN;
+  module.exports = { mnn_schema: MNN };
 }
 EOT
     if [[ -n ${crlf} ]]; then

+ 10 - 5
tools/tflite

@@ -71,15 +71,20 @@ schema() {
             ;;
     esac
     [[ $(grep -U $'\x0D' ./src/tflite-schema.js) ]] && crlf=1
-    sed 's/namespace tflite;/namespace TFLITE;/g' < ./third_party/src/tensorflow/tensorflow/lite/schema/schema.fbs > ./tools/tflite.schema.fbs
-    flatc --no-js-exports --js ./tools/tflite.schema.fbs
-    rm ./tools/tflite.schema.fbs
-    mv ./tflite.schema_generated.js ./src/tflite-schema.js
+    sed 's/namespace tflite;/namespace tflite_schema;/g' < ./third_party/src/tensorflow/tensorflow/lite/schema/schema.fbs > ./tools/tflite_schema.fbs
+    sed 's/namespace tflite;/namespace tflite_metadata_schema;/g' < ./third_party/src/tensorflow/tensorflow/lite/experimental/support/metadata/metadata_schema.fbs > ./tools/tflite_metadata_schema.fbs
+    flatc --no-js-exports --js ./tools/tflite_schema.fbs
+    flatc --no-js-exports --js ./tools/tflite_metadata_schema.fbs
+    mv ./tflite_schema_generated.js ./src/tflite-schema.js
+    cat ./tflite_metadata_schema_generated.js >> ./src/tflite-schema.js
     cat <<EOT >> ./src/tflite-schema.js
 if (typeof module !== 'undefined' && typeof module.exports === 'object') {
-  module.exports = TFLITE;
+  module.exports = { tflite_schema: tflite_schema, tflite_metadata_schema: tflite_metadata_schema };
 }
 EOT
+    rm ./tools/tflite_schema.fbs
+    rm ./tools/tflite_metadata_schema.fbs
+    rm ./tflite_metadata_schema_generated.js
     if [[ -n ${crlf} ]]; then
         unix2dos --quiet --newfile ./src/tflite-schema.js ./src/tflite-schema.js
     fi

Kaikkia tiedostoja ei voida näyttää, sillä liian monta tiedostoa muuttui tässä diffissä