|
|
@@ -1,11 +1,10 @@
|
|
|
#!/usr/bin/python
|
|
|
|
|
|
-import codecs
|
|
|
import os
|
|
|
import platform
|
|
|
-import re
|
|
|
import sys
|
|
|
import base64
|
|
|
+import onnx
|
|
|
|
|
|
if sys.version_info[0] > 2:
|
|
|
from urllib.parse import urlparse
|
|
|
@@ -18,13 +17,28 @@ else:
|
|
|
|
|
|
class MyHTTPRequestHandler(BaseHTTPRequestHandler):
|
|
|
def handler(self):
|
|
|
+ if not hasattr(self, 'mime_types_map'):
|
|
|
+ self.mime_types_map = {
|
|
|
+ '.html': 'text/html',
|
|
|
+ '.js': 'text/javascript',
|
|
|
+ '.css': 'text/css',
|
|
|
+ '.png': 'image/png',
|
|
|
+ '.gif': 'image/gif',
|
|
|
+ '.jpg': 'image/jpeg',
|
|
|
+ '.json': 'application/json',
|
|
|
+ '.ttf': 'font/truetype',
|
|
|
+ '.woff': 'font/woff',
|
|
|
+ '.otf': 'font/opentype',
|
|
|
+ '.eot': 'application/vnd.ms-fontobject',
|
|
|
+ '.woff': 'application/font-woff',
|
|
|
+ '.woff2': 'application/font-woff2',
|
|
|
+ '.svg': 'image/svg+xml'
|
|
|
+ }
|
|
|
pathname = urlparse(self.path).path
|
|
|
folder = os.path.dirname(os.path.realpath(__file__))
|
|
|
if pathname == '/':
|
|
|
pathname = '/view-browser.html'
|
|
|
location = folder + pathname;
|
|
|
- if pathname == '/model':
|
|
|
- location = self.model_file;
|
|
|
status_code = 0
|
|
|
headers = {}
|
|
|
buffer = None
|
|
|
@@ -36,40 +50,23 @@ class MyHTTPRequestHandler(BaseHTTPRequestHandler):
|
|
|
status_code = 302
|
|
|
headers = { 'Location': pathname + '/' }
|
|
|
if status_code == 0:
|
|
|
- if os.path.exists(location) and not os.path.isdir(location):
|
|
|
+ if pathname == '/model':
|
|
|
+ buffer = base64.b64encode(self.buffer)
|
|
|
+ headers['Content-Type'] = 'text/plain'
|
|
|
+ headers['Content-Length'] = len(buffer)
|
|
|
status_code = 200
|
|
|
else:
|
|
|
- status_code = 404
|
|
|
- if os.path.exists(location) and not os.path.isdir(location):
|
|
|
- with open(location, 'rb') as binary:
|
|
|
- buffer = binary.read()
|
|
|
- headers['Content-Length'] = len(buffer)
|
|
|
- extension = os.path.splitext(location)[1]
|
|
|
- if pathname == '/model':
|
|
|
- content_type = 'text/plain'
|
|
|
- buffer = base64.b64encode(buffer)
|
|
|
- headers['Content-Length'] = len(buffer)
|
|
|
- else:
|
|
|
- if not hasattr(self, 'mime_types_map'):
|
|
|
- self.mime_types_map = {
|
|
|
- '.html': 'text/html',
|
|
|
- '.js': 'text/javascript',
|
|
|
- '.css': 'text/css',
|
|
|
- '.png': 'image/png',
|
|
|
- '.gif': 'image/gif',
|
|
|
- '.jpg': 'image/jpeg',
|
|
|
- '.json': 'application/json',
|
|
|
- '.ttf': 'font/truetype',
|
|
|
- '.woff': 'font/woff',
|
|
|
- '.otf': 'font/opentype',
|
|
|
- '.eot': 'application/vnd.ms-fontobject',
|
|
|
- '.woff': 'application/font-woff',
|
|
|
- '.woff2': 'application/font-woff2',
|
|
|
- '.svg': 'image/svg+xml'
|
|
|
- }
|
|
|
+ if os.path.exists(location) and not os.path.isdir(location):
|
|
|
+ extension = os.path.splitext(location)[1]
|
|
|
content_type = self.mime_types_map[extension]
|
|
|
- if content_type:
|
|
|
- headers['Content-Type'] = content_type
|
|
|
+ if content_type:
|
|
|
+ with open(location, 'rb') as binary:
|
|
|
+ buffer = binary.read()
|
|
|
+ headers['Content-Type'] = content_type
|
|
|
+ headers['Content-Length'] = len(buffer)
|
|
|
+ status_code = 200
|
|
|
+ else:
|
|
|
+ status_code = 404
|
|
|
# print(str(status_code) + ' ' + self.command + ' ' + self.path)
|
|
|
sys.stdout.flush()
|
|
|
self.send_response(status_code)
|
|
|
@@ -90,8 +87,8 @@ class MyHTTPRequestHandler(BaseHTTPRequestHandler):
|
|
|
return
|
|
|
|
|
|
class MyHTTPServer(HTTPServer):
|
|
|
- def serve_forever(self, model_file):
|
|
|
- self.RequestHandlerClass.model_file = model_file
|
|
|
+ def serve_forever(self, buffer):
|
|
|
+ self.RequestHandlerClass.buffer = buffer
|
|
|
HTTPServer.serve_forever(self)
|
|
|
|
|
|
def show_help():
|
|
|
@@ -103,12 +100,14 @@ def show_help():
|
|
|
print(' --help Show help.')
|
|
|
print(' --port <port> Port to serve (default: 8080).')
|
|
|
print(' --browse Launch web browser.')
|
|
|
+ print(' --initializer Keep graph initializer tensors.')
|
|
|
print('')
|
|
|
|
|
|
def serve(args):
|
|
|
port = 8080
|
|
|
browse = False
|
|
|
- model_file = ''
|
|
|
+ initializer = False
|
|
|
+ file = ''
|
|
|
while len(args) > 0:
|
|
|
arg = args.pop(0)
|
|
|
if (arg == '--help' or arg == '-h'):
|
|
|
@@ -118,17 +117,29 @@ def serve(args):
|
|
|
port = int(args.pop(0))
|
|
|
elif arg == '--browse' or arg == '-b':
|
|
|
browse = True
|
|
|
+ elif arg == '--initialier' or arg == '-i':
|
|
|
+ initialier = True
|
|
|
elif not arg.startswith('-'):
|
|
|
- model_file = arg
|
|
|
- if len(model_file) == 0:
|
|
|
+ file = arg
|
|
|
+ if len(file) == 0:
|
|
|
show_help()
|
|
|
return
|
|
|
- if not os.path.exists(model_file):
|
|
|
- print("Model file '" + model_file + "' does not exist.")
|
|
|
+ if not os.path.exists(file):
|
|
|
+ print("Model file '" + file + "' does not exist.")
|
|
|
return
|
|
|
server = MyHTTPServer(('localhost', port), MyHTTPRequestHandler)
|
|
|
url = 'http://localhost:' + str(port)
|
|
|
- print("Serving '" + model_file + "' at " + url + "...")
|
|
|
+ buffer = None
|
|
|
+ with open(file, 'rb') as binary:
|
|
|
+ buffer = binary.read()
|
|
|
+ if not initializer:
|
|
|
+ # Remove raw initializer data
|
|
|
+ model = onnx.ModelProto()
|
|
|
+ model.ParseFromString(buffer)
|
|
|
+ for initializer in model.graph.initializer:
|
|
|
+ initializer.raw_data = ""
|
|
|
+ buffer = model.SerializeToString()
|
|
|
+ print("Serving '" + file + "' at " + url + "...")
|
|
|
if browse:
|
|
|
command = 'xdg-open';
|
|
|
if platform.system() == 'Darwin':
|
|
|
@@ -137,4 +148,4 @@ def serve(args):
|
|
|
command = 'start ""'
|
|
|
os.system(command + ' "' + url.replace('"', '\"') + '"')
|
|
|
sys.stdout.flush()
|
|
|
- server.serve_forever(model_file)
|
|
|
+ server.serve_forever(buffer)
|