Просмотр исходного кода

Fix TorchScript function undefined tensor (#546)

Lutz Roeder 5 лет назад
Родитель
Сommit
38e8975f44
1 измененных файлов с 26 добавлено и 2 удалено
  1. 26 2
      src/pytorch.js

+ 26 - 2
src/pytorch.js

@@ -131,6 +131,30 @@ pytorch.Graph = class {
             if (container.data) {
                 this._loadScriptModule(metadata, container, container.data, initializers);
             }
+            if (container.constants) {
+                const obj = {
+                    type: 'torch.nn.Constants',
+                    attributes: [],
+                    inputs: [],
+                    outputs: [],
+                };
+                let index = 0;
+                for (const constant of container.constants) {
+                    if (constant.__variable__ && constant.__count__ > 1 && constant.storage) {
+                        const initializer = new pytorch.Tensor(constant.name, constant, true);
+                        obj.inputs.push(new pytorch.Parameter('c' + index.toString(), false, [
+                            new pytorch.Argument(constant.__variable__, initializer.type, initializer)
+                        ]));
+                        obj.outputs.push(new pytorch.Parameter('c' + index.toString(), false, [
+                            new pytorch.Argument(constant.__variable__)
+                        ]));
+                    }
+                    index++;
+                }
+                if (obj.inputs.length > 0) {
+                    this._nodes.push(new pytorch.Node(metadata, '', obj, null));
+                }
+            }
         }
         else if (container.data) {
             const data = container.data;
@@ -2953,7 +2977,7 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                         switch (parameter.type) {
                             case 'tensor': {
                                 let argument = copyEvalArgs[0];
-                                if (Array.isArray(argument) || (!pytorch.Utility.isTensor(argument) && argument !== null)) {
+                                if (Array.isArray(argument) || (!pytorch.Utility.isTensor(argument) && argument !== null && argument !== undefined)) {
                                     if (parameter.optional) {
                                         if (argument === undefined) {
                                             copyArgs.shift();
@@ -2966,7 +2990,7 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                                 }
                                 copyArgs.shift();
                                 copyEvalArgs.shift();
-                                if (argument === null) {
+                                if (argument === null || argument === undefined) {
                                     argument = {};
                                 }
                                 if (!argument.__variable__) {