Przeglądaj źródła

Add ExecuTorch test file (#1175)

Lutz Roeder 1 rok temu
rodzic
commit
7a2712e72e
3 zmienionych plików z 262 dodań i 69 usunięć
  1. 254 68
      source/executorch.js
  2. 1 1
      source/view.js
  3. 7 0
      test/models.json

+ 254 - 68
source/executorch.js

@@ -7,7 +7,6 @@ const vulkan = {};
 const xnnpack = {};
 
 import * as base from './base.js';
-import * as flatbuffers from './flatbuffers.js';
 import * as python from './python.js';
 import * as pytorch from './pytorch.js';
 
@@ -23,23 +22,20 @@ executorch.ModelFactory = class {
 
     async open(context) {
         executorch.schema = await context.require('./executorch-schema');
-        const metadata = await pytorch.Metadata.open(context);
-        const execution = new python.Execution();
-        metadata.register(execution);
-        const reader = context.target;
-        await reader.read();
-        return new executorch.Model(execution, reader);
+        const target = context.target;
+        await target.read();
+        return new executorch.Model(target);
     }
 };
 
 executorch.Model = class {
 
-    constructor(execution, reader) {
-        this.format = `ExecuTorch v${reader.program.version}`;
+    constructor(target) {
+        this.format = `ExecuTorch v${target.program.version}`;
         this.graphs = [];
-        for (const plan of reader.program.execution_plan) {
+        for (const plan of target.program.execution_plan) {
             for (const chain of plan.chains) {
-                const graph = new executorch.Graph(execution, reader, plan, chain);
+                const graph = new executorch.Graph(target, plan, chain);
                 this.graphs.push(graph);
             }
         }
@@ -48,7 +44,7 @@ executorch.Model = class {
 
 executorch.Graph = class {
 
-    constructor(execution, reader, plan, chain) {
+    constructor(target, plan, chain) {
         this.inputs = [];
         this.outputs = [];
         this.nodes = [];
@@ -69,7 +65,7 @@ executorch.Graph = class {
                         const type = new executorch.TensorType(tensor);
                         let initializer = null;
                         if (v.data_buffer_idx > 0) {
-                            initializer = new executorch.Tensor(tensor, reader);
+                            initializer = new executorch.Tensor(tensor, target);
                         }
                         const identifier = tensors.length > 1 ? `${index}.${i}` : index.toString();
                         const value = new executorch.Value(identifier, type, initializer);
@@ -108,7 +104,7 @@ executorch.Graph = class {
             this.outputs.push(argument);
         }
         for (const instruction of chain.instructions) {
-            const node = new executorch.Node(execution, reader, plan, chain, instruction, values);
+            const node = new executorch.Node(target, plan, chain, instruction, values);
             this.nodes.push(node);
         }
     }
@@ -138,7 +134,7 @@ executorch.Value = class Value {
 
 executorch.Node = class {
 
-    constructor(execution, reader, plan, chain, instruction, values) {
+    constructor(target, plan, chain, instruction, values) {
         this.name = '';
         this.inputs = [];
         this.outputs = [];
@@ -149,7 +145,7 @@ executorch.Node = class {
             const op = plan.operators[instr_args.op_index];
             const name = op.name.split('::').pop();
             const identifier = op.overload ? `${op.name}.${op.overload}` : op.name;
-            const schemas = execution.invoke('torch._C._jit_get_schemas_for_operator', [op.name]);
+            const schemas = target.execution.invoke('torch._C._jit_get_schemas_for_operator', [op.name]);
             const schema = schemas.find((schema) => schema.name === op.name && schema.overload_name === op.overload);
             if (!schema) {
                 throw new executorch.Error(`Operator schema for '${identifier}' not found.`);
@@ -187,26 +183,9 @@ executorch.Node = class {
         } else if (instr_args instanceof executorch_flatbuffer.DelegateCall) {
             const delegate = plan.delegates[instr_args.delegate_index];
             const args = instr_args.args;
-            const name = delegate.id;
-            let data = null;
-            switch (delegate.processed.location) {
-                case executorch_flatbuffer.DataLocation.INLINE: {
-                    data = reader.program.backend_delegate_data[delegate.processed.index].data;
-                    break;
-                }
-                case executorch_flatbuffer.DataLocation.SEGMENT: {
-                    const segment = reader.program.segments[delegate.processed.index];
-                    data = reader.blob(segment.offset.toNumber(), segment.size.toNumber());
-                    break;
-                }
-                default: {
-                    throw new executorch.Error(`Delegate data location '${delegate.processed.location}' not implemented.`);
-                }
-            }
-            switch (name) {
+            switch (delegate.id) {
                 case 'XnnpackBackend': {
-                    const reader = xnnpack.Reader.open(data);
-                    this.type = reader.read();
+                    this.type = delegate.backend.type || { name: delegate.id };
                     for (const arg of args.slice(0, this.type.inputs.length)) {
                         const value = values.map(arg);
                         const argument = new executorch.Argument('', value.value, value.type);
@@ -220,27 +199,23 @@ executorch.Node = class {
                     break;
                 }
                 case 'CoreMLBackend': {
-                    const reader = coreml.Reader.open(data);
-                    reader.read();
+                    this.type = delegate.backend.type || { name: delegate.id };
                     const input = values.map(args[0]);
                     const output = values.map(args[1], true);
                     this.inputs.push(new executorch.Argument('input', input.value, input.type));
                     this.outputs.push(new executorch.Argument('output', output.value, output.type));
-                    this.type = { name };
                     break;
                 }
                 case 'VulkanBackend': {
-                    const reader = vulkan.Reader.open(data);
-                    reader.read();
+                    this.type = delegate.backend.type || { name: delegate.id };
                     const input = values.map(args[0]);
                     const output = values.map(args[1], true);
                     this.inputs.push(new executorch.Argument('input', input.value, input.type));
                     this.outputs.push(new executorch.Argument('output', output.value, output.type));
-                    this.type = { name };
                     break;
                 }
                 default: {
-                    throw new executorch.Error(`ExecuTorch delegate '${name}' not implemented.`);
+                    throw new executorch.Error(`ExecuTorch delegate '${delegate.id}' not implemented.`);
                 }
             }
             for (const spec of delegate.compile_specs) {
@@ -297,10 +272,10 @@ executorch.TensorShape = class {
 
 executorch.Tensor = class {
 
-    constructor(tensor, reader) {
+    constructor(tensor, target) {
         this.type = new executorch.TensorType(tensor);
         const data_buffer_idx = tensor.data_buffer_idx;
-        const program = reader.program;
+        const program = target.program;
         if (tensor.extra_tensor_info) {
             throw new executorch.Error('Extra tensor info not implemented.');
         } else if (program.constant_buffers) {
@@ -311,7 +286,7 @@ executorch.Tensor = class {
             const offset = constant_segment.offsets[data_buffer_idx].toNumber();
             const next = data_buffer_idx + 1 < constant_segment.offsets.length ? constant_segment.offsets[data_buffer_idx + 1].toNumber() : data_segment.size.toNumber();
             const size = next - offset;
-            this.values = reader.blob(data_segment.offset.toNumber() + offset, size);
+            this.values = target.blob(data_segment.offset.toNumber() + offset, size);
             this.encoding = '<';
         } else {
             throw new executorch.Error('Tensor allocation info not implemented.');
@@ -335,7 +310,11 @@ executorch.Reader = class {
     }
 
     async read() {
-        this.program = executorch.schema.executorch_flatbuffer.Program.create(this.reader);
+        this.metadata = await pytorch.Metadata.open(this.context);
+        this.execution = new python.Execution();
+        this.metadata.register(this.execution);
+        const executorch_flatbuffer = executorch.schema.executorch_flatbuffer;
+        this.program = executorch_flatbuffer.Program.create(this.reader);
         this.reader = this.context.read('binary');
         if (this.reader.length >= 32) {
             this.reader.seek(8);
@@ -349,6 +328,54 @@ executorch.Reader = class {
             }
             this.reader.seek(0);
         }
+        for (const plan of this.program.execution_plan) {
+            for (const chain of plan.chains) {
+                for (const instruction of chain.instructions) {
+                    const instr_args = instruction.instr_args;
+                    if (instr_args instanceof executorch_flatbuffer.DelegateCall) {
+                        const delegate = plan.delegates[instr_args.delegate_index];
+                        if (delegate.backend) {
+                            continue;
+                        }
+                        let data = null;
+                        switch (delegate.processed.location) {
+                            case executorch_flatbuffer.DataLocation.INLINE: {
+                                data = this.program.backend_delegate_data[delegate.processed.index].data;
+                                break;
+                            }
+                            case executorch_flatbuffer.DataLocation.SEGMENT: {
+                                const segment = this.program.segments[delegate.processed.index];
+                                data = this.blob(segment.offset.toNumber(), segment.size.toNumber());
+                                break;
+                            }
+                            default: {
+                                throw new executorch.Error(`Delegate data location '${delegate.processed.location}' not implemented.`);
+                            }
+                        }
+                        switch (delegate.id) {
+                            case 'XnnpackBackend': {
+                                delegate.backend = xnnpack.Reader.open(data, this);
+                                break;
+                            }
+                            case 'CoreMLBackend': {
+                                delegate.backend = coreml.Reader.open(data, this);
+                                break;
+                            }
+                            case 'VulkanBackend': {
+                                delegate.backend = vulkan.Reader.open(data, this);
+                                break;
+                            }
+                            default: {
+                                throw new executorch.Error(`ExecuTorch delegate '${delegate.id}' not implemented.`);
+                            }
+                        }
+                        /* eslint-disable no-await-in-loop */
+                        await delegate.backend.read();
+                        /* eslint-enable no-await-in-loop */
+                    }
+                }
+            }
+        }
     }
 
     blob(offset, size) {
@@ -372,20 +399,21 @@ executorch.Error = class extends Error {
 
 xnnpack.Reader = class {
 
-    static open(data) {
+    static open(data, target) {
         if (data.length >= 30) {
             const reader = base.BinaryReader.open(data);
             reader.skip(4);
             const magic = String.fromCharCode(...reader.read(4));
             if (magic === 'XH00') {
-                return new xnnpack.Reader(reader);
+                return new xnnpack.Reader(reader, target);
             }
         }
         return null;
     }
 
-    constructor(reader) {
+    constructor(reader, target) {
         this.reader = reader;
+        this.target = target;
         reader.skip(2);
         this.flatbuffer = {
             offset: reader.uint32(),
@@ -397,8 +425,9 @@ xnnpack.Reader = class {
         };
     }
 
-    read() {
+    async read() {
         this.reader.seek(this.flatbuffer.offset);
+        const flatbuffers = await import('./flatbuffers.js');
         const data = this.reader.read(this.flatbuffer.size);
         const reader = flatbuffers.BinaryReader.open(data);
         if (!executorch.schema.fb_xnnpack.XNNGraph.identifier(reader)) {
@@ -407,7 +436,7 @@ xnnpack.Reader = class {
         this.graph = executorch.schema.fb_xnnpack.XNNGraph.create(reader);
         this.reader.seek(0);
         const metadata = new xnnpack.Metadata();
-        return new xnnpack.Graph(metadata, this.graph, this);
+        this.type = new xnnpack.Graph(metadata, this.graph, this);
     }
 
     constant(idx) {
@@ -540,8 +569,8 @@ xnnpack.TensorType = class {
             'qcint8', 'qcint32', 'qcint4',
             'qdint8', 'qbint4'
         ];
-        if (tensor.datatype >= executorch.TensorType._types.length) {
-            throw new executorch.Error(`Unknown tensor data type '${tensor.datatype}'.`);
+        if (tensor.datatype >= xnnpack.TensorType._types.length) {
+            throw new xnnpack.Error(`Unknown tensor data type '${tensor.datatype}'.`);
         }
         this.dataType = xnnpack.TensorType._types[tensor.datatype];
         this.shape = new xnnpack.TensorShape(Array.from(tensor.dims));
@@ -585,20 +614,101 @@ xnnpack.Error = class extends Error {
 
 vulkan.Reader = class {
 
-    static open(data) {
-        const reader = flatbuffers.BinaryReader.open(data);
-        if (executorch.schema.vkgraph.XNNGraph.identifier(reader)) {
-            return new vulkan.Reader(reader);
+    static open(data, target) {
+        if (data.length >= 30) {
+            const reader = base.BinaryReader.open(data);
+            reader.skip(4);
+            const magic = String.fromCharCode(...reader.read(4));
+            if (magic === 'VH00') {
+                return new vulkan.Reader(reader, target);
+            }
         }
         return null;
     }
 
-    constructor(reader) {
+    constructor(reader, target) {
         this.reader = reader;
+        this.target = target;
+        reader.skip(2);
+        this.flatbuffer = {
+            offset: reader.uint32(),
+            size: reader.uint32(),
+        };
+        this.constants = {
+            offset: reader.uint32(),
+            size: reader.uint32(),
+        };
+    }
+
+    async read() {
+        this.reader.seek(this.flatbuffer.offset);
+        const metadata = new vulkan.Metadata(this.target.execution);
+        metadata.register('conv_with_clamp(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Scalar? output_min, Scalar? output_max) -> Tensor)');
+        const flatbuffers = await import('./flatbuffers.js');
+        const data = this.reader.read(this.flatbuffer.size);
+        const reader = flatbuffers.BinaryReader.open(data);
+        if (!executorch.schema.vkgraph.VkGraph.identifier(reader)) {
+            throw new xnnpack.Error('Invalid Vuklan data.');
+        }
+        this.graph = executorch.schema.vkgraph.VkGraph.create(reader);
+        this.reader.seek(0);
+        this.type = new vulkan.Graph(metadata, this.graph, this);
+    }
+};
+
+vulkan.Graph = class {
+
+    constructor(metadata, graph /*, reader */) {
+        this.name = 'VulkanBackend';
+        this.inputs = [];
+        this.outputs = [];
+        this.nodes = [];
+        for (const op of graph.chain) {
+            const node = new vulkan.Node(metadata, op);
+            this.nodes.push(node);
+        }
+    }
+};
+
+vulkan.Node = class {
+
+    constructor(metadata, op) {
+        const schema = metadata.type(op.name);
+        this.type = {
+            name: schema.name,
+            identifier: op.name
+        };
+        if (schema.category) {
+            this.type.category = schema.category;
+        }
+        this.name = '';
+        this.inputs = [];
+        this.outputs = [];
+        this.attributes = [];
+    }
+};
+
+vulkan.Metadata = class {
+
+    constructor(execution) {
+        this.execution = execution;
+    }
+
+    register(signature) {
+        const torch = this.execution.register('torch');
+        const registry = torch._C.getRegistry();
+        const schema = torch.FunctionSchema.parse(signature);
+        const op = new torch._C.Operator(schema);
+        registry.registerOperator(op);
     }
 
-    read() {
-        /* const graph = */ executorch.schema.vkgraph.XNNGraph.create(this.reader);
+    type(identifier) {
+        identifier = identifier.split(/\.([^.]*)$/);
+        const name = identifier[0].replace('.', '::');
+        const overload = identifier[1] === 'default' ? '' : identifier[1];
+        const schemas = this.execution.invoke('torch._C._jit_get_schemas_for_operator', [name]);
+        const schema = schemas.find((schema) => schema.name === name && schema.overload_name === overload);
+        return schema;
     }
 };
 
@@ -612,20 +722,44 @@ vulkan.Error = class extends Error {
 
 coreml.Reader = class {
 
-    static open(data) {
+    static open(data, target) {
         const reader = base.BinaryReader.open(data);
-        return new coreml.Reader(reader);
+        return new coreml.Reader(reader, target);
     }
 
-    constructor(reader) {
+    constructor(reader, target) {
         this.reader = reader;
+        this.target = target;
     }
 
-    read() {
-        this.files(this.reader);
+    async factory() {
+        const coreml = await import('./coreml.js');
+        return new coreml.ModelFactory();
     }
 
-    files(reader) {
+    async read() {
+        const entries = this.entries(this.reader);
+        const factory = await this.factory();
+        const protobuf = await import('./protobuf.js');
+        for (const [key, value] of entries) {
+            const identifier = key.split('/').pop();
+            this.reader.seek(value.offset);
+            const stream = this.reader.stream(value.size);
+            this.reader.seek(0);
+            const context = new coreml.Context(this.target.context, identifier, stream, entries, protobuf);
+            factory.match(context);
+            if (context.type === 'coreml.pb') {
+                /* eslint-disable no-await-in-loop */
+                const model = await factory.open(context);
+                /* eslint-enable no-await-in-loop */
+                [this.type] = model.graphs;
+                this.type.name = 'CoreMLBackend';
+                return;
+            }
+        }
+    }
+
+    entries(reader) {
         const files = new Map();
         reader.seek(reader.length - 1);
         const str = [];
@@ -665,9 +799,61 @@ coreml.Reader = class {
         for (const root of roots.filter((node) => node !== null)) {
             process('', root);
         }
-        return files;
+        if (!Array.from(files.keys()).every((key) => key.startsWith('lowered_module/'))) {
+            throw new executorch.Error('');
+        }
+        const entries = new Map(Array.from(files).map(([key, value]) => {
+            return [key.replace(/^lowered_module\//, ''), value];
+        }));
+        return entries;
+    }
+};
+
+coreml.Context = class {
+
+    constructor(context, identifier, stream, entries, protobuf) {
+        this.context = context;
+        this.identifier = identifier;
+        this.stream = stream;
+        this.entries = entries;
+        this.protobuf = protobuf;
+    }
+
+    tags(type) {
+        if (type === 'pb' && this.identifier.endsWith('.mlmodel')) {
+            return new Map([[1,0],[2,2]]);
+        }
+        return new Map();
     }
 
+    peek(type) {
+        if (type === 'json') {
+            const data = this.stream.peek();
+            const decoder = new TextDecoder('utf-8');
+            const text = decoder.decode(data);
+            return JSON.parse(text);
+        }
+        return null;
+    }
+
+    read(type) {
+        if (type === 'protobuf.binary') {
+            return this.protobuf.BinaryReader.open(this.stream);
+        }
+        return null;
+    }
+
+    async fetch(/* file */) {
+        return Promise.resolve(null);
+    }
+
+    async require(id) {
+        return this.context.require(id);
+    }
+
+    async metadata(name) {
+        return this.context.metadata(name);
+    }
 };
 
-export const ModelFactory = executorch.ModelFactory;
+export const ModelFactory = executorch.ModelFactory;

+ 1 - 1
source/view.js

@@ -5706,7 +5706,7 @@ view.Context = class {
         return this._tags.get(type);
     }
 
-    metadata(name) {
+    async metadata(name) {
         return view.Metadata.open(this, name);
     }
 };

+ 7 - 0
test/models.json

@@ -2189,6 +2189,13 @@
     "format":   "ExecuTorch v0",
     "link":     "https://github.com/lutzroeder/netron/issues/1175"
   },
+  {
+    "type":     "executorch",
+    "target":   "mv2_vulkan.pte",
+    "source":   "https://github.com/user-attachments/files/18514115/mv2_vulkan.pte.zip[mv2_vulkan.pte]",
+    "format":   "ExecuTorch v0",
+    "link":     "https://github.com/lutzroeder/netron/issues/1175"
+  },
   {
     "type":     "executorch",
     "target":   "style_transfer_candy_coreml.pte",