فهرست منبع

Update view.js (#637)

Lutz Roeder 1 سال پیش
والد
کامیت
7648424162
6فایلهای تغییر یافته به همراه148 افزوده شده و 173 حذف شده
  1. 24 27
      source/pickle.js
  2. 28 21
      source/pytorch.js
  3. 25 31
      source/sklearn.js
  4. 55 80
      source/view.js
  5. 1 1
      test/models.json
  6. 15 13
      test/worker.js

+ 24 - 27
source/pickle.js

@@ -94,7 +94,6 @@ pickle.Node = class {
         this.name = name || '';
         this.inputs = [];
         this.outputs = [];
-        this.attributes = [];
         const isArray = (obj) => {
             return obj && obj.__class__ &&
                 ((obj.__class__.__module__ === 'numpy' && obj.__class__.__name__ === 'ndarray') ||
@@ -111,8 +110,8 @@ pickle.Node = class {
             return false;
         };
         if (type === 'builtins.bytearray') {
-            const attribute = new pickle.Argument('value', Array.from(obj), 'byte[]');
-            this.attributes.push(attribute);
+            const argument = new pickle.Argument('value', Array.from(obj), 'byte[]');
+            this.inputs.push(argument);
         } else {
             const entries = obj instanceof Map ? Array.from(obj) : Object.entries(obj);
             for (const [name, value] of entries) {
@@ -120,29 +119,29 @@ pickle.Node = class {
                     continue;
                 } else if (value && isArray(value)) {
                     const tensor = new pickle.Tensor(value);
-                    const attribute = new pickle.Argument(name, tensor, 'tensor');
-                    this.attributes.push(attribute);
+                    const argument = new pickle.Argument(name, tensor, 'tensor');
+                    this.inputs.push(argument);
                 } else if (Array.isArray(value) && value.length > 0 && value.every((obj) => isArray(obj))) {
                     const tensors = value.map((obj) => new pickle.Tensor(obj));
-                    const attribute = new pickle.Argument(name, tensors, 'tensor[]');
-                    this.attributes.push(attribute);
+                    const argument = new pickle.Argument(name, tensors, 'tensor[]');
+                    this.inputs.push(argument);
                 } else if (value && value.__class__ && value.__class__.__module__ === 'builtins' && (value.__class__.__name__ === 'function' || value.__class__.__name__ === 'type')) {
                     const obj = {};
                     obj.__class__ = value;
                     const node = new pickle.Node(obj, '', stack);
-                    const attribute = new pickle.Argument(name, node, 'object');
-                    this.attributes.push(attribute);
+                    const argument = new pickle.Argument(name, node, 'object');
+                    this.inputs.push(argument);
                 } else if (isByteArray(value)) {
-                    const attribute = new pickle.Argument(name, Array.from(value), 'byte[]');
-                    this.attributes.push(attribute);
+                    const argument = new pickle.Argument(name, Array.from(value), 'byte[]');
+                    this.inputs.push(argument);
                 } else {
                     stack = stack || new Set();
                     if (value && Array.isArray(value) && value.every((obj) => typeof obj === 'string')) {
-                        const attribute = new pickle.Argument(name, value, 'string[]');
-                        this.attributes.push(attribute);
+                        const argument = new pickle.Argument(name, value, 'string[]');
+                        this.inputs.push(argument);
                     } else if (value && Array.isArray(value) && value.every((obj) => typeof obj === 'number')) {
-                        const attribute = new pickle.Argument(name, value);
-                        this.attributes.push(attribute);
+                        const argument = new pickle.Argument(name, value, 'attribute');
+                        this.inputs.push(argument);
                     } else if (value && Array.isArray(value) && value.length > 0 && value.every((obj) => obj && (obj.__class__ || obj === Object(obj)))) {
                         const values = value.filter((value) => !stack.has(value));
                         const nodes = values.map((value) => {
@@ -151,19 +150,17 @@ pickle.Node = class {
                             stack.delete(value);
                             return node;
                         });
-                        const attribute = new pickle.Argument(name, nodes, 'object[]');
-                        this.attributes.push(attribute);
-                    } else if (value && (value.__class__ || isObject(value))) {
-                        if (!stack.has(value)) {
-                            stack.add(value);
-                            const node = new pickle.Node(value, '', stack);
-                            const attribute = new pickle.Argument(name, node, 'object');
-                            this.attributes.push(attribute);
-                            stack.delete(value);
-                        }
+                        const argument = new pickle.Argument(name, nodes, 'object[]');
+                        this.inputs.push(argument);
+                    } else if (value && (value.__class__ || isObject(value)) && !stack.has(value)) {
+                        stack.add(value);
+                        const node = new pickle.Node(value, '', stack);
+                        const argument = new pickle.Argument(name, node, 'object');
+                        this.inputs.push(argument);
+                        stack.delete(value);
                     } else {
-                        const attribute = new pickle.Argument(name, value);
-                        this.attributes.push(attribute);
+                        const argument = new pickle.Argument(name, value, 'attribute');
+                        this.inputs.push(argument);
                     }
                 }
             }

+ 28 - 21
source/pytorch.js

@@ -292,7 +292,7 @@ pytorch.Node = class {
         };
         const createAttribute = (metadata, name, value) => {
             let visible = true;
-            let type = null;
+            let type = 'attribute';
             if (name === 'training') {
                 visible = false;
                 type = 'boolean';
@@ -386,18 +386,22 @@ pytorch.Node = class {
                 const argument = new pytorch.Argument(name, values, null, visible);
                 this.inputs.push(argument);
             }
-            this.attributes = Array.from(attributes).map(([name, value]) => {
+            for (const [name, value] of attributes) {
                 const type = this.type.identifier;
                 if (pytorch.Utility.isTensor(value)) {
                     const tensor = new pytorch.Tensor('', value);
-                    return new pytorch.Argument(name, tensor, 'tensor');
+                    const argument = new pytorch.Argument(name, tensor, 'tensor');
+                    this.inputs.push(argument);
                 } else if (Array.isArray(value) && value.every((value) => pytorch.Utility.isTensor(value))) {
                     const tensors = value.map((value) => new pytorch.Tensor('', value));
-                    return new pytorch.Argument(name, tensors, 'tensor[]');
+                    const argument = new pytorch.Argument(name, tensors, 'tensor[]');
+                    this.inputs.push(argument);
                 } else if (Array.isArray(value) && value.every((value) => typeof value === 'string')) {
-                    return new pytorch.Argument(name, value, 'string[]');
+                    const argument = new pytorch.Argument(name, value, 'string[]');
+                    this.inputs.push(argument);
                 } else if (Array.isArray(value) && value.every((value) => typeof value === 'number')) {
-                    return new pytorch.Argument(name, value);
+                    const argument = new pytorch.Argument(name, value, 'attribute');
+                    this.inputs.push(argument);
                 } else if (name === '_modules' && value && value.__class__ && value.__class__.__module__ === 'collections' && value.__class__.__name__ === 'OrderedDict' &&
                     value instanceof Map && Array.from(value).every(([, value]) => value === null || value.__class__)) {
                     const values = Array.from(value).filter(([, value]) => !stack.has(value)).map(([name, obj]) => {
@@ -407,7 +411,8 @@ pytorch.Node = class {
                         stack.delete(value);
                         return node;
                     });
-                    return new pytorch.Argument(name, values, 'object[]');
+                    const argument = new pytorch.Argument(name, values, 'object[]');
+                    this.inputs.push(argument);
                 } else if (value && Array.isArray(value) && value.length > 0 && value.every((obj) => obj && (obj.__class__ || obj === Object(obj)))) {
                     const values = value.filter((value) => !stack.has(value));
                     const nodes = values.map((value) => {
@@ -420,21 +425,23 @@ pytorch.Node = class {
                         stack.delete(value);
                         return node;
                     });
-                    return new pytorch.Argument(name, nodes, 'object[]');
-                } else if (value && (value.__class__ || isObject(value))) {
-                    if (!stack.has(value)) {
-                        stack.add(value);
-                        const item = {
-                            type: value.__class__ ? `${value.__class__.__module__}.${value.__class__.__name__}` : 'builtins.object',
-                            obj: value
-                        };
-                        const node = new pytorch.Node(metadata, group, item, initializers, values, stack);
-                        stack.delete(value);
-                        return new pytorch.Argument(name, node, 'object');
-                    }
+                    const argument = new pytorch.Argument(name, nodes, 'object[]');
+                    this.inputs.push(argument);
+                } else if (value && (value.__class__ || isObject(value)) && !stack.has(value)) {
+                    stack.add(value);
+                    const item = {
+                        type: value.__class__ ? `${value.__class__.__module__}.${value.__class__.__name__}` : 'builtins.object',
+                        obj: value
+                    };
+                    const node = new pytorch.Node(metadata, group, item, initializers, values, stack);
+                    stack.delete(value);
+                    const argument = new pytorch.Argument(name, node, 'object');
+                    this.inputs.push(argument);
+                } else {
+                    const argument = createAttribute(metadata.attribute(type, name), name, value);
+                    this.inputs.push(argument);
                 }
-                return createAttribute(metadata.attribute(type, name), name, value);
-            });
+            }
         } else {
             this.attributes = [];
             this.inputs = [];

+ 25 - 31
source/sklearn.js

@@ -202,7 +202,6 @@ sklearn.Node = class {
         this.type = metadata.type(type) || { name: type };
         this.inputs = inputs.map((input) => new sklearn.Argument(input, [values.map(input)]));
         this.outputs = outputs.map((output) => new sklearn.Argument(output, [values.map(output)]));
-        this.attributes = [];
         const isArray = (obj) => {
             return obj && obj.__class__ &&
                 ((obj.__class__.__module__ === 'numpy' && obj.__class__.__name__ === 'ndarray') ||
@@ -220,7 +219,7 @@ sklearn.Node = class {
         };
         if (type === 'builtins.bytearray') {
             const attribute = new sklearn.Argument('value', Array.from(obj), 'byte[]');
-            this.attributes.push(attribute);
+            this.inputs.push(attribute);
         } else {
             const entries = Object.entries(obj);
             for (const [name, value] of entries) {
@@ -228,29 +227,29 @@ sklearn.Node = class {
                     continue;
                 } else if (value && isArray(value)) {
                     const tensor = new sklearn.Tensor(value);
-                    const attribute = new sklearn.Argument(name, tensor, 'tensor');
-                    this.attributes.push(attribute);
+                    const argument = new sklearn.Argument(name, tensor, 'tensor');
+                    this.inputs.push(argument);
                 } else if (Array.isArray(value) && value.length > 0 && value.every((obj) => isArray(obj))) {
                     const tensors = value.map((obj) => new sklearn.Tensor(obj));
-                    const attribute = new sklearn.Argument(name, tensors, 'tensor[]');
-                    this.attributes.push(attribute);
+                    const argument = new sklearn.Argument(name, tensors, 'tensor[]');
+                    this.inputs.push(argument);
                 } else if (isByteArray(value)) {
-                    const attribute = new sklearn.Argument(name, Array.from(value), 'byte[]');
-                    this.attributes.push(attribute);
+                    const argument = new sklearn.Argument(name, Array.from(value), 'byte[]');
+                    this.inputs.push(argument);
                 } else {
                     stack = stack || new Set();
                     if (value && Array.isArray(value) && value.every((obj) => typeof obj === 'string')) {
-                        const attribute = new sklearn.Argument(name, value, 'string[]');
-                        this.attributes.push(attribute);
+                        const argument = new sklearn.Argument(name, value, 'string[]');
+                        this.inputs.push(argument);
                     } else if (value && Array.isArray(value) && value.every((obj) => typeof obj === 'number')) {
-                        const attribute = new sklearn.Argument(name, value);
-                        this.attributes.push(attribute);
+                        const argument = new sklearn.Argument(name, value, 'attribute');
+                        this.inputs.push(argument);
                     } else if (value && value.__class__ && value.__class__.__module__ === 'builtins' && (value.__class__.__name__ === 'function' || value.__class__.__name__ === 'type')) {
                         const obj = {};
                         obj.__class__ = value;
                         const node = new sklearn.Node(metadata, group, '', obj, [], [], null, stack);
-                        const attribute = new sklearn.Argument(name, node, 'object');
-                        this.attributes.push(attribute);
+                        const argument = new sklearn.Argument(name, node, 'object');
+                        this.inputs.push(argument);
                     } else if (value && Array.isArray(value) && value.length > 0 && value.every((obj) => obj && (obj.__class__ || obj === Object(obj)))) {
                         const values = value.filter((value) => !stack.has(value));
                         const nodes = values.map((value) => {
@@ -259,21 +258,19 @@ sklearn.Node = class {
                             stack.delete(value);
                             return node;
                         });
-                        const attribute = new sklearn.Argument(name, nodes, 'object[]');
-                        this.attributes.push(attribute);
-                    } else if (value && (value.__class__ || isObject(value))) {
-                        if (!stack.has(value)) {
-                            stack.add(value);
-                            const node = new sklearn.Node(metadata, group, '', value, [], [], null, stack);
-                            const attribute = new sklearn.Argument(name, node, 'object');
-                            this.attributes.push(attribute);
-                            stack.delete(value);
-                        }
+                        const argument = new sklearn.Argument(name, nodes, 'object[]');
+                        this.inputs.push(argument);
+                    } else if (value && (value.__class__ || isObject(value)) && !stack.has(value)) {
+                        stack.add(value);
+                        const node = new sklearn.Node(metadata, group, '', value, [], [], null, stack);
+                        const argument = new sklearn.Argument(name, node, 'object');
+                        this.inputs.push(argument);
+                        stack.delete(value);
                     } else {
+                        let type = 'attribute';
+                        let visible = true;
                         const schema = metadata.attribute(type, name);
                         if (schema) {
-                            let type = '';
-                            let visible = true;
                             if (schema.type) {
                                 type = schema.type;
                             }
@@ -290,12 +287,9 @@ sklearn.Node = class {
                                     visible = value !== schema.default;
                                 }
                             }
-                            const attribute = new sklearn.Argument(name, value, type, visible);
-                            this.attributes.push(attribute);
-                        } else {
-                            const attribute = new sklearn.Argument(name, value);
-                            this.attributes.push(attribute);
                         }
+                        const argument = new sklearn.Argument(name, value, type, visible);
+                        this.inputs.push(argument);
                     }
                 }
             }

+ 55 - 80
source/view.js

@@ -1810,12 +1810,12 @@ view.Graph = class extends grapher.Graph {
             const viewNode = this.createNode(node);
             this.setNode(viewNode);
             const inputs = node.inputs;
-            for (const input of inputs) {
-                if (!input.type || input.type.endsWith('*')) {
-                    if (Array.isArray(input.value) && input.value.length === 1 && input.value[0].initializer) {
-                        this.createArgument(input);
+            for (const argument of inputs) {
+                if (!argument.type || argument.type.endsWith('*')) {
+                    if (Array.isArray(argument.value) && argument.value.length === 1 && argument.value[0].initializer) {
+                        this.createArgument(argument);
                     } else {
-                        for (const value of input.value) {
+                        for (const value of argument.value) {
                             if (value.name !== '' && !value.initializer) {
                                 this.createValue(value).to.push(viewNode);
                             }
@@ -2008,90 +2008,62 @@ view.Node = class extends grapher.Node {
             return current;
         };
         let hiddenTensors = false;
-        const tensors = [];
         const objects = [];
-        const attributes = [];
+        const attribute = (argument) => {
+            let content = new view.Formatter(argument.value, argument.type).toString();
+            if (content && content.length > 12) {
+                content = `${content.substring(0, 12)}\u2026`;
+            }
+            const item = list().argument(argument.name, content);
+            item.tooltip = argument.type;
+            item.separator = ' = ';
+            return item;
+        };
         if (Array.isArray(node.inputs)) {
-            for (const input of node.inputs) {
-                switch (input.type) {
-                    case 'graph':
-                    case 'object':
-                    case 'object[]':
-                    case 'function':
-                    case 'function[]': {
-                        objects.push(input);
-                        break;
-                    }
-                    default: {
-                        if (options.weights && input.visible !== false && input.value.length === 1 && input.value[0].initializer) {
-                            tensors.push(input);
-                        } else if (options.weights && (input.visible === false || input.value.length > 1) && (!input.type || input.type.endsWith('*')) && input.value.some((value) => value.initializer)) {
-                            hiddenTensors = true;
-                        } else if (options.attributes && input.visible !== false && input.type && !input.type.endsWith('*')) {
-                            attributes.push(input);
-                        }
-                    }
+            for (const argument of node.inputs) {
+                const type = argument.type;
+                if (type === 'graph' || type === 'object' || type === 'object[]' || type === 'function' || type === 'function[]') {
+                    objects.push(argument);
+                } else if (options.weights && argument.visible !== false && Array.isArray(argument.value) && argument.value.length === 1 && argument.value[0].initializer) {
+                    const item = this.context.createArgument(argument);
+                    list().add(item);
+                } else if (options.weights && (argument.visible === false || Array.isArray(argument.value) && argument.value.length > 1) && (!argument.type || argument.type.endsWith('*')) && argument.value.some((value) => value.initializer)) {
+                    hiddenTensors = true;
+                } else if (options.attributes && argument.visible !== false && argument.type && !argument.type.endsWith('*')) {
+                    const item = attribute(argument);
+                    list().add(item);
                 }
             }
         }
         if (Array.isArray(node.attributes)) {
-            for (const attribute of node.attributes) {
-                switch (attribute.type) {
-                    case 'graph':
-                    case 'object':
-                    case 'object[]':
-                    case 'function':
-                    case 'function[]': {
-                        objects.push(attribute);
-                        break;
-                    }
-                    default: {
-                        if (options.attributes && attribute.visible !== false) {
-                            attributes.push(attribute);
-                        }
-                    }
+            const attributes = node.attributes.slice();
+            attributes.sort((a, b) => a.name.toUpperCase().localeCompare(b.name.toUpperCase()));
+            for (const argument of node.attributes) {
+                const type = argument.type;
+                if (type === 'graph' || type === 'object' || type === 'object[]' || type === 'function' || type === 'function[]') {
+                    objects.push(argument);
+                } else if (options.attributes && argument.visible !== false) {
+                    const item = attribute(argument);
+                    list().add(item);
                 }
             }
         }
-        if (attributes.length > 0) {
-            attributes.sort((a, b) => a.name.toUpperCase().localeCompare(b.name.toUpperCase()));
-        }
-        for (const argument of tensors) {
-            const item = this.context.createArgument(argument);
-            list().add(item);
-        }
         if (hiddenTensors) {
             const item = list().argument('\u3008\u2026\u3009', '');
             list().add(item);
         }
-        for (const attribute of attributes) {
-            if (attribute.visible !== false) {
-                let value = new view.Formatter(attribute.value, attribute.type).toString();
-                if (value && value.length > 12) {
-                    value = `${value.substring(0, 12)}\u2026`;
-                }
-                const item = list().argument(attribute.name, value);
-                item.tooltip = attribute.type;
-                item.separator = ' = ';
-                list().add(item);
-            }
-        }
         for (const argument of objects) {
-            if (argument.type === 'graph') {
-                const node = this.context.createNode(null, argument.value);
-                const item = list().argument(argument.name, node);
-                list().add(item);
-            }
-            if (argument.type === 'function' || argument.type === 'object') {
-                const node = this.context.createNode(argument.value);
-                const item = list().argument(argument.name, node);
-                list().add(item);
-            }
-            if (argument.type === 'function[]' || argument.type === 'object[]') {
-                const nodes = argument.value.map((value) => this.context.createNode(value));
-                const item = list().argument(argument.name, nodes);
-                list().add(item);
-            }
+            const type = argument.type;
+            let content = null;
+            if (type === 'graph') {
+                content = this.context.createNode(null, argument.value);
+            } else if (type === 'function' || argument.type === 'object') {
+                content = this.context.createNode(argument.value);
+            } else if (type === 'function[]' || argument.type === 'object[]') {
+                content = argument.value.map((value) => this.context.createNode(value));
+            }
+            const item = list().argument(argument.name, content);
+            list().add(item);
         }
         if (Array.isArray(node.nodes) && node.nodes.length > 0) {
             // this.canvas = this.canvas();
@@ -2796,19 +2768,22 @@ view.ArgumentView = class extends view.Control {
         this._source = source;
         this._elements = [];
         this._items = [];
-        const type = argument.type;
+        const type = argument.type === 'attribute' ? null : argument.type;
         let value = argument.value;
+        if (argument.type === 'attribute') {
+            this._source = 'attribute';
+        }
         if (argument.type === 'tensor') {
             value = [{ type: value.type, initializer: value }];
         } else if (argument.type === 'tensor[]') {
             value = value.map((value) => ({ type: value.type, initializer: value }));
         }
-        source = typeof type === 'string' && !type.endsWith('*') ? 'attribute' : source;
-        if (source === 'attribute' && type !== 'tensor' && type !== 'tensor[]') {
+        this._source = typeof type === 'string' && !type.endsWith('*') ? 'attribute' : this._source;
+        if (this._source === 'attribute' && type !== 'tensor' && type !== 'tensor[]') {
             this._source = 'attribute';
             const item = new view.PrimitiveView(context, argument);
             this._items.push(item);
-        } else if (value.length === 0) {
+        } else if (Array.isArray(value) && value.length === 0) {
             const item = new view.TextView(this._view, null);
             this._items.push(item);
         } else {
@@ -2863,7 +2838,7 @@ view.PrimitiveView = class extends view.Expander {
         super(context);
         try {
             this._argument = argument;
-            const type = argument.type;
+            const type = argument.type === 'attribute' ? null : argument.type;
             const value = argument.value;
             if (type) {
                 this.enable();

+ 1 - 1
test/models.json

@@ -5861,7 +5861,7 @@
     "target":   "tensor-and-integer-in-tuple.pt",
     "source":   "https://github.com/user-attachments/files/15879202/tensor-and-integer-in-tuple.pt.zip[tensor-and-integer-in-tuple.pt]",
     "format":   "PyTorch v1.6",
-    "assert":   "model.graphs[0].nodes[0].attributes[0].value == 234",
+    "assert":   "model.graphs[0].nodes[0].inputs[1].value == 234",
     "link":     "https://github.com/lutzroeder/netron/issues/543"
   },
   {

+ 15 - 13
test/worker.js

@@ -720,20 +720,22 @@ export class Target {
                 view.Documentation.open(type);
                 node.name.toString();
                 node.description;
-                node.attributes.slice();
-                for (const attribute of node.attributes) {
-                    attribute.name.toString();
-                    attribute.name.length;
-                    const type = attribute.type;
-                    const value = attribute.value;
-                    if ((type === 'graph' || type === 'function') && value && Array.isArray(value.nodes)) {
-                        validateGraph(value);
-                    } else {
-                        let text = new view.Formatter(attribute.value, attribute.type).toString();
-                        if (text && text.length > 1000) {
-                            text = `${text.substring(0, 1000)}...`;
+                const attributes = node.attributes;
+                if (attributes) {
+                    for (const attribute of attributes) {
+                        attribute.name.toString();
+                        attribute.name.length;
+                        const type = attribute.type;
+                        const value = attribute.value;
+                        if ((type === 'graph' || type === 'function') && value && Array.isArray(value.nodes)) {
+                            validateGraph(value);
+                        } else {
+                            let text = new view.Formatter(attribute.value, attribute.type).toString();
+                            if (text && text.length > 1000) {
+                                text = `${text.substring(0, 1000)}...`;
+                            }
+                            /* value = */ text.split('<');
                         }
-                        /* value = */ text.split('<');
                     }
                 }
                 for (const input of node.inputs) {