Просмотр исходного кода

Add TorchScript Mobile test files (#1023)

Lutz Roeder 3 лет назад
Родитель
Сommit
7e465ac2f9
5 измененных файлов с 445 добавлено и 7 удалено
  1. 373 0
      source/pytorch-schema.js
  2. 51 4
      source/pytorch.js
  3. 1 2
      source/view.js
  4. 16 0
      test/models.json
  5. 4 1
      tools/pytorch

+ 373 - 0
source/pytorch-schema.js

@@ -0,0 +1,373 @@
+var $root = flatbuffers.get('torch');
+
+$root.torch = $root.torch || {};
+
+$root.torch.jit = $root.torch.jit || {};
+
+$root.torch.jit.mobile = $root.torch.jit.mobile || {};
+
+$root.torch.jit.mobile.serialization = $root.torch.jit.mobile.serialization || {};
+
+$root.torch.jit.mobile.serialization.Int = class Int {
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.Int();
+        $.int_val = reader.int64(position + 0);
+        return $;
+    }
+};
+
+$root.torch.jit.mobile.serialization.Bool = class Bool {
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.Bool();
+        $.bool_val = reader.bool(position + 0);
+        return $;
+    }
+};
+
+$root.torch.jit.mobile.serialization.Double = class Double {
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.Double();
+        $.double_val = reader.float64(position + 0);
+        return $;
+    }
+};
+
+$root.torch.jit.mobile.serialization.PerTensorAffineSchema = class PerTensorAffineSchema {
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.PerTensorAffineSchema();
+        $.q_scale = reader.float64(position + 0);
+        $.q_zero_point = reader.int32(position + 4);
+        return $;
+    }
+};
+
+$root.torch.jit.mobile.serialization.QuantizedSchema = class QuantizedSchema {
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.QuantizedSchema();
+        $.qscheme = reader.int8_(position, 4, 0);
+        $.scale = reader.float64_(position, 6, 0);
+        $.zero_point = reader.int32_(position, 8, 0);
+        $.scales = reader.table(position, 10, $root.torch.jit.mobile.serialization.TensorMetadata.decode);
+        $.zero_points = reader.table(position, 12, $root.torch.jit.mobile.serialization.TensorMetadata.decode);
+        $.axis = reader.int32_(position, 14, 0);
+        return $;
+    }
+};
+
+$root.torch.jit.mobile.serialization.TensorMetadata = class TensorMetadata {
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.TensorMetadata();
+        $.storage_location_index = reader.uint32_(position, 4, 0);
+        $.scalar_type = reader.int8_(position, 6, 0);
+        $.storage_offset = reader.int32_(position, 8, 0);
+        $.sizes = reader.typedArray(position, 10, Int32Array);
+        $.strides = reader.typedArray(position, 12, Int32Array);
+        $.requires_grad = reader.bool_(position, 14, false);
+        $.quantized_schema = reader.table(position, 16, $root.torch.jit.mobile.serialization.QuantizedSchema.decode);
+        return $;
+    }
+};
+
+$root.torch.jit.mobile.serialization.String = class String {
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.String();
+        $.data = reader.string_(position, 4, null);
+        return $;
+    }
+};
+
+$root.torch.jit.mobile.serialization.Device = class Device {
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.Device();
+        $.str = reader.string_(position, 4, null);
+        return $;
+    }
+};
+
+$root.torch.jit.mobile.serialization.List = class List {
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.List();
+        $.items = reader.typedArray(position, 4, Uint32Array);
+        $.annotation_str = reader.string_(position, 6, null);
+        return $;
+    }
+};
+
+$root.torch.jit.mobile.serialization.IntList = class IntList {
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.IntList();
+        $.items = reader.int64s_(position, 4);
+        return $;
+    }
+};
+
+$root.torch.jit.mobile.serialization.DoubleList = class DoubleList {
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.DoubleList();
+        $.items = reader.typedArray(position, 4, Float64Array);
+        return $;
+    }
+};
+
+$root.torch.jit.mobile.serialization.BoolList = class BoolList {
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.BoolList();
+        $.items = reader.bools_(position, 4);
+        return $;
+    }
+};
+
+$root.torch.jit.mobile.serialization.Tuple = class Tuple {
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.Tuple();
+        $.items = reader.typedArray(position, 4, Uint32Array);
+        return $;
+    }
+};
+
+$root.torch.jit.mobile.serialization.Dict = class Dict {
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.Dict();
+        $.keys = reader.typedArray(position, 4, Uint32Array);
+        $.values = reader.typedArray(position, 6, Uint32Array);
+        $.annotation_str = reader.string_(position, 8, null);
+        return $;
+    }
+};
+
+$root.torch.jit.mobile.serialization.TypeType = {
+    UNSET: 0,
+    CLASS_WITH_FIELD: 1,
+    CUSTOM_CLASS: 2,
+    CLASS_WITH_SETSTATE: 3,
+    NON_OBJ: 4
+};
+
+$root.torch.jit.mobile.serialization.ObjectType = class ObjectType {
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.ObjectType();
+        $.type_name = reader.string_(position, 4, null);
+        $.type = reader.uint8_(position, 6, 0);
+        $.attr_names = reader.strings_(position, 8);
+        return $;
+    }
+};
+
+$root.torch.jit.mobile.serialization.Object = class Object {
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.Object();
+        $.type_index = reader.uint32_(position, 4, 0);
+        $.state = reader.uint32_(position, 6, 0);
+        $.attrs = reader.typedArray(position, 8, Uint32Array);
+        $.setstate_func = reader.uint32_(position, 10, 0);
+        return $;
+    }
+};
+
+$root.torch.jit.mobile.serialization.ComplexDouble = class ComplexDouble {
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.ComplexDouble();
+        $.real = reader.float64(position + 0);
+        $.imag = reader.float64(position + 4);
+        return $;
+    }
+};
+
+$root.torch.jit.mobile.serialization.EnumValue = class EnumValue {
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.EnumValue();
+        $.type_name = reader.string_(position, 4, null);
+        $.value = reader.uint32_(position, 6, 0);
+        return $;
+    }
+};
+
+$root.torch.jit.mobile.serialization.Instruction = class Instruction {
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.Instruction();
+        $.op = reader.int8(position + 0);
+        $.n = reader.uint16(position + 2);
+        $.x = reader.int32(position + 4);
+        return $;
+    }
+};
+
+$root.torch.jit.mobile.serialization.Operator = class Operator {
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.Operator();
+        $.name = reader.string_(position, 4, null);
+        $.overload_name = reader.string_(position, 6, null);
+        $.num_args_serialized = reader.int32_(position, 8, -1);
+        return $;
+    }
+};
+
+$root.torch.jit.mobile.serialization.Arg = class Arg {
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.Arg();
+        $.name = reader.string_(position, 4, null);
+        $.type = reader.string_(position, 6, null);
+        $.default_value = reader.uint32_(position, 8, 0);
+        return $;
+    }
+};
+
+$root.torch.jit.mobile.serialization.Schema = class Schema {
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.Schema();
+        $.arguments = reader.tableArray(position, 4, $root.torch.jit.mobile.serialization.Arg.decode);
+        $.returns = reader.tableArray(position, 6, $root.torch.jit.mobile.serialization.Arg.decode);
+        return $;
+    }
+};
+
+$root.torch.jit.mobile.serialization.DebugInfo = class DebugInfo {
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.DebugInfo();
+        $.debug_handle = reader.int64s_(position, 4);
+        return $;
+    }
+};
+
+$root.torch.jit.mobile.serialization.Function = class Function {
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.Function();
+        $.qn = reader.string_(position, 4, null);
+        $.instructions = reader.structArray(position, 6, undefined,$root.torch.jit.mobile.serialization.Instruction.decode);
+        $.operators = reader.tableArray(position, 8, $root.torch.jit.mobile.serialization.Operator.decode);
+        $.constants = reader.typedArray(position, 10, Uint32Array);
+        $.type_annotations = reader.strings_(position, 12);
+        $.register_size = reader.int32_(position, 14, 0);
+        $.schema = reader.table(position, 16, $root.torch.jit.mobile.serialization.Schema.decode);
+        $.debug_info = reader.table(position, 18, $root.torch.jit.mobile.serialization.DebugInfo.decode);
+        $.class_type = reader.uint32_(position, 20, 0);
+        return $;
+    }
+};
+
+$root.torch.jit.mobile.serialization.StorageData = class StorageData {
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.StorageData();
+        $.data = reader.typedArray(position, 4, Uint8Array);
+        return $;
+    }
+};
+
+$root.torch.jit.mobile.serialization.IValueUnion = class {
+
+    static decode(reader, position, type) {
+        switch (type) {
+            case 1: return $root.torch.jit.mobile.serialization.Int.decode(reader, position);
+            case 2: return $root.torch.jit.mobile.serialization.Bool.decode(reader, position);
+            case 3: return $root.torch.jit.mobile.serialization.Double.decode(reader, position);
+            case 4: return $root.torch.jit.mobile.serialization.ComplexDouble.decode(reader, position);
+            case 5: return $root.torch.jit.mobile.serialization.TensorMetadata.decode(reader, position);
+            case 6: return $root.torch.jit.mobile.serialization.String.decode(reader, position);
+            case 7: return $root.torch.jit.mobile.serialization.List.decode(reader, position);
+            case 8: return $root.torch.jit.mobile.serialization.Tuple.decode(reader, position);
+            case 9: return $root.torch.jit.mobile.serialization.Dict.decode(reader, position);
+            case 10: return $root.torch.jit.mobile.serialization.Object.decode(reader, position);
+            case 11: return $root.torch.jit.mobile.serialization.IntList.decode(reader, position);
+            case 12: return $root.torch.jit.mobile.serialization.DoubleList.decode(reader, position);
+            case 13: return $root.torch.jit.mobile.serialization.BoolList.decode(reader, position);
+            case 14: return $root.torch.jit.mobile.serialization.Device.decode(reader, position);
+            case 15: return $root.torch.jit.mobile.serialization.EnumValue.decode(reader, position);
+            case 16: return $root.torch.jit.mobile.serialization.Function.decode(reader, position);
+            default: return undefined;
+        }
+    }
+
+    static decodeText(reader, json, type) {
+        switch (type) {
+            case 'Int': return $root.torch.jit.mobile.serialization.Int.decodeText(reader, json);
+            case 'Bool': return $root.torch.jit.mobile.serialization.Bool.decodeText(reader, json);
+            case 'Double': return $root.torch.jit.mobile.serialization.Double.decodeText(reader, json);
+            case 'ComplexDouble': return $root.torch.jit.mobile.serialization.ComplexDouble.decodeText(reader, json);
+            case 'TensorMetadata': return $root.torch.jit.mobile.serialization.TensorMetadata.decodeText(reader, json);
+            case 'String': return $root.torch.jit.mobile.serialization.String.decodeText(reader, json);
+            case 'List': return $root.torch.jit.mobile.serialization.List.decodeText(reader, json);
+            case 'Tuple': return $root.torch.jit.mobile.serialization.Tuple.decodeText(reader, json);
+            case 'Dict': return $root.torch.jit.mobile.serialization.Dict.decodeText(reader, json);
+            case 'Object': return $root.torch.jit.mobile.serialization.Object.decodeText(reader, json);
+            case 'IntList': return $root.torch.jit.mobile.serialization.IntList.decodeText(reader, json);
+            case 'DoubleList': return $root.torch.jit.mobile.serialization.DoubleList.decodeText(reader, json);
+            case 'BoolList': return $root.torch.jit.mobile.serialization.BoolList.decodeText(reader, json);
+            case 'Device': return $root.torch.jit.mobile.serialization.Device.decodeText(reader, json);
+            case 'EnumValue': return $root.torch.jit.mobile.serialization.EnumValue.decodeText(reader, json);
+            case 'Function': return $root.torch.jit.mobile.serialization.Function.decodeText(reader, json);
+            default: return undefined;
+        }
+    }
+};
+
+$root.torch.jit.mobile.serialization.IValue = class IValue {
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.IValue();
+        $.val = reader.union(position, 4, $root.torch.jit.mobile.serialization.IValueUnion.decode);
+        return $;
+    }
+};
+
+$root.torch.jit.mobile.serialization.ExtraFile = class ExtraFile {
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.ExtraFile();
+        $.name = reader.string_(position, 4, null);
+        $.content = reader.string_(position, 6, null);
+        return $;
+    }
+};
+
+$root.torch.jit.mobile.serialization.Module = class Module {
+
+    static identifier(reader) {
+        return reader.identifier === 'PTMF';
+    }
+
+    static create(reader) {
+        return $root.torch.jit.mobile.serialization.Module.decode(reader, reader.root);
+    }
+
+    static decode(reader, position) {
+        const $ = new $root.torch.jit.mobile.serialization.Module();
+        $.bytecode_version = reader.uint32_(position, 4, 0);
+        $.extra_files = reader.tableArray(position, 6, $root.torch.jit.mobile.serialization.ExtraFile.decode);
+        $.methods = reader.typedArray(position, 8, Uint32Array);
+        $.state_obj = reader.uint32_(position, 10, 0);
+        $.ivalues = reader.tableArray(position, 12, $root.torch.jit.mobile.serialization.IValue.decode);
+        $.storage_data_size = reader.int32_(position, 14, 0);
+        $.storage_data = reader.tableArray(position, 16, $root.torch.jit.mobile.serialization.StorageData.decode);
+        $.object_types = reader.tableArray(position, 18, $root.torch.jit.mobile.serialization.ObjectType.decode);
+        $.jit_sources = reader.tableArray(position, 20, $root.torch.jit.mobile.serialization.ExtraFile.decode);
+        $.jit_constants = reader.typedArray(position, 22, Uint32Array);
+        $.operator_version = reader.uint32_(position, 24, 0);
+        $.mobile_ivalue_size = reader.uint32_(position, 26, 0);
+        return $;
+    }
+};

+ 51 - 4
source/pytorch.js

@@ -4,6 +4,7 @@
 var pytorch = {};
 var python = require('./python');
 var base = require('./base');
+var flatbuffers = require('./flatbuffers');
 
 pytorch.ModelFactory = class {
 
@@ -15,11 +16,13 @@ pytorch.ModelFactory = class {
         const identifier = context.identifier;
         return pytorch.Metadata.open(context).then((metadata) => {
             const container = match;
-            container.metadata = metadata;
-            container.on('resolve', (_, name) => {
-                context.exception(new pytorch.Error("Unknown type name '" + name + "' in '" + identifier + "'."), false);
+            return container.read().then(() => {
+                container.metadata = metadata;
+                container.on('resolve', (_, name) => {
+                    context.exception(new pytorch.Error("Unknown type name '" + name + "' in '" + identifier + "'."), false);
+                });
+                return new pytorch.Model(metadata, container);
             });
-            return new pytorch.Model(metadata, container);
         });
     }
 };
@@ -1041,6 +1044,10 @@ pytorch.Container = class {
         if (torch_utils) {
             return torch_utils;
         }
+        const mobile = pytorch.Container.Mobile.open(context);
+        if (mobile) {
+            return mobile;
+        }
         return null;
     }
 
@@ -1049,6 +1056,10 @@ pytorch.Container = class {
         this._events = [];
     }
 
+    read() {
+        return Promise.resolve();
+    }
+
     set metadata(value) {
         this._metadata = value;
     }
@@ -1208,6 +1219,42 @@ pytorch.Container.torch_utils = class extends pytorch.Container {
     }
 };
 
+pytorch.Container.Mobile = class extends pytorch.Container {
+
+    static open(context) {
+        const tags = context.tags('flatbuffers');
+        if (tags.get('file_identifier') === 'PTMF') {
+            return new pytorch.Container.Mobile(context);
+        }
+        return null;
+    }
+
+    constructor(context) {
+        super();
+        this._context = context;
+        this._graphs = [];
+    }
+
+    read() {
+        return this._context.require('./pytorch-schema').then(() => {
+            pytorch.schema = flatbuffers.get('torch').torch.jit.mobile.serialization;
+            // const stream = this._context.stream;
+            // const reader = flatbuffers.BinaryReader.open(stream);
+            // const model = pytorch.schema.Module.create(reader);
+            delete this._context;
+            throw new pytorch.Error('torch.jit.mobile.serialization.Module not supported.');
+        });
+    }
+
+    get format() {
+        return 'TorchScript Mobile';
+    }
+
+    get graphs() {
+        return this._graphs;
+    }
+};
+
 pytorch.Container.Zip = class extends pytorch.Container {
 
     static open(entries) {

+ 1 - 2
source/view.js

@@ -1606,7 +1606,7 @@ view.ModelFactoryService = class {
         this._extensions = new Set([ '.zip', '.tar', '.tar.gz', '.tgz', '.gz' ]);
         this._factories = [];
         this.register('./server', [ '.netron']);
-        this.register('./pytorch', [ '.pt', '.pth', '.ptl', '.pt1', '.pyt', '.pyth', '.pkl', '.pickle', '.h5', '.t7', '.model', '.dms', '.tar', '.ckpt', '.chkpt', '.tckpt', '.bin', '.pb', '.zip', '.nn', '.torchmodel', '.torchscript', '.pytorch', '.ot', '.params', '.trt' ], [ '.model' ]);
+        this.register('./pytorch', [ '.pt', '.pth', '.ptl', '.pt1', '.pyt', '.pyth', '.pkl', '.pickle', '.h5', '.t7', '.model', '.dms', '.tar', '.ckpt', '.chkpt', '.tckpt', '.bin', '.pb', '.zip', '.nn', '.torchmodel', '.torchscript', '.pytorch', '.ot', '.params', '.trt', '.ff', '.ptmf' ], [ '.model' ]);
         this.register('./onnx', [ '.onnx', '.onn', '.pb', '.onnxtxt', '.pbtxt', '.prototxt', '.txt', '.model', '.pt', '.pth', '.pkl', '.ort', '.ort.onnx', 'onnxmodel', 'ngf' ]);
         this.register('./mxnet', [ '.json', '.params' ], [ '.mar'] );
         this.register('./coreml', [ '.mlmodel', '.bin', 'manifest.json', 'metadata.json', 'featuredescriptions.json', '.pb' ], [ '.mlpackage' ]);
@@ -1857,7 +1857,6 @@ view.ModelFactoryService = class {
                 const formats = [
                     { name: 'onnxruntime.experimental.fbs.InferenceSession data', identifier: 'ORTM' },
                     { name: 'tflite.Model data', identifier: 'TFL3' },
-                    { name: 'torch.jit.mobile.serialization.Module data', identifier: 'PTMF' }, // https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/serialization/mobile_bytecode.fbs
                     { name: 'FlatBuffers ENNC data', identifier: 'ENNC' },
                 ];
                 for (const format of formats) {

+ 16 - 0
test/models.json

@@ -4836,6 +4836,14 @@
     "format":   "TorchScript v1.6",
     "link":     "https://github.com/PeterL1n/RobustVideoMatting"
   },
+  {
+    "type":     "pytorch",
+    "target":   "scriptmodule.ff",
+    "source":   "https://github.com/lutzroeder/netron/files/10230651/scriptmodule.ff.zip[scriptmodule.ff]",
+    "format":   "TorchScript Mobile",
+    "error":    "torch.jit.mobile.serialization.Module not supported in 'scriptmodule.ff'.",
+    "link":     "https://github.com/lutzroeder/netron/issues/1023"
+  },
   {
     "type":     "pytorch",
     "target":   "segmentor.pt",
@@ -4884,6 +4892,14 @@
     "format":   "TorchScript v1.6",
     "link":     "https://github.com/lutzroeder/netron/issues/281"
   },
+  {
+    "type":     "pytorch",
+    "target":   "squeezenet1_1_traced.ff",
+    "source":   "https://github.com/lutzroeder/netron/files/10230684/squeezenet1_1_traced.ff.zip[squeezenet1_1_traced.ff]",
+    "format":   "TorchScript Mobile",
+    "error":    "torch.jit.mobile.serialization.Module not supported in 'squeezenet1_1_traced.ff'.",
+    "link":     "https://github.com/lutzroeder/netron/issues/1023"
+  },
   {
     "type":     "pytorch",
     "target":   "squeezenet1_1.pt",

+ 4 - 1
tools/pytorch

@@ -20,9 +20,11 @@ sync() {
 }
 
 schema() {
-    echo "caffe2 schema"
     [[ $(grep -U $'\x0D' ./source/caffe2-proto.js) ]] && crlf=1
+    echo "caffe2 schema"
     node ./tools/protoc.js --text --root caffe2 --out ./source/caffe2-proto.js ./third_party/source/pytorch/caffe2/proto/caffe2.proto
+    echo "pytorch schema"
+    node ./tools/flatc.js --root torch --out ./source/pytorch-schema.js ./third_party/source/pytorch/torch/csrc/jit/serialization/mobile_bytecode.fbs
     if [[ -n ${crlf} ]]; then
         unix2dos --quiet --newfile ./source/caffe2-proto.js ./source/caffe2-proto.js
     fi
@@ -42,6 +44,7 @@ while [ "$#" != 0 ]; do
     case "${command}" in
         "clean") clean;;
         "sync") sync;;
+        "schema") schema;;
         "metadata") metadata;;
     esac
 done