瀏覽代碼

Update pytorch.js

Lutz Roeder 2 年之前
父節點
當前提交
834d02a3d7
共有 1 個文件被更改,包括 152 次插入和 180 次刪除
  1. +152 −180
      source/pytorch.js

+ 152 - 180
source/pytorch.js

@@ -47,10 +47,10 @@ pytorch.Graph = class {
         const values = new Map();
         values.map = (name, type, tensor) => {
             if (tensor) {
-                return new pytorch.Value(name, type || null, tensor);
+                return new pytorch.Value(name, type, null, tensor);
             }
             if (!values.has(name)) {
-                values.set(name, new pytorch.Value(name, type || null, tensor || null));
+                values.set(name, new pytorch.Value(name, type, null, tensor));
             } else if (type || tensor) {
                 throw new pytorch.Error(`Duplicate value '${name}'.`);
             }
@@ -265,13 +265,14 @@ pytorch.Argument = class {
 
 pytorch.Value = class {
 
-    constructor(name, type, initializer) {
+    constructor(name, type, quantization, initializer) {
         if (typeof name !== 'string') {
             throw new pytorch.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
         }
         this.name = name;
-        this.type = initializer ? initializer.type : type;
-        this.initializer = initializer;
+        this.type = initializer && initializer.type ? initializer.type : type || null;
+        this.quantization = quantization;
+        this.initializer = initializer || null;
     }
 };
 
@@ -379,7 +380,7 @@ pytorch.Node = class {
                 const values = list.filter((value) => value !== null).map((value) => {
                     const identifier = value && value.name ? value.name : '';
                     const tensor = value ? new pytorch.Tensor(identifier, value) : null;
-                    return new pytorch.Value(identifier, null, tensor);
+                    return new pytorch.Value(identifier, null, null, tensor);
                 });
                 const argument = new pytorch.Argument(name, values, null, visible);
                 this.inputs.push(argument);
@@ -554,7 +555,7 @@ pytorch.Node = class {
                                         identifier = initializer ? initializer.name : identifier;
                                     }
                                     if (initializer) {
-                                        return new pytorch.Value(identifier, null, initializer);
+                                        return new pytorch.Value(identifier, null, null, initializer);
                                     }
                                     return values.map(identifier);
                                 });
@@ -605,23 +606,23 @@ pytorch.Node = class {
 pytorch.Tensor = class {
 
     constructor(name, tensor) {
-        this._name = name || '';
+        this.name = name || '';
         const layout = tensor.layout ? tensor.layout.__str__() : null;
         const storage = tensor.storage();
         const size = tensor.size() || [];
         if (layout && layout.startsWith('torch.sparse_')) {
-            this._type = new pytorch.TensorType(storage.dtype.__reduce__(), new pytorch.TensorShape(size), layout.split('.').pop().replace('_', '.'));
-            this._indices = new pytorch.Tensor('', tensor.indices);
+            this.type = new pytorch.TensorType(storage.dtype.__reduce__(), new pytorch.TensorShape(size), layout.split('.').pop().replace('_', '.'));
+            this.indices = new pytorch.Tensor('', tensor.indices);
             this._values = new pytorch.Tensor('', tensor.values);
         } else if (!layout || layout === 'torch.strided') {
-            this._type = new pytorch.TensorType(storage.dtype.__reduce__(), new pytorch.TensorShape(size));
+            this.type = new pytorch.TensorType(storage.dtype.__reduce__(), new pytorch.TensorShape(size));
             this._data = storage.data;
-            this._encoding = '<';
-            this._indices = null;
-            this._stride = tensor.stride();
-            const stride = this._stride;
+            this.encoding = '<';
+            this.indices = null;
+            this.stride = tensor.stride();
+            const stride = this.stride;
             const offset = tensor.storage_offset();
-            const length = size.every((v) => v !== 0) ? size.reduce((a, v, i) => a + stride[i] * (v - 1), 1) : 0;
+            const length = stride ? size.every((v) => v !== 0) ? size.reduce((a, v, i) => a + stride[i] * (v - 1), 1) : 0 : storage.size();
             if (offset !== 0 || length !== storage.size()) {
                 const itemsize = storage.dtype.itemsize();
                 this._offset = itemsize * offset;
@@ -632,28 +633,8 @@ pytorch.Tensor = class {
         }
     }
 
-    get name() {
-        return this._name;
-    }
-
-    get type() {
-        return this._type;
-    }
-
-    get encoding() {
-        return this._encoding;
-    }
-
-    get stride() {
-        return this._stride;
-    }
-
-    get indices() {
-        return this._indices;
-    }
-
     get values() {
-        const type = this._type.layout;
+        const type = this.type.layout;
         if (type && type.startsWith('sparse.')) {
             return this._values;
         }
@@ -675,10 +656,10 @@ pytorch.Tensor = class {
     }
 
     decode() {
-        if (this._encoding !== '<') {
-            throw new pytorch.Error(`Tensor encoding '${this._encoding}' not implemented.`);
+        if (this.encoding !== '<') {
+            throw new pytorch.Error(`Tensor encoding '${this.encoding}' not implemented.`);
         }
-        const type = this._type;
+        const type = this.type;
         const data = this.values;
         const view = new DataView(data.buffer, data.byteOffset, data.byteLength);
         switch (type.dataType) {
@@ -3960,6 +3941,126 @@ pytorch.nnapi.SerializedModel = class {
     }
 };
 
+pytorch.nnapi.Graph = class {
+
+    constructor(model) {
+        this.name = 'torch.classes._nnapi.Compilation';
+        this.nodes = [];
+        this.inputs = [];
+        this.outputs = [];
+        const values = new Map();
+        values.map = (operand) => {
+            if (!values.has(operand.index)) {
+                const name = operand.index.toString();
+                const dimensions = operand.dimensions;
+                const shape = new pytorch.TensorShape(dimensions);
+                let dataType = operand.data_type.replace('[]', '');
+                let quantization = null;
+                switch (dataType) {
+                    case 'quant8_asymm':
+                    case 'quant8_symm_per_channel':
+                    case 'quant8_symm':
+                    case 'quant8_asymm_signed[]':
+                    case 'quant16_asymm':
+                    case 'quant16_symm':
+                        quantization = dataType;
+                        dataType = dataType.indexOf('16') !== -1 ? 'uint16' : 'uint8';
+                        break;
+                    default:
+                        break;
+                }
+                const type = new pytorch.TensorType(dataType, shape);
+                let initializer = null;
+                if (operand.data) {
+                    const size = dimensions.reduce((a, b) => a * b, 1);
+                    const tensor = {
+                        size: () => dimensions,
+                        stride: () => null,
+                        storage_offset: () => 0,
+                        storage: () => ({
+                            dtype: { __reduce__: () => type.dataType },
+                            data: operand.data, size: () => size
+                        })
+                    };
+                    initializer = new pytorch.Tensor(null, tensor);
+                }
+                if (quantization || (operand.scale !== undefined && operand.scale !== 0) || (operand.zero_point !== undefined && operand.zero_point !== 0)) {
+                    quantization = {
+                        type: quantization || 'linear',
+                        scale: [ operand.scale ],
+                        offset: [ operand.zero_point ]
+                    };
+                }
+                const value = new pytorch.Value(name, type, quantization, initializer);
+                values.set(operand.index, value);
+            }
+            return values.get(operand.index);
+        };
+        const metadata = new pytorch.nnapi.Metadata();
+        for (const operation of model.operations) {
+            const node = new pytorch.nnapi.Node(metadata, operation, values);
+            this.nodes.push(node);
+        }
+        for (let i = 0; i < model.inputs.length; i++) {
+            const name = i.toString();
+            const operand = model.inputs[i];
+            const argument = new pytorch.Argument(name, [ values.map(operand) ]);
+            this.inputs.push(argument);
+        }
+        for (let i = 0; i < model.outputs.length; i++) {
+            const name = i.toString();
+            const operand = model.outputs[i];
+            const argument = new pytorch.Argument(name, [ values.map(operand) ]);
+            this.outputs.push(argument);
+        }
+    }
+};
+
+pytorch.nnapi.Node = class {
+
+    constructor(metadata, operation, values) {
+        const signature = (operation.inputs || []).map((input) => input.data_type);
+        this.name = '';
+        this.type = metadata.type(operation.index, signature);
+        this.inputs = [];
+        this.outputs = [];
+        this.attributes = [];
+        this.chain = [];
+        if (operation.location !== undefined) {
+            this.location = operation.location.toString();
+        }
+        const inputs = this.type.inputs.concat(this.type.attributes);
+        if (operation.inputs) {
+            for (let i = 0; i < operation.inputs.length; i++) {
+                const name = i < inputs.length ? inputs[i].name : i.toString();
+                const operand = operation.inputs[i];
+                if (operand.dimensions.length > 0) {
+                    const value = values.map(operand);
+                    const argument = new pytorch.Argument(name, [ value ]);
+                    this.inputs.push(argument);
+                } else if (name === 'activation') {
+                    const activation = new Map([ [ 1, 19 ], [ 2, 20 ], [ 3, 21 ] ]).get(operand.value) || 0;
+                    if (activation !== 0) {
+                        this.chain.push(new pytorch.nnapi.Node(metadata, { index: activation }));
+                    }
+                } else {
+                    const attribute = new pytorch.Argument(name, operand.value, operand.data_type, false);
+                    this.attributes.push(attribute);
+                }
+            }
+        }
+        if (operation.outputs) {
+            for (let i = 0; i < operation.outputs.length; i++) {
+                const name = i < inputs.length ? inputs[i].name : i.toString();
+                const operand = operation.outputs[i];
+                const value = values.map(operand);
+                const argument = new pytorch.Argument(name, [ value ]);
+                this.outputs.push(argument);
+            }
+        }
+    }
+};
+
 pytorch.nnapi.Metadata = class {
 
     constructor() {
@@ -4096,15 +4197,7 @@ pytorch.nnapi.Metadata = class {
         for (const type of types) {
             const inputs = type.inputs.concat(type.attributes);
             if (signature.length < inputs.length) {
-                let match = true;
-                for (let i = 0; i < inputs.length; i++) {
-                    const input = inputs[i];
-                    if (input.type === undefined || input.type === 'Tensor' || input.type === signature[i]) {
-                        continue;
-                    }
-                    match = false;
-                }
-                if (match) {
+                if (inputs.every((input, i) => input.type === undefined || input.type === 'Tensor' || input.type === signature[i])) {
                     return type;
                 }
             }
@@ -4113,127 +4206,6 @@ pytorch.nnapi.Metadata = class {
     }
 };
 
-pytorch.nnapi.Graph = class {
-
-    constructor(model) {
-        this.name = 'torch.classes._nnapi.Compilation';
-        this.nodes = [];
-        this.inputs = [];
-        this.outputs = [];
-        const values = new Map();
-        values.map = (operand) => {
-            if (!values.has(operand.index)) {
-                const value = new pytorch.nnapi.Argument(operand);
-                values.set(operand.index, value);
-            }
-            return values.get(operand.index);
-        };
-        const metadata = new pytorch.nnapi.Metadata();
-        for (const operation of model.operations) {
-            const node = new pytorch.nnapi.Node(metadata, operation, values);
-            this.nodes.push(node);
-        }
-        for (let i = 0; i < model.inputs.length; i++) {
-            const name = i.toString();
-            const operand = model.inputs[i];
-            const argument = new pytorch.Argument(name, [ values.map(operand) ]);
-            this.inputs.push(argument);
-        }
-        for (let i = 0; i < model.outputs.length; i++) {
-            const name = i.toString();
-            const operand = model.outputs[i];
-            const argument = new pytorch.Argument(name, [ values.map(operand) ]);
-            this.outputs.push(argument);
-        }
-    }
-};
-
-pytorch.nnapi.Argument = class {
-
-    constructor(operand) {
-        this.name = operand.index.toString();
-        const shape = new pytorch.TensorShape(operand.dimensions);
-        let dataType = operand.data_type.replace('[]', '');
-        let quantization = null;
-        switch (dataType) {
-            case 'quant8_asymm':
-            case 'quant8_symm_per_channel':
-            case 'quant8_symm':
-            case 'quant8_asymm_signed[]':
-            case 'quant16_asymm':
-            case 'quant16_symm':
-                quantization = dataType;
-                dataType = dataType.indexOf('16') !== -1 ? 'uint16' : 'uint8';
-                break;
-            default:
-                break;
-        }
-        this.type = new pytorch.TensorType(dataType, shape);
-        this.initializer = operand.data ? new pytorch.nnapi.Tensor(this.type, operand.data) : null;
-        if (quantization || (operand.scale !== undefined && operand.scale !== 0) || (operand.zero_point !== undefined && operand.zero_point !== 0)) {
-            this.quantization = {
-                type: quantization || 'linear',
-                scale: [ operand.scale ],
-                offset: [ operand.zero_point ]
-            };
-        }
-    }
-};
-
-pytorch.nnapi.Node = class {
-
-    constructor(metadata, operation, values) {
-        const signature = (operation.inputs || []).map((input) => input.data_type);
-        this.name = '';
-        this.type = metadata.type(operation.index, signature);
-        this.inputs = [];
-        this.outputs = [];
-        this.attributes = [];
-        this.chain = [];
-        if (operation.location !== undefined) {
-            this.location = operation.location.toString();
-        }
-        const inputs = this.type.inputs.concat(this.type.attributes);
-        if (operation.inputs) {
-            for (let i = 0; i < operation.inputs.length; i++) {
-                const name = i < inputs.length ? inputs[i].name : i.toString();
-                const operand = operation.inputs[i];
-                if (operand.dimensions.length > 0) {
-                    const value = values.map(operand);
-                    const argument = new pytorch.Argument(name, [ value ]);
-                    this.inputs.push(argument);
-                } else if (name === 'activation') {
-                    const activation = new Map([ [ 1, 19 ], [ 2, 20 ], [ 3, 21 ] ]).get(operand.value) || 0;
-                    if (activation !== 0) {
-                        this.chain.push(new pytorch.nnapi.Node(metadata, { index: activation }));
-                    }
-                } else {
-                    const attribute = new pytorch.Argument(name, operand.value, operand.data_type, false);
-                    this.attributes.push(attribute);
-                }
-            }
-        }
-        if (operation.outputs) {
-            for (let i = 0; i < operation.outputs.length; i++) {
-                const name = i < inputs.length ? inputs[i].name : i.toString();
-                const operand = operation.outputs[i];
-                const value = values.map(operand);
-                const argument = new pytorch.Argument(name, [ value ]);
-                this.outputs.push(argument);
-            }
-        }
-    }
-};
-
-pytorch.nnapi.Tensor = class {
-
-    constructor(type, data) {
-        this.type = type;
-        this.encoding = '<';
-        this.values = data;
-    }
-};
-
 pytorch.Metadata = class {
 
     static async open(context) {
@@ -4271,24 +4243,24 @@ pytorch.Metadata = class {
     }
 
     attribute(type, name) {
-        const attributeName = `${type}:${name}`;
-        if (!this._attributes.has(attributeName)) {
-            this._attributes.set(attributeName, null);
-            const schema = this.type(type);
-            if (schema) {
-                if (schema.inputs) {
-                    for (const input of schema.inputs) {
+        const key = `${type}:${name}`;
+        if (!this._attributes.has(key)) {
+            this._attributes.set(key, null);
+            const metadata = this.type(type);
+            if (metadata) {
+                if (metadata.inputs) {
+                    for (const input of metadata.inputs) {
                         this._attributes.set(`${type}:${input.name}`, input);
                     }
                 }
-                if (schema.attributes) {
-                    for (const attribute of schema.attributes) {
+                if (metadata.attributes) {
+                    for (const attribute of metadata.attributes) {
                         this._attributes.set(`${type}:${attribute.name}`, attribute);
                     }
                 }
             }
         }
-        return this._attributes.get(attributeName);
+        return this._attributes.get(key);
     }
 };