Răsfoiți Sursa

Add PyTorch test file (#720)

Lutz Roeder 3 ani în urmă
părinte
comite
44abcdc28a
2 a modificat fișierele cu 83 adăugiri și 31 ștergeri
  1. 76 31
      source/pytorch.js
  2. 7 0
      test/models.json

+ 76 - 31
source/pytorch.js

@@ -678,9 +678,11 @@ pytorch.Tensor = class {
             case 'float32':
             case 'float64':
             case 'bfloat16':
+            case 'complex64':
+            case 'complex128':
                 break;
             default:
-                context.state = "Tensor data type '" + this._type.dataType + "' is not supported.";
+                context.state = "Tensor data type '" + this._type.dataType + "' is not implemented.";
                 return context;
         }
         if (!this._type.shape) {
@@ -702,7 +704,7 @@ pytorch.Tensor = class {
 
         context.dataType = this._type.dataType;
         context.dimensions = this._type.shape.dimensions;
-        context.dataView = new DataView(context.data.buffer, context.data.byteOffset, context.data.byteLength);
+        context.view = new DataView(context.data.buffer, context.data.byteOffset, context.data.byteLength);
         return context;
     }
 
@@ -718,56 +720,66 @@ pytorch.Tensor = class {
                 }
                 switch (context.dataType) {
                     case 'boolean':
-                        results.push(context.dataView.getUint8(context.index) === 0 ?  false : true);
+                        results.push(context.view.getUint8(context.index) === 0 ?  false : true);
                         context.index++;
                         context.count++;
                         break;
                     case 'uint8':
-                        results.push(context.dataView.getUint8(context.index));
+                        results.push(context.view.getUint8(context.index));
                         context.index++;
                         context.count++;
                         break;
                     case 'qint8':
                     case 'int8':
-                        results.push(context.dataView.getInt8(context.index));
+                        results.push(context.view.getInt8(context.index));
                         context.index++;
                         context.count++;
                         break;
                     case 'int16':
-                        results.push(context.dataView.getInt16(context.index, this._littleEndian));
+                        results.push(context.view.getInt16(context.index, this._littleEndian));
                         context.index += 2;
                         context.count++;
                         break;
                     case 'int32':
-                        results.push(context.dataView.getInt32(context.index, this._littleEndian));
+                        results.push(context.view.getInt32(context.index, this._littleEndian));
                         context.index += 4;
                         context.count++;
                         break;
                     case 'int64':
-                        results.push(context.dataView.getInt64(context.index, this._littleEndian));
+                        results.push(context.view.getInt64(context.index, this._littleEndian));
                         context.index += 8;
                         context.count++;
                         break;
                     case 'float16':
-                        results.push(context.dataView.getFloat16(context.index, this._littleEndian));
+                        results.push(context.view.getFloat16(context.index, this._littleEndian));
                         context.index += 2;
                         context.count++;
                         break;
                     case 'float32':
-                        results.push(context.dataView.getFloat32(context.index, this._littleEndian));
+                        results.push(context.view.getFloat32(context.index, this._littleEndian));
                         context.index += 4;
                         context.count++;
                         break;
                     case 'float64':
-                        results.push(context.dataView.getFloat64(context.index, this._littleEndian));
+                        results.push(context.view.getFloat64(context.index, this._littleEndian));
                         context.index += 8;
                         context.count++;
                         break;
                     case 'bfloat16':
-                        results.push(context.dataView.getBfloat16(context.index, this._littleEndian));
+                        results.push(context.view.getBfloat16(context.index, this._littleEndian));
                         context.index += 2;
                         context.count++;
                         break;
+                    case 'complex64':
+                        results.push(context.view.getComplex64(i << 3, this._littleEndian));
+                        context.index += 8;
+                        context.count++;
+                        break;
+                    case 'complex128':
+                        results.push(context.view.getComplex128(i << 4, this._littleEndian));
+                        context.index += 16;
+                        context.count++;
+                        break;
                     default:
                         throw new pytorch.Error("Unsupported tensor data type '" + context.dataType + "'.");
                 }
@@ -799,22 +811,26 @@ pytorch.Tensor = class {
             result.push(indentation + ']');
             return result.join('\n');
         }
-        if (value && (value instanceof base.Int64 || value instanceof base.Uint64)) {
-            return indentation + value.toString();
-        }
-        if (typeof value == 'string') {
-            return indentation + value;
-        }
-        if (value == Infinity) {
-            return indentation + 'Infinity';
-        }
-        if (value == -Infinity) {
-            return indentation + '-Infinity';
-        }
-        if (isNaN(value)) {
-            return indentation + 'NaN';
+        switch (typeof value) {
+            case 'string':
+                return indentation + value;
+            case 'number':
+                if (value == Infinity) {
+                    return indentation + 'Infinity';
+                }
+                if (value == -Infinity) {
+                    return indentation + '-Infinity';
+                }
+                if (isNaN(value)) {
+                    return indentation + 'NaN';
+                }
+                return indentation + value.toString();
+            default:
+                if (value && value.toString) {
+                    return indentation + value.toString();
+                }
+                return indentation + '(undefined)';
         }
-        return indentation + value.toString();
     }
 };
 
@@ -1833,20 +1849,20 @@ pytorch.Execution = class extends python.Execution {
                 this._device = null;
             }
             get device() {
-                return null;
+                return this._device;
             }
             get dtype() {
                 return this._dtype;
             }
-            get data() {
-                return this._cdata;
-            }
             element_size() {
                 return this._dtype.element_size;
             }
             size() {
                 return this._size;
             }
+            get data() {
+                return this._cdata;
+            }
             _set_cdata(data) {
                 const length = this.size() * this.dtype.itemsize();
                 if (length !== data.length) {
@@ -1876,6 +1892,33 @@ pytorch.Execution = class extends python.Execution {
                 return storage;
             }
         });
+        this.registerType('torch.storage._UntypedStorage', class extends torch_storage._StorageBase {
+            constructor() {
+                super();
+                throw new python.Error('_UntypedStorage not implemented.');
+            }
+        });
+        this.registerType('torch.storage._TypedStorage', class {
+            constructor() {
+                throw new python.Error('_TypedStorage not implemented.');
+            }
+        });
+        this.registerType('torch.storage._LegacyStorage', class extends torch_storage._TypedStorage {
+            constructor() {
+                super();
+                throw new python.Error('_LegacyStorage not implemented.');
+            }
+        });
+        this.registerType('torch.ComplexFloatStorage', class extends torch_storage._StorageBase {
+            constructor(size) {
+                super(size, torch.complex64);
+            }
+        });
+        this.registerType('torch.ComplexDoubleStorage', class extends torch_storage._StorageBase {
+            constructor(size) {
+                super(size, torch.complex128);
+            }
+        });
         this.registerType('torch.BoolStorage', class extends torch_storage._StorageBase {
             constructor(size) {
                 super(size, torch.bool);
@@ -2058,6 +2101,8 @@ pytorch.Execution = class extends python.Execution {
         this.registerType('torch.HalfTensor', class extends torch.Tensor {});
         this.registerType('torch.FloatTensor', class extends torch.Tensor {});
         this.registerType('torch.DoubleTensor', class extends torch.Tensor {});
+        this.registerType('torch.ComplexFloatTensor', class extends torch.Tensor {});
+        this.registerType('torch.ComplexDoubleTensor', class extends torch.Tensor {});
         this.registerType('torch.QInt8Tensor', class extends torch.Tensor {});
         this.registerType('torch.QUInt8Tensor', class extends torch.Tensor {});
         this.registerType('torch.QInt32Tensor', class extends torch.Tensor {});

+ 7 - 0
test/models.json

@@ -4155,6 +4155,13 @@
     "format":   "TorchScript v1.0",
     "link":     "https://github.com/ApolloAuto/apollo"
   },
+  {
+    "type":     "pytorch",
+    "target":   "complex_tensor.pt",
+    "source":   "https://github.com/lutzroeder/netron/files/9108149/complex_tensor.pt.zip[complex_tensor.pt]",
+    "format":   "PyTorch v1.6",
+    "link":     "https://github.com/lutzroeder/netron/issues/720"
+  },
   {
     "type":     "pytorch",
     "target":   "d2go.pt",