Sfoglia il codice sorgente

CoreML enum support (#193)

Lutz Roeder 5 anni fa
parent
commit
1fda494f93
1 ha cambiato i file con 28 aggiunte e 9 eliminazioni
  1. 28 9
      src/coreml.js

+ 28 - 9
src/coreml.js

@@ -693,6 +693,7 @@ coreml.Node = class {
                 this._initializer(initializers, 'Weights', 'weights', [ data.inputDim, data.outputChannels ], data.weights);
                 return { 'weights': true };
             case 'loadConstant':
+            case 'loadConstantND':
                 this._initializer(initializers, 'Weights', 'data', data.shape, data.data);
                 return { 'data': true };
             case 'scale':
@@ -809,16 +810,8 @@ coreml.Attribute = class {
                 this._type = schema.type;
             }
             if (this._type && coreml.proto) {
-                let type = coreml.proto;
-                const parts = this._type.split('.');
-                while (type && parts.length > 0) {
-                    type = type[parts.shift()];
-                }
-                if (type && type[this._value]) {
-                    this._value = type[this.value];
-                }
+                this._value = coreml.Utility.enum(this._type, this._value);
             }
-
             if (Object.prototype.hasOwnProperty.call(schema, 'visible') && !schema.visible) {
                 this._visible = false;
             }
@@ -1099,6 +1092,32 @@ coreml.OptionalType = class {
     }
 };
 
+coreml.Utility = class {
+
+    static enum(name, value) {
+        let type = coreml.proto;
+        const parts = name.split('.');
+        while (type && parts.length > 0) {
+            type = type[parts.shift()];
+        }
+        if (type) {
+            coreml.Utility._enumKeyMap = coreml.Utility._enumKeyMap || new Map();
+            if (!coreml.Utility._enumKeyMap.has(name)) {
+                const map = new Map();
+                for (const key of Object.keys(type)) {
+                    map.set(type[key], key);
+                }
+                coreml.Utility._enumKeyMap.set(name, map);
+            }
+            const map = coreml.Utility._enumKeyMap.get(name);
+            if (map.has(value)) {
+                return map.get(value);
+            }
+        }
+        return value;
+    }
+};
+
 coreml.Metadata = class {
 
     static open(host) {