|
|
@@ -6,7 +6,6 @@ import platform
|
|
|
import sys
|
|
|
import threading
|
|
|
import webbrowser
|
|
|
-from .onnx_ml_pb2 import ModelProto
|
|
|
|
|
|
if sys.version_info[0] > 2:
|
|
|
from urllib.parse import urlparse
|
|
|
@@ -95,59 +94,14 @@ class MyHTTPServer(HTTPServer):
|
|
|
self.RequestHandlerClass.data = data
|
|
|
self.RequestHandlerClass.verbose = verbose
|
|
|
|
|
|
-def optimize_onnx(model):
|
|
|
- def remove_tensor_data(tensor):
|
|
|
- del tensor.string_data[:]
|
|
|
- del tensor.int32_data[:]
|
|
|
- del tensor.int64_data[:]
|
|
|
- del tensor.float_data[:]
|
|
|
- if tensor.HasField('raw_data'):
|
|
|
- tensor.raw_data = b''
|
|
|
- # Remove raw initializer data
|
|
|
- onnx_model = ModelProto()
|
|
|
- try:
|
|
|
- onnx_model.ParseFromString(model.data)
|
|
|
- except:
|
|
|
- return False
|
|
|
- for initializer in onnx_model.graph.initializer:
|
|
|
- remove_tensor_data(initializer)
|
|
|
- for node in onnx_model.graph.node:
|
|
|
- for attribute in node.attribute:
|
|
|
- if attribute.HasField('t'):
|
|
|
- remove_tensor_data(attribute.t)
|
|
|
- model.data = onnx_model.SerializeToString()
|
|
|
- return True
|
|
|
-
|
|
|
-def optimize_tf(model):
|
|
|
- return True
|
|
|
-
|
|
|
-def optimize_tflite(model):
|
|
|
- return True
|
|
|
-
|
|
|
-def optimize_keras(model):
|
|
|
- return True
|
|
|
-
|
|
|
class Model:
|
|
|
def __init__(self, data, file):
|
|
|
self.data = data
|
|
|
self.file = file
|
|
|
|
|
|
-def serve_data(data, file, verbose=False, browse=False, port=8080, host='localhost', optimize=False):
|
|
|
+def serve_data(data, file, verbose=False, browse=False, port=8080, host='localhost'):
|
|
|
server = MyHTTPServer((host, port), MyHTTPRequestHandler)
|
|
|
model = Model(data, file)
|
|
|
- if optimize:
|
|
|
- print("Processing '" + file + "'...")
|
|
|
- ok = False
|
|
|
- if not ok and file.endswith('.tflite'):
|
|
|
- ok = optimize_tflite(model)
|
|
|
- if not ok and os.path.basename(file) == 'saved_model.pb':
|
|
|
- ok = optimize_tf(model)
|
|
|
- if not ok and file.endswith('.onnx') or file.endswith('.pb'):
|
|
|
- ok = optimize_onnx(model)
|
|
|
- if not ok and file.endswith('.json') or file.endswith('.h5') or file.endswith('.keras'):
|
|
|
- ok = optimize_keras(model)
|
|
|
- if not ok and file.endswith('.pb'):
|
|
|
- ok = optimize_tf(model)
|
|
|
url = 'http://' + host + ':' + str(port)
|
|
|
print("Serving '" + file + "' at " + url + "...")
|
|
|
server.initialize_data(model, verbose)
|
|
|
@@ -160,9 +114,9 @@ def serve_data(data, file, verbose=False, browse=False, port=8080, host='localho
|
|
|
print("\nStopping...")
|
|
|
server.server_close()
|
|
|
|
|
|
-def serve_file(file, verbose=False, browse=False, port=8080, host='localhost', optimize=False):
|
|
|
+def serve_file(file, verbose=False, browse=False, port=8080, host='localhost'):
|
|
|
print("Reading '" + file + "'...")
|
|
|
data = None
|
|
|
with open(file, 'rb') as binary:
|
|
|
data = binary.read()
|
|
|
- serve_data(data, file, verbose=verbose, browse=browse, port=port, host=host, optimize=optimize)
|
|
|
+ serve_data(data, file, verbose=verbose, browse=browse, port=port, host=host)
|