|
|
@@ -234,22 +234,40 @@ executorch.Node = class {
|
|
|
executorch.TensorType = class {
|
|
|
|
|
|
constructor(tensor) {
|
|
|
- executorch.TensorType._types = executorch.TensorType._types || [
|
|
|
- 'uint8',
|
|
|
- 'int8', 'int16', 'int32', 'int64',
|
|
|
- 'float16', 'float32', 'float64',
|
|
|
- 'complex16', 'complex32', 'complex64',
|
|
|
- 'boolean',
|
|
|
- 'qint8', 'quint8', 'qint32',
|
|
|
- 'bfloat16',
|
|
|
- 'quint4x2', 'quint2x4', 'bits1x8', 'bits2x4', 'bits4x2', 'bits8', 'bits16',
|
|
|
- 'float8e5m2', 'float8e4m3fn', 'float8e5m2fnuz', 'float8e4m3fnuz',
|
|
|
- 'uint16', 'uint32', 'uint64'
|
|
|
- ];
|
|
|
- if (tensor.scalar_type >= executorch.TensorType._types.length) {
|
|
|
- throw new executorch.Error(`Unknown tensor data type '${tensor.scalar_type}'.`);
|
|
|
+ const ScalarType = executorch.schema.executorch_flatbuffer.ScalarType;
|
|
|
+ switch (tensor.scalar_type) {
|
|
|
+ case ScalarType.BYTE: this.dataType = 'uint8'; break;
|
|
|
+ case ScalarType.CHAR: this.dataType = 'int8'; break;
|
|
|
+ case ScalarType.SHORT: this.dataType = 'int16'; break;
|
|
|
+ case ScalarType.INT: this.dataType = 'int32'; break;
|
|
|
+ case ScalarType.LONG: this.dataType = 'int64'; break;
|
|
|
+ case ScalarType.HALF: this.dataType = 'float16'; break;
|
|
|
+ case ScalarType.FLOAT: this.dataType = 'float32'; break;
|
|
|
+ case ScalarType.DOUBLE: this.dataType = 'float64'; break;
|
|
|
+ case 8: this.dataType = 'complex32'; break;
|
|
|
+ case 9: this.dataType = 'complex64'; break;
|
|
|
+ case 10: this.dataType = 'complex128'; break;
|
|
|
+ case ScalarType.BOOL: this.dataType = 'boolean'; break;
|
|
|
+ case ScalarType.QINT8: this.dataType = 'qint8'; break;
|
|
|
+ case ScalarType.QUINT8: this.dataType = 'quint8'; break;
|
|
|
+ case ScalarType.QINT32: this.dataType = 'qint32'; break;
|
|
|
+ case 15: this.dataType = 'bfloat16'; break;
|
|
|
+ case ScalarType.QUINT4X2: this.dataType = 'quint4x2'; break;
|
|
|
+ case ScalarType.QUINT2X4: this.dataType = 'quint2x4'; break;
|
|
|
+ case 18: this.dataType = 'bits1x8'; break;
|
|
|
+ case 19: this.dataType = 'bits2x4'; break;
|
|
|
+ case 20: this.dataType = 'bits4x2'; break;
|
|
|
+ case 21: this.dataType = 'bits8'; break;
|
|
|
+ case ScalarType.BITS16: this.dataType = 'bits16'; break;
|
|
|
+ case ScalarType.FLOAT8E5M2: this.dataType = 'float8e5m2'; break;
|
|
|
+ case ScalarType.FLOAT8E4M3FN: this.dataType = 'float8e4m3fn'; break;
|
|
|
+ case ScalarType.FLOAT8E5M2FNUZ: this.dataType = 'float8e5m2fnuz'; break;
|
|
|
+ case ScalarType.FLOAT8E4M3FNUZ: this.dataType = 'float8e4m3fnuz'; break;
|
|
|
+ case ScalarType.UINT16: this.dataType = 'uint16'; break;
|
|
|
+ case ScalarType.UINT32: this.dataType = 'uint32'; break;
|
|
|
+ case ScalarType.UINT64: this.dataType = 'uint64'; break;
|
|
|
+ default: throw new executorch.Error(`Unknown tensor data type '${tensor.scalar_type}'.`);
|
|
|
}
|
|
|
- this.dataType = executorch.TensorType._types[tensor.scalar_type];
|
|
|
this.shape = new executorch.TensorShape(Array.from(tensor.sizes));
|
|
|
}
|
|
|
|