netron.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. #!/usr/bin/python
  2. import codecs
  3. import os
  4. import platform
  5. import sys
  6. import threading
  7. import webbrowser
  8. from .onnx_ml_pb2 import ModelProto
  9. if sys.version_info[0] > 2:
  10. from urllib.parse import urlparse
  11. from http.server import HTTPServer
  12. from http.server import BaseHTTPRequestHandler
  13. else:
  14. from urlparse import urlparse
  15. from BaseHTTPServer import HTTPServer
  16. from BaseHTTPServer import BaseHTTPRequestHandler
  class MyHTTPRequestHandler(BaseHTTPRequestHandler):
      """Serve the viewer's static assets and the loaded model over HTTP.

      The class attributes ``data`` (an object exposing ``data`` bytes and a
      ``file`` name) and ``verbose`` are set on the class by
      ``MyHTTPServer.initialize_data`` before any request is handled.
      """

      # Extension -> Content-Type for static files served from this module's
      # folder. Files with any other extension are answered with 404.
      mime_types_map = {
          '.html': 'text/html',
          '.js': 'text/javascript',
          '.css': 'text/css',
          '.png': 'image/png',
          '.gif': 'image/gif',
          '.jpg': 'image/jpeg',
          '.ico': 'image/x-icon',
          '.json': 'application/json',
          '.pb': 'application/octet-stream',
          '.ttf': 'font/truetype',
          '.woff': 'application/font-woff',
          '.otf': 'font/opentype',
          '.eot': 'application/vnd.ms-fontobject',
          '.woff2': 'application/font-woff2',
          '.svg': 'image/svg+xml',
      }

      def handler(self):
          """Answer the current GET or HEAD request.

          ``/`` returns ``view-browser.html`` with ``{{{title}}}`` replaced by the
          model's file name, and ``/data`` returns the raw model bytes. Any other
          path is looked up as a regular file inside this module's folder.
          Missing files, directories, unknown extensions and paths that escape
          the folder are answered with a plain-text 404. HEAD requests get the
          same status and headers without a body.
          """
          pathname = urlparse(self.path).path
          folder = os.path.dirname(os.path.realpath(__file__))
          headers = {}
          buffer = None
          status_code = 404
          if pathname == '/':
              location = os.path.join(folder, 'view-browser.html')
              with codecs.open(location, mode='r', encoding='utf-8') as open_file:
                  buffer = open_file.read()
              buffer = buffer.replace('{{{title}}}', self.data.file)
              buffer = buffer.encode('utf-8')
              headers['Content-Type'] = 'text/html'
              status_code = 200
          elif pathname == '/data':
              buffer = self.data.data
              headers['Content-Type'] = 'application/octet-stream'
              status_code = 200
          else:
              location = self._resolve_static(folder, pathname)
              content_type = None
              if location is not None:
                  # .get(): unknown extensions must become a 404, not a KeyError.
                  content_type = self.mime_types_map.get(os.path.splitext(location)[1])
              if content_type:
                  with open(location, 'rb') as binary:
                      buffer = binary.read()
                  headers['Content-Type'] = content_type
                  status_code = 200
          if status_code == 404:
              # wfile is a binary stream, so the error body must be bytes.
              buffer = str(status_code).encode('ascii')
              headers['Content-Type'] = 'text/plain'
          headers['Content-Length'] = len(buffer)
          if self.verbose:
              print(str(status_code) + ' ' + self.command + ' ' + self.path)
              sys.stdout.flush()
          self.send_response(status_code)
          for key, value in headers.items():
              self.send_header(key, value)
          self.end_headers()
          if self.command != 'HEAD':
              self.wfile.write(buffer)

      @staticmethod
      def _resolve_static(folder, pathname):
          """Return the path of an existing regular file under ``folder`` for ``pathname``, else None."""
          location = os.path.normpath(folder + pathname)
          # Reject '..' segments that would escape the served folder.
          if not location.startswith(os.path.join(folder, '')):
              return None
          if not os.path.isfile(location):
              return None
          return location

      def do_GET(self):
          self.handler()

      def do_HEAD(self):
          self.handler()

      def log_message(self, format, *args):
          # Silence BaseHTTPRequestHandler's default per-request stderr logging;
          # request logging is controlled by ``verbose`` in handler().
          return
  class MyHTTPServer(HTTPServer):
      """HTTPServer that passes the served model and verbosity to its handler class."""

      def initialize_data(self, data, verbose):
          """Store ``data`` and ``verbose`` as class attributes of the request handler.

          This mutates the handler class itself, so every server sharing that
          class sees the same values.
          """
          handler_class = self.RequestHandlerClass
          handler_class.verbose = verbose
          handler_class.data = data
  def optimize_onnx(model):
      """Strip tensor payloads from an ONNX model so only its structure is served.

      Parses ``model.data`` as an ONNX ``ModelProto``. On success it clears the
      stored values of every initializer and of every tensor-valued node
      attribute, including those inside graph-valued attributes. It then
      replaces ``model.data`` with the re-serialized protobuf.

      Returns:
          True if ``model.data`` parsed as a ``ModelProto`` and was rewritten.
          False if parsing failed, in which case ``model.data`` is untouched.
      """
      def remove_tensor_data(tensor):
          # ClearField handles both repeated and scalar fields; assigning None
          # to a scalar bytes field such as raw_data raises TypeError in protobuf.
          for field in ('string_data', 'int32_data', 'int64_data', 'float_data', 'raw_data'):
              tensor.ClearField(field)

      def strip_graph(graph):
          # Remove raw initializer data
          for initializer in graph.initializer:
              remove_tensor_data(initializer)
          for node in graph.node:
              for attribute in node.attribute:
                  if attribute.HasField('t'):
                      remove_tensor_data(attribute.t)
                  for attribute_tensor in attribute.tensors:
                      remove_tensor_data(attribute_tensor)
                  # Recurse into graph-valued attributes so nested graphs are stripped too.
                  if attribute.HasField('g'):
                      strip_graph(attribute.g)
                  for subgraph in attribute.graphs:
                      strip_graph(subgraph)

      onnx_model = ModelProto()
      try:
          onnx_model.ParseFromString(model.data)
      except Exception:
          # Not parseable as ONNX; the caller falls back to other formats.
          return False
      strip_graph(onnx_model.graph)
      model.data = onnx_model.SerializeToString()
      return True
  def optimize_tf(model):
      """Placeholder TensorFlow optimizer: leaves ``model`` untouched and reports success."""
      return True
  def optimize_tflite(model):
      """Placeholder TensorFlow Lite optimizer: leaves ``model`` untouched and reports success."""
      return True
  class Model:
      """Mutable holder for a model's bytes (``data``) and the name of its source file (``file``)."""

      def __init__(self, data, file):
          self.file = file
          self.data = data
  def serve_data(data, file, verbose=False, browse=False, port=8080, host='localhost', tensor=False):
      """Serve an in-memory model through the viewer's HTTP server until interrupted.

      Unless ``tensor`` is true, the model is first passed through the
      format-specific optimizers (tflite, TensorFlow SavedModel, ONNX, then
      TensorFlow ``.pb``). Processing stops at the first optimizer that reports
      success. The server is always closed on exit, including when processing
      or serving raises.

      Args:
          data: raw model file contents.
          file: model file name; used to pick the optimizer and as the page title.
          verbose: print each request's status line to stdout.
          browse: open the default web browser on the served URL after one second.
          port: TCP port to bind.
          host: host name or address to bind.
          tensor: skip optimization and serve ``data`` unchanged.
      """
      server = MyHTTPServer((host, port), MyHTTPRequestHandler)
      try:
          model = Model(data, file)
          if not tensor:
              print("Processing '" + file + "'...")
              ok = False
              if not ok and file.endswith('.tflite'):
                  ok = optimize_tflite(model)
              if not ok and os.path.basename(file) == 'saved_model.pb':
                  ok = optimize_tf(model)
              # Both extensions must be guarded by `not ok` so a successful
              # earlier optimizer is not overridden by an ONNX re-parse.
              if not ok and file.endswith(('.onnx', '.pb')):
                  ok = optimize_onnx(model)
              if not ok and file.endswith('.pb'):
                  ok = optimize_tf(model)
          url = 'http://' + host + ':' + str(port)
          print("Serving '" + file + "' at " + url + "...")
          server.initialize_data(model, verbose)
          sys.stdout.flush()
          if browse:
              threading.Timer(1, webbrowser.open, args=(url,)).start()
          try:
              server.serve_forever()
          except (KeyboardInterrupt, SystemExit):
              print("\nStopping...")
      finally:
          # Release the listening socket on every exit path, not only Ctrl+C.
          server.server_close()
  def serve_file(file, verbose=False, browse=False, port=8080, host='localhost', tensor=False):
      """Read ``file`` from disk and serve its bytes via ``serve_data`` with the same options."""
      print("Reading '" + file + "'...")
      with open(file, 'rb') as stream:
          contents = stream.read()
      serve_data(contents, file, verbose=verbose, browse=browse, port=port, host=host, tensor=tensor)