|
|
@@ -105,41 +105,42 @@ class ModelFactory:
|
|
|
'arguments': [ argument(value) ]
|
|
|
})
|
|
|
json_node['attributes'] = []
|
|
|
+ AttributeProto = onnx.onnx_pb.AttributeProto
|
|
|
for _ in node.attribute:
|
|
|
- if _.type == onnx.onnx_pb.AttributeProto.UNDEFINED:
|
|
|
+ if _.type == AttributeProto.UNDEFINED:
|
|
|
attribute_type = None
|
|
|
value = None
|
|
|
- elif _.type == onnx.onnx_pb.AttributeProto.FLOAT:
|
|
|
+ elif _.type == AttributeProto.FLOAT:
|
|
|
attribute_type = 'float32'
|
|
|
value = _.f
|
|
|
- elif _.type == onnx.onnx_pb.AttributeProto.INT:
|
|
|
+ elif _.type == AttributeProto.INT:
|
|
|
attribute_type = 'int64'
|
|
|
value = _.i
|
|
|
- elif _.type == onnx.onnx_pb.AttributeProto.STRING:
|
|
|
+ elif _.type == AttributeProto.STRING:
|
|
|
attribute_type = 'string'
|
|
|
value = _.s.decode('latin1' if op_type == 'Int8GivenTensorFill' else 'utf-8')
|
|
|
- elif _.type == onnx.onnx_pb.AttributeProto.TENSOR:
|
|
|
+ elif _.type == AttributeProto.TENSOR:
|
|
|
attribute_type = 'tensor'
|
|
|
value = tensor(_.t)
|
|
|
- elif _.type == onnx.onnx_pb.AttributeProto.GRAPH:
|
|
|
+ elif _.type == AttributeProto.GRAPH:
|
|
|
attribute_type = 'tensor'
|
|
|
raise Exception('Unsupported graph attribute type')
|
|
|
- elif _.type == onnx.onnx_pb.AttributeProto.FLOATS:
|
|
|
+ elif _.type == AttributeProto.FLOATS:
|
|
|
attribute_type = 'float32[]'
|
|
|
value = [ item for item in _.floats ]
|
|
|
- elif _.type == onnx.onnx_pb.AttributeProto.INTS:
|
|
|
+ elif _.type == AttributeProto.INTS:
|
|
|
attribute_type = 'int64[]'
|
|
|
value = [ item for item in _.ints ]
|
|
|
- elif _.type == onnx.onnx_pb.AttributeProto.STRINGS:
|
|
|
+ elif _.type == AttributeProto.STRINGS:
|
|
|
attribute_type = 'string[]'
|
|
|
value = [ item.decode('utf-8') for item in _.strings ]
|
|
|
- elif _.type == onnx.onnx_pb.AttributeProto.TENSORS:
|
|
|
+ elif _.type == AttributeProto.TENSORS:
|
|
|
attribute_type = 'tensor[]'
|
|
|
raise Exception('Unsupported tensors attribute type')
|
|
|
- elif _.type == onnx.onnx_pb.AttributeProto.GRAPHS:
|
|
|
+ elif _.type == AttributeProto.GRAPHS:
|
|
|
attribute_type = 'graph[]'
|
|
|
raise Exception('Unsupported graphs attribute type')
|
|
|
- elif _.type == onnx.onnx_pb.AttributeProto.SPARSE_TENSOR:
|
|
|
+ elif _.type == AttributeProto.SPARSE_TENSOR:
|
|
|
attribute_type = 'tensor'
|
|
|
value = tensor(_.sparse_tensor)
|
|
|
else:
|