Bladeren bron

Python logging

Lutz Roeder 9 maanden geleden
bovenliggende
commit
32e7ebead2
7 gewijzigde bestanden met toevoegingen van 67 en 49 verwijderingen
  1. 13 5
      source/__init__.py
  2. 1 1
      source/onnx.py
  3. 19 29
      source/server.py
  4. 15 5
      test/backend.py
  5. 10 4
      test/measures.py
  6. 6 1
      tools/pytorch_script.py
  7. 3 4
      tools/tf_script.py

+ 13 - 5
source/__init__.py

@@ -1,6 +1,7 @@
 """ Python Server entry point """
 
 import argparse
+import logging
 import os
 import sys
 
@@ -20,18 +21,25 @@ def main():
     parser.add_argument("--host",
         metavar="ADDR", help="host to serve", default="localhost")
     parser.add_argument("--verbosity",
-        metavar="LEVEL", help="output verbosity (quiet, default, debug)",
-        choices=[ "quiet", "default", "debug", "0", "1", "2" ], default="default")
+        metavar="LEVEL", help="log verbosity (quiet, default, debug)",
+        choices=[ "quiet", "debug", "default" ], default="default")
     parser.add_argument("--version", help="print version", action="store_true")
     args = parser.parse_args()
+    levels = {
+        "quiet": logging.CRITICAL,
+        "default": logging.INFO,
+        "debug": logging.DEBUG,
+    }
+    logging.basicConfig(level=levels[args.verbosity], format="%(message)s")
+    logger = logging.getLogger(__name__)
     if args.file and not os.path.exists(args.file):
-        print("Model file '" + args.file + "' does not exist.")
+        logger.error(f"Model file '{args.file}' does not exist.")
         sys.exit(2)
     if args.version:
-        print(__version__)
+        logger.info(__version__)
         sys.exit(0)
     address = (args.host, args.port) if args.host else args.port if args.port else None
-    start(args.file, address=address, browse=args.browse, verbosity=args.verbosity)
+    start(args.file, address=address, browse=args.browse)
     wait()
     sys.exit(0)
 

+ 1 - 1
source/onnx.py

@@ -112,7 +112,7 @@ class _Graph:
             attribute_type = "tensor"
             value = self._tensor(_.t)
         elif _.type == _AttributeType.GRAPH:
-            attribute_type = "tensor"
+            attribute_type = "graph"
             raise Exception("Unsupported graph attribute type")
         elif _.type == _AttributeType.FLOATS:
             attribute_type = "float32[]"

+ 19 - 29
source/server.py

@@ -2,14 +2,14 @@
 
 import errno
 import http.server
-import importlib.util
+import importlib
 import json
+import logging
 import os
 import random
 import re
 import socket
 import socketserver
-import sys
 import threading
 import time
 import urllib.parse
@@ -17,6 +17,8 @@ import webbrowser
 
 __version__ = "0.0.0"
 
+logger = logging.getLogger(__name__)
+
 class _ContentProvider:
     data = bytearray()
     base_dir = ""
@@ -43,7 +45,6 @@ class _ContentProvider:
 
 class _HTTPRequestHandler(http.server.BaseHTTPRequestHandler):
     content = None
-    verbosity = 1
     mime_types = {
         ".html": "text/html",
         ".js":   "text/javascript",
@@ -107,10 +108,9 @@ class _HTTPRequestHandler(http.server.BaseHTTPRequestHandler):
                     content = re.sub(regex, lambda _: meta, content)
                     content = content.encode("utf-8")
                 status_code = 200
-        _log(self.verbosity > 1, f"{str(status_code)} {self.command} {self.path}\n")
         self._write(status_code, content_type, content)
     def log_message(self, format, *args):
-        return
+        logger.debug(" ".join(args))
     def _write(self, status_code, content_type, content):
         self.send_response(status_code)
         if content:
@@ -127,16 +127,14 @@ class _ThreadedHTTPServer(socketserver.ThreadingMixIn, http.server.HTTPServer):
     pass
 
 class _HTTPServerThread(threading.Thread):
-    def __init__(self, content, address, verbosity):
+    def __init__(self, content, address):
         threading.Thread.__init__(self)
-        self.verbosity = verbosity
         self.address = address
         self.url = "http://" + address[0] + ":" + str(address[1])
         self.server = _ThreadedHTTPServer(address, _HTTPRequestHandler)
         self.server.timeout = 0.25
         self.server.block_on_close = False
         self.server.RequestHandlerClass.content = content
-        self.server.RequestHandlerClass.verbosity = verbosity
         self.terminate_event = threading.Event()
         self.terminate_event.set()
         self.stop_event = threading.Event()
@@ -155,10 +153,10 @@ class _HTTPServerThread(threading.Thread):
     def stop(self):
         """ Stop server """
         if self.alive():
-            _log(self.verbosity > 0, "Stopping " + self.url + "\n")
+            logger.info("Stopping " + self.url)
             self.stop_event.set()
             self.server.server_close()
-            self.terminate_event.wait(1000)
+            self.terminate_event.wait(1)
 
     def alive(self):
         """ Check server status """
@@ -198,11 +196,6 @@ def _threads(address=None):
             threads = [ _ for _ in threads if address[1] == _.address[1] ]
     return threads
 
-def _log(condition, message):
-    if condition:
-        sys.stdout.write(message)
-        sys.stdout.flush()
-
 def _make_address(address):
     if address is None or isinstance(address, int):
         port = address
@@ -253,13 +246,13 @@ def stop(address=None):
     for thread in threads:
         thread.stop()
 
-def status(adrress=None):
+def status(address=None):
     """Is model served at address.
 
     Args:
         address (tuple, optional): A (host, port) tuple, or a port number.
     """
-    threads = _threads(adrress)
+    threads = _threads(address)
     return len(threads) > 0
 
 def wait():
@@ -268,10 +261,10 @@ def wait():
         while len(_threads()) > 0:
             time.sleep(0.1)
     except (KeyboardInterrupt, SystemExit):
-        _log(True, "\n")
+        logger.info("")
         stop()
 
-def serve(file, data=None, address=None, browse=False, verbosity=1):
+def serve(file, data=None, address=None, browse=False):
     """Start serving model from file or data buffer at address and open in web browser.
 
     Args:
@@ -279,13 +272,12 @@ def serve(file, data=None, address=None, browse=False, verbosity=1):
         data (bytes): Model data to serve. None will load data from file.
         address (tuple, optional): A (host, port) tuple, or a port number.
         browse (bool, optional): Launch web browser. Default: False
-        log (bool, optional): Log details to console. Default: False
 
     Returns:
         A (host, port) address tuple.
     """
-    verbosities = { "0": 0, "quiet": 0, "1": 1, "default": 1, "2": 2, "debug": 2 }
-    verbosity = verbosities[str(verbosity)]
+    if not logging.getLogger().hasHandlers():
+        logging.basicConfig(level=logging.INFO, format="%(message)s")
 
     if not data and file and not os.path.exists(file):
         raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), file)
@@ -293,7 +285,7 @@ def serve(file, data=None, address=None, browse=False, verbosity=1):
     content = _ContentProvider(data, file, file, file)
 
     if data and not isinstance(data, bytearray) and isinstance(data.__class__, type):
-        _log(verbosity > 1, "Experimental\n")
+        logger.info("Experimental")
         model = _open(data)
         if model:
             text = json.dumps(model.to_json(), indent=2, ensure_ascii=False)
@@ -305,31 +297,29 @@ def serve(file, data=None, address=None, browse=False, verbosity=1):
     else:
         address = _make_port(address)
 
-    thread = _HTTPServerThread(content, address, verbosity)
+    thread = _HTTPServerThread(content, address)
     thread.start()
     while not thread.alive():
         time.sleep(0.01)
     state = ("Serving '" + file + "'") if file else "Serving"
-    message = f"{state} at {thread.url}\n"
-    _log(verbosity > 0, message)
+    logger.info(f"{state} at {thread.url}")
     if browse:
         webbrowser.open(thread.url)
 
     return address
 
-def start(file=None, address=None, browse=True, verbosity=1):
+def start(file=None, address=None, browse=True):
     """Start serving model file at address and open in web browser.
 
     Args:
         file (string): Model file to serve.
-        log (bool, optional): Log details to console. Default: False
         browse (bool, optional): Launch web browser. Default: True
         address (tuple, optional): A (host, port) tuple, or a port number.
 
     Returns:
         A (host, port) address tuple.
     """
-    return serve(file, None, browse=browse, address=address, verbosity=verbosity)
+    return serve(file, None, browse=browse, address=address)
 
 def widget(address, height=800):
     """ Open address as Jupyter Notebook IFrame.

+ 15 - 5
test/backend.py

@@ -2,6 +2,7 @@
 
 """ Experimental Python Server backend test """
 
+import logging
 import os
 import sys
 
@@ -13,6 +14,9 @@ netron = __import__("source")
 third_party_dir = os.path.join(root_dir, "third_party")
 test_data_dir = os.path.join(third_party_dir, "test")
 
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO, format="%(message)s")
+
 def _test_onnx():
     file = os.path.join(test_data_dir, "onnx", "candy.onnx")
     onnx = __import__("onnx")
@@ -20,16 +24,22 @@ def _test_onnx():
     netron.serve(None, model)
 
 def _test_onnx_iterate():
+    logging.getLogger(netron.__name__).setLevel(logging.WARNING)
     folder = os.path.join(test_data_dir, "onnx")
     for item in os.listdir(folder):
         file = os.path.join(folder, item)
-        if file.endswith(".onnx") and \
-            item != "super_resolution.onnx" and \
-            item != "arcface-resnet100.onnx":
-            print(item)
+        skip = (
+            "super_resolution.onnx",
+            "arcface-resnet100.onnx",
+            "aten_sum_dim_onnx_inlined.onnx",
+            "phi3-mini-128k-instruct-cuda-fp16.onnx",
+            "if_k1.onnx"
+        )
+        if file.endswith(".onnx") and item not in skip:
+            logger.info(item)
             onnx = __import__("onnx")
             model = onnx.load(file)
-            address = netron.serve(file, model, verbosity="quiet")
+            address = netron.serve(file, model)
             netron.stop(address)
 
 def _test_torchscript(file):

+ 10 - 4
test/measures.py

@@ -2,10 +2,16 @@
 
 """ Test Measures Script """
 
+import logging
+
 import pandas
 
 pandas.set_option("display.max_rows", None)
 
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO, format="%(message)s")
+
+
 def _summarize(summary_df, measures_df, column, threshold):
     measures_df = measures_df.sort_values(column, ascending=False)
     total = measures_df[column].sum()
@@ -19,10 +25,10 @@ def main():
     measures_df = pandas.read_csv("dist/test/measures.csv")
     measures_df.fillna(0, inplace=True)
     summary_df = pandas.DataFrame(columns=[ "Name", "Total", "Top", "Count", "Ratio" ])
-    print(_summarize(summary_df, measures_df, "load", 1))
-    print(_summarize(summary_df, measures_df, "validate", 1))
-    print(_summarize(summary_df, measures_df, "render", 1))
-    print(summary_df.to_string(index=False))
+    logger.info(_summarize(summary_df, measures_df, "load", 1))
+    logger.info(_summarize(summary_df, measures_df, "validate", 1))
+    logger.info(_summarize(summary_df, measures_df, "render", 1))
+    logger.info(summary_df.to_string(index=False))
 
 if __name__ == "__main__":
     main()

+ 6 - 1
tools/pytorch_script.py

@@ -2,6 +2,7 @@
 
 import collections
 import json
+import logging
 import os
 import re
 import sys
@@ -15,6 +16,10 @@ third_party_dir = os.path.join(root_dir, "third_party")
 metadata_file = os.path.join(source_dir, "pytorch-metadata.json")
 pytorch_source_dir = os.path.join(third_party_dir, "source", "pytorch")
 
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO, format="%(message)s")
+
+
 def _read(path):
     with open(path, encoding="utf-8") as file:
         return file.read()
@@ -333,7 +338,7 @@ def _parse_schemas():
         if key not in schemas:
             schemas[key] = schema
         else:
-            print(f"-> {key}")
+            logging.warning(f"-> {key}")
     return schemas
 
 def _filter_schemas(schemas, types):

+ 3 - 4
tools/tf_script.py

@@ -146,10 +146,9 @@ attr_type_table = {
 }
 
 def _convert_attr_type(attr_type):
-    if attr_type in attr_type_table:
-        return attr_type_table[attr_type]
-    print(attr_type)
-    return attr_type
+    if attr_type not in attr_type_table:
+        raise ValueError(f"Unknown attribute type '{attr_type}'")
+    return attr_type_table[attr_type]
 
 def _convert_attr_list(attr_value):
     result = []