| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337 |
- ''' Python Server implementation '''
- import codecs
- import errno
- import http.server
- import importlib.util
- import os
- import random
- import re
- import socket
- import socketserver
- import sys
- import threading
- import time
- import webbrowser
- import urllib.parse
- __version__ = '0.0.0'
- class HTTPRequestHandler(http.server.BaseHTTPRequestHandler):
- ''' HTTP Request Handler '''
- file = ""
- data = bytearray()
- folder = ""
- verbosity = 1
- mime_types_map = {
- '.html': 'text/html',
- '.js': 'text/javascript',
- '.css': 'text/css',
- '.png': 'image/png',
- '.gif': 'image/gif',
- '.jpg': 'image/jpeg',
- '.ico': 'image/x-icon',
- '.json': 'application/json',
- '.pb': 'application/octet-stream',
- '.ttf': 'font/truetype',
- '.otf': 'font/opentype',
- '.eot': 'application/vnd.ms-fontobject',
- '.woff': 'font/woff',
- '.woff2': 'application/font-woff2',
- '.svg': 'image/svg+xml'
- }
- def do_GET(self): # pylint: disable=invalid-name
- ''' Serve a GET request '''
- path = urllib.parse.urlparse(self.path).path
- status_code = 0
- headers = {}
- buffer = None
- if path in ('/', '/index.html'):
- meta = []
- meta.append('<meta name="type" content="Python">')
- meta.append('<meta name="version" content="' + __version__ + '">')
- if self.file:
- meta.append('<meta name="file" content="/data/' + self.file + '">')
- basedir = os.path.dirname(os.path.realpath(__file__))
- with codecs.open(basedir + '/index.html', mode="r", encoding="utf-8") as open_file:
- buffer = open_file.read()
- meta = '\n'.join(meta)
- buffer = re.sub(r'<meta name="version" content="\d+.\d+.\d+">', meta, buffer)
- buffer = buffer.encode('utf-8')
- headers['Content-Type'] = 'text/html'
- headers['Content-Length'] = len(buffer)
- status_code = 200
- elif path.startswith('/data/'):
- status_code = 404
- path = urllib.parse.unquote(path[len('/data/'):])
- if path == self.file and self.data:
- buffer = self.data
- else:
- basedir = os.path.realpath(self.folder)
- path = os.path.normpath(os.path.realpath(basedir + '/' + path))
- if os.path.commonprefix([basedir, path]) == basedir:
- if os.path.exists(path) and not os.path.isdir(path):
- with open(path, 'rb') as file:
- buffer = file.read()
- if buffer:
- headers['Content-Type'] = 'application/octet-stream'
- headers['Content-Length'] = len(buffer)
- status_code = 200
- else:
- status_code = 404
- basedir = os.path.dirname(os.path.realpath(__file__))
- path = os.path.normpath(os.path.realpath(basedir + path))
- if os.path.commonprefix([basedir, path]) == basedir:
- if os.path.exists(path) and not os.path.isdir(path):
- extension = os.path.splitext(path)[1]
- content_type = self.mime_types_map[extension]
- if content_type:
- with open(path, 'rb') as file:
- buffer = file.read()
- headers['Content-Type'] = content_type
- headers['Content-Length'] = len(buffer)
- status_code = 200
- if self.verbosity > 1:
- sys.stdout.write(str(status_code) + ' ' + self.command + ' ' + self.path + '\n')
- sys.stdout.flush()
- self.send_response(status_code)
- for key, value in headers.items():
- self.send_header(key, value)
- self.end_headers()
- if self.command != 'HEAD':
- if status_code == 404 and buffer is None:
- self.wfile.write(bytes(status_code))
- elif (status_code in (200, 404)) and buffer is not None:
- self.wfile.write(buffer)
- def do_HEAD(self): # pylint: disable=invalid-name
- ''' Serve a HEAD request '''
- self.do_GET()
- def log_message(self, format, *args): # pylint: disable=redefined-builtin
- return
- class HTTPServerThread(threading.Thread):
- ''' HTTP Server Thread '''
- def __init__(self, data, file, address, verbosity):
- threading.Thread.__init__(self)
- class ThreadedHTTPServer(socketserver.ThreadingMixIn, http.server.HTTPServer):
- ''' Threaded HTTP Server '''
- self.verbosity = verbosity
- self.address = address
- self.url = 'http://' + address[0] + ':' + str(address[1])
- self.file = file
- self.server = ThreadedHTTPServer(address, HTTPRequestHandler)
- self.server.timeout = 0.25
- if file:
- folder = os.path.dirname(file) if os.path.dirname(file) else '.'
- self.server.RequestHandlerClass.folder = folder
- self.server.RequestHandlerClass.file = os.path.basename(file)
- else:
- self.server.RequestHandlerClass.folder = ''
- self.server.RequestHandlerClass.file = ''
- self.server.RequestHandlerClass.data = data
- self.server.RequestHandlerClass.verbosity = verbosity
- self.terminate_event = threading.Event()
- self.terminate_event.set()
- self.stop_event = threading.Event()
- def run(self):
- self.stop_event.clear()
- self.terminate_event.clear()
- try:
- while not self.stop_event.is_set():
- self.server.handle_request()
- except: # pylint: disable=bare-except
- pass
- self.terminate_event.set()
- self.stop_event.clear()
- def stop(self):
- ''' Stop server '''
- if self.alive():
- if self.verbosity > 0:
- sys.stdout.write("Stopping " + self.url + "\n")
- sys.stdout.flush()
- self.stop_event.set()
- self.server.server_close()
- self.terminate_event.wait(1000)
- def alive(self):
- ''' Check server status '''
- return not self.terminate_event.is_set()
- _thread_list = []
- def _add_thread(thread):
- global _thread_list
- _thread_list.append(thread)
- def _update_thread_list(address=None):
- global _thread_list
- _thread_list = [ thread for thread in _thread_list if thread.alive() ]
- threads = _thread_list
- if address is not None:
- address = _make_address(address)
- threads = [ _ for _ in threads if address[0] == _.address[0] ]
- if address[1]:
- threads = [ _ for _ in threads if address[1] == _.address[1] ]
- return threads
- def _make_address(address):
- if address is None or isinstance(address, int):
- port = address
- address = ('localhost', port)
- if isinstance(address, tuple) and len(address) == 2:
- host = address[0]
- port = address[1]
- if isinstance(host, str) and (port is None or isinstance(port, int)):
- return address
- raise ValueError('Invalid address.')
- def _make_port(address):
- if address[1] is None or address[1] == 0:
- ports = []
- if address[1] != 0:
- ports.append(8080)
- ports.append(8081)
- rnd = random.Random()
- for _ in range(4):
- port = rnd.randrange(15000, 25000)
- if port not in ports:
- ports.append(port)
- ports.append(0)
- for port in ports:
- temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- temp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- temp_socket.settimeout(1)
- try:
- temp_socket.bind((address[0], port))
- sockname = temp_socket.getsockname()
- address = (address[0], sockname[1])
- return address
- except: # pylint: disable=bare-except
- pass
- finally:
- temp_socket.close()
- if isinstance(address[1], int):
- return address
- raise ValueError('Failed to allocate port.')
- def stop(address=None):
- '''Stop serving model at address.
- Args:
- address (tuple, optional): A (host, port) tuple, or a port number.
- '''
- threads = _update_thread_list(address)
- for thread in threads:
- thread.stop()
- _update_thread_list()
- def status(adrress=None):
- '''Is model served at address.
- Args:
- address (tuple, optional): A (host, port) tuple, or a port number.
- '''
- threads = _update_thread_list(adrress)
- return len(threads) > 0
- def wait():
- '''Wait for console exit and stop all model servers.'''
- try:
- while len(_update_thread_list()) > 0:
- time.sleep(1000)
- except (KeyboardInterrupt, SystemExit):
- sys.stdout.write('\n')
- sys.stdout.flush()
- stop()
- def serve(file, data, address=None, browse=False, verbosity=1):
- '''Start serving model from file or data buffer at address and open in web browser.
- Args:
- file (string): Model file to serve. Required to detect format.
- data (bytes): Model data to serve. None will load data from file.
- address (tuple, optional): A (host, port) tuple, or a port number.
- browse (bool, optional): Launch web browser. Default: True
- log (bool, optional): Log details to console. Default: False
- Returns:
- A (host, port) address tuple.
- '''
- verbosity = { '0': 0, 'quiet': 0, '1': 1, 'default': 1, '2': 2, 'debug': 2 }[str(verbosity)]
- if not data and file and not os.path.exists(file):
- raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), file)
- if data and not isinstance(data, bytearray) and isinstance(data.__class__, type):
- registry = dict([
- ('onnx.onnx_ml_pb2.ModelProto', '.onnx'),
- ('torch.Graph', '.pytorch'),
- ('torch._C.Graph', '.pytorch'),
- ('torch.nn.modules.module.Module', '.pytorch')
- ])
- queue = [ data.__class__ ]
- while len(queue) > 0:
- current = queue.pop(0)
- if current.__module__ and current.__name__:
- name = current.__module__ + '.' + current.__name__
- if name in registry:
- module_name = registry[name]
- if module_name.startswith('.'):
- file = os.path.join(os.path.dirname(__file__), module_name[1:] + '.py')
- spec = importlib.util.spec_from_file_location(module_name, file)
- module = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(module)
- else:
- module = __import__(module_name)
- model_factory = module.ModelFactory()
- if verbosity > 1:
- sys.stdout.write('Experimental\n')
- sys.stdout.flush()
- data = model_factory.serialize(data)
- file = 'test.json'
- break
- for base in current.__bases__:
- if isinstance(base, type):
- queue.append(base)
- _update_thread_list()
- address = _make_address(address)
- if isinstance(address[1], int) and address[1] != 0:
- stop(address)
- else:
- address = _make_port(address)
- _update_thread_list()
- thread = HTTPServerThread(data, file, address, verbosity)
- thread.start()
- while not thread.alive():
- time.sleep(10)
- _add_thread(thread)
- if file:
- if verbosity > 0:
- sys.stdout.write("Serving '" + file + "' at " + thread.url + "\n")
- sys.stdout.flush()
- else:
- if verbosity > 0:
- sys.stdout.write("Serving at " + thread.url + "\n")
- sys.stdout.flush()
- if browse:
- webbrowser.open(thread.url)
- return address
- def start(file=None, address=None, browse=True, verbosity=1):
- '''Start serving model file at address and open in web browser.
- Args:
- file (string): Model file to serve.
- log (bool, optional): Log details to console. Default: False
- browse (bool, optional): Launch web browser, Default: True
- address (tuple, optional): A (host, port) tuple, or a port number.
- Returns:
- A (host, port) address tuple.
- '''
- return serve(file, None, browse=browse, address=address, verbosity=verbosity)
|