Explorar el Código

Add PyTorch .ot support (#686)

Lutz Roeder hace 4 años
padre
commit
7e286f19aa
Se han modificado 3 ficheros con 23 adiciones y 3 borrados
  1. 0 1
      source/python.js
  2. 22 1
      source/pytorch.js
  3. 1 1
      source/view.js

+ 0 - 1
source/python.js

@@ -2566,7 +2566,6 @@ python.Execution = class {
                 }
                 const obj = {};
                 obj.__proto__ = target;
-                obj.__class__ = target;
                 if (obj.__init__ && typeof obj.__init__ === 'function') {
                     obj.__init__.apply(obj, args);
                 }

+ 22 - 1
source/pytorch.js

@@ -2478,7 +2478,28 @@ pytorch.Container.Zip = class {
                 }
             }
             if (this.format.startsWith('TorchScript ')) {
-                this._type = 'script';
+                if (this._torchscriptArena || this._data.forward) {
+                    this._type = 'script';
+                }
+                else {
+                    if (!Object.entries(this._data).every((entry) => entry[0].indexOf('|') !== -1 && pytorch.Utility.isTensor(entry[1]))) {
+                        throw new pytorch.Error('File does not contain forward function or state dictionary.');
+                    }
+                    const layers = new Map();
+                    for (const entry of Object.entries(this._data)) {
+                        const key = entry[0].split('|');
+                        const value = entry[1];
+                        const parameterName = key.pop();
+                        const name = key.join('|');
+                        if (!layers.has(name)) {
+                            layers.set(name, { name: name, states: [] });
+                        }
+                        const layer = layers.get(name);
+                        layer.states.push({ name: parameterName, arguments: [ { id: '', value: value } ] });
+                    }
+                    this._type = 'weights';
+                    this._data = [ { layers: layers.values() } ];
+                }
             }
             else {
                 const obj = this._data;

+ 1 - 1
source/view.js

@@ -1550,7 +1550,7 @@ view.ModelFactoryService = class {
     constructor(host) {
         this._host = host;
         this._extensions = [];
-        this.register('./pytorch', [ '.pt', '.pth', '.pt1', '.pyt', '.pkl', '.pickle', '.h5', '.t7', '.model', '.dms', '.tar', '.ckpt', '.chkpt', '.tckpt', '.bin', '.pb', '.zip', '.nn', '.torchmodel' ]);
+        this.register('./pytorch', [ '.pt', '.pth', '.pt1', '.pyt', '.pkl', '.pickle', '.h5', '.t7', '.model', '.dms', '.tar', '.ckpt', '.chkpt', '.tckpt', '.bin', '.pb', '.zip', '.nn', '.torchmodel', '.ot' ]);
         this.register('./onnx', [ '.onnx', '.onn', '.pb', '.pbtxt', '.prototxt', '.model', '.pt', '.pth', '.pkl', '.ort', '.ort.onnx' ]);
         this.register('./mxnet', [ '.json', '.params' ]);
         this.register('./coreml', [ '.mlmodel', '.bin', 'manifest.json', 'metadata.json', 'featuredescriptions.json' ]);