|
|
@@ -106,7 +106,7 @@ class OnnxModel:
|
|
|
self.remove_tensor_data(initializer)
|
|
|
for node in model.graph.node:
|
|
|
for attribute in node.attribute:
|
|
|
- if attribute.t:
|
|
|
+ if attribute.HasField('t'):
|
|
|
self.remove_tensor_data(attribute.t)
|
|
|
self.data = model.SerializeToString()
|
|
|
def remove_tensor_data(self, tensor):
|
|
|
@@ -114,7 +114,7 @@ class OnnxModel:
|
|
|
del tensor.int32_data[:]
|
|
|
del tensor.int64_data[:]
|
|
|
del tensor.float_data[:]
|
|
|
- tensor.raw_data = b''
|
|
|
+ tensor.raw_data = None
|
|
|
|
|
|
class TensorFlowLiteModel:
|
|
|
def __init__(self, data, file):
|