Jelajahi Sumber

Add Transformers support (#1480)

Lutz Roeder committed 7 months ago
parent commit
57bd6f4aca
7 mengubah file dengan 223 tambahan dan 32 penghapusan
  1. 1 1
      source/acuity.js
  2. 0 1
      source/json.js
  3. 174 0
      source/transformers.js
  4. 15 16
      source/view.js
  5. 1 1
      source/xgboost.js
  6. 14 7
      test/models.json
  7. 18 6
      test/worker.js

+ 1 - 1
source/acuity.js

@@ -5,7 +5,7 @@ acuity.ModelFactory = class {
 
     async match(context) {
         const obj = await context.peek('json');
-        if (obj && obj.MetaData && obj.Layers) {
+        if (obj && obj.MetaData && obj.Layers && Object.keys(obj).length < 256) {
             return context.set('acuity', obj);
         }
         return null;

+ 0 - 1
source/json.js

@@ -148,7 +148,6 @@ json.TextReader = class {
                     const key = this._string();
                     switch (key) {
                         case '__proto__':
-                        case 'constructor':
                             throw new json.Error(`Invalid key '${key}' ${this._location()}`);
                         default:
                             break;

+ 174 - 0
source/transformers.js

@@ -0,0 +1,174 @@
+
+// import * as python from './python.js';
+// import * as safetensors from './safetensors.js';
+
+const transformers = {};
+
+// Factory for Hugging Face Transformers JSON artifacts: model configuration,
+// serialized tokenizer, tokenizer configuration, and vocabulary files.
+transformers.ModelFactory = class {
+
+    // Probes the peeked JSON object and tags the context with the most
+    // specific artifact type it recognizes, or returns null otherwise.
+    async match(context) {
+        const obj = await context.peek('json');
+        if (obj) {
+            // config.json — model configuration with architectures list.
+            if (obj.model_type && obj.architectures) {
+                return context.set('transformers.config', obj);
+            }
+            // tokenizer.json — huggingface/tokenizers serialization format.
+            if (obj.version && obj.added_tokens && obj.model) {
+                return context.set('transformers.tokenizer', obj);
+            }
+            // tokenizer_config.json / special_tokens_map.json variants.
+            if (obj.tokenizer_class ||
+                (obj.bos_token && obj.eos_token && obj.unk_token) ||
+                (obj.pad_token && obj.additional_special_tokens) ||
+                obj.special_tokens_map_file || obj.full_tokenizer_file) {
+                return context.set('transformers.tokenizer.config', obj);
+            }
+            // vocab.json — a large flat token map; the size threshold avoids
+            // matching small unrelated files that happen to share the name.
+            if (context.identifier === 'vocab.json' && Object.keys(obj).length > 256) {
+                return context.set('transformers.vocab', obj);
+            }
+        }
+        return null;
+    }
+
+    // Opens the matched artifact and pulls in the sibling files so that one
+    // model aggregates configuration, tokenizer, and vocabulary together.
+    async open(context) {
+        // Best-effort fetch of a sibling file from the same location:
+        // returns the matched sibling context, or null when the file is
+        // missing or not recognized by match() above.
+        const fetch = async (name) => {
+            try {
+                const content = await context.fetch(name);
+                await this.match(content);
+                if (content.value) {
+                    return content;
+                }
+            } catch {
+                // continue regardless of error
+            }
+            return null;
+        };
+        // Whichever artifact was matched first becomes the anchor; the other
+        // three slots are filled from sibling files when available.
+        switch (context.type) {
+            case 'transformers.config': {
+                const tokenizer = await fetch('tokenizer.json');
+                const tokenizer_config = await fetch('tokenizer_config.json');
+                const vocab = await fetch('vocab.json');
+                return new transformers.Model(context, tokenizer, tokenizer_config, vocab);
+            }
+            case 'transformers.tokenizer': {
+                const config = await fetch('config.json');
+                const tokenizer_config = await fetch('tokenizer_config.json');
+                const vocab = await fetch('vocab.json');
+                return new transformers.Model(config, context, tokenizer_config, vocab);
+            }
+            case 'transformers.tokenizer.config': {
+                const config = await fetch('config.json');
+                const tokenizer = await fetch('tokenizer.json');
+                const vocab = await fetch('vocab.json');
+                return new transformers.Model(config, tokenizer, context, vocab);
+            }
+            case 'transformers.vocab': {
+                const config = await fetch('config.json');
+                const tokenizer = await fetch('tokenizer.json');
+                const tokenizer_config = await fetch('tokenizer_config.json');
+                return new transformers.Model(config, tokenizer, tokenizer_config, context);
+            }
+            default: {
+                throw new transformers.Error(`Unsupported Transformers format '${context.type}'.`);
+            }
+        }
+    }
+
+    // When the configuration is the primary match, suppress the sibling
+    // tokenizer/vocab candidates (they are folded into this model instead of
+    // being offered as separate formats).
+    filter(context, type) {
+        return context.type !== 'transformers.config' || (type !== 'transformers.tokenizer' && type !== 'transformers.tokenizer.config' && type !== 'transformers.vocab' && type !== 'safetensors.json');
+    }
+};
+
+transformers.Model = class {
+
+    constructor(config, tokenizer, tokenizer_config, vocab) {
+        this.format = 'Transformers';
+        this.metadata = [];
+        this.modules = [new transformers.Graph(config, tokenizer, tokenizer_config, vocab)];
+    }
+};
+
+transformers.Graph = class {
+
+    constructor(config, tokenizer, tokenizer_config, vocab) {
+        this.type = 'graph';
+        this.nodes = [];
+        this.inputs = [];
+        this.outputs = [];
+        this.metadata = [];
+        if (config) {
+            for (const [key, value] of Object.entries(config.value)) {
+                const argument = new transformers.Argument(key, value);
+                this.metadata.push(argument);
+            }
+        }
+        if (tokenizer || tokenizer_config) {
+            const node = new transformers.Tokenizer(tokenizer, tokenizer_config, vocab);
+            this.nodes.push(node);
+        }
+    }
+};
+
+transformers.Tokenizer = class {
+
+    constructor(tokenizer, tokenizer_config) {
+        this.type = { name: 'Tokenizer' };
+        this.name = (tokenizer || tokenizer_config).identifier;
+        this.attributes = [];
+        if (tokenizer) {
+            const obj = tokenizer.value;
+            const keys = new Set(['decoder', 'model', 'post_processor', 'pre_tokenizer']);
+            for (const [key, value] of Object.entries(tokenizer.value)) {
+                if (!keys.has(key)) {
+                    const argument = new transformers.Argument(key, value);
+                    this.attributes.push(argument);
+                }
+            }
+            for (const key of keys) {
+                const value = obj[key];
+                if (value) {
+                    const module = new transformers.Object(value);
+                    const argument = new transformers.Argument(key, module, 'object');
+                    this.attributes.push(argument);
+                }
+            }
+        }
+    }
+};
+
+transformers.Object = class {
+
+    constructor(obj) {
+        this.type = { name: obj.type };
+        this.attributes = [];
+        for (const [key, value] of Object.entries(obj)) {
+            if (key !== 'type') {
+                let argument = null;
+                if (Array.isArray(value) && value.every((item) => typeof item === 'object')) {
+                    const values = value.map((item) => new transformers.Object(item));
+                    argument = new transformers.Argument(key, values, 'object[]');
+                } else {
+                    argument = new transformers.Argument(key, value);
+                }
+                this.attributes.push(argument);
+            }
+        }
+    }
+};
+
+transformers.Argument = class {
+
+    constructor(name, value, type) {
+        this.name = name;
+        this.value = value;
+        this.type = type || null;
+    }
+};
+
+// Error raised for unsupported or malformed Transformers content.
+// NOTE(review): 'name' holds a descriptive loading-failure label rather than
+// a class identifier — presumably the host displays it as a message prefix;
+// confirm against the application's error-reporting convention.
+transformers.Error = class extends Error {
+
+    constructor(message) {
+        super(message);
+        this.name = 'Error loading Transformers model.';
+    }
+};
+
+export const ModelFactory = transformers.ModelFactory;

+ 15 - 16
source/view.js

@@ -866,15 +866,15 @@ view.View = class {
 
     async renderGraph(model, graph, signature, options) {
         this._graph = null;
-        const document = this._host.document;
-        const window = this._host.window;
         const canvas = this._element('canvas');
         while (canvas.lastChild) {
             canvas.removeChild(canvas.lastChild);
         }
-        if (!graph) {
+        if (!graph || graph.type === 'tokenizer' || graph.type === 'vocabulary') {
             return '';
         }
+        const document = this._host.document;
+        const window = this._host.window;
         this._zoom = 1;
         const groups = graph.groups || false;
         const nodes = graph.nodes;
@@ -1120,6 +1120,12 @@ view.View = class {
                 case 'weights':
                     title = 'Weights Properties';
                     break;
+                case 'tokenizer':
+                    title = 'Tokenizer Properties';
+                    break;
+                case 'vocabulary':
+                    title = 'Vocabulary Properties';
+                    break;
                 default:
                     throw new view.Error(`Unsupported graph type '${type}'.`);
             }
@@ -2731,13 +2737,13 @@ view.TargetSelector = class extends view.Control {
                 }
             }
         };
-        const graphs = [];
+        const modules = [];
         const signatures = [];
         const functions = [];
         if (model && Array.isArray(model.modules)) {
             for (const graph of model.modules) {
                 const name = graph.name || '(unnamed)';
-                graphs.push({ name, target: graph, signature: null });
+                modules.push({ name, target: graph, signature: null });
                 if (Array.isArray(graph.functions)) {
                     for (const func of graph.functions) {
                         functions.push({ name: `${name}.${func.name}`, target: func, signature: null });
@@ -2755,10 +2761,10 @@ view.TargetSelector = class extends view.Control {
                 functions.push({ name: func.name, target: func, signature: null });
             }
         }
-        section('Graphs', graphs);
+        section('Modules', modules);
         section('Signatures', signatures);
         section('Functions', functions);
-        const visible = functions.length > 0 || signatures.length > 0 || graphs.length > 1;
+        const visible = functions.length > 0 || signatures.length > 0 || modules.length > 1;
         this._element.style.display = visible ? 'inline' : 'none';
     }
 };
@@ -6216,6 +6222,7 @@ view.ModelFactoryService = class {
         this.register('./qnn', ['.json', '.bin', '.serialized', '.dlc']);
         this.register('./kann', ['.kann', '.bin', '.kgraph'], [], [/^....KaNN/]);
         this.register('./xgboost', ['.xgb', '.xgboost', '.json', '.model', '.bin', '.txt'], [], [/^{L\x00\x00/, /^binf/, /^bs64/, /^\s*booster\[0\]:/]);
+        this.register('./transformers', ['.json']);
         this.register('', ['.cambricon', '.vnnmodel', '.nnc']);
         /* eslint-enable no-control-regex */
     }
@@ -6343,18 +6350,10 @@ view.ModelFactoryService = class {
                     { name: 'Trace Event data', tags: ['traceEvents'] },
                     { name: 'Trace Event data', tags: ['[].pid', '[].ph'] },
                     { name: 'Diffusers configuration', tags: ['_class_name', '_diffusers_version'] },
-                    { name: 'Transformers configuration', tags: ['architectures', 'model_type'] }, // https://huggingface.co/docs/transformers/en/create_a_model
                     { name: 'Transformers generation configuration', tags: ['transformers_version'] },
-                    { name: 'Transformers tokenizer configuration', tags: ['tokenizer_class'] },
-                    { name: 'Transformers tokenizer configuration', tags: ['bos_token', 'eos_token', 'unk_token'] },
-                    { name: 'Transformers tokenizer configuration', tags: ['bos_token', 'eos_token', 'pad_token'] },
-                    { name: 'Transformers tokenizer configuration', tags: ['additional_special_tokens'] },
-                    { name: 'Transformers tokenizer configuration', tags: ['special_tokens_map_file'] },
-                    { name: 'Transformers tokenizer configuration', tags: ['full_tokenizer_file'] },
                     { name: 'Transformers vocabulary data', tags: ['<|im_start|>'] },
                     { name: 'Transformers vocabulary data', tags: ['<|endoftext|>'] },
                     { name: 'Transformers preprocessor configuration', tags: ['crop_size', 'do_center_crop', 'image_mean', 'image_std', 'do_resize'] },
-                    { name: 'Tokenizers data', tags: ['version', 'added_tokens', 'model'] }, // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/tokenizer/serialization.rs
                     { name: 'Tokenizer data', tags: ['<eos>', '<bos>'] },
                     { name: 'Jupyter Notebook data', tags: ['cells', 'nbformat'] },
                     { name: 'Kaggle credentials', tags: ['username','key'] },
@@ -6719,7 +6718,7 @@ view.ModelFactoryService = class {
     _filter(context) {
         const identifier = context.identifier.toLowerCase().split('/').pop();
         const stream = context.stream;
-        if (stream) {
+        if (stream && stream.length < 0x7FFFFFFF) {
             const buffer = stream.peek(Math.min(4096, stream.length));
             const content = String.fromCharCode.apply(null, buffer);
             const list = this._factories.filter((entry) =>

+ 1 - 1
source/xgboost.js

@@ -9,7 +9,7 @@ xgboost.ModelFactory = class {
 
     async match(context) {
         const obj = await context.peek('json');
-        if (obj && obj.learner && obj.version) {
+        if (obj && obj.learner && obj.version && Object.keys(obj).length < 256) {
             return context.set('xgboost.json', obj);
         }
         const stream = context.stream;

+ 14 - 7
test/models.json

@@ -173,13 +173,6 @@
     "error":    "Unsupported file content.",
     "link":     "https://github.com/lutzroeder/netron/issues/458"
   },
-  {
-    "type":     "_",
-    "target":   "tokenizer.json",
-    "source":   "https://github.com/user-attachments/files/16203071/tokenizer.json.zip[tokenizer.json]",
-    "error":    "Invalid file content. File contains Tokenizers data.",
-    "link":     "https://github.com/lutzroeder/netron/issues/458"
-  },
   {
     "type":     "_",
     "target":   "sm_uint8_fence.nnc",
@@ -8463,6 +8456,20 @@
     "format":   "Torch v7",
     "link":     "https://github.com/cpra/fer-cnn-sota"
   },
+  {
+    "type":     "transformers",
+    "target":   "Qwen2-7B-Instruct.zip",
+    "source":   "https://github.com/user-attachments/files/21336228/Qwen2-7B-Instruct.zip",
+    "format":   "Transformers",
+    "link":     "https://github.com/lutzroeder/netron/issues/1480"
+  },
+  {
+    "type":     "transformers",
+    "target":   "tokenizer.json",
+    "source":   "https://github.com/user-attachments/files/16203071/tokenizer.json.zip[tokenizer.json]",
+    "format":   "Transformers",
+    "link":     "https://github.com/lutzroeder/netron/issues/1480"
+  },
   {
     "type":     "tvm",
     "target":   "mobilenet-v1-tvm.json",

+ 18 - 6
test/worker.js

@@ -598,7 +598,7 @@ export class Target {
         if (this.runtime && model.runtime !== this.runtime) {
             throw new Error(`Invalid runtime '${model.runtime}'.`);
         }
-        if (model.metadata && (!Array.isArray(model.metadata) || !model.metadata.every((argument) => argument.name && argument.value))) {
+        if (model.metadata && (!Array.isArray(model.metadata) || !model.metadata.every((argument) => argument.name && (argument.value || argument.value === null || argument.value === '' || argument.value === false || argument.value === 0)))) {
             throw new Error("Invalid model metadata.'");
         }
         if (this.assert) {
@@ -632,8 +632,8 @@ export class Target {
         if (model.version || model.description || model.author || model.license) {
             // continue
         }
-        /* eslint-disable no-unused-expressions */
-        const validateTarget = async (graph) => {
+        const validateGraph = async (graph) => {
+            /* eslint-disable no-unused-expressions */
             const values = new Map();
             const validateValue = async (value) => {
                 if (value === null) {
@@ -741,7 +741,7 @@ export class Target {
                 }
                 if (Array.isArray(type.nodes)) {
                     /* eslint-disable no-await-in-loop */
-                    await validateTarget(type);
+                    await validateGraph(type);
                     /* eslint-enable no-await-in-loop */
                 }
                 view.Documentation.open(type);
@@ -759,7 +759,7 @@ export class Target {
                         const value = attribute.value;
                         if ((type === 'graph' || type === 'function') && value && Array.isArray(value.nodes)) {
                             /* eslint-disable no-await-in-loop */
-                            await validateTarget(value);
+                            await validateGraph(value);
                             /* eslint-enable no-await-in-loop */
                         } else {
                             let text = new view.Formatter(attribute.value, attribute.type).toString();
@@ -813,8 +813,20 @@ export class Target {
                 const sidebar = new view.NodeSidebar(this.view, node);
                 sidebar.render();
             }
-            const sidebar = new view.ModelSidebar(this.view, model, graph);
+            const sidebar = new view.ModelSidebar(this.view, this.model, graph);
             sidebar.render();
+            /* eslint-enable no-unused-expressions */
+        };
+        const validateTarget = async (target) => {
+            switch (target.type) {
+                case 'tokenizer':
+                case 'vocabulary': {
+                    break;
+                }
+                default: {
+                    await validateGraph(target);
+                }
+            }
         };
         for (const module of model.modules) {
             /* eslint-disable no-await-in-loop */