Răsfoiți Sursa

Strip initializer raw_data for fast browser loading

Lutz Roeder 8 ani în urmă
părinte
comite
03a5c4c31f
2 a modificat fișierele cu 59 adăugiri și 45 ștergeri
  1. 4 1
      setup.py
  2. 55 44
      src/__init__.py

+ 4 - 1
setup.py

@@ -55,7 +55,7 @@ class build_py(setuptools.command.build_py.build_py):
         print("## get_outputs ##")
         return result
 
-packages = ['netron']
+packages = [ 'netron' ]
 
 package_data={
     'netron': [ 
@@ -73,6 +73,8 @@ package_data={
         ]
 }
 
+install_requires = [ 'protobuf' ]
+
 scripts = [ 'src/netron' ]
 
 setuptools.setup(
@@ -84,6 +86,7 @@ setuptools.setup(
     package_dir={ 'netron': 'src' },
     packages=packages,
     package_data=package_data,
+    install_requires=install_requires,
     author='Lutz Roeder',
     author_email='[email protected]',
     url='https://github.com/lutzroeder/Netron',

+ 55 - 44
src/__init__.py

@@ -1,11 +1,10 @@
 #!/usr/bin/python
 
-import codecs
 import os
 import platform
-import re
 import sys
 import base64
+import onnx
 
 if sys.version_info[0] > 2:
     from urllib.parse import urlparse
@@ -18,13 +17,28 @@ else:
 
 class MyHTTPRequestHandler(BaseHTTPRequestHandler):
     def handler(self):
+        if not hasattr(self, 'mime_types_map'):
+            self.mime_types_map = {
+                '.html': 'text/html',
+                '.js':   'text/javascript',
+                '.css':  'text/css',
+                '.png':  'image/png',
+                '.gif':  'image/gif',
+                '.jpg':  'image/jpeg',
+                '.json': 'application/json',
+                '.ttf': 'font/truetype',
+                '.woff': 'font/woff',
+                '.otf': 'font/opentype',
+                '.eot': 'application/vnd.ms-fontobject',
+                '.woff': 'application/font-woff',
+                '.woff2': 'application/font-woff2',
+                '.svg': 'image/svg+xml'
+            }
         pathname = urlparse(self.path).path
         folder = os.path.dirname(os.path.realpath(__file__))
         if pathname == '/':
             pathname = '/view-browser.html'
         location = folder + pathname;
-        if pathname == '/model':
-            location = self.model_file;
         status_code = 0
         headers = {}
         buffer = None
@@ -36,40 +50,23 @@ class MyHTTPRequestHandler(BaseHTTPRequestHandler):
                     status_code = 302
                     headers = { 'Location': pathname + '/' }
         if status_code == 0:
-            if os.path.exists(location) and not os.path.isdir(location):
+            if pathname == '/model':
+                buffer = base64.b64encode(self.buffer)
+                headers['Content-Type'] = 'text/plain'
+                headers['Content-Length'] = len(buffer)
                 status_code = 200
             else:
-                status_code = 404
-            if os.path.exists(location) and not os.path.isdir(location):
-                with open(location, 'rb') as binary:
-                    buffer = binary.read()
-                headers['Content-Length'] = len(buffer)
-                extension = os.path.splitext(location)[1]
-                if pathname == '/model':
-                    content_type = 'text/plain'
-                    buffer = base64.b64encode(buffer)
-                    headers['Content-Length'] = len(buffer)
-                else:
-                    if not hasattr(self, 'mime_types_map'):
-                        self.mime_types_map = {
-                            '.html': 'text/html',
-                            '.js':   'text/javascript',
-                            '.css':  'text/css',
-                            '.png':  'image/png',
-                            '.gif':  'image/gif',
-                            '.jpg':  'image/jpeg',
-                            '.json': 'application/json',
-                            '.ttf': 'font/truetype',
-                            '.woff': 'font/woff',
-                            '.otf': 'font/opentype',
-                            '.eot': 'application/vnd.ms-fontobject',
-                            '.woff': 'application/font-woff',
-                            '.woff2': 'application/font-woff2',
-                            '.svg': 'image/svg+xml'
-                        }
+                if os.path.exists(location) and not os.path.isdir(location):
+                    extension = os.path.splitext(location)[1]
                     content_type = self.mime_types_map[extension]
-                if content_type:
-                    headers['Content-Type'] = content_type
+                    if content_type:
+                        with open(location, 'rb') as binary:
+                            buffer = binary.read()
+                        headers['Content-Type'] = content_type
+                        headers['Content-Length'] = len(buffer)
+                        status_code = 200
+                else:
+                    status_code = 404
         # print(str(status_code) + ' ' + self.command + ' ' + self.path)
         sys.stdout.flush()
         self.send_response(status_code)
@@ -90,8 +87,8 @@ class MyHTTPRequestHandler(BaseHTTPRequestHandler):
         return
 
 class MyHTTPServer(HTTPServer):
-    def serve_forever(self, model_file):
-        self.RequestHandlerClass.model_file = model_file 
+    def serve_forever(self, buffer):
+        self.RequestHandlerClass.buffer = buffer 
         HTTPServer.serve_forever(self)
 
 def show_help():
@@ -103,12 +100,14 @@ def show_help():
     print('  --help          Show help.')
     print('  --port <port>   Port to serve (default: 8080).')
     print('  --browse        Launch web browser.')
+    print('  --initializer   Keep graph initializer tensors.')
     print('')
 
 def serve(args):
     port = 8080
     browse = False
-    model_file = ''
+    initializer = False
+    file = ''
     while len(args) > 0:
         arg = args.pop(0)
         if (arg == '--help' or arg == '-h'):
@@ -118,17 +117,29 @@ def serve(args):
             port = int(args.pop(0))
         elif arg == '--browse' or arg == '-b':
             browse = True
+        elif arg == '--initialier' or arg == '-i':
+            initialier = True
         elif not arg.startswith('-'):
-            model_file = arg
-    if len(model_file) == 0:
+            file = arg
+    if len(file) == 0:
         show_help()
         return
-    if not os.path.exists(model_file):
-        print("Model file '" + model_file + "' does not exist.")
+    if not os.path.exists(file):
+        print("Model file '" + file + "' does not exist.")
         return
     server = MyHTTPServer(('localhost', port), MyHTTPRequestHandler)
     url = 'http://localhost:' + str(port)
-    print("Serving '" + model_file + "' at " + url + "...")
+    buffer = None
+    with open(file, 'rb') as binary:
+        buffer = binary.read()
+    if not initializer:
+        # Remove raw initializer data
+        model = onnx.ModelProto()
+        model.ParseFromString(buffer)
+        for initializer in model.graph.initializer:
+          initializer.raw_data = ""
+        buffer = model.SerializeToString()
+    print("Serving '" + file + "' at " + url + "...")
     if browse:
         command = 'xdg-open';
         if platform.system() == 'Darwin':
@@ -137,4 +148,4 @@ def serve(args):
             command = 'start ""'
         os.system(command + ' "' + url.replace('"', '\"') + '"')
     sys.stdout.flush()
-    server.serve_forever(model_file)
+    server.serve_forever(buffer)