Lutz Roeder преди 4 месеца
родител
ревизия
06df36330e
променени са 5 файла, в които са добавени 39 реда и са изтрити 19 реда
  1. 33 15
      source/executorch.js
  2. 1 1
      source/python.js
  3. 1 1
      source/pytorch-metadata.json
  4. 3 1
      source/view.js
  5. 1 1
      tools/pytorch_script.py

+ 33 - 15
source/executorch.js

@@ -234,22 +234,40 @@ executorch.Node = class {
 executorch.TensorType = class {
 
     constructor(tensor) {
-        executorch.TensorType._types = executorch.TensorType._types || [
-            'uint8',
-            'int8', 'int16', 'int32', 'int64',
-            'float16', 'float32', 'float64',
-            'complex16', 'complex32', 'complex64',
-            'boolean',
-            'qint8', 'quint8', 'qint32',
-            'bfloat16',
-            'quint4x2', 'quint2x4', 'bits1x8', 'bits2x4', 'bits4x2', 'bits8', 'bits16',
-            'float8e5m2', 'float8e4m3fn', 'float8e5m2fnuz', 'float8e4m3fnuz',
-            'uint16', 'uint32', 'uint64'
-        ];
-        if (tensor.scalar_type >= executorch.TensorType._types.length) {
-            throw new executorch.Error(`Unknown tensor data type '${tensor.scalar_type}'.`);
+        const ScalarType = executorch.schema.executorch_flatbuffer.ScalarType;
+        switch (tensor.scalar_type) {
+            case ScalarType.BYTE: this.dataType = 'uint8'; break;
+            case ScalarType.CHAR: this.dataType = 'int8'; break;
+            case ScalarType.SHORT: this.dataType = 'int16'; break;
+            case ScalarType.INT: this.dataType = 'int32'; break;
+            case ScalarType.LONG: this.dataType = 'int64'; break;
+            case ScalarType.HALF: this.dataType = 'float16'; break;
+            case ScalarType.FLOAT: this.dataType = 'float32'; break;
+            case ScalarType.DOUBLE: this.dataType = 'float64'; break;
+            case 8: this.dataType = 'complex32'; break;
+            case 9: this.dataType = 'complex64'; break;
+            case 10: this.dataType = 'complex128'; break;
+            case ScalarType.BOOL: this.dataType = 'boolean'; break;
+            case ScalarType.QINT8: this.dataType = 'qint8'; break;
+            case ScalarType.QUINT8: this.dataType = 'quint8'; break;
+            case ScalarType.QINT32: this.dataType = 'qint32'; break;
+            case 15: this.dataType = 'bfloat16'; break;
+            case ScalarType.QUINT4X2: this.dataType = 'quint4x2'; break;
+            case ScalarType.QUINT2X4: this.dataType = 'quint2x4'; break;
+            case 18: this.dataType = 'bits1x8'; break;
+            case 19: this.dataType = 'bits2x4'; break;
+            case 20: this.dataType = 'bits4x2'; break;
+            case 21: this.dataType = 'bits8'; break;
+            case ScalarType.BITS16: this.dataType = 'bits16'; break;
+            case ScalarType.FLOAT8E5M2: this.dataType = 'float8e5m2'; break;
+            case ScalarType.FLOAT8E4M3FN: this.dataType = 'float8e4m3fn'; break;
+            case ScalarType.FLOAT8E5M2FNUZ: this.dataType = 'float8e5m2fnuz'; break;
+            case ScalarType.FLOAT8E4M3FNUZ: this.dataType = 'float8e4m3fnuz'; break;
+            case ScalarType.UINT16: this.dataType = 'uint16'; break;
+            case ScalarType.UINT32: this.dataType = 'uint32'; break;
+            case ScalarType.UINT64: this.dataType = 'uint64'; break;
+            default: throw new executorch.Error(`Unknown tensor data type '${tensor.scalar_type}'.`);
         }
-        this.dataType = executorch.TensorType._types[tensor.scalar_type];
         this.shape = new executorch.TensorShape(Array.from(tensor.sizes));
     }
 

+ 1 - 1
source/python.js

@@ -10323,7 +10323,7 @@ python.Execution = class {
             }
             expect(kind) {
                 if (this.cur().kind !== kind) {
-                    throw new python.Error(`Unexpected '${this.kind}' instead of '${kind}'.`);
+                    throw new python.Error(`Unexpected '${this.cur().kind}' instead of '${kind}'.`);
                 }
                 return this.next();
             }

+ 1 - 1
source/pytorch-metadata.json

@@ -7090,7 +7090,7 @@
     "name": "detectron2::roi_align_rotated_forward(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> Tensor"
   },
   {
-    "name": "dim_order_ops::_clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!)"
+    "name": "dim_order_ops::_clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"
   },
   {
     "name": "dim_order_ops::_empty_dim_order.out(int[] size, *, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"

+ 3 - 1
source/view.js

@@ -3720,7 +3720,9 @@ view.TensorSidebar = class extends view.ObjectSidebar {
                 const dataType = type.dataType;
                 this.addProperty('type', `${dataType}`, 'code');
                 const shape = type.shape && Array.isArray(type.shape.dimensions) ? type.shape.dimensions.toString(', ') : '?';
-                this.addProperty('shape', `${shape || ' '}`, 'code');
+                if (shape) {
+                    this.addProperty('shape', shape, 'code');
+                }
                 const denotation = type.denotation;
                 if (denotation) {
                     this.addProperty('denotation', denotation, 'code');

+ 1 - 1
tools/pytorch_script.py

@@ -77,7 +77,7 @@ known_legacy_schema_definitions = [
     "cadence::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)", # noqa E501
     "detectron2::nms_rotated(Tensor boxes, Tensor scores, float iou_threshold) -> Tensor", # noqa E501
     "detectron2::roi_align_rotated_forward(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> Tensor", # noqa E501
-    "dim_order_ops::_clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!)", # noqa E501
+    "dim_order_ops::_clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)", # noqa E501
     "dim_order_ops::_empty_dim_order.out(int[] size, *, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)", # noqa E501
     "dim_order_ops::_to_dim_order_copy.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)", # noqa E501
     "executorch_prim::et_view.default(Tensor self, int[] size) -> (Tensor out)",