server.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. """ Python Server implementation """
import collections
import errno
import html
import http.server
import importlib
import importlib.metadata
import json
import logging
import os
import random
import re
import socket
import socketserver
import threading
import time
import urllib.parse
import webbrowser
  17. __version__ = "0.0.0"
  18. logger = logging.getLogger(__name__)
  19. class _ContentProvider:
  20. data = bytearray()
  21. base_dir = ""
  22. base = ""
  23. identifier = ""
  24. def __init__(self, data, path, file, name):
  25. self.data = data if data else bytearray()
  26. self.identifier = os.path.basename(file) if file else ""
  27. self.name = name
  28. if path:
  29. self.dir = os.path.dirname(path) if os.path.dirname(path) else "."
  30. self.base = os.path.basename(path)
  31. def read(self, path):
  32. if path == self.base and self.data:
  33. return self.data
  34. base_dir = os.path.realpath(self.dir)
  35. filename = os.path.normpath(os.path.realpath(base_dir + "/" + path))
  36. if os.path.commonprefix([ base_dir, filename ]) == base_dir:
  37. if os.path.exists(filename) and not os.path.isdir(filename):
  38. with open(filename, "rb") as file:
  39. return file.read()
  40. return None
  41. class _HTTPRequestHandler(http.server.BaseHTTPRequestHandler):
  42. content = None
  43. mime_types = {
  44. ".html": "text/html",
  45. ".js": "text/javascript",
  46. ".css": "text/css",
  47. ".png": "image/png",
  48. ".gif": "image/gif",
  49. ".jpg": "image/jpeg",
  50. ".ico": "image/x-icon",
  51. ".json": "application/json",
  52. ".pb": "application/octet-stream",
  53. ".ttf": "font/truetype",
  54. ".otf": "font/opentype",
  55. ".eot": "application/vnd.ms-fontobject",
  56. ".woff": "font/woff",
  57. ".woff2": "application/font-woff2",
  58. ".svg": "image/svg+xml"
  59. }
  60. def do_HEAD(self):
  61. self.do_GET()
  62. def do_GET(self):
  63. path = urllib.parse.urlparse(self.path).path
  64. path = "/index.html" if path == "/" else path
  65. status_code = 404
  66. content = None
  67. content_type = None
  68. if path.startswith("/data/"):
  69. path = urllib.parse.unquote(path[len("/data/"):])
  70. content = self.content.read(path)
  71. if content:
  72. content_type = "application/octet-stream"
  73. status_code = 200
  74. else:
  75. base_dir = os.path.dirname(os.path.realpath(__file__))
  76. filename = os.path.normpath(os.path.realpath(base_dir + path))
  77. extension = os.path.splitext(filename)[1]
  78. if os.path.commonprefix([base_dir, filename]) == base_dir and \
  79. os.path.exists(filename) and not os.path.isdir(filename) and \
  80. extension in self.mime_types:
  81. content_type = self.mime_types[extension]
  82. with open(filename, "rb") as file:
  83. content = file.read()
  84. if path == "/index.html":
  85. content = content.decode("utf-8")
  86. meta = [
  87. '<meta name="type" content="Python">',
  88. '<meta name="version" content="' + __version__ + '">'
  89. ]
  90. base = self.content.base
  91. if base:
  92. meta.append('<meta name="file" content="/data/' + base + '">')
  93. name = self.content.name
  94. if name:
  95. meta.append('<meta name="name" content="' + name + '">')
  96. identifier = self.content.identifier
  97. if identifier:
  98. meta.append(f'<meta name="identifier" content="{identifier}">')
  99. meta = "\n".join(meta)
  100. regex = r'<meta name="version" content=".*">'
  101. content = re.sub(regex, lambda _: meta, content)
  102. content = content.encode("utf-8")
  103. status_code = 200
  104. self._write(status_code, content_type, content)
  105. def log_message(self, format, *args):
  106. logger.debug(" ".join(args))
  107. def _write(self, status_code, content_type, content):
  108. self.send_response(status_code)
  109. if content:
  110. self.send_header("Content-Type", content_type)
  111. self.send_header("Content-Length", len(content))
  112. self.end_headers()
  113. if self.command != "HEAD":
  114. if status_code == 404 and content is None:
  115. self.wfile.write(str(status_code).encode("utf-8"))
  116. elif (status_code in (200, 404)) and content is not None:
  117. self.wfile.write(content)
  118. class _ThreadedHTTPServer(socketserver.ThreadingMixIn, http.server.HTTPServer):
  119. pass
  120. class _HTTPServerThread(threading.Thread):
  121. def __init__(self, content, address):
  122. threading.Thread.__init__(self)
  123. self.address = address
  124. self.url = "http://" + address[0] + ":" + str(address[1])
  125. self.server = _ThreadedHTTPServer(address, _HTTPRequestHandler)
  126. self.server.timeout = 0.25
  127. self.server.block_on_close = False
  128. self.server.RequestHandlerClass.content = content
  129. self.terminate_event = threading.Event()
  130. self.terminate_event.set()
  131. self.stop_event = threading.Event()
  132. def run(self):
  133. self.stop_event.clear()
  134. self.terminate_event.clear()
  135. try:
  136. while not self.stop_event.is_set():
  137. self.server.handle_request()
  138. except: # noqa: E722
  139. pass
  140. self.terminate_event.set()
  141. self.stop_event.clear()
  142. def stop(self):
  143. if self.alive():
  144. logger.info("Stopping " + self.url)
  145. self.stop_event.set()
  146. self.server.server_close()
  147. self.terminate_event.wait(1)
  148. def alive(self):
  149. value = not self.terminate_event.is_set()
  150. return value
  151. def _open(data):
  152. registry = dict([
  153. ("onnx.onnx_ml_pb2.ModelProto", ".onnx"),
  154. ("torch.jit._script.ScriptModule", ".pytorch"),
  155. ("torch.Graph", ".pytorch"),
  156. ("torch._C.Graph", ".pytorch"),
  157. ("torch.nn.modules.module.Module", ".pytorch")
  158. ])
  159. queue = [ data.__class__ ]
  160. while len(queue) > 0:
  161. current = queue.pop(0)
  162. if current.__module__ and current.__name__:
  163. name = current.__module__ + "." + current.__name__
  164. if name in registry:
  165. module_name = registry[name]
  166. module = importlib.import_module(module_name, package=__package__)
  167. model_factory = module.ModelFactory()
  168. return model_factory.open(data)
  169. queue.extend(_ for _ in current.__bases__ if isinstance(_, type))
  170. return None
  171. def _threads(address=None):
  172. threads = []
  173. for thread in threading.enumerate():
  174. if isinstance(thread, _HTTPServerThread) and thread.alive():
  175. threads.append(thread)
  176. if address is not None:
  177. address = _make_address(address)
  178. threads = [ _ for _ in threads if address[0] == _.address[0] ]
  179. if address[1]:
  180. threads = [ _ for _ in threads if address[1] == _.address[1] ]
  181. return threads
  182. def _make_address(address):
  183. if address is None or isinstance(address, int):
  184. port = address
  185. address = ("localhost", port)
  186. if isinstance(address, tuple) and len(address) == 2:
  187. host = address[0]
  188. port = address[1]
  189. if isinstance(host, str) and (port is None or isinstance(port, int)):
  190. return address
  191. raise ValueError("Invalid address.")
  192. def _make_port(address):
  193. if address[1] is None or address[1] == 0:
  194. ports = []
  195. if address[1] != 0:
  196. ports.append(8080)
  197. ports.append(8081)
  198. rnd = random.Random()
  199. for _ in range(4):
  200. port = rnd.randrange(15000, 25000)
  201. if port not in ports:
  202. ports.append(port)
  203. ports.append(0)
  204. for port in ports:
  205. temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  206. temp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  207. temp_socket.settimeout(1)
  208. try:
  209. temp_socket.bind((address[0], port))
  210. sockname = temp_socket.getsockname()
  211. address = (address[0], sockname[1])
  212. return address
  213. except: # noqa: E722
  214. pass
  215. finally:
  216. temp_socket.close()
  217. if isinstance(address[1], int):
  218. return address
  219. raise ValueError("Failed to allocate port.")
  220. def stop(address=None):
  221. """Stop serving model at address.
  222. Args:
  223. address (tuple, optional): A (host, port) tuple, or a port number.
  224. """
  225. threads = _threads(address)
  226. for thread in threads:
  227. thread.stop()
  228. def status(address=None):
  229. """Is model served at address.
  230. Args:
  231. address (tuple, optional): A (host, port) tuple, or a port number.
  232. """
  233. threads = _threads(address)
  234. return len(threads) > 0
  235. def wait():
  236. """Wait for console exit and stop all model servers."""
  237. try:
  238. while len(_threads()) > 0:
  239. time.sleep(0.1)
  240. except (KeyboardInterrupt, SystemExit):
  241. logger.info("")
  242. stop()
  243. def serve(file, data=None, address=None, browse=False):
  244. """Start serving model from file or data buffer at address and open in web browser.
  245. Args:
  246. file (string): Model file to serve. Required to detect format.
  247. data (bytes): Model data to serve. None will load data from file.
  248. address (tuple, optional): A (host, port) tuple, or a port number.
  249. browse (bool, optional): Launch web browser. Default: True
  250. Returns:
  251. A (host, port) address tuple.
  252. """
  253. if not logging.getLogger().hasHandlers():
  254. logging.basicConfig(level=logging.INFO, format="%(message)s")
  255. if not data and file and not os.path.exists(file):
  256. raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), file)
  257. content = _ContentProvider(data, file, file, file)
  258. if data and not isinstance(data, bytearray) and isinstance(data.__class__, type):
  259. logger.info("Experimental")
  260. model = _open(data)
  261. if model:
  262. text = json.dumps(model.to_json(), indent=2, ensure_ascii=False)
  263. content = _ContentProvider(text.encode("utf-8"), "model.netron", None, file)
  264. address = _make_address(address)
  265. if isinstance(address[1], int) and address[1] != 0:
  266. stop(address)
  267. else:
  268. address = _make_port(address)
  269. thread = _HTTPServerThread(content, address)
  270. thread.start()
  271. while not thread.alive():
  272. time.sleep(0.01)
  273. state = ("Serving '" + file + "'") if file else "Serving"
  274. logger.info(f"{state} at {thread.url}")
  275. if browse:
  276. webbrowser.open(thread.url)
  277. return address
  278. def start(file=None, address=None, browse=True):
  279. """Start serving model file at address and open in web browser.
  280. Args:
  281. file (string): Model file to serve.
  282. browse (bool, optional): Launch web browser, Default: True
  283. address (tuple, optional): A (host, port) tuple, or a port number.
  284. Returns:
  285. A (host, port) address tuple.
  286. """
  287. return serve(file, None, browse=browse, address=address)
  288. def widget(address, height=800):
  289. """ Open address as Jupyter Notebook IFrame.
  290. Args:
  291. address (tuple, optional): A (host, port) tuple, or a port number.
  292. height (int, optional): Height of the IFrame, Default: 800
  293. Returns:
  294. A Jupyter Notebook IFrame.
  295. """
  296. address = _make_address(address)
  297. url = f"http://{address[0]}:{address[1]}"
  298. IPython = __import__("IPython")
  299. return IPython.display.IFrame(url, width="100%", height=height)