Explorar o código

Update paddle.js

Lutz Roeder %!s(int64=3) %!d(string=hai) anos
pai
achega
48384e2f7f
Modificáronse 2 ficheiros con 19 adicións e 4 borrados
  1. 9 3
      source/paddle.js
  2. 10 1
      source/python.js

+ 9 - 3
source/paddle.js

@@ -822,7 +822,10 @@ paddle.Utility = class {
         const map = null; // this._data['StructuredToParameterName@@'];
         for (const entry of Object.entries(obj)) {
             const key = entry[0];
-            const value = entry[1];
+            let value = entry[1];
+            if (Array.isArray(value) && value.length === 2 && value[0] === key) {
+                value = value[1];
+            }
             if (paddle.Utility.isNumPyArray(value)) {
                 const name = map ? map[key] : key;
                 const type = new paddle.TensorType(value.dtype.__name__, new paddle.TensorShape(value.shape));
@@ -833,7 +836,10 @@ paddle.Utility = class {
         }
     }
 
-    static isNumPyArray(value) {
+    static isNumPyArray(value, name) {
+        if (Array.isArray(value) && value.length === 2 && value[0] === name) {
+            value = value[1];
+        }
         return value && !Array.isArray(value) && value.__class__ && value.__class__.__module__ === 'numpy' && value.__class__.__name__ === 'ndarray';
     }
 };
@@ -898,7 +904,7 @@ paddle.Pickle = class {
     static open(context) {
         const obj = context.open('pkl');
         if (obj && !Array.isArray(obj) && Object(obj) === obj &&
-            Object.entries(obj).filter((entry) => paddle.Utility.isNumPyArray(entry[1])).length > 0) {
+            Object.entries(obj).filter((entry) => paddle.Utility.isNumPyArray(entry[1], entry[0])).length > 0) {
             return new paddle.Pickle(obj);
         }
         return null;

+ 10 - 1
source/python.js

@@ -1650,7 +1650,16 @@ python.Execution = class {
         this.registerType('builtins.float', class {});
         this.registerType('builtins.object', class {});
         this.registerType('builtins.str', class {});
-        this.registerType('builtins.tuple', class {});
+        this.registerType('builtins.tuple', class extends Array {
+            constructor(items) {
+                super(items ? items.length : 0);
+                if (items) {
+                    for (let i = 0; i < items.length; i++) {
+                        this[i] = items[i];
+                    }
+                }
+            }
+        });
         this.registerType('typing._Final', class {});
         this.registerType('typing._SpecialForm', class extends typing._Final {});
         this.registerType('typing._BaseGenericAlias', class extends typing._Final {});