""" Python Server implementation """

import errno
import http.server
import importlib
import importlib.metadata
import json
import logging
import os
import random
import re
import socket
import socketserver
import threading
import time
import urllib.parse
import webbrowser

__version__ = "0.0.0"

logger = logging.getLogger(__name__)


class _ContentProvider:
    """Provide the bytes served under ``/data/``.

    Content comes either from an in-memory buffer (returned when the
    requested path equals the base name of the model path) or from files
    located inside the directory of the model path.
    """

    data = bytearray()
    base_dir = ""
    base = ""
    identifier = ""
    # Directory that read() resolves requests against. It stays None when the
    # provider is created without a path, so read() can detect that case
    # instead of failing on a missing attribute.
    dir = None

    def __init__(self, data, path, file, name):
        """Create a content provider.

        Args:
            data: Buffer to serve for the base name of `path`; a falsy value
                is replaced by an empty bytearray.
            path: Path whose directory and base name drive `read`.
            file: Path whose base name becomes `identifier`.
            name: Stored unchanged as `name`.
        """
        self.data = data if data else bytearray()
        self.identifier = os.path.basename(file) if file else ""
        self.name = name
        if path:
            self.dir = os.path.dirname(path) or "."
            self.base = os.path.basename(path)

    def read(self, path):
        """Return the bytes for `path`, or None if nothing can be served.

        Args:
            path: Request path relative to the provider directory.

        Returns:
            The in-memory buffer when `path` equals the base name and the
            buffer is non-empty; otherwise the content of an existing regular
            file that resolves inside the provider directory; otherwise None.
        """
        if path == self.base and self.data:
            return self.data
        if not self.dir:
            # No directory was configured, so there is nothing on disk to serve.
            return None
        base_dir = os.path.realpath(self.dir)
        filename = os.path.normpath(os.path.realpath(base_dir + "/" + path))
        # Refuse anything that resolves outside the provider directory.
        if os.path.commonpath([base_dir, filename]) == base_dir:
            if os.path.exists(filename) and not os.path.isdir(filename):
                with open(filename, "rb") as file:
                    return file.read()
        return None


class _HTTPRequestHandler(http.server.BaseHTTPRequestHandler):
    """Serve static files next to this module and model data under /data/."""

    # Content provider consulted for /data/ requests and index.html metadata.
    content = None
    # Only files with one of these extensions are served as static content.
    mime_types = {
        ".html": "text/html",
        ".js": "text/javascript",
        ".css": "text/css",
        ".png": "image/png",
        ".gif": "image/gif",
        ".jpg": "image/jpeg",
        ".ico": "image/x-icon",
        ".json": "application/json",
        ".pb": "application/octet-stream",
        ".ttf": "font/truetype",
        ".otf": "font/opentype",
        ".eot": "application/vnd.ms-fontobject",
        ".woff": "font/woff",
        ".woff2": "font/woff2",
        ".svg": "image/svg+xml"
    }

    def do_HEAD(self):
        """Handle HEAD like GET; `_write` omits the body for HEAD."""
        self.do_GET()

    def do_GET(self):
        """Respond with model data, a static file, or a 404."""
        path = urllib.parse.urlparse(self.path).path
        path = "/index.html" if path == "/" else path
        status_code = 404
        content = None
        content_type = None
        if path.startswith("/data/"):
            path = urllib.parse.unquote(path[len("/data/"):])
            content = self.content.read(path)
            if content:
                content_type = "application/octet-stream"
                status_code = 200
        else:
            base_dir = os.path.dirname(os.path.realpath(__file__))
            filename = os.path.normpath(os.path.realpath(base_dir + path))
            extension = os.path.splitext(filename)[1]
            # Serve only known file types that resolve inside this directory.
            if os.path.commonpath([base_dir, filename]) == base_dir and \
                    os.path.exists(filename) and not os.path.isdir(filename) and \
                    extension in self.mime_types:
                content_type = self.mime_types[extension]
                with open(filename, "rb") as file:
                    content = file.read()
                if path == "/index.html":
                    # NOTE(review): every literal below, including the regex,
                    # is an empty string in this copy of the file. That looks
                    # like markup stripped during extraction. As written,
                    # re.sub with an empty pattern inserts the joined text at
                    # every position of the page. Confirm the intended
                    # literals before relying on this branch.
                    content = content.decode("utf-8")
                    meta = [
                        '',
                        ''
                    ]
                    base = self.content.base
                    if base:
                        meta.append('')
                    name = self.content.name
                    if name:
                        meta.append('')
                    identifier = self.content.identifier
                    if identifier:
                        meta.append(f'')
                    meta = "\n".join(meta)
                    regex = r''
                    # A callable replacement keeps backslashes in the text
                    # from being interpreted as regex escapes.
                    content = re.sub(regex, lambda _: meta, content)
                    content = content.encode("utf-8")
                status_code = 200
        self._write(status_code, content_type, content)

    def log_message(self, format, *args):
        """Send request log output to the module logger at DEBUG level.

        Args:
            format: Format string supplied by the base class (unused).
            *args: Log arguments; each is converted to text because the base
                class also passes non-string values such as status codes.
        """
        logger.debug(" ".join(str(arg) for arg in args))

    def _write(self, status_code, content_type, content):
        """Send the status line, headers and, unless HEAD, the body.

        Args:
            status_code: HTTP status code to send.
            content_type: Value for Content-Type, sent only with content.
            content: Response body, or None.
        """
        self.send_response(status_code)
        if content:
            self.send_header("Content-Type", content_type)
            self.send_header("Content-Length", len(content))
        self.end_headers()
        if self.command != "HEAD":
            if status_code == 404 and content is None:
                # Minimal body so a 404 is not an empty response.
                self.wfile.write(str(status_code).encode("utf-8"))
            elif (status_code in (200, 404)) and content is not None:
                self.wfile.write(content)


class _ThreadedHTTPServer(socketserver.ThreadingMixIn, http.server.HTTPServer):
    """HTTP server that handles each request in its own thread."""


class _HTTPServerThread(threading.Thread):
    """Daemon thread that runs one HTTP server until stopped."""

    def __init__(self, content, address):
        """Create the server bound to `address`.

        Args:
            content: Content provider used by the request handler.
            address: A (host, port) tuple to bind to.
        """
        threading.Thread.__init__(self)
        self.daemon = True
        self.address = address
        self.url = "http://" + address[0] + ":" + str(address[1])
        # Bind the content to a handler subclass owned by this server.
        # Assigning it on the shared _HTTPRequestHandler class would make every
        # running server serve whichever content was registered last.
        handler = type("_HTTPRequestHandler", (_HTTPRequestHandler,), {"content": content})
        self.server = _ThreadedHTTPServer(address, handler)
        # Short timeout so run() re-checks stop_event regularly.
        self.server.timeout = 0.25
        self.server.block_on_close = False
        # terminate_event is set while the thread is NOT running.
        self.terminate_event = threading.Event()
        self.terminate_event.set()
        self.stop_event = threading.Event()

    def run(self):
        """Handle requests until `stop` is called or the server fails."""
        self.stop_event.clear()
        self.terminate_event.clear()
        try:
            while not self.stop_event.is_set():
                self.server.handle_request()
        except:  # noqa: E722
            # Deliberate best-effort: stop() closes the socket from another
            # thread, so any failure here simply ends the serving loop.
            pass
        self.terminate_event.set()
        self.stop_event.clear()

    def stop(self):
        """Request shutdown and close the server socket if running."""
        if self.alive():
            logger.info("Stopping " + self.url)
            self.stop_event.set()
            self.server.server_close()

    def alive(self):
        """Return True while `run` is executing its serving loop."""
        value = not self.terminate_event.is_set()
        return value


def _open(data):
    """Convert a known in-memory model object through its model factory.

    Walks the class hierarchy of `data` breadth-first and, for the first class
    whose qualified name is registered, imports the mapped module relative to
    this package and returns `ModelFactory().open(data)`.

    Args:
        data: Object to convert.

    Returns:
        The result of the factory's `open`, or None when no class matches.
    """
    registry = {
        "onnx.onnx_ml_pb2.ModelProto": ".onnx",
        "torch.jit._script.ScriptModule": ".pytorch",
        "torch.Graph": ".pytorch",
        "torch._C.Graph": ".pytorch",
        "torch.nn.modules.module.Module": ".pytorch"
    }
    queue = [data.__class__]
    while queue:
        current = queue.pop(0)
        if current.__module__ and current.__name__:
            name = current.__module__ + "." + current.__name__
            if name in registry:
                module_name = registry[name]
                module = importlib.import_module(module_name, package=__package__)
                model_factory = module.ModelFactory()
                return model_factory.open(data)
        queue.extend(_ for _ in current.__bases__ if isinstance(_, type))
    return None


def _threads(address=None):
    """Return the running server threads, optionally filtered by address.

    Args:
        address: None for all threads, otherwise a port number or a
            (host, port) tuple. A falsy port matches any port on the host.

    Returns:
        A list of `_HTTPServerThread` instances.
    """
    threads = [
        thread for thread in threading.enumerate()
        if isinstance(thread, _HTTPServerThread) and thread.alive()
    ]
    if address is not None:
        address = _make_address(address)
        threads = [_ for _ in threads if address[0] == _.address[0]]
        if address[1]:
            threads = [_ for _ in threads if address[1] == _.address[1]]
    return threads


def _make_address(address):
    """Normalize `address` to a (host, port) tuple.

    Args:
        address: None, a port number, or a (host, port) tuple.

    Returns:
        A (host, port) tuple; the host is "localhost" when only a port or
        None was given.

    Raises:
        ValueError: If `address` has any other shape or types.
    """
    if address is None or isinstance(address, int):
        port = address
        address = ("localhost", port)
    if isinstance(address, tuple) and len(address) == 2:
        host = address[0]
        port = address[1]
        if isinstance(host, str) and (port is None or isinstance(port, int)):
            return address
    raise ValueError("Invalid address.")


def _make_port(address):
    """Pick a bindable port when `address` does not name one.

    Args:
        address: A (host, port) tuple. Port None tries 8080, 8081, a few
            random ports and finally an OS-assigned port; port 0 tries only
            an OS-assigned port; any other port is returned unchanged.

    Returns:
        A (host, port) tuple.

    Raises:
        ValueError: If no port could be bound and none was given.
    """
    if address[1] is None or address[1] == 0:
        ports = []
        if address[1] != 0:
            ports.append(8080)
            ports.append(8081)
            rnd = random.Random()
            for _ in range(4):
                port = rnd.randrange(15000, 25000)
                if port not in ports:
                    ports.append(port)
        # Port 0 lets the operating system choose as a last resort.
        ports.append(0)
        for port in ports:
            temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            temp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            temp_socket.settimeout(1)
            try:
                temp_socket.bind((address[0], port))
                sockname = temp_socket.getsockname()
                address = (address[0], sockname[1])
                return address
            except OSError:
                # Port unavailable; try the next candidate.
                pass
            finally:
                temp_socket.close()
    if isinstance(address[1], int):
        return address
    raise ValueError("Failed to allocate port.")


def stop(address=None):
    """Stop serving model at address.

    Args:
        address (tuple, optional): A (host, port) tuple, or a port number.
    """
    threads = _threads(address)
    for thread in threads:
        thread.stop()


def status(address=None):
    """Is model served at address.

    Args:
        address (tuple, optional): A (host, port) tuple, or a port number.

    Returns:
        True if at least one matching server thread is running.
    """
    threads = _threads(address)
    return len(threads) > 0


def wait():
    """Wait for console exit and stop all model servers."""
    try:
        while len(_threads()) > 0:
            time.sleep(0.1)
    except (KeyboardInterrupt, SystemExit):
        stop()


def serve(file, data=None, address=None, browse=False):
    """Start serving model from file or data buffer at address and open in web browser.

    Args:
        file (string): Model file to serve. Required to detect format.
        data (bytes): Model data to serve. None will load data from file.
        address (tuple, optional): A (host, port) tuple, or a port number.
        browse (bool, optional): Launch web browser. Default: False

    Returns:
        A (host, port) address tuple.

    Raises:
        FileNotFoundError: If no data is given and `file` does not exist.
        ValueError: If `address` is invalid or no port could be allocated.
    """
    if not logging.getLogger().hasHandlers():
        logging.basicConfig(level=logging.INFO, format="%(message)s")

    if not data and file and not os.path.exists(file):
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), file)

    content = _ContentProvider(data, file, file, file)

    # Raw byte buffers are served as-is; any other object is treated as an
    # in-memory model and converted when a matching factory is registered.
    if data and not isinstance(data, (bytes, bytearray)) and isinstance(data.__class__, type):
        logger.info("Experimental")
        model = _open(data)
        if model:
            text = json.dumps(model.to_json(), indent=2, ensure_ascii=False)
            content = _ContentProvider(text.encode("utf-8"), "model.netron", None, file)

    address = _make_address(address)
    if isinstance(address[1], int) and address[1] != 0:
        # An explicit port was requested: free it from any earlier server.
        stop(address)
    else:
        address = _make_port(address)

    thread = _HTTPServerThread(content, address)
    thread.start()
    # Wait until the thread has entered its serving loop.
    while not thread.alive():
        time.sleep(0.01)
    state = ("Serving '" + file + "'") if file else "Serving"
    logger.info(f"{state} at {thread.url}")
    if browse:
        webbrowser.open(thread.url)

    return address


def start(file=None, address=None, browse=True):
    """Start serving model file at address and open in web browser.

    Args:
        file (string): Model file to serve.
        browse (bool, optional): Launch web browser, Default: True
        address (tuple, optional): A (host, port) tuple, or a port number.

    Returns:
        A (host, port) address tuple.
    """
    return serve(file, None, browse=browse, address=address)


def widget(address, height=800):
    """Open address as Jupyter Notebook IFrame.

    Args:
        address (tuple): A (host, port) tuple, or a port number.
        height (int, optional): Height of the IFrame, Default: 800

    Returns:
        A Jupyter Notebook IFrame.
    """
    address = _make_address(address)
    url = f"http://{address[0]}:{address[1]}"
    # Imported lazily so IPython is only required when a widget is requested.
    IPython = __import__("IPython")
    return IPython.display.IFrame(url, width="100%", height=height)