Răsfoiți Sursa

Update tensor formatter (#741) (#961)

Lutz Roeder 2 ani în urmă
părinte
comite
916ec3f2e3

+ 4 - 4
source/bigdl.js

@@ -327,11 +327,11 @@ bigdl.Tensor = class {
                 case 'float32':
                     if (storage.bytes_data && storage.bytes_data.length > 0) {
                         this._values = storage.bytes_data[0];
-                        this._layout = '<';
+                        this._encoding = '<';
                     }
                     else if (storage.float_data && storage.float_data.length > 0) {
                         this._values = storage.float_data;
-                        this._layout = '|';
+                        this._encoding = '|';
                     }
                     break;
                 default:
@@ -349,8 +349,8 @@ bigdl.Tensor = class {
         return this._type;
     }
 
-    get layout() {
-        return this._layout;
+    get encoding() {
+        return this._encoding;
     }
 
     get values() {

+ 1 - 1
source/caffe.js

@@ -627,7 +627,7 @@ caffe.Tensor = class {
         return this._type;
     }
 
-    get layout() {
+    get encoding() {
         return '|';
     }
 

+ 1 - 1
source/caffe2.js

@@ -561,7 +561,7 @@ caffe2.Tensor = class {
         return null;
     }
 
-    get layout() {
+    get encoding() {
         return '|';
     }
 

+ 1 - 1
source/circle.js

@@ -555,7 +555,7 @@ circle.Tensor = class {
         return this._type;
     }
 
-    get layout() {
+    get encoding() {
         switch (this._type.dataType) {
             case 'string': return '|';
             default: return '<';

+ 1 - 1
source/cntk.js

@@ -545,7 +545,7 @@ cntk.Tensor = class {
         return this._type;
     }
 
-    get layout() {
+    get encoding() {
         return '|';
     }
 

+ 1 - 1
source/coreml.js

@@ -1240,7 +1240,7 @@ coreml.Tensor = class {
         return null;
     }
 
-    get layout() {
+    get encoding() {
         switch (this._type.dataType) {
             case 'float32': return '|';
             default: return '<';

+ 2 - 2
source/dlc.js

@@ -243,10 +243,10 @@ dlc.Tensor = class {
     constructor(type, data) {
         this.type = type;
         if (data instanceof Uint8Array) {
-            this.layout = '<';
+            this.encoding = '<';
             this.values = data;
         } else {
-            this.layout = '|';
+            this.encoding = '|';
             switch (type.dataType) {
                 case 'uint8': this.values = data.bytes; break;
                 case 'float32': this.values = data.floats; break;

+ 1 - 1
source/flax.js

@@ -256,7 +256,7 @@ flax.Tensor = class {
         return this._type;
     }
 
-    get layout() {
+    get encoding() {
         switch (this._type.dataType) {
             case 'string':
             case 'object':

+ 1 - 1
source/hickle.js

@@ -196,7 +196,7 @@ hickle.Tensor = class {
         return this._type;
     }
 
-    get layout() {
+    get encoding() {
         return this._littleEndian ? '<' : '>';
     }
 

+ 7 - 7
source/keras.js

@@ -126,8 +126,8 @@ keras.ModelFactory = class {
                                         const components = weight_name.split('/');
                                         components.pop();
                                         const name = (components.length == 0 || components[0] !== layer_name) ? [ layer_name ].concat(components).join('/') : components.join('/');
-                                        const layout = variable.littleEndian ? '<' : '>';
-                                        const tensor = new keras.Tensor(weight_name, variable.shape, variable.type, null, layout, variable.data);
+                                        const encoding = variable.littleEndian ? '<' : '>';
+                                        const tensor = new keras.Tensor(weight_name, variable.shape, variable.type, null, encoding, variable.data);
                                         weights.add(name, tensor);
                                     }
                                 }
@@ -962,11 +962,11 @@ keras.Attribute = class {
 
 keras.Tensor = class {
 
-    constructor(name, shape, type, quantization, layout, data) {
+    constructor(name, shape, type, quantization, encoding, data) {
         this._name = name;
         this._type = new keras.TensorType(type, new keras.TensorShape(shape));
         this._quantization = quantization;
-        this._layout = layout;
+        this._encoding = encoding;
         this._data = data;
     }
 
@@ -978,8 +978,8 @@ keras.Tensor = class {
         return this._type;
     }
 
-    get layout() {
-        return this._layout;
+    get encoding() {
+        return this._encoding;
     }
 
     get quantization() {
@@ -992,7 +992,7 @@ keras.Tensor = class {
     }
 
     get values() {
-        if (this._layout === '|') {
+        if (this._encoding === '|') {
             return this._data;
         }
         if (this._data === null) {

+ 1 - 1
source/mlir.js

@@ -403,7 +403,7 @@ mlir.Tensor = class {
         return null;
     }
 
-    get layout() {
+    get encoding() {
         switch (this._type.dataType) {
             case 'float32': return '|';
             default: return '<';

+ 1 - 1
source/mnn.js

@@ -364,7 +364,7 @@ mnn.Tensor = class {
         return this._type;
     }
 
-    get layout() {
+    get encoding() {
         switch (this._type.dataType) {
             case 'int32':
             case 'float32':

+ 1 - 1
source/mslite.js

@@ -321,7 +321,7 @@ mslite.Tensor = class {
         return this._type;
     }
 
-    get layout() {
+    get encoding() {
         switch (this._type.dataType) {
             case 'string': return '|';
             default: return '<';

+ 1 - 1
source/mxnet.js

@@ -750,7 +750,7 @@ mxnet.Tensor = class {
         return this._type;
     }
 
-    get layout() {
+    get encoding() {
         return '<';
     }
 

+ 1 - 1
source/nnabla.js

@@ -326,7 +326,7 @@ nnabla.Tensor = class {
         return this._type;
     }
 
-    get layout() {
+    get encoding() {
         return '|';
     }
 

+ 1 - 1
source/numpy.js

@@ -233,7 +233,7 @@ numpy.Tensor = class  {
     constructor(array) {
         this.type = new numpy.TensorType(array.dtype.__name__, new numpy.TensorShape(array.shape));
         this.values = this.type.dataType == 'string' || this.type.dataType == 'object' ? array.flatten().tolist() : array.tobytes();
-        this.layout = this.type.dataType == 'string' || this.type.dataType == 'object' ? '|' : array.dtype.byteorder;
+        this.encoding = this.type.dataType == 'string' || this.type.dataType == 'object' ? '|' : array.dtype.byteorder;
     }
 };
 

+ 18 - 19
source/onnx.js

@@ -618,7 +618,6 @@ onnx.Tensor = class {
             this._name = tensor.values.name || '';
             this._type = context.createTensorType(tensor.values.data_type, tensor.dims.map((dim) => dim), 'sparse');
             this._location = context.createLocation(tensor.values.data_location);
-            this._layout = 'sparse';
             this._values = new onnx.Tensor(context, tensor.values);
             this._indices = new onnx.Tensor(context, tensor.indices);
         } else {
@@ -633,11 +632,11 @@ onnx.Tensor = class {
                         }
                         case onnx.DataType.FLOAT:
                             this._data = new Float32Array(tensor.float_data);
-                            this._layout = '|';
+                            this._encoding = '|';
                             break;
                         case onnx.DataType.DOUBLE:
                             this._data = new Float64Array(tensor.double_data);
-                            this._layout = '|';
+                            this._encoding = '|';
                             break;
                         case onnx.DataType.BOOL:
                             if (tensor.int32_data && tensor.int32_data.length > 0) {
@@ -646,41 +645,41 @@ onnx.Tensor = class {
                                 for (let i = 0; i < this._data.length; i++) {
                                     this._data[i] = array[i] === 0 ? false : true;
                                 }
-                                this._layout = '|';
+                                this._encoding = '|';
                             }
                             break;
                         case onnx.DataType.INT8:
                             this._data = new Int8Array(tensor.int32_data);
-                            this._layout = '|';
+                            this._encoding = '|';
                             break;
                         case onnx.DataType.UINT8:
                             this._data = new Uint8Array(tensor.int32_data);
-                            this._layout = '|';
+                            this._encoding = '|';
                             break;
                         case onnx.DataType.INT16:
                             this._data = new Int32Array(tensor.int32_data);
-                            this._layout = '|';
+                            this._encoding = '|';
                             break;
                         case onnx.DataType.UINT16:
                             this._data = new Int32Array(tensor.int32_data);
-                            this._layout = '|';
+                            this._encoding = '|';
                             break;
                         case onnx.DataType.INT32:
                             this._data = new Int32Array(tensor.int32_data);
-                            this._layout = '|';
+                            this._encoding = '|';
                             break;
                         case onnx.DataType.UINT32:
                         case onnx.DataType.UINT64:
                             this._data = tensor.uint64_data;
-                            this._layout = '|';
+                            this._encoding = '|';
                             break;
                         case onnx.DataType.INT64:
                             this._data = tensor.int64_data;
-                            this._layout = '|';
+                            this._encoding = '|';
                             break;
                         case onnx.DataType.STRING:
                             this._data = tensor.string_data;
-                            this._layout = '|';
+                            this._encoding = '|';
                             break;
                         case onnx.DataType.COMPLEX64:
                         case onnx.DataType.COMPLEX128:
@@ -695,7 +694,7 @@ onnx.Tensor = class {
                                     view.setUint16(i << 1, array[i], true);
                                 }
                                 this._data = buffer;
-                                this._layout = '<';
+                                this._encoding = '<';
                             }
                             break;
                         case onnx.DataType.FLOAT8E4M3FN:
@@ -704,7 +703,7 @@ onnx.Tensor = class {
                         case onnx.DataType.FLOAT8E5M2FNUZ:
                             if (tensor.int32_data && tensor.int32_data.length > 0) {
                                 this._data = new Uint8Array(Array.from(tensor.int32_data));
-                                this._layout = '<';
+                                this._encoding = '<';
                             }
                             break;
                         default:
@@ -715,7 +714,7 @@ onnx.Tensor = class {
                     }
                     if (!this._data && tensor.raw_data && tensor.raw_data.length > 0) {
                         this._data = tensor.raw_data;
-                        this._layout = '<';
+                        this._encoding = '<';
                     }
                     break;
                 }
@@ -730,7 +729,7 @@ onnx.Tensor = class {
                             const length = parseInt(external_data.length, 10);
                             if (Number.isInteger(offset) && Number.isInteger(length)) {
                                 this._data = context.location(external_data.location, offset, length);
-                                this._layout = '<';
+                                this._encoding = '<';
                             }
                         }
                     }
@@ -751,8 +750,8 @@ onnx.Tensor = class {
         return this._category;
     }
 
-    get layout() {
-        return this._layout;
+    get encoding() {
+        return this._encoding;
     }
 
     get type() {
@@ -764,7 +763,7 @@ onnx.Tensor = class {
     }
 
     get values() {
-        switch (this._layout) {
+        switch (this.type.layout) {
             case 'sparse': {
                 return this._values;
             }

+ 16 - 10
source/pytorch.js

@@ -663,16 +663,16 @@ pytorch.Tensor = class {
         this._name = name || '';
         const storage = tensor.storage();
         const size = tensor.size();
-        this._type = new pytorch.TensorType(storage.dtype.__reduce__(), new pytorch.TensorShape(size));
         const layout = tensor.layout ? tensor.layout.__str__() : null;
         this._stride = tensor.stride();
         if (layout && layout.startsWith('torch.sparse_')) {
-            this._layout = layout.split('.').pop().replace('_', '.');
+            this._type = new pytorch.TensorType(storage.dtype.__reduce__(), new pytorch.TensorShape(size), layout.split('.').pop().replace('_', '.'));
             this._indices = new pytorch.Tensor('', tensor.indices);
             this._values = new pytorch.Tensor('', tensor.values);
         } else if (!layout || layout === 'torch.strided') {
+            this._type = new pytorch.TensorType(storage.dtype.__reduce__(), new pytorch.TensorShape(size));
             this._data = storage.data;
-            this._layout = '<';
+            this._encoding = '<';
             this._indices = null;
         } else {
             throw new pytorch.Error("Unsupported tensor layout '" + layout + "'.");
@@ -687,8 +687,8 @@ pytorch.Tensor = class {
         return this._type;
     }
 
-    get layout() {
-        return this._layout;
+    get encoding() {
+        return this._encoding;
     }
 
     get stride() {
@@ -700,15 +700,16 @@ pytorch.Tensor = class {
     }
 
     get values() {
-        if (this._layout && this._layout.startsWith('sparse.')) {
+        const type = this._type.layout;
+        if (type && type.startsWith('sparse.')) {
             return this._values;
         }
         return this._data instanceof Uint8Array ? this._data : this._data.peek();
     }
 
     decode() {
-        if (this._layout !== '<') {
-            throw new pytorch.Error("Tensor layout '" + this._layout + "' not implemented.");
+        if (this._encoding !== '<') {
+            throw new pytorch.Error("Tensor encoding '" + this._encoding + "' not implemented.");
         }
         const type = this._type;
         const data = this.values;
@@ -740,9 +741,10 @@ pytorch.Tensor = class {
 
 pytorch.TensorType = class {
 
-    constructor(dataType, shape) {
+    constructor(dataType, shape, layout) {
         this._dataType = dataType;
         this._shape = shape;
+        this._layout = layout;
     }
 
     get dataType() {
@@ -753,6 +755,10 @@ pytorch.TensorType = class {
         return this._shape;
     }
 
+    get layout() {
+        return this._layout;
+    }
+
     toString() {
         return this._dataType + this._shape.toString();
     }
@@ -4210,7 +4216,7 @@ pytorch.nnapi.Tensor = class {
         return this._type;
     }
 
-    get layout() {
+    get encoding() {
         return '<';
     }
 

+ 1 - 1
source/safetensors.js

@@ -138,7 +138,7 @@ safetensors.Tensor = class {
     constructor(obj, position, stream) {
         const shape = new safetensors.TensorShape(obj.shape);
         this.type = new safetensors.TensorType(obj.dtype, shape);
-        this.layout = '<';
+        this.encoding = '<';
         const size = obj.data_offsets[1] - obj.data_offsets[0];
         position += obj.data_offsets[0];
         stream.seek(position);

+ 1 - 1
source/sklearn.js

@@ -259,7 +259,7 @@ sklearn.Tensor = class {
 
     constructor(array) {
         this.type = new sklearn.TensorType(array.dtype.__name__, new sklearn.TensorShape(array.shape));
-        this.layout = this.type.dataType == 'string' || this.type.dataType == 'object' ? '|' : array.dtype.byteorder;
+        this.encoding = this.type.dataType == 'string' || this.type.dataType == 'object' ? '|' : array.dtype.byteorder;
         this.values = this.type.dataType == 'string' || this.type.dataType == 'object' ? array.tolist() : array.tobytes();
     }
 };

+ 16 - 16
source/tf.js

@@ -1225,7 +1225,7 @@ tf.Tensor = class {
             this._tensor = tensor;
             if (Object.prototype.hasOwnProperty.call(tensor, 'tensor_content')) {
                 this._values = tensor.tensor_content;
-                this._layout = '<';
+                this._encoding = '<';
             } else {
                 const DataType = tf.proto.tensorflow.DataType;
                 switch (tensor.dtype) {
@@ -1239,7 +1239,7 @@ tf.Tensor = class {
                         for (let i = 0; i < values.length; i++) {
                             view.setUint32(i << 2, values[i] << 16, true);
                         }
-                        this._layout = '<';
+                        this._encoding = '<';
                         break;
                     }
                     case DataType.DT_HALF: {
@@ -1249,17 +1249,17 @@ tf.Tensor = class {
                         for (let i = 0; i < values.length; i++) {
                             view.setUint16(i << 1, values[i], true);
                         }
-                        this._layout = '<';
+                        this._encoding = '<';
                         break;
                     }
                     case DataType.DT_FLOAT: {
                         this._values = tensor.float_val || null;
-                        this._layout = '|';
+                        this._encoding = '|';
                         break;
                     }
                     case DataType.DT_DOUBLE: {
                         this._values = tensor.double_val || null;
-                        this._layout = '|';
+                        this._encoding = '|';
                         break;
                     }
                     case DataType.DT_UINT8:
@@ -1268,36 +1268,36 @@ tf.Tensor = class {
                     case DataType.DT_INT16:
                     case DataType.DT_INT32: {
                         this._values = tensor.int_val || null;
-                        this._layout = '|';
+                        this._encoding = '|';
                         break;
                     }
                     case DataType.DT_UINT32: {
                         this._values = tensor.uint32_val || null;
-                        this._layout = '|';
+                        this._encoding = '|';
                         break;
                     }
                     case DataType.DT_INT64: {
                         this._values = tensor.int64_val || null;
-                        this._layout = '|';
+                        this._encoding = '|';
                         break;
                     }
                     case DataType.DT_UINT64: {
                         this._values = tensor.uint64_val || null;
-                        this._layout = '|';
+                        this._encoding = '|';
                         break;
                     }
                     case DataType.DT_BOOL: {
                         this._values = tensor.bool_val || null;
-                        this._layout = '|';
+                        this._encoding = '|';
                         break;
                     }
                     case DataType.DT_STRING: {
                         this._values = tensor.string_val || null;
-                        this._layout = '|';
+                        this._encoding = '|';
                         break;
                     }
                     case DataType.DT_COMPLEX64: {
-                        this._layout = '|';
+                        this._encoding = '|';
                         const values = tensor.scomplex_val || null;
                         this._values = new Array(values.length >> 1);
                         for (let i = 0; i < values.length; i += 2) {
@@ -1306,7 +1306,7 @@ tf.Tensor = class {
                         break;
                     }
                     case DataType.DT_COMPLEX128: {
-                        this._layout = '|';
+                        this._encoding = '|';
                         const values = tensor.dcomplex_val || null;
                         this._values = new Array(values.length >> 1);
                         for (let i = 0; i < values.length; i += 2) {
@@ -1337,13 +1337,13 @@ tf.Tensor = class {
         return this._category;
     }
 
-    get layout() {
-        return this._layout;
+    get encoding() {
+        return this._encoding;
     }
 
     get values() {
         let values = this._values;
-        if (this._layout === '|' && Array.isArray(values)) {
+        if (this._encoding === '|' && Array.isArray(values)) {
             if (this._type.dataType === 'string') {
                 values = values.map((value) => tf.Utility.decodeText(value));
             }

+ 1 - 1
source/tflite.js

@@ -569,7 +569,7 @@ tflite.Tensor = class {
         return this._type;
     }
 
-    get layout() {
+    get encoding() {
         switch (this._type.dataType) {
             case 'string': return '|';
             default: return '<';

+ 1 - 1
source/torch.js

@@ -490,7 +490,7 @@ torch.Tensor = class {
         return this._type;
     }
 
-    get layout() {
+    get encoding() {
         return '|';
     }
 

+ 76 - 68
source/view.js

@@ -2675,13 +2675,7 @@ view.ValueView = class extends view.Control {
                 if (location !== undefined) {
                     this._bold('location', location);
                 }
-                let layout = this._value.type ? this._value.type.layout : null;
-                if (initializer) {
-                    if (layout && layout !== initializer.layout) {
-                        throw new view.Error('Tensor type layout mismatch.');
-                    }
-                    layout = layout || initializer.layout;
-                }
+                const layout = this._value.type ? this._value.type.layout : null;
                 if (layout) {
                     const layouts = new Map([
                         [ 'sparse', 'sparse' ],
@@ -2729,7 +2723,9 @@ view.ValueView = class extends view.Control {
             if (Array.isArray(tensor.stride) && tensor.stride.length > 0) {
                 this._code('stride', tensor.stride.join(','));
             }
-            if (tensor.layout !== '<' && tensor.layout !== '>' && tensor.layout !== '|' && tensor.layout !== 'sparse' && tensor.layout !== 'sparse.coo') {
+            if (tensor.encoding !== '<' && tensor.encoding !== '>' && tensor.encoding !== '|') {
+                contentLine.innerHTML = "Tensor encoding '" + tensor.layout + "' is not implemented.";
+            } else if (tensor.layout && (tensor.layout !== 'sparse' && tensor.layout !== 'sparse.coo')) {
                 contentLine.innerHTML = "Tensor layout '" + tensor.layout + "' is not implemented.";
             } else if (tensor.empty) {
                 contentLine.innerHTML = 'Tensor data is empty.';
@@ -3237,40 +3233,40 @@ view.Tensor = class {
         this._tensor = tensor;
         this._type = tensor.type;
         this._stride = tensor.stride;
-        switch (tensor.layout) {
+        this._encoding = tensor.encoding;
+        this._layout = tensor.type.layout;
+        switch (this._encoding) {
             case undefined:
             case '':
             case '<': {
                 this._data = this._tensor.values;
-                this._layout = '<';
+                this._encoding = '<';
                 this._littleEndian = true;
                 break;
             }
             case '>': {
                 this._data = this._tensor.values;
-                this._layout = '>';
+                this._encoding = '>';
                 this._littleEndian = false;
                 break;
             }
             case '|': {
                 this._values = this._tensor.values;
-                this._layout = '|';
+                this._encoding = '|';
                 break;
             }
-            case 'sparse': {
-                this._indices = this._tensor.indices;
-                this._values = this._tensor.values;
-                this._layout = 'sparse';
-                break;
+            default: {
+                throw new view.Error("Unsupported tensor encoding '" + this._encoding + "'.");
             }
+        }
+        switch (this._layout) {
+            case 'sparse':
             case 'sparse.coo': {
                 this._indices = this._tensor.indices;
                 this._values = this._tensor.values;
-                this._layout = 'sparse.coo';
                 break;
             }
             default: {
-                this._layout = tensor.layout;
                 break;
             }
         }
@@ -3291,6 +3287,10 @@ view.Tensor = class {
         return this._type;
     }
 
+    get encoding() {
+        return this._encoding;
+    }
+
     get layout() {
         return this._layout;
     }
@@ -3301,19 +3301,20 @@ view.Tensor = class {
 
     get empty() {
         switch (this._layout) {
-            case '<':
-            case '>': {
-                return !(Array.isArray(this._data) || this._data instanceof Uint8Array || this._data instanceof Int8Array) || this._data.length === 0;
-            }
-            case '|': {
-                return !(Array.isArray(this._values) || ArrayBuffer.isView(this._values)) || this._values.length === 0;
-            }
             case 'sparse':
             case 'sparse.coo': {
                 return !this._values || this.indices || this._values.values.length === 0;
             }
             default: {
-                throw new Error("Unsupported tensor format '" + this._format + "'.");
+                switch (this._encoding) {
+                    case '<':
+                    case '>':
+                        return !(Array.isArray(this._data) || this._data instanceof Uint8Array || this._data instanceof Int8Array) || this._data.length === 0;
+                    case '|':
+                        return !(Array.isArray(this._values) || ArrayBuffer.isView(this._values)) || this._values.length === 0;
+                    default:
+                        throw new Error("Unsupported tensor encoding '" + this._encoding + "'.");
+                }
             }
         }
     }
@@ -3321,7 +3322,7 @@ view.Tensor = class {
     get value() {
         const context = this._context();
         context.limit = Number.MAX_SAFE_INTEGER;
-        switch (context.layout) {
+        switch (context.encoding) {
             case '<':
             case '>': {
                 return this._decodeData(context, 0);
@@ -3330,7 +3331,7 @@ view.Tensor = class {
                 return this._decodeValues(context, 0);
             }
             default: {
-                throw new Error("Unsupported tensor format '" + this._format + "'.");
+                throw new Error("Unsupported tensor encoding '" + context.encoding + "'.");
             }
         }
     }
@@ -3338,7 +3339,7 @@ view.Tensor = class {
     toString() {
         const context = this._context();
         context.limit = 10000;
-        switch (context.layout) {
+        switch (context.encoding) {
             case '<':
             case '>': {
                 const value = this._decodeData(context, 0);
@@ -3349,59 +3350,30 @@ view.Tensor = class {
                 return view.Tensor._stringify(value, '', '    ');
             }
             default: {
-                throw new Error("Unsupported tensor format '" + this._format + "'.");
+                throw new Error("Unsupported tensor encoding '" + context.encoding + "'.");
             }
         }
     }
 
     _context() {
-        if (this._layout !== '<' && this._layout !== '>' && this._layout !== '|' && this._layout !== 'sparse' && this._layout !== 'sparse.coo') {
+        if (this._encoding !== '<' && this._encoding !== '>' && this._encoding !== '|') {
+            throw new Error("Tensor encoding '" + this._encoding + "' is not supported.");
+        }
+        if (this._layout && (this._layout !== 'sparse' && this._layout !== 'sparse.coo')) {
             throw new Error("Tensor layout '" + this._layout + "' is not supported.");
         }
         const dataType = this._type.dataType;
         const context = {};
-        context.layout = this._layout;
+        context.encoding = this._encoding;
         context.dimensions = this._type.shape.dimensions.map((value) => !Number.isInteger(value) && value.toNumber ? value.toNumber() : value);
         context.dataType = dataType;
         const size = context.dimensions.reduce((a, b) => a * b, 1);
         switch (this._layout) {
-            case '<':
-            case '>': {
-                context.data = (this._data instanceof Uint8Array || this._data instanceof Int8Array) ? this._data : this._data.peek();
-                context.view = new DataView(context.data.buffer, context.data.byteOffset, context.data.byteLength);
-                if (view.Tensor.dataTypes.has(dataType)) {
-                    context.itemsize = view.Tensor.dataTypes.get(dataType);
-                    if (context.data.length < (context.itemsize * size)) {
-                        throw new Error('Invalid tensor data size.');
-                    }
-                } else if (dataType.startsWith('uint') && !isNaN(parseInt(dataType.substring(4), 10))) {
-                    context.dataType = 'uint';
-                    context.bits = parseInt(dataType.substring(4), 10);
-                    context.itemsize = 1;
-                } else if (dataType.startsWith('int') && !isNaN(parseInt(dataType.substring(3), 10))) {
-                    context.dataType = 'int';
-                    context.bits = parseInt(dataType.substring(3), 10);
-                    context.itemsize = 1;
-                } else {
-                    throw new Error("Tensor data type '" + dataType + "' is not implemented.");
-                }
-                break;
-            }
-            case '|': {
-                context.data = this._values;
-                if (!view.Tensor.dataTypes.has(dataType) && dataType !== 'string' && dataType !== 'object') {
-                    throw new Error("Tensor data type '" + dataType + "' is not implemented.");
-                }
-                if (size !== this._values.length) {
-                    throw new Error('Invalid tensor data length.');
-                }
-                break;
-            }
             case 'sparse': {
                 const indices = new view.Tensor(this._indices).value;
                 const values = new view.Tensor(this._values).value;
                 context.data = this._decodeSparse(dataType, context.dimensions, indices, values);
-                context.layout = '|';
+                context.encoding = '|';
                 break;
             }
             case 'sparse.coo': {
@@ -3423,11 +3395,47 @@ view.Tensor = class {
                     }
                 }
                 context.data = this._decodeSparse(dataType, context.dimensions, indices, values);
-                context.layout = '|';
+                context.encoding = '|';
                 break;
             }
             default: {
-                throw new view.Tensor("Unsupported tensor layout '" + this._layout + "'.");
+                switch (this._encoding) {
+                    case '<':
+                    case '>': {
+                        context.data = (this._data instanceof Uint8Array || this._data instanceof Int8Array) ? this._data : this._data.peek();
+                        context.view = new DataView(context.data.buffer, context.data.byteOffset, context.data.byteLength);
+                        if (view.Tensor.dataTypes.has(dataType)) {
+                            context.itemsize = view.Tensor.dataTypes.get(dataType);
+                            if (context.data.length < (context.itemsize * size)) {
+                                throw new Error('Invalid tensor data size.');
+                            }
+                        } else if (dataType.startsWith('uint') && !isNaN(parseInt(dataType.substring(4), 10))) {
+                            context.dataType = 'uint';
+                            context.bits = parseInt(dataType.substring(4), 10);
+                            context.itemsize = 1;
+                        } else if (dataType.startsWith('int') && !isNaN(parseInt(dataType.substring(3), 10))) {
+                            context.dataType = 'int';
+                            context.bits = parseInt(dataType.substring(3), 10);
+                            context.itemsize = 1;
+                        } else {
+                            throw new Error("Tensor data type '" + dataType + "' is not implemented.");
+                        }
+                        break;
+                    }
+                    case '|': {
+                        context.data = this._values;
+                        if (!view.Tensor.dataTypes.has(dataType) && dataType !== 'string' && dataType !== 'object') {
+                            throw new Error("Tensor data type '" + dataType + "' is not implemented.");
+                        }
+                        if (size !== this._values.length) {
+                            throw new Error('Invalid tensor data length.');
+                        }
+                        break;
+                    }
+                    default: {
+                        throw new view.Tensor("Unsupported tensor encoding '" + this._encoding + "'.");
+                    }
+                }
             }
         }
         context.index = 0;

+ 1 - 1
source/xmodel.js

@@ -244,7 +244,7 @@ xmodel.Tensor = class {
         if (node.op_attr && node.op_attr.data) {
             const data = node.op_attr.data;
             if (data.bytes_value && data.bytes_value.value) {
-                this.layout = '<';
+                this.encoding = '<';
                 this.values = data.bytes_value.value;
             }
         }

+ 7 - 4
test/models.js

@@ -550,9 +550,9 @@ class Target {
         }
         if (this.assert) {
             for (const assert of this.assert) {
-                const parts = assert.split('=').map((item) => item.trim());
+                const parts = assert.split('==').map((item) => item.trim());
                 const properties = parts[0].split('.');
-                const value = parts[1];
+                const value = JSON.parse(parts[1].replace(/\s*'|'\s*/g, '"'));
                 let context = { model: this.model };
                 while (properties.length) {
                     const property = properties.shift();
@@ -571,7 +571,7 @@ class Target {
                     }
                     throw new Error("Invalid property path: '" + parts[0]);
                 }
-                if (context !== value.toString()) {
+                if (context !== value) {
                     throw new Error("Invalid '" + context.toString() + "' != '" + assert + "'.");
                 }
             }
@@ -592,7 +592,10 @@ class Target {
                 if (value.initializer) {
                     value.initializer.type.toString();
                     const tensor = new view.Tensor(value.initializer);
-                    if (tensor.layout !== '<' && tensor.layout !== '>' && tensor.layout !== '|' && tensor.layout !== 'sparse' && tensor.layout !== 'sparse.coo') {
+                    if (tensor.encoding !== '<' && tensor.encoding !== '>' && tensor.encoding !== '|') {
+                        throw new Error("Tensor encoding '" + tensor.encoding + "' is not implemented.");
+                    }
+                    if (tensor.layout && (tensor.layout !== 'sparse' && tensor.layout !== 'sparse.coo')) {
                         throw new Error("Tensor layout '" + tensor.layout + "' is not implemented.");
                     }
                     if (!tensor.empty) {

+ 8 - 7
test/models.json

@@ -3112,7 +3112,7 @@
     "target":   "centerface.param,centerface.bin",
     "source":   "https://raw.githubusercontent.com/MirrorYuChen/ncnn_example/798f64b7d5f0b883e05cb994258d43658b0661b6/models/centerface.param,https://raw.githubusercontent.com/MirrorYuChen/ncnn_example/798f64b7d5f0b883e05cb994258d43658b0661b6/models/centerface.bin",
     "format":   "ncnn",
-    "assert":   [ "model.graphs[0].nodes[0].type.name=Convolution" ],
+    "assert":   [ "model.graphs[0].nodes[0].type.name == 'Convolution'" ],
     "link":     "https://github.com/MirrorYuChen/ncnn_example"
   },
   {
@@ -3176,7 +3176,7 @@
     "target":   "MobileNetSSD_deploy.param.bin,MobileNetSSD_deploy.bin",
     "source":   "https://raw.githubusercontent.com/chehongshu/ncnnforandroid_objectiondetection_Mobilenetssd/master/MobileNetSSD_demo/app/src/main/assets/MobileNetSSD_deploy.param.bin,https://raw.githubusercontent.com/chehongshu/ncnnforandroid_objectiondetection_Mobilenetssd/master/MobileNetSSD_demo/app/src/main/assets/MobileNetSSD_deploy.bin",
     "format":   "ncnn",
-    "assert":   [ "model.graphs[0].nodes[1].type.name=Convolution" ],
+    "assert":   [ "model.graphs[0].nodes[1].type.name == 'Convolution'" ],
     "link":     "https://github.com/chehongshu/ncnnforandroid_objectiondetection_Mobilenetssd"
   },
   {
@@ -3518,7 +3518,7 @@
     "target":   "denotation_Add_ImageNet1920WithImageMetadataBgr8_SRGB_0_255.onnx",
     "source":   "https://github.com/lutzroeder/netron/files/2587943/onnx_denotation_models.zip[denotation_Add_ImageNet1920WithImageMetadataBgr8_SRGB_0_255.onnx]",
     "format":   "ONNX v3",
-    "assert":   [ "model.graphs[0].nodes[0].outputs[0].value[0].type.denotation = Image(Bgr8,SRGB,NominalRange_0_255)" ],
+    "assert":   [ "model.graphs[0].nodes[0].outputs[0].value[0].type.denotation == 'Image(Bgr8,SRGB,NominalRange_0_255)'" ],
     "link":     "https://github.com/lutzroeder/netron/issues/183"
   },
   {
@@ -3880,7 +3880,7 @@
     "target":   "sparse_initializer_as_output.json",
     "source":   "https://github.com/lutzroeder/netron/files/12444489/sparse_initializer_as_output.json.zip[sparse_initializer_as_output.json]",
     "format":   "ONNX JSON v7",
-    "assert":   [ "model.graphs[0].outputs[0].value[0].type.layout=sparse" ],
+    "assert":   [ "model.graphs[0].outputs[0].value[0].type.layout == 'sparse'" ],
     "link":     "https://github.com/lutzroeder/netron/issues/741"
   },
   {
@@ -3895,7 +3895,7 @@
     "target":   "sparse_to_dense_matmul.onnx",
     "source":   "https://github.com/lutzroeder/netron/files/12444490/sparse_to_dense_matmul.onnx.zip[sparse_to_dense_matmul.onnx]",
     "format":   "ONNX v7",
-    "assert":   [ "model.graphs[0].nodes[0].inputs[0].value[0].type.layout=sparse" ],
+    "assert":   [ "model.graphs[0].nodes[0].inputs[0].value[0].type.layout == 'sparse'" ],
     "link":     "https://github.com/lutzroeder/netron/issues/741"
   },
   {
@@ -5194,6 +5194,7 @@
     "target":   "sparse_coo.pth",
     "source":   "https://github.com/lutzroeder/netron/files/9541426/sparse_coo.pth.zip[sparse_coo.pth]",
     "format":   "PyTorch v1.6",
+    "assert":   [ "model.graphs[0].nodes[0].inputs[0].value[0].type.layout == 'sparse.coo'" ],
     "link":     "https://github.com/lutzroeder/netron/issues/720"
   },
   {
@@ -5972,7 +5973,7 @@
     "source":   "https://github.com/lutzroeder/netron/files/5613448/events.out.tfevents.1606692323.b5c8f88cc7ee.58.2.zip[events.out.tfevents.1606692323.b5c8f88cc7ee.58.2]",
     "format":   "TensorFlow Event File v2",
     "producer": "PyTorch",
-    "assert":   [ "model.graphs[0].nodes[0].type.name=aten::_convolution" ],
+    "assert":   [ "model.graphs[0].nodes[0].type.name == 'aten::_convolution'" ],
     "link":     "https://github.com/lutzroeder/netron/issues/638"
   },
   {
@@ -6452,7 +6453,7 @@
     "target":   "densenet.tflite",
     "source":   "https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/densenet_2018_04_27.tgz[densenet/densenet.tflite]",
     "format":   "TensorFlow Lite v3",
-    "assert":   [ "model.graphs[0].nodes[0].type.name=Conv2D" ],
+    "assert":   [ "model.graphs[0].nodes[0].type.name == 'Conv2D'" ],
     "link":     "https://www.tensorflow.org/lite/guide/hosted_models"
   },
   {