ソースを参照

Update view.js (#1285)

Lutz Roeder 1 年間 前
コミット
69957d0494
1 ファイル変更201 行追加98 行削除
  1. 201 98
      source/view.js

+ 201 - 98
source/view.js

@@ -1026,26 +1026,6 @@ view.View = class {
                 sidebar.on('show-documentation', async (/* sender, e */) => {
                     await this.showDefinition(node.type);
                 });
-                sidebar.on('export-tensor', async (sender, tensor) => {
-                    const defaultPath = tensor.name ? tensor.name.split('/').join('_').split(':').join('_').split('.').join('_') : 'tensor';
-                    const file = await this._host.save('NumPy Array', 'npy', defaultPath);
-                    if (file) {
-                        try {
-                            let data_type = tensor.type.dataType;
-                            data_type = data_type === 'boolean' ? 'bool' : data_type;
-                            const execution = new python.Execution();
-                            const bytes = execution.invoke('io.BytesIO', []);
-                            const dtype = execution.invoke('numpy.dtype', [data_type]);
-                            const array = execution.invoke('numpy.asarray', [tensor.value, dtype]);
-                            execution.invoke('numpy.save', [bytes, array]);
-                            bytes.seek(0);
-                            const blob = new Blob([bytes.read()], { type: 'application/octet-stream' });
-                            await this._host.export(file, blob);
-                        } catch (error) {
-                            this.error(error, 'Error saving NumPy tensor.', null);
-                        }
-                    }
-                });
                 sidebar.on('activate', (sender, value) => {
                     this._graph.select([value]);
                 });
@@ -2546,7 +2526,6 @@ view.NodeSidebar = class extends view.ObjectSidebar {
                 const name = input.name;
                 if (input.value.length > 0) {
                     const value = new view.ArgumentView(this._view, input);
-                    value.on('export-tensor', (sender, value) => this.emit('export-tensor', value));
                     value.on('activate', (sender, value) => this.emit('activate', value));
                     value.on('deactivate', (sender, value) => this.emit('deactivate', value));
                     value.on('select', (sender, value) => this.emit('select', value));
@@ -2579,7 +2558,6 @@ view.NodeSidebar = class extends view.ObjectSidebar {
         switch (attribute.type) {
             case 'tensor': {
                 value = new view.ValueView(this._view, { type: attribute.value.type, initializer: attribute.value }, '');
-                value.on('export-tensor', (sender, value) => this.emit('export-tensor', value));
                 break;
             }
             case 'tensor[]': {
@@ -2803,7 +2781,6 @@ view.ArgumentView = class extends view.Control {
         this._items = [];
         for (const value of argument.value) {
             const item = new view.ValueView(context, value);
-            item.on('export-tensor', (sender, value) => this.emit('export-tensor', value));
             item.on('activate', (sender, value) => this.emit('activate', value));
             item.on('deactivate', (sender, value) => this.emit('deactivate', value));
             item.on('select', (sender, value) => this.emit('select', value));
@@ -2873,7 +2850,8 @@ view.ValueView = class extends view.Control {
             } else if (this._hasCategory) {
                 this._bold('category', initializer.category);
             } else if (type) {
-                this._code('tensor', type.toString().split('<').join('&lt;').split('>').join('&gt;'));
+                const value = type.toString().split('<').join('&lt;').split('>').join('&gt;');
+                this._code('tensor', value);
             }
         } catch (error) {
             super.error(error, false);
@@ -2912,6 +2890,14 @@ view.ValueView = class extends view.Control {
                         descriptionLine.innerHTML = description;
                         this._element.appendChild(descriptionLine);
                     }
+                    const identifier = this._value.identifier;
+                    if (identifier !== undefined) {
+                        this._bold('identifier', identifier);
+                    }
+                    const layout = this._value.type ? this._value.type.layout : null;
+                    if (layout) {
+                        this._bold('layout', layout.replace('.', ' '));
+                    }
                     const quantization = this._value.quantization;
                     if (quantization) {
                         if (typeof quantization.type !== 'string') {
@@ -2930,27 +2916,16 @@ view.ValueView = class extends view.Control {
                             this._element.appendChild(line);
                         }
                     }
-                    const identifier = this._value.identifier;
-                    if (identifier !== undefined) {
-                        this._bold('identifier', identifier);
-                    }
-                    const layout = this._value.type ? this._value.type.layout : null;
-                    if (layout) {
-                        const layouts = new Map([
-                            ['sparse', 'sparse'],
-                            ['sparse.coo', 'sparse coo'],
-                            ['sparse.csr', 'sparse csr'],
-                            ['sparse.csc', 'sparse csc'],
-                            ['sparse.bsr', 'sparse bsr'],
-                            ['sparse.bsc', 'sparse bsc']
-                        ]);
-                        this._bold('layout', layouts.get(layout));
-                    }
                     if (initializer) {
                         if (initializer.location) {
                             this._bold('location', initializer.location);
                         }
-                        this._tensor(initializer);
+                        const stride = initializer.stride;
+                        if (Array.isArray(stride) && stride.length > 0) {
+                            this._code('stride', stride.join(','));
+                        }
+                        const tensor = new view.TensorView(this._view, initializer);
+                        tensor.tensor(this._element);
                     }
                 } catch (error) {
                     super.error(error, false);
@@ -2987,13 +2962,65 @@ view.ValueView = class extends view.Control {
         child.className = this._element.childNodes.length < 2 ? 'sidebar-item-value-line' : 'sidebar-item-value-line-border';
         this._element.appendChild(child);
     }
+};
+
+view.TensorView = class extends view.Control {
+
+    constructor(context, value) {
+        super(context);
+        this._value = value;
+    }
+
+    render() {
+        if (!this._element) {
+            this._element = this.createElement('div', 'sidebar-item-value');
+            this._expander = this.createElement('div', 'sidebar-item-value-expander');
+            this._expander.innerText = '+';
+            this._expander.addEventListener('click', () => {
+                try {
+                    this.toggle();
+                } catch (error) {
+                    super.error(error, false);
+                    this._info('ERROR', error.message);
+                }
+            });
+            this._element.appendChild(this._expander);
+            this._style = 'sidebar-item-value-line';
+            this._collapse();
+        }
+        return [this._element];
+    }
+
+    toggle() {
+        if (this._expander) {
+            while (this._element.childElementCount > 1) {
+                this._element.removeChild(this._element.lastChild);
+            }
+            if (this._expander.innerText === '+') {
+                this._expander.innerText = '-';
+                try {
+                    this.tensor(this._element);
+                } catch (error) {
+                    super.error(error, false);
+                    this._info('ERROR', error.message);
+                }
+            } else {
+                this._expander.innerText = '+';
+                this._collapse();
+            }
+        }
+    }
 
-    _tensor(value) {
+    _collapse() {
+        const line = this.createElement('div', this._style);
+        line.innerHTML = '\u2026';
+        this._element.appendChild(line);
+    }
+
+    tensor(element) {
+        const value = this._value;
         const contentLine = this.createElement('pre');
         const tensor = new view.Tensor(value);
-        if (Array.isArray(tensor.stride) && tensor.stride.length > 0) {
-            this._code('stride', tensor.stride.join(','));
-        }
         if (tensor.encoding !== '<' && tensor.encoding !== '>' && tensor.encoding !== '|') {
             contentLine.innerHTML = `Tensor encoding '${tensor.layout}' is not implemented.`;
         } else if (tensor.layout && (tensor.layout !== 'sparse' && tensor.layout !== 'sparse.coo')) {
@@ -3011,15 +3038,37 @@ view.ValueView = class extends view.Control {
                 value.type.shape.dimensions.length > 0) {
                 this._saveButton = this.createElement('div', 'sidebar-item-value-expander');
                 this._saveButton.innerHTML = '&#x1F4BE;';
-                this._saveButton.addEventListener('click', () => {
-                    this.emit('export-tensor', tensor);
+                this._saveButton.addEventListener('click', async () => {
+                    await this.export();
                 });
-                this._element.appendChild(this._saveButton);
+                element.appendChild(this._saveButton);
             }
         }
-        const valueLine = this.createElement('div', 'sidebar-item-value-line-border');
+        const valueLine = this.createElement('div', this._style || 'sidebar-item-value-line-border');
         valueLine.appendChild(contentLine);
-        this._element.appendChild(valueLine);
+        element.appendChild(valueLine);
+    }
+
+    async export() {
+        const tensor = new view.Tensor(this._value);
+        const defaultPath = tensor.name ? tensor.name.split('/').join('_').split(':').join('_').split('.').join('_') : 'tensor';
+        const file = await this._host.save('NumPy Array', 'npy', defaultPath);
+        if (file) {
+            try {
+                let data_type = tensor.type.dataType;
+                data_type = data_type === 'boolean' ? 'bool' : data_type;
+                const execution = new python.Execution();
+                const bytes = execution.invoke('io.BytesIO', []);
+                const dtype = execution.invoke('numpy.dtype', [data_type]);
+                const array = execution.invoke('numpy.asarray', [tensor.value, dtype]);
+                execution.invoke('numpy.save', [bytes, array]);
+                bytes.seek(0);
+                const blob = new Blob([bytes.read()], { type: 'application/octet-stream' });
+                await this._host.export(file, blob);
+            } catch (error) {
+                this.error(error, 'Error saving NumPy tensor.', null);
+            }
+        }
     }
 };
 
@@ -3157,14 +3206,46 @@ view.TensorSidebar = class extends view.ObjectSidebar {
 
     render() {
         const value = this._value;
+        const tensor = value.initializer;
         const [name] = value.name.split('\n');
         this.addProperty('name', name);
-        if (value.type) {
-            const item = new view.ValueView(this._view, value, '');
-            this.add('type', item);
-            // item.toggle();
+        const category = tensor.category;
+        if (category) {
+            this.addProperty('category', category);
+        }
+        const description = tensor.description;
+        if (description) {
+            this.addProperty('description', description);
+        }
+        const type = tensor.type;
+        if (type) {
+            const value = type.toString().split('<').join('&lt;').split('>').join('&gt;');
+            const denotation = type.denotation;
+            const layout = type.layout;
+            this.addProperty('type', `${value}`, 'code');
+            if (denotation) {
+                this.addProperty('denotation', denotation, 'code');
+            }
+            if (layout) {
+                this.addProperty('layout', layout.replace('.', ' '));
+            }
+        }
+        const identifier = this._value.identifier;
+        if (identifier !== undefined) {
+            this.addProperty('identifier', tensor.identifier);
+        }
+        const location = tensor.location;
+        if (location) {
+            this.addProperty('location', tensor.location);
+        }
+        const stride = tensor.stride;
+        if (Array.isArray(stride) && stride.length > 0) {
+            this.addProperty('stride', stride.join(','), 'code');
+        }
+        if (tensor) {
+            const value = new view.TensorView(this._view, tensor);
+            this.add('value', value);
         }
-
         /*
         // TODO
         if (value.initializer) {
@@ -3650,44 +3731,8 @@ view.Tensor = class {
     constructor(tensor) {
         this._tensor = tensor;
         this._type = tensor.type;
-        this._encoding = tensor.encoding;
         this._layout = tensor.type.layout;
         this._stride = tensor.stride;
-        switch (this._encoding) {
-            case undefined:
-            case '':
-            case '<': {
-                this._data = this._tensor.values;
-                this._encoding = '<';
-                this._littleEndian = true;
-                break;
-            }
-            case '>': {
-                this._data = this._tensor.values;
-                this._encoding = '>';
-                this._littleEndian = false;
-                break;
-            }
-            case '|': {
-                this._values = this._tensor.values;
-                this._encoding = '|';
-                break;
-            }
-            default: {
-                throw new view.Error(`Unsupported tensor encoding '${this._encoding}'.`);
-            }
-        }
-        switch (this._layout) {
-            case 'sparse':
-            case 'sparse.coo': {
-                this._indices = this._tensor.indices;
-                this._values = this._tensor.values;
-                break;
-            }
-            default: {
-                break;
-            }
-        }
         view.Tensor.dataTypes = view.Tensor.dataTypeSizes || new Map([
             ['boolean', 1],
             ['qint8', 1], ['qint16', 2], ['qint32', 4],
@@ -3705,10 +3750,6 @@ view.Tensor = class {
         return this._type;
     }
 
-    get encoding() {
-        return this._encoding;
-    }
-
     get layout() {
         return this._layout;
     }
@@ -3717,19 +3758,39 @@ view.Tensor = class {
         return this._stride;
     }
 
+    get encoding() {
+        this._read();
+        return this._encoding;
+    }
+
+    get values() {
+        this._read();
+        return this._values;
+    }
+
+    get indices() {
+        this._read();
+        return this._indices;
+    }
+
+    get data() {
+        this._read();
+        return this._data;
+    }
+
     get empty() {
         switch (this._layout) {
             case 'sparse':
             case 'sparse.coo': {
-                return !this._values || this.indices || this._values.values === null || this._values.values.length === 0;
+                return !this.values || this.indices || this.values.values === null || this.values.values.length === 0;
             }
             default: {
                 switch (this._encoding) {
                     case '<':
                     case '>':
-                        return !(Array.isArray(this._data) || this._data instanceof Uint8Array || this._data instanceof Int8Array) || this._data.length === 0;
+                        return !(Array.isArray(this.data) || this.data instanceof Uint8Array || this.data instanceof Int8Array) || this.data.length === 0;
                     case '|':
-                        return !(Array.isArray(this._values) || ArrayBuffer.isView(this._values)) || this._values.length === 0;
+                        return !(Array.isArray(this.values) || ArrayBuffer.isView(this.values)) || this.values.length === 0;
                     default:
                         throw new Error(`Unsupported tensor encoding '${this._encoding}'.`);
                 }
@@ -4133,6 +4194,48 @@ view.Tensor = class {
         }
     }
 
+    _read() {
+        if (this._encoding === undefined) {
+            this._encoding = this._tensor.encoding;
+            this._values = null;
+            switch (this._encoding) {
+                case undefined:
+                case '':
+                case '<': {
+                    this._data = this._tensor.values;
+                    this._encoding = '<';
+                    this._littleEndian = true;
+                    break;
+                }
+                case '>': {
+                    this._data = this._tensor.values;
+                    this._encoding = '>';
+                    this._littleEndian = false;
+                    break;
+                }
+                case '|': {
+                    this._values = this._tensor.values;
+                    this._encoding = '|';
+                    break;
+                }
+                default: {
+                    throw new view.Error(`Unsupported tensor encoding '${this._encoding}'.`);
+                }
+            }
+            switch (this._layout) {
+                case 'sparse':
+                case 'sparse.coo': {
+                    this._indices = this._tensor.indices;
+                    this._values = this._tensor.values;
+                    break;
+                }
+                default: {
+                    break;
+                }
+            }
+        }
+    }
+
     get metrics() {
         if (!this._metrics) {
             const data = this.value;