| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304 |
- """ PyTorch backend """
- import json
- import os
class ModelFactory:
    """ PyTorch backend model factory """

    def open(self, model):
        """ Wrap a TorchScript module in a serializable _Model.

        Merges the PyTorch and ONNX operator metadata files shipped next to
        this module into a single name -> schema lookup table.
        """
        metadata = {}
        base = os.path.dirname(__file__)
        sources = (
            ("pytorch-metadata.json", ""),
            ("onnx-metadata.json", "onnx::"),
        )
        for filename, prefix in sources:
            with open(os.path.join(base, filename), encoding="utf-8") as handle:
                items = json.load(handle)
            for item in items:
                # Drop the argument list from the signature to get the op name.
                key = prefix + item["name"].split("(", 1)[0]
                metadata[key] = item
        return _Model(Metadata(metadata), model)
class _Model:
    """ Serializable wrapper pairing a TorchScript module with its graph. """

    def __init__(self, metadata, model):
        self.graph = _Graph(metadata, model)

    def to_json(self):
        """ Serialize model to JSON message """
        import torch
        return {
            "signature": "netron:pytorch",
            "format": f"TorchScript v{torch.__version__}",
            "graphs": [self.graph.to_json()],
        }
class _Graph:
    """ Serializable wrapper around the graph of a TorchScript module. """

    def __init__(self, metadata, model):
        self.metadata = metadata
        # The ScriptModule itself: root object for prim::GetAttr resolution.
        self.param = model
        self.value = model.graph
        self.nodes = []

    def _getattr(self, node):
        """ Resolve a prim::GetAttr chain to (attribute value, dotted path).

        Raises NotImplementedError for node kinds other than prim::Param /
        prim::GetAttr.
        """
        if node.kind() == "prim::Param":
            return (self.param, "")
        if node.kind() == "prim::GetAttr":
            name = node.s("name")
            obj, parent = self._getattr(node.input().node())
            value = getattr(obj, name)
            path = parent + "." + name if len(parent) > 0 else name
            return (value, path)
        raise NotImplementedError()

    def to_json(self):
        """ Serialize the TorchScript graph to a JSON message.

        Returns a dict with "values", "nodes", "inputs" and "outputs" lists.
        Constants and constant-only list constructions consumed by a single
        node are folded into that node's attributes instead of being emitted
        as standalone nodes.
        """
        import torch
        graph = self.value
        json_graph = {
            "values": [],
            "nodes": [],
            "inputs": [],
            "outputs": []
        }
        data_type_map = {
            torch.float16: "float16",
            torch.float32: "float32",
            torch.float64: "float64",
            torch.int32: "int32",
            torch.int64: "int64",
        }

        def constant_value(node):
            # Read a node's "value" attribute through the accessor that
            # matches the attribute's kind (s/i/f/t/...).
            if node.hasAttribute("value"):
                selector = node.kindOf("value")
                return getattr(node, selector)("value")
            return None

        values_index = {}

        def argument(value):
            # Return the index of `value` in json_graph["values"],
            # appending a new entry on first use.
            if value not in values_index:
                json_value = {}
                json_value["name"] = str(value.unique())
                node = value.node()
                if node.kind() == "prim::GetAttr":
                    tensor, name = self._getattr(node)
                    if tensor is not None and len(name) > 0 and \
                        isinstance(tensor, torch.Tensor):
                        json_tensor_shape = {
                            "dimensions": list(tensor.shape)
                        }
                        tensor_type = {
                            "dataType": data_type_map[tensor.dtype],
                            "shape": json_tensor_shape
                        }
                        json_value["name"] = name
                        json_value["type"] = tensor_type
                        json_value["initializer"] = { "type": tensor_type }
                elif node.kind() == "prim::Constant":
                    tensor = constant_value(node)
                    # isinstance() also rejects None. Do NOT write
                    # `if tensor and ...`: bool() on a tensor with more than
                    # one element raises RuntimeError, and a zero-valued
                    # scalar constant would be skipped incorrectly.
                    if isinstance(tensor, torch.Tensor):
                        json_tensor_shape = {
                            "dimensions": list(tensor.shape)
                        }
                        tensor_type = {
                            "dataType": data_type_map[tensor.dtype],
                            "shape": json_tensor_shape
                        }
                        json_value["type"] = tensor_type
                        json_value["initializer"] = { "type": tensor_type }
                elif value.isCompleteTensor():
                    json_tensor_shape = {
                        "dimensions": value.type().sizes()
                    }
                    json_value["type"] = {
                        "dataType": data_type_map[value.type().dtype()],
                        "shape": json_tensor_shape
                    }
                values = json_graph["values"]
                values_index[value] = len(values)
                values.append(json_value)
            return values_index[value]

        for value in graph.inputs():
            # Skip unused inputs and the implicit module (ClassType) input.
            if len(value.uses()) != 0 and value.type().kind() != "ClassType":
                json_graph["inputs"].append({
                    "name": value.debugName(),
                    "value": [ argument(value) ]
                })
        for value in graph.outputs():
            json_graph["outputs"].append({
                "name": value.debugName(),
                "value": [ argument(value) ]
            })
        # Count how often each constant / constant-only list construction is
        # consumed inline; fully-consumed ones need no standalone node.
        constants = {}
        for node in graph.nodes():
            if node.kind() == "prim::Constant":
                constants[node] = 0
        lists = {}
        for node in graph.nodes():
            if node.kind() == "prim::ListConstruct":
                if all(_.node() in constants for _ in node.inputs()):
                    for _ in node.inputs():
                        constants[_.node()] += 1
                    lists[node] = 0

        def create_node(node):
            # Append `node` to json_graph["nodes"], folding constant and
            # constant-list inputs into attributes where possible.
            identifier = node.schema()
            schema, category = self.metadata.type(identifier)
            json_node = {
                "type": {
                    "name": node.kind(),
                    "category": category
                },
                "inputs": [],
                "outputs": [],
                "attributes": []
            }
            json_graph["nodes"].append(json_node)
            for name in node.attributeNames():
                selector = node.kindOf(name)
                value = getattr(node, selector)(name)
                json_attribute = {
                    "name": name,
                    "value": value
                }
                if torch.is_tensor(value):
                    # Tensor attributes are surfaced as (empty) inputs.
                    json_node["inputs"].append({
                        "name": name,
                        "value": []
                    })
                else:
                    json_node["attributes"].append(json_attribute)
            for i, value in enumerate(node.inputs()):
                arg = None
                if schema and i < len(schema.arguments):
                    arg = schema.arguments[i]
                parameter_name = arg.name if arg else "input"
                real_type = arg.real_type if arg else None
                input_node = value.node()
                if input_node in constants:
                    if (real_type and real_type.kind() == "TensorType") or \
                        value.type().kind() == "TensorType":
                        # Tensor-typed constants stay real inputs.
                        json_node["inputs"].append({
                            "name": parameter_name,
                            "value": [ argument(value) ]
                        })
                    else:
                        # Scalar constants are folded into attributes.
                        json_attribute = {
                            "name": parameter_name,
                            "value": constant_value(input_node)
                        }
                        if real_type:
                            json_attribute["type"] = self._argument_type(real_type)
                        json_node["attributes"].append(json_attribute)
                    constants[input_node] += 1
                    continue
                if input_node in lists:
                    # Constant-only lists become a single list attribute.
                    value = [ constant_value(_.node()) for _ in input_node.inputs() ]
                    json_attribute = {
                        "name": parameter_name,
                        "value": value
                    }
                    json_node["attributes"].append(json_attribute)
                    lists[input_node] += 1
                    continue
                if input_node.kind() == "prim::TupleUnpack":
                    continue
                if input_node.kind() == "prim::TupleConstruct":
                    continue
                json_node["inputs"].append({
                    "name": parameter_name,
                    "value": [ argument(value) ]
                })
            for i, value in enumerate(node.outputs()):
                ret = schema.returns[i] if schema and i < len(schema.returns) else None
                name = ret.name if ret else "output"
                json_node["outputs"].append({
                    "name": name,
                    "value": [ argument(value) ]
                })

        for node in graph.nodes():
            if node in lists:
                continue
            if node in constants:
                continue
            if node.kind() == "prim::GetAttr":
                continue
            create_node(node)
        # Emit explicit nodes for constants / lists that still have uses
        # which were not folded into node attributes or inputs above.
        for node in graph.nodes():
            if node.kind() == "prim::Constant" and \
                node in constants and constants[node] != len(node.output().uses()):
                create_node(node)
            if node.kind() == "prim::ListConstruct" and \
                node in lists and lists[node] != len(node.output().uses()):
                create_node(node)
        return json_graph

    def _argument_type(self, value):
        """ Map a torch JIT type to its schema type string.

        Raises NotImplementedError for unsupported type kinds.
        """
        # Kinds with a fixed spelling.
        simple = {
            "TensorType": "Tensor",
            "IntType": "int64",
            "SymIntType": "SymInt",
            "FloatType": "float32",
            "BoolType": "boolean",
            "StringType": "string",
            "NumberType": "Scalar",
            "ScalarTypeType": "ScalarType",
            "LayoutType": "Layout",
            "MemoryFormatType": "MemoryFormat",
            "DeviceObjType": "Device",
            "GeneratorType": "Generator",
        }
        kind = value.kind()
        if kind in simple:
            return simple[kind]
        if kind == "OptionalType":
            element_type = self._argument_type(value.getElementType())
            return f"{element_type}?"
        if kind == "ListType":
            element_type = self._argument_type(value.getElementType())
            size = str(value.size) if hasattr(value, "size") else ""
            return f"{element_type}[{size}]"
        if kind == "DictType":
            key_type = self._argument_type(value.getKeyType())
            value_type = self._argument_type(value.getValueType())
            return f"Dict({key_type}, {value_type})"
        if kind == "TupleType":
            elements = []
            for element in value.elements():
                elements.append(self._argument_type(element))
            return f"({', '.join(elements)})"
        if kind == "VarType":
            return value.annotation_str
        raise NotImplementedError()
class Metadata:
    """ Operator schema lookup backed by the bundled metadata table. """

    def __init__(self, metadata):
        self.types = metadata

    def type(self, identifier):
        """ Return (torch schema, category) for an operator signature string. """
        if identifier == "(no schema)":
            return (None, "")
        key = identifier.split("(", 1)[0]
        entry = self.types.get(key)
        category = entry.get("category", "") if entry else ""
        dot = key.find(".")
        if dot > 0:
            name, overload_name = key.split(".", 1)
        else:
            name, overload_name = key, ""
        import torch
        schema = torch._C._get_schema(name, overload_name)
        return (schema, category)
|