2
0

server.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. ''' Python Server implementation '''
  import codecs
  import errno
  import html
  import http.server
  import importlib.util
  import os
  import random
  import re
  import socket
  import socketserver
  import sys
  import threading
  import time
  import urllib.parse
  import webbrowser
  # Version string reported to the web viewer through the 'version' <meta> tag.
  __version__ = "0.0.0"
  class HTTPRequestHandler(http.server.BaseHTTPRequestHandler):
      ''' HTTP Request Handler

      Serves the viewer page at ``/`` and ``/index.html``. Serves model bytes
      under ``/data/<name>``, taken from ``data`` or read from ``folder``.
      Serves any other path as a viewer asset from this module's directory.
      '''
      # Configuration below is assigned by the code that starts the server.
      file = ""  # basename of the served model file, '' when none
      data = bytearray()  # in-memory model bytes; falsy means read from ``folder``
      folder = ""  # directory that ``/data/`` requests are resolved against
      log = False  # when True, write one line per request to stdout
      # Extension -> Content-Type for viewer assets; other extensions are not served.
      mime_types_map = {
          '.html': 'text/html',
          '.js': 'text/javascript',
          '.css': 'text/css',
          '.png': 'image/png',
          '.gif': 'image/gif',
          '.jpg': 'image/jpeg',
          '.ico': 'image/x-icon',
          '.json': 'application/json',
          '.pb': 'application/octet-stream',
          '.ttf': 'font/truetype',
          '.otf': 'font/opentype',
          '.eot': 'application/vnd.ms-fontobject',
          '.woff': 'font/woff',
          '.woff2': 'application/font-woff2',
          '.svg': 'image/svg+xml'
      }

      @staticmethod
      def _is_within(basedir, path):
          ''' Return True if *path* equals *basedir* or lies beneath it (component-wise). '''
          try:
              return os.path.commonpath([basedir, path]) == basedir
          except ValueError:
              # Raised e.g. on Windows when the paths are on different drives.
              return False

      def _read_index(self):
          ''' Return index.html as UTF-8 bytes with server/file <meta> tags injected. '''
          meta = [
              '<meta name="type" content="Python">',
              '<meta name="version" content="' + __version__ + '">',
          ]
          if self.file:
              # Escape so quotes or '<' in a file name cannot break out of the attribute.
              meta.append('<meta name="file" content="/data/' + html.escape(self.file) + '">')
          basedir = os.path.dirname(os.path.realpath(__file__))
          with codecs.open(basedir + '/index.html', mode="r", encoding="utf-8") as open_file:
              buffer = open_file.read()
          meta = '\n'.join(meta)
          # Callable replacement: backslashes in file names must not be read as group refs.
          buffer = re.sub(r'<meta name="version" content="\d+\.\d+\.\d+">', lambda _: meta, buffer)
          return buffer.encode('utf-8')

      def _read_data(self, name):
          ''' Return the bytes for ``/data/<name>``, or None if not available. '''
          if name == self.file and self.data:
              return self.data
          if '\0' in name:
              # os.path.realpath raises ValueError on embedded NUL bytes.
              return None
          basedir = os.path.realpath(self.folder)
          path = os.path.normpath(os.path.realpath(basedir + '/' + name))
          if self._is_within(basedir, path) and os.path.exists(path) and not os.path.isdir(path):
              with open(path, 'rb') as file:
                  return file.read()
          return None

      def _read_static(self, path):
          ''' Return ``(content_type, bytes)`` for a viewer asset, or None if not servable. '''
          basedir = os.path.dirname(os.path.realpath(__file__))
          path = os.path.normpath(os.path.realpath(basedir + path))
          if not self._is_within(basedir, path):
              return None
          if not os.path.exists(path) or os.path.isdir(path):
              return None
          content_type = self.mime_types_map.get(os.path.splitext(path)[1])
          if not content_type:
              return None
          with open(path, 'rb') as file:
              return content_type, file.read()

      def do_GET(self): # pylint: disable=invalid-name
          ''' Serve a GET request (HEAD reuses this and skips the body).

          Unknown, disallowed or empty resources get a small plain-text 404.
          '''
          path = urllib.parse.urlparse(self.path).path
          status_code = 404
          headers = {}
          buffer = None
          if path in ('/', '/index.html'):
              buffer = self._read_index()
              headers['Content-Type'] = 'text/html'
              status_code = 200
          elif path.startswith('/data/'):
              buffer = self._read_data(urllib.parse.unquote(path[len('/data/'):]))
              # Empty content is reported as not found.
              if buffer:
                  headers['Content-Type'] = 'application/octet-stream'
                  status_code = 200
          else:
              asset = self._read_static(path)
              if asset is not None:
                  content_type, buffer = asset
                  headers['Content-Type'] = content_type
                  status_code = 200
          if status_code != 200:
              # Small, correctly sized error body instead of raw status bytes.
              buffer = str(status_code).encode('utf-8')
              headers = {'Content-Type': 'text/plain'}
          headers['Content-Length'] = len(buffer)
          if self.log:
              sys.stdout.write(str(status_code) + ' ' + self.command + ' ' + self.path + '\n')
              sys.stdout.flush()
          self.send_response(status_code)
          for key, value in headers.items():
              self.send_header(key, value)
          self.end_headers()
          if self.command != 'HEAD':
              self.wfile.write(buffer)

      def do_HEAD(self): # pylint: disable=invalid-name
          ''' Serve a HEAD request '''
          self.do_GET()

      def log_message(self, format, *args): # pylint: disable=redefined-builtin
          # Silence BaseHTTPRequestHandler's default per-request stderr logging.
          return
  class HTTPServerThread(threading.Thread):
      ''' HTTP Server Thread

      Binds a threaded HTTP server to ``address`` on construction and runs its
      request loop on this thread until ``stop()`` is called.
      '''
      class _ThreadedHTTPServer(socketserver.ThreadingMixIn, http.server.HTTPServer):
          ''' Threaded HTTP Server '''

      def __init__(self, data, file, address, log):
          super().__init__()
          self.address = address
          self.url = 'http://' + address[0] + ':' + str(address[1])
          self.file = file
          # Configure a private subclass per server. Writing to the shared
          # HTTPRequestHandler class let a second server overwrite the file,
          # data and folder of servers that were already running.
          handler = type('HTTPRequestHandler', (HTTPRequestHandler,), {})
          if file:
              handler.folder = os.path.dirname(file) or '.'
              handler.file = os.path.basename(file)
          else:
              handler.folder = ''
              handler.file = ''
          handler.data = data
          handler.log = log
          self.server = self._ThreadedHTTPServer(address, handler)
          # Short timeout so the loop in run() re-checks stop_event regularly.
          self.server.timeout = 0.25
          # terminate_event is set whenever the request loop is not running.
          self.terminate_event = threading.Event()
          self.terminate_event.set()
          self.stop_event = threading.Event()

      def run(self):
          ''' Handle requests until stop() is requested or the server socket fails. '''
          self.stop_event.clear()
          self.terminate_event.clear()
          try:
              while not self.stop_event.is_set():
                  self.server.handle_request()
          except Exception: # pylint: disable=broad-except
              # stop() closes the socket under handle_request(); the resulting
              # error is an expected way out of the loop.
              pass
          finally:
              self.terminate_event.set()
              self.stop_event.clear()

      def stop(self):
          ''' Stop server '''
          if self.alive():
              sys.stdout.write("Stopping " + self.url + "\n")
              self.stop_event.set()
              self.server.server_close()
              self.terminate_event.wait(1000)

      def alive(self):
          ''' Check server status '''
          return not self.terminate_event.is_set()
  # Server threads registered by serve(); dead entries are pruned by _update_thread_list().
  _thread_list = []
  def _add_thread(thread):
      '''Append a started server thread to the module-level registry.'''
      # append() mutates the current global list, so no `global` rebinding is needed.
      _thread_list.append(thread)
  def _update_thread_list(address=None):
      '''Prune stopped servers and return the live ones, optionally filtered.

      Args:
          address (optional): A (host, port) tuple or a port number. The host
              must match; a falsy port matches any port.
      Returns:
          The registry list itself when address is None, else a new list of matches.
      '''
      global _thread_list
      _thread_list = [entry for entry in _thread_list if entry.alive()]
      if address is None:
          return _thread_list
      host, port = _make_address(address)
      matches = [entry for entry in _thread_list if host == entry.address[0]]
      if port:
          matches = [entry for entry in matches if port == entry.address[1]]
      return matches
  def _make_address(address):
      '''Normalize an address argument into a (host, port) tuple.

      None or an int is treated as a port on 'localhost'. A 2-tuple with a str
      host and an int-or-None port is returned as-is.

      Raises:
          ValueError: address has any other shape.
      '''
      if address is None or isinstance(address, int):
          return ('localhost', address)
      if isinstance(address, tuple) and len(address) == 2:
          host, port = address
          port_ok = port is None or isinstance(port, int)
          if isinstance(host, str) and port_ok:
              return address
      raise ValueError('Invalid address.')
  def _make_port(address):
      '''Resolve an unspecified (None) or "any" (0) port to a bindable port.

      For None, ports 8080, 8081 and four random ports in [15000, 25000) are
      probed before falling back to an OS-assigned port (0). For 0, only the
      OS-assigned port is probed. Explicit non-zero ports are returned as-is.

      Raises:
          ValueError: the port is not an int after probing.
      '''
      host, requested = address
      if requested is None or requested == 0:
          candidates = []
          if requested != 0:
              candidates.extend([8080, 8081])
              generator = random.Random()
              for _ in range(4):
                  candidate = generator.randrange(15000, 25000)
                  if candidate not in candidates:
                      candidates.append(candidate)
          candidates.append(0)
          for candidate in candidates:
              probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
              probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
              probe.settimeout(1)
              try:
                  probe.bind((host, candidate))
                  return (host, probe.getsockname()[1])
              except: # pylint: disable=bare-except
                  continue
              finally:
                  probe.close()
      if isinstance(address[1], int):
          return address
      raise ValueError('Failed to allocate port.')
  def stop(address=None):
      '''Stop every model server matching address (all servers when None).

      Args:
          address (tuple, optional): A (host, port) tuple, or a port number.
      '''
      matching = _update_thread_list(address)
      for server_thread in matching:
          server_thread.stop()
      # Drop the threads that just terminated from the registry.
      _update_thread_list()
  def status(address=None, adrress=None):
      '''Return True if a model server is running at address.

      Args:
          address (tuple, optional): A (host, port) tuple, or a port number.
              None matches any running server.
          adrress (optional): Misspelled former name of ``address``, kept so
              existing keyword callers continue to work. Used only when
              ``address`` is None.
      '''
      if address is None:
          address = adrress
      return len(_update_thread_list(address)) > 0
  def wait():
      '''Block while any model server is running; on Ctrl+C or exit, stop them all.'''
      try:
          # Poll once per second. The previous 1000-second sleep meant servers
          # stopped elsewhere were noticed up to ~17 minutes late.
          while len(_update_thread_list()) > 0:
              time.sleep(1)
      except (KeyboardInterrupt, SystemExit):
          sys.stdout.write('\n')
          sys.stdout.flush()
          stop()
  def _serialize_model(data, file):
      '''Serialize an in-memory model object if its type is registered.

      The class hierarchy of ``data`` is searched breadth-first for a registered
      type name. On a match, ``<name>.py`` from this module's directory is
      executed, and ``ModelFactory().serialize(data)`` is returned together with
      the placeholder file name 'test.json'. Otherwise ``(data, file)`` is
      returned unchanged.
      '''
      registry = {
          'onnx.onnx_ml_pb2.ModelProto': '.onnx',
          'torch.Graph': '.pytorch',
          'torch._C.Graph': '.pytorch',
          'torch.nn.modules.module.Module': '.pytorch',
      }
      queue = [data.__class__]
      while queue:
          current = queue.pop(0)
          if current.__module__ and current.__name__:
              name = current.__module__ + '.' + current.__name__
              if name in registry:
                  module_name = registry[name]
                  if module_name.startswith('.'):
                      path = os.path.join(os.path.dirname(__file__), module_name[1:] + '.py')
                      spec = importlib.util.spec_from_file_location(module_name, path)
                      module = importlib.util.module_from_spec(spec)
                      spec.loader.exec_module(module)
                  else:
                      module = __import__(module_name)
                  return module.ModelFactory().serialize(data), 'test.json'
          queue.extend(base for base in current.__bases__ if isinstance(base, type))
      return data, file

  def serve(file, data, address=None, browse=False, log=False):
      '''Start serving model from file or data buffer at address and open in web browser.

      Args:
          file (string): Model file to serve. Required to detect format.
          data (bytes): Model data to serve. None will load data from file.
          address (tuple, optional): A (host, port) tuple, or a port number.
          browse (bool, optional): Launch web browser. Default: False
          log (bool, optional): Log details to console. Default: False

      Returns:
          A (host, port) address tuple.

      Raises:
          FileNotFoundError: No data was given and file does not exist.
          ValueError: address is malformed or no port could be allocated.
      '''
      if not data and file and not os.path.exists(file):
          raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), file)
      if data and not isinstance(data, bytearray) and isinstance(data.__class__, type):
          # Registered model objects (e.g. an ONNX ModelProto) are serialized first.
          data, file = _serialize_model(data, file)
      _update_thread_list()
      address = _make_address(address)
      if isinstance(address[1], int) and address[1] != 0:
          # Explicit port: stop any server this module already runs there.
          stop(address)
      else:
          address = _make_port(address)
      _update_thread_list()
      thread = HTTPServerThread(data, file, address, log)
      thread.start()
      # Poll briefly until run() has entered its request loop. The old
      # sleep(10) could stall startup for 10 s. Stop waiting if the thread has
      # already exited; otherwise this loop would never end.
      while thread.is_alive() and not thread.alive():
          time.sleep(0.01)
      _add_thread(thread)
      if file:
          sys.stdout.write("Serving '" + file + "' at " + thread.url + "\n")
      else:
          sys.stdout.write("Serving at " + thread.url + "\n")
      sys.stdout.flush()
      if browse:
          webbrowser.open(thread.url)
      return address
  def start(file=None, address=None, browse=True, log=False):
      '''Serve a model file and, by default, open it in the web browser.

      Thin wrapper over serve() that always loads the model from ``file``.

      Args:
          file (string): Model file to serve.
          address (tuple, optional): A (host, port) tuple, or a port number.
          browse (bool, optional): Launch web browser. Default: True
          log (bool, optional): Log details to console. Default: False
      Returns:
          A (host, port) address tuple.
      '''
      no_data = None
      return serve(file, no_data, address=address, browse=browse, log=log)