فهرست منبع

Add Safetensors .json support (#1113)

Lutz Roeder 2 سال پیش
والد
کامیت
924897cfcf
3فایلهای تغییر یافته به همراه110 افزوده شده و 30 حذف شده
  1. 100 29
      source/safetensors.js
  2. 1 1
      source/view.js
  3. 9 0
      test/models.json

+ 100 - 29
source/safetensors.js

@@ -6,13 +6,16 @@ const safetensors = {};
 safetensors.ModelFactory = class {
 
     match(context) {
-        const stream = context.stream;
-        if (stream.length > 9) {
-            const buffer = stream.peek(9);
-            if (buffer[6] === 0 && buffer[7] === 0 && buffer[8] === 0x7b) {
-                const size = buffer[0] | buffer[1] << 8 | buffer[2] << 16 | buffer [3] << 24 | buffer [3] << 32 | buffer [3] << 40;
-                if (size < stream.length) {
-                    return { name: 'safetensor', size: size };
+        const container = safetensors.Container.open(context);
+        if (container) {
+            return { name: 'safetensors', value: container };
+        }
+        const obj = context.peek('json');
+        if (obj.weight_map) {
+            const entries = Object.entries(obj.weight_map);
+            if (entries.every(([, value]) => typeof value === 'string' && value.endsWith('.safetensors'))) {
+                if (entries.length > 0) {
+                    return { name: 'safetensors.json', value: entries };
                 }
             }
         }
@@ -20,37 +23,56 @@ safetensors.ModelFactory = class {
     }
 
     async open(context, target) {
-        const stream = context.stream;
-        stream.seek(8);
-        const buffer = stream.read(target.size);
-        const reader = json.TextReader.open(buffer);
-        const obj = reader.read();
-        const model = new safetensors.Model(obj, stream.position, stream);
-        stream.seek(0);
-        return model;
+        switch (target.name) {
+            case 'safetensors': {
+                const container = target.value;
+                await container.read();
+                return new safetensors.Model(container.entries);
+            }
+            case 'safetensors.json': {
+                target = new Map(target.value);
+                const keys = new Set(target.keys());
+                const files = Array.from(new Set(target.values()));
+                const contexts = await Promise.all(files.map((name) => context.fetch(name)));
+                const containers = contexts.map((context) => safetensors.Container.open(context));
+                await Promise.all(containers.map((container) => container.read()));
+                const entries = new Map();
+                for (const container of containers) {
+                    for (const [key, value] of Array.from(container.entries)) {
+                        if (keys.has(key)) {
+                            entries.set(key, value);
+                        }
+                    }
+                }
+                return new safetensors.Model(entries);
+            }
+            default: {
+                throw new safetensors.Error("Unsupported Safetensors format '" + target.name + "'.");
+            }
+        }
     }
 };
 
 safetensors.Model = class {
 
-    constructor(obj, position, stream) {
+    constructor(entries) {
         this.format = 'Safetensors';
-        this.graphs = [ new safetensors.Graph(obj, position, stream) ];
+        this.graphs = [ new safetensors.Graph(entries) ];
     }
 };
 
 safetensors.Graph = class {
 
-    constructor(obj, position, stream) {
+    constructor(entries) {
         this.inputs = [];
         this.outputs = [];
         this.nodes = [];
         const layers = new Map();
-        for (const [key, value] of Object.entries(obj)) {
+        for (const [key, value] of Array.from(entries)) {
             if (key === '__metadata__') {
                 continue;
             }
-            const parts = key[0].split('.');
+            const parts = key.split('.');
             const name = parts.pop();
             const layer = parts.join('.');
             if (!layers.has(layer)) {
@@ -59,7 +81,7 @@ safetensors.Graph = class {
             layers.get(layer).push([ name, key, value]);
         }
         for (const [name, values] of layers) {
-            const node = new safetensors.Node(name, values, position, stream);
+            const node = new safetensors.Node(name, values);
             this.nodes.push(node);
         }
     }
@@ -87,14 +109,14 @@ safetensors.Value = class {
 
 safetensors.Node = class {
 
-    constructor(name, values, position, stream) {
+    constructor(name, values) {
         this.name = name;
         this.type = { name: 'Module' };
         this.inputs = [];
         this.outputs = [];
         this.attributes = [];
         for (const [name, identifier, obj] of values) {
-            const tensor = new safetensors.Tensor(obj, position, stream);
+            const tensor = new safetensors.Tensor(obj);
             const value = new safetensors.Value(identifier, tensor);
             const argument = new safetensors.Argument(name, [ value ]);
             this.inputs.push(argument);
@@ -141,21 +163,70 @@ safetensors.TensorShape = class {
 
 safetensors.Tensor = class {
 
-    constructor(obj, position, stream) {
+    constructor(obj) {
         const shape = new safetensors.TensorShape(obj.shape);
         this.type = new safetensors.TensorType(obj.dtype, shape);
         this.encoding = '<';
-        const size = obj.data_offsets[1] - obj.data_offsets[0];
-        position += obj.data_offsets[0];
-        stream.seek(position);
-        this._data = stream.stream(size);
+        this.data = obj.__data__;
     }
 
     get values() {
-        return this._data instanceof Uint8Array ? this._data : this._data.peek();
+        if (this.data instanceof Uint8Array) {
+            return this.data;
+        }
+        if (this.data && this.data.peek) {
+            return this.data.peek();
+        }
+        return null;
     }
 };
 
+safetensors.Container = class {
+
+    static open(context) {
+        const identifier = context.identifier;
+        const stream = context.stream;
+        if (stream.length > 9) {
+            const buffer = stream.peek(9);
+            if (buffer[6] === 0 && buffer[7] === 0 && buffer[8] === 0x7b) {
+                const size = buffer[0] | buffer[1] << 8 | buffer[2] << 16 | buffer [3] << 24 | buffer [3] << 32 | buffer [3] << 40;
+                if (size < stream.length) {
+                    return new safetensors.Container(identifier, stream, size);
+                }
+            }
+        }
+        return null;
+    }
+
+    constructor(identifier, stream, size) {
+        this.identifier = identifier;
+        this.size = size;
+        this.stream = stream;
+        this.entries = new Map();
+    }
+
+    async read() {
+        const stream = this.stream;
+        const position = stream.position;
+        stream.seek(8);
+        const buffer = stream.read(this.size);
+        const reader = json.TextReader.open(buffer);
+        const obj = reader.read();
+        const offset = stream.position;
+        for (const [key, value] of Object.entries(obj)) {
+            if (key === '__metadata__') {
+                continue;
+            }
+            const [start, end] = value.data_offsets;
+            stream.seek(offset + start);
+            value.__data__ = stream.stream(end - start);
+            this.entries.set(key, value);
+        }
+        stream.seek(position);
+        delete this.size;
+        delete this.stream;
+    }
+};
 
 safetensors.Error = class extends Error {
 

+ 1 - 1
source/view.js

@@ -5218,7 +5218,7 @@ view.ModelFactoryService = class {
         this.register('./sentencepiece', [ '.model' ]);
         this.register('./hailo', [ '.hn', '.har', '.metadata.json' ]);
         this.register('./nnc', [ '.nnc' ]);
-        this.register('./safetensors', [ '.safetensors' ]);
+        this.register('./safetensors', [ '.safetensors', '.json' ]);
         this.register('./modular', [ '.maxviz' ]);
         this.register('./cambricon', [ '.cambricon' ]);
         this.register('./weka', [ '.model' ]);

+ 9 - 0
test/models.json

@@ -5902,6 +5902,15 @@
     "target":   "simple.safetensors",
     "source":   "https://github.com/lutzroeder/netron/files/11943528/simple.safetensors.zip[simple.safetensors]",
     "format":   "Safetensors",
+    "assert":   [ "model.graphs[0].nodes[0].inputs[0].name == 'uint64'" ],
+    "link":     "https://github.com/lutzroeder/netron/issues/1113"
+  },
+  {
+    "type":     "safetensors",
+    "target":   "model.safetensors.index.json,model.safetensors",
+    "source":   "https://github.com/lutzroeder/netron/files/13689293/model.safetensors.index.json.zip[model.safetensors.index.json,model.safetensors]",
+    "format":   "Safetensors",
+    "assert":   [ "model.graphs[0].nodes[0].inputs[0].name == 'float32'" ],
     "link":     "https://github.com/lutzroeder/netron/issues/1113"
   },
   {