Browse source

Fix lint issues

Lutz Roeder, 3 years ago
parent
commit
369adeaa7f
4 changed files with 29 additions and 24 deletions
  1. +7 −2
      source/__init__.py
  2. +13 −12
      source/onnx.py
  3. +5 −5
      source/pytorch.py
  4. +4 −5
      source/server.py

+ 7 - 2
source/__init__.py

@@ -1,3 +1,5 @@
+''' Python Server entry point '''
+
 import argparse
 import sys
 import os
@@ -10,8 +12,11 @@ from .server import serve
 from .__version__ import __version__
 
 def main():
-    parser = argparse.ArgumentParser(description='Viewer for neural network, deep learning and machine learning models.')
-    parser.add_argument('file', metavar='MODEL_FILE', help='model file to serve', nargs='?', default=None)
+    ''' main entry point '''
+    parser = argparse.ArgumentParser(
+        description='Viewer for neural network, deep learning and machine learning models.')
+    parser.add_argument('file',
+        metavar='MODEL_FILE', help='model file to serve', nargs='?', default=None)
     parser.add_argument('-v', '--version', help="print version", action='store_true')
     parser.add_argument('-b', '--browse', help='launch web browser', action='store_true')
     parser.add_argument('-p', '--port', help='port to serve', type=int)

+ 13 - 12
source/onnx.py

@@ -105,41 +105,42 @@ class ModelFactory:
                         'arguments': [ argument(value) ]
                     })
             json_node['attributes'] = []
+            AttributeProto = onnx.onnx_pb.AttributeProto
             for _ in node.attribute:
-                if _.type == onnx.onnx_pb.AttributeProto.UNDEFINED:
+                if _.type == AttributeProto.UNDEFINED:
                     attribute_type = None
                     value = None
-                elif _.type == onnx.onnx_pb.AttributeProto.FLOAT:
+                elif _.type == AttributeProto.FLOAT:
                     attribute_type = 'float32'
                     value = _.f
-                elif _.type == onnx.onnx_pb.AttributeProto.INT:
+                elif _.type == AttributeProto.INT:
                     attribute_type = 'int64'
                     value = _.i
-                elif _.type == onnx.onnx_pb.AttributeProto.STRING:
+                elif _.type == AttributeProto.STRING:
                     attribute_type = 'string'
                     value = _.s.decode('latin1' if op_type == 'Int8GivenTensorFill' else 'utf-8')
-                elif _.type == onnx.onnx_pb.AttributeProto.TENSOR:
+                elif _.type == AttributeProto.TENSOR:
                     attribute_type = 'tensor'
                     value = tensor(_.t)
-                elif _.type == onnx.onnx_pb.AttributeProto.GRAPH:
+                elif _.type == AttributeProto.GRAPH:
                     attribute_type = 'tensor'
                     raise Exception('Unsupported graph attribute type')
-                elif _.type == onnx.onnx_pb.AttributeProto.FLOATS:
+                elif _.type == AttributeProto.FLOATS:
                     attribute_type = 'float32[]'
                     value = [ item for item in _.floats ]
-                elif _.type == onnx.onnx_pb.AttributeProto.INTS:
+                elif _.type == AttributeProto.INTS:
                     attribute_type = 'int64[]'
                     value = [ item for item in _.ints ]
-                elif _.type == onnx.onnx_pb.AttributeProto.STRINGS:
+                elif _.type == AttributeProto.STRINGS:
                     attribute_type = 'string[]'
                     value = [ item.decode('utf-8') for item in _.strings ]
-                elif _.type == onnx.onnx_pb.AttributeProto.TENSORS:
+                elif _.type == AttributeProto.TENSORS:
                     attribute_type = 'tensor[]'
                     raise Exception('Unsupported tensors attribute type')
-                elif _.type == onnx.onnx_pb.AttributeProto.GRAPHS:
+                elif _.type == AttributeProto.GRAPHS:
                     attribute_type = 'graph[]'
                     raise Exception('Unsupported graphs attribute type')
-                elif _.type == onnx.onnx_pb.AttributeProto.SPARSE_TENSOR:
+                elif _.type == AttributeProto.SPARSE_TENSOR:
                     attribute_type = 'tensor'
                     value = tensor(_.sparse_tensor)
                 else:

+ 5 - 5
source/pytorch.py

@@ -27,11 +27,11 @@ class ModelFactory:
         json_graph['outputs'] = []
         json_model['graphs'].append(json_graph)
         data_type_map = dict([
-            [ torch.float16, 'float16'],
-            [ torch.float32, 'float32'],
-            [ torch.float64, 'float64'],
-            [ torch.int32, 'int32'],
-            [ torch.int64, 'int64'],
+            [ torch.float16, 'float16'], # pylint: disable=no-member
+            [ torch.float32, 'float32'], # pylint: disable=no-member
+            [ torch.float64, 'float64'], # pylint: disable=no-member
+            [ torch.int32, 'int32'], # pylint: disable=no-member
+            [ torch.int64, 'int64'], # pylint: disable=no-member
         ])
         arguments_map = {}
         def argument(value):

+ 4 - 5
source/server.py

@@ -1,4 +1,4 @@
-''' Python Server '''
+''' Python Server implementation '''
 
 import codecs
 import errno
@@ -166,10 +166,9 @@ def _update_thread_list(address=None):
     threads = _thread_list
     if address is not None:
         address = _make_address(address)
-        if address[1] is None:
-            threads = [ _ for _ in threads if address[0] == _.address[0] ]
-        else:
-            threads = [ _ for _ in threads if address[0] == _.address[0] and address[1] == _.address[1] ]
+        threads = [ _ for _ in threads if address[0] == _.address[0] ]
+        if address[1]:
+            threads = [ _ for _ in threads if address[1] == _.address[1] ]
     return threads
 
 def _make_address(address):