Jelajahi Sumber

Update pytorch.js (#546)

Lutz Roeder 3 tahun lalu
induk
melakukan
200003e293
2 mengubah file dengan 42 tambahan dan 10 penghapusan
  1. 14 1
      source/python.js
  2. 28 9
      source/pytorch.js

+ 14 - 1
source/python.js

@@ -3124,7 +3124,14 @@ python.Execution = class {
                         if (target.target.value === '__annotations__') {
                             context.set(target.target.value, context.get(target.target.value) || {});
                         }
-                        context.get(target.target.value)[index] = this.expression(expression.expression, context);
+                        const obj = context.get(target.target.value);
+                        const value = this.expression(expression.expression, context);
+                        if (obj instanceof Map) {
+                            obj.set(index, value);
+                        }
+                        else {
+                            obj[index] = value;
+                        }
                         return undefined;
                     }
                 }
@@ -3168,6 +3175,9 @@ python.Execution = class {
                     if (context.get(expression.target.value)) {
                         const index = this.expression(expression.arguments.value[0], context);
                         const target = context.get(expression.target.value);
+                        if (target instanceof Map) {
+                            return target.get(index);
+                        }
                         return target[index < 0 ? target.length + index : index];
                     }
                 }
@@ -3182,6 +3192,9 @@ python.Execution = class {
                 }
                 if (expression.arguments.type === 'list' && expression.arguments.value.length === 1) {
                     const index = this.expression(expression.arguments.value[0], context);
+                    if (target instanceof Map) {
+                        return target.get(index);
+                    }
                     return target[index < 0 ? target.length + index : index];
                 }
                 break;

+ 28 - 9
source/pytorch.js

@@ -2774,22 +2774,41 @@ pytorch.Container.Zip = class {
                                     tensor.__origin__ = 'graph-input';
                                     return tensor;
                                 }
-                                case 'Tuple':
+                                case 'Tuple': {
                                     return type.arguments.map((type, index) => defaultValue(type, name + '[' + index.toString() + ']'));
-                                case 'List':
+                                }
+                                case 'List': {
                                     return type.arguments.map((type, index) => defaultValue(type, name + '[' + index.toString() + ']' ));
-                                case 'Dict':
-                                    return {};
-                                case 'int':
+                                }
+                                case 'Dict': {
+                                    if (type.arguments[1].name.value === 'Tensor') {
+                                        const Dict = class extends Map {
+                                            get(key) {
+                                                if (!super.has(key)) {
+                                                    super.set(key, defaultValue(type.arguments[1], name + ':' + key));
+                                                }
+                                                return super.get(key);
+                                            }
+                                        };
+                                        return new Dict();
+                                    }
+                                    return new Map();
+                                }
+                                case 'int': {
                                     return 0;
-                                case 'float':
+                                }
+                                case 'float': {
                                     return 0.0;
-                                case 'bool':
+                                }
+                                case 'bool': {
                                     return false;
-                                case 'Optional':
+                                }
+                                case 'Optional': {
                                     return undefined;
-                                default:
+                                }
+                                default: {
                                     break;
+                                }
                             }
                         }
                         throw new pytorch.Error("Unsupported function parameter type '" + JSON.stringify(type) + "'.");