Pārlūkot izejas kodu

TensorFlow support in Python server

Lutz Roeder 8 gadi atpakaļ
vecāks
revīzija
efb2917fce
3 mainītis faili ar 35 papildinājumiem un 30 dzēšanām
  1. 1 1
      setup.py
  2. 0 1
      src/netron
  3. 34 28
      src/netron.py

+ 1 - 1
setup.py

@@ -24,7 +24,7 @@ package_data={
         'onnx.js',
         'onnx-operator.json',
         'tf.js',
-        'tf-operator.json',
+        'tf-operator.pb',
         'tflite.js',
         'tflite-operator.json',
         'favicon.ico',

+ 0 - 1
src/netron

@@ -14,7 +14,6 @@ if __name__ == '__main__':
     parser.add_argument('--browse', help='launch web browser', action='store_true')
     parser.add_argument('-t', '--tensor', help='skip removing tensor data', action='store_true')
     args = parser.parse_args()
-    print(args.file)
     if not os.path.exists(args.file):
         print("Model file '" + args.file + "' does not exist.")
         sys.exit(2)

+ 34 - 28
src/netron.py

@@ -29,6 +29,7 @@ class MyHTTPRequestHandler(BaseHTTPRequestHandler):
                 '.jpg':  'image/jpeg',
                 '.ico':  'image/x-icon',
                 '.json': 'application/json',
+                '.pb': 'application/octet-stream',
                 '.ttf': 'font/truetype',
                 '.woff': 'font/woff',
                 '.otf': 'font/opentype',
@@ -94,48 +95,53 @@ class MyHTTPServer(HTTPServer):
         self.RequestHandlerClass.data = data
         self.RequestHandlerClass.verbose = verbose
 
-class OnnxModel:
-    def __init__(self, data, file):
-        self.data = data
-        self.file = file
-    def optimize(self):
-        # Remove raw initializer data
-        model = ModelProto()
-        model.ParseFromString(self.data)
-        for initializer in model.graph.initializer:
-            self.remove_tensor_data(initializer)
-        for node in model.graph.node:
-            for attribute in node.attribute:
-                if attribute.HasField('t'):
-                    self.remove_tensor_data(attribute.t)
-        self.data = model.SerializeToString()
-    def remove_tensor_data(self, tensor):
+def optimize_onnx(model):
+    def remove_tensor_data(tensor):
         del tensor.string_data[:]
         del tensor.int32_data[:]
         del tensor.int64_data[:]
         del tensor.float_data[:]
         tensor.raw_data = None
+    # Remove raw initializer data
+    onnx_model = ModelProto()
+    try:
+        onnx_model.ParseFromString(model.data)
+    except:
+        return False
+    for initializer in onnx_model.graph.initializer:
+        remove_tensor_data(initializer)
+    for node in onnx_model.graph.node:
+        for attribute in node.attribute:
+            if attribute.HasField('t'):
+                remove_tensor_data(attribute.t)
+    model.data = onnx_model.SerializeToString()
+    return True
 
-class TensorFlowLiteModel:
+def optimize_tf(model):
+    return True;
+
+def optimize_tflite(model):
+    return True;
+
+class Model:
     def __init__(self, data, file):
         self.data = data
         self.file = file
-    def optimize(self):
-        return
 
 def serve_data(data, file, verbose=False, browse=False, port=8080, host='localhost', tensor=False):
     server = MyHTTPServer((host, port), MyHTTPRequestHandler)
-    model = None
-    if file.endswith('.tflite'):
-        model = TensorFlowLiteModel(data, file)
-    elif os.path.basename(file) == 'saved_model.pb':
-        print('Not supported.')
-        return
-    else:
-        model = OnnxModel(data, file)
+    model = Model(data, file)
     if not tensor:
         print("Processing '" + file + "'...")
-        model.optimize()
+        ok = False
+        if not ok and file.endswith('.tflite'):
+             ok = optimize_tflite(model)
+        if not ok and os.path.basename(file) == 'saved_model.pb':
+            ok = optimize_tf(model)
+        if not ok and file.endswith('.onnx') or file.endswith('.pb'):
+            ok = optimize_onnx(model)
+        if not ok and file.endswith('.pb'):
+            ok = optimize_tf(model)
     url = 'http://' + host + ':' + str(port)
     print("Serving '" + file + "' at " + url + "...")
     server.initialize_data(model, verbose)