Explorar o código

Add Core ML test file (#832)

Lutz Roeder %!s(int64=4) %!d(string=hai) anos
pai
achega
b146f93803
Modificáronse 2 ficheiros con 41 adicións e 13 borrados
  1. 34 13
      source/coreml.js
  2. 7 0
      test/models.json

+ 34 - 13
source/coreml.js

@@ -740,24 +740,42 @@ coreml.Graph = class {
                     else {
                         argument.value = data;
                     }
+                    argument.const = true;
                     op.delete = true;
                 }
             }
         }
 
+        for (const op of operations) {
+            for (const input of op.inputs) {
+                if (input.arguments.length > 1 && input.arguments.some((argument) => argument.const)) {
+                    if (input.arguments.every((argument) => argument.value instanceof coreml.Tensor)) {
+                        continue;
+                    }
+                    for (const argument of input.arguments) {
+                        for (const from of argument.from) {
+                            from.delete = false;
+                        }
+                        delete argument.value;
+                    }
+                }
+            }
+        }
+
         for (const op of operations) {
             if (op.delete) {
                 continue;
             }
             op.inputs = op.inputs.filter((input) => {
-                if (input.arguments.length !== 1) {
+                if (input.arguments.every((argument) => argument.value === undefined || argument.value instanceof coreml.Tensor)) {
                     return true;
                 }
-                const argument = input.arguments[0];
-                if (argument.value === undefined || argument.value instanceof coreml.Tensor) {
-                    return true;
+                if (input.arguments.length === 1) {
+                    const argument = input.arguments[0];
+                    op.attributes[input.name] = argument.value;
+                    return false;
                 }
-                op.attributes[input.name] = argument.value;
+                op.attributes[input.name] = input.arguments.map((argument) => argument.value[0]);
                 return false;
             });
         }
@@ -1127,24 +1145,27 @@ coreml.Node = class {
 
 coreml.Attribute = class {
 
-    constructor(schema, name, value) {
+    constructor(metadata, name, value) {
         this._name = name;
         this._value = value;
-        if (schema) {
-            if (schema.type) {
-                this._type = schema.type;
+        if (this._value instanceof coreml.Tensor) {
+            this._type = 'tensor';
+        }
+        if (metadata) {
+            if (metadata.type) {
+                this._type = metadata.type;
             }
             if (this._type && coreml.proto) {
                 this._value = coreml.Utility.enum(this._type, this._value);
             }
-            if (Object.prototype.hasOwnProperty.call(schema, 'visible') && !schema.visible) {
+            if (Object.prototype.hasOwnProperty.call(metadata, 'visible') && !metadata.visible) {
                 this._visible = false;
             }
-            else if (Object.prototype.hasOwnProperty.call(schema, 'default')) {
+            else if (Object.prototype.hasOwnProperty.call(metadata, 'default')) {
                 if (Array.isArray(value)) {
                     value = value.map((item) => item.toNumber());
                 }
-                if (JSON.stringify(schema.default) == JSON.stringify(value)) {
+                if (JSON.stringify(metadata.default) == JSON.stringify(value)) {
                     this._visible = false;
                 }
             }
@@ -1530,7 +1551,7 @@ coreml.Utility = class {
             case 'tensorType':
                 return coreml.Utility.tensorType(type.tensorType);
             case 'listType':
-                return new coreml.ListType(coreml.Utility.tensorType(type.listType.tensorType));
+                return new coreml.ListType(coreml.Utility.valueType(type.listType.type));
             default:
                 throw new coreml.Error("Unsupported value type '" + type.type + "'.");
         }

+ 7 - 0
test/models.json

@@ -1420,6 +1420,13 @@
     "format": "Core ML v4",
     "link":   "https://github.com/lutzroeder/netron/issues/193"
   },
+  {
+    "type":   "coreml",
+    "target": "lstm_model.mlpackage.zip",
+    "source": "https://github.com/lutzroeder/netron/files/7406821/lstm_model.mlpackage.zip",
+    "format": "Core ML Package v6",
+    "link":   "https://github.com/lutzroeder/netron/issues/832"
+  },
   {
     "type":   "coreml",
     "target": "MessageClassifier.mlmodel",