|
|
@@ -29,6 +29,7 @@ class MyHTTPRequestHandler(BaseHTTPRequestHandler):
|
|
|
'.jpg': 'image/jpeg',
|
|
|
'.ico': 'image/x-icon',
|
|
|
'.json': 'application/json',
|
|
|
+ '.pb': 'application/octet-stream',
|
|
|
'.ttf': 'font/truetype',
|
|
|
'.woff': 'font/woff',
|
|
|
'.otf': 'font/opentype',
|
|
|
@@ -94,48 +95,53 @@ class MyHTTPServer(HTTPServer):
|
|
|
self.RequestHandlerClass.data = data
|
|
|
self.RequestHandlerClass.verbose = verbose
|
|
|
|
|
|
-class OnnxModel:
|
|
|
- def __init__(self, data, file):
|
|
|
- self.data = data
|
|
|
- self.file = file
|
|
|
- def optimize(self):
|
|
|
- # Remove raw initializer data
|
|
|
- model = ModelProto()
|
|
|
- model.ParseFromString(self.data)
|
|
|
- for initializer in model.graph.initializer:
|
|
|
- self.remove_tensor_data(initializer)
|
|
|
- for node in model.graph.node:
|
|
|
- for attribute in node.attribute:
|
|
|
- if attribute.HasField('t'):
|
|
|
- self.remove_tensor_data(attribute.t)
|
|
|
- self.data = model.SerializeToString()
|
|
|
- def remove_tensor_data(self, tensor):
|
|
|
+def optimize_onnx(model):
|
|
|
+ def remove_tensor_data(tensor):
|
|
|
del tensor.string_data[:]
|
|
|
del tensor.int32_data[:]
|
|
|
del tensor.int64_data[:]
|
|
|
del tensor.float_data[:]
|
|
|
tensor.raw_data = None
|
|
|
+ # Remove raw initializer data
|
|
|
+ onnx_model = ModelProto()
|
|
|
+ try:
|
|
|
+ onnx_model.ParseFromString(model.data)
|
|
|
+ except:
|
|
|
+ return False
|
|
|
+ for initializer in onnx_model.graph.initializer:
|
|
|
+ remove_tensor_data(initializer)
|
|
|
+ for node in onnx_model.graph.node:
|
|
|
+ for attribute in node.attribute:
|
|
|
+ if attribute.HasField('t'):
|
|
|
+ remove_tensor_data(attribute.t)
|
|
|
+ model.data = onnx_model.SerializeToString()
|
|
|
+ return True
|
|
|
|
|
|
-class TensorFlowLiteModel:
|
|
|
+def optimize_tf(model):
|
|
|
+ return True;
|
|
|
+
|
|
|
+def optimize_tflite(model):
|
|
|
+ return True;
|
|
|
+
|
|
|
+class Model:
|
|
|
def __init__(self, data, file):
|
|
|
self.data = data
|
|
|
self.file = file
|
|
|
- def optimize(self):
|
|
|
- return
|
|
|
|
|
|
def serve_data(data, file, verbose=False, browse=False, port=8080, host='localhost', tensor=False):
|
|
|
server = MyHTTPServer((host, port), MyHTTPRequestHandler)
|
|
|
- model = None
|
|
|
- if file.endswith('.tflite'):
|
|
|
- model = TensorFlowLiteModel(data, file)
|
|
|
- elif os.path.basename(file) == 'saved_model.pb':
|
|
|
- print('Not supported.')
|
|
|
- return
|
|
|
- else:
|
|
|
- model = OnnxModel(data, file)
|
|
|
+ model = Model(data, file)
|
|
|
if not tensor:
|
|
|
print("Processing '" + file + "'...")
|
|
|
- model.optimize()
|
|
|
+ ok = False
|
|
|
+ if not ok and file.endswith('.tflite'):
|
|
|
+ ok = optimize_tflite(model)
|
|
|
+ if not ok and os.path.basename(file) == 'saved_model.pb':
|
|
|
+ ok = optimize_tf(model)
|
|
|
+ if not ok and file.endswith('.onnx') or file.endswith('.pb'):
|
|
|
+ ok = optimize_onnx(model)
|
|
|
+ if not ok and file.endswith('.pb'):
|
|
|
+ ok = optimize_tf(model)
|
|
|
url = 'http://' + host + ':' + str(port)
|
|
|
print("Serving '" + file + "' at " + url + "...")
|
|
|
server.initialize_data(model, verbose)
|