|
|
@@ -6,13 +6,16 @@ const safetensors = {};
|
|
|
safetensors.ModelFactory = class {
|
|
|
|
|
|
match(context) {
|
|
|
- const stream = context.stream;
|
|
|
- if (stream.length > 9) {
|
|
|
- const buffer = stream.peek(9);
|
|
|
- if (buffer[6] === 0 && buffer[7] === 0 && buffer[8] === 0x7b) {
|
|
|
- const size = buffer[0] | buffer[1] << 8 | buffer[2] << 16 | buffer [3] << 24 | buffer [3] << 32 | buffer [3] << 40;
|
|
|
- if (size < stream.length) {
|
|
|
- return { name: 'safetensor', size: size };
|
|
|
+ const container = safetensors.Container.open(context);
|
|
|
+ if (container) {
|
|
|
+ return { name: 'safetensors', value: container };
|
|
|
+ }
|
|
|
+ const obj = context.peek('json');
|
|
|
+ if (obj.weight_map) {
|
|
|
+ const entries = Object.entries(obj.weight_map);
|
|
|
+ if (entries.every(([, value]) => typeof value === 'string' && value.endsWith('.safetensors'))) {
|
|
|
+ if (entries.length > 0) {
|
|
|
+ return { name: 'safetensors.json', value: entries };
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -20,37 +23,56 @@ safetensors.ModelFactory = class {
|
|
|
}
|
|
|
|
|
|
async open(context, target) {
|
|
|
- const stream = context.stream;
|
|
|
- stream.seek(8);
|
|
|
- const buffer = stream.read(target.size);
|
|
|
- const reader = json.TextReader.open(buffer);
|
|
|
- const obj = reader.read();
|
|
|
- const model = new safetensors.Model(obj, stream.position, stream);
|
|
|
- stream.seek(0);
|
|
|
- return model;
|
|
|
+ switch (target.name) {
|
|
|
+ case 'safetensors': {
|
|
|
+ const container = target.value;
|
|
|
+ await container.read();
|
|
|
+ return new safetensors.Model(container.entries);
|
|
|
+ }
|
|
|
+ case 'safetensors.json': {
|
|
|
+ target = new Map(target.value);
|
|
|
+ const keys = new Set(target.keys());
|
|
|
+ const files = Array.from(new Set(target.values()));
|
|
|
+ const contexts = await Promise.all(files.map((name) => context.fetch(name)));
|
|
|
+ const containers = contexts.map((context) => safetensors.Container.open(context));
|
|
|
+ await Promise.all(containers.map((container) => container.read()));
|
|
|
+ const entries = new Map();
|
|
|
+ for (const container of containers) {
|
|
|
+ for (const [key, value] of Array.from(container.entries)) {
|
|
|
+ if (keys.has(key)) {
|
|
|
+ entries.set(key, value);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return new safetensors.Model(entries);
|
|
|
+ }
|
|
|
+ default: {
|
|
|
+ throw new safetensors.Error("Unsupported Safetensors format '" + target.name + "'.");
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
};
|
|
|
|
|
|
safetensors.Model = class {
|
|
|
|
|
|
- constructor(obj, position, stream) {
|
|
|
+ constructor(entries) {
|
|
|
this.format = 'Safetensors';
|
|
|
- this.graphs = [ new safetensors.Graph(obj, position, stream) ];
|
|
|
+ this.graphs = [ new safetensors.Graph(entries) ];
|
|
|
}
|
|
|
};
|
|
|
|
|
|
safetensors.Graph = class {
|
|
|
|
|
|
- constructor(obj, position, stream) {
|
|
|
+ constructor(entries) {
|
|
|
this.inputs = [];
|
|
|
this.outputs = [];
|
|
|
this.nodes = [];
|
|
|
const layers = new Map();
|
|
|
- for (const [key, value] of Object.entries(obj)) {
|
|
|
+ for (const [key, value] of Array.from(entries)) {
|
|
|
if (key === '__metadata__') {
|
|
|
continue;
|
|
|
}
|
|
|
- const parts = key[0].split('.');
|
|
|
+ const parts = key.split('.');
|
|
|
const name = parts.pop();
|
|
|
const layer = parts.join('.');
|
|
|
if (!layers.has(layer)) {
|
|
|
@@ -59,7 +81,7 @@ safetensors.Graph = class {
|
|
|
layers.get(layer).push([ name, key, value]);
|
|
|
}
|
|
|
for (const [name, values] of layers) {
|
|
|
- const node = new safetensors.Node(name, values, position, stream);
|
|
|
+ const node = new safetensors.Node(name, values);
|
|
|
this.nodes.push(node);
|
|
|
}
|
|
|
}
|
|
|
@@ -87,14 +109,14 @@ safetensors.Value = class {
|
|
|
|
|
|
safetensors.Node = class {
|
|
|
|
|
|
- constructor(name, values, position, stream) {
|
|
|
+ constructor(name, values) {
|
|
|
this.name = name;
|
|
|
this.type = { name: 'Module' };
|
|
|
this.inputs = [];
|
|
|
this.outputs = [];
|
|
|
this.attributes = [];
|
|
|
for (const [name, identifier, obj] of values) {
|
|
|
- const tensor = new safetensors.Tensor(obj, position, stream);
|
|
|
+ const tensor = new safetensors.Tensor(obj);
|
|
|
const value = new safetensors.Value(identifier, tensor);
|
|
|
const argument = new safetensors.Argument(name, [ value ]);
|
|
|
this.inputs.push(argument);
|
|
|
@@ -141,21 +163,70 @@ safetensors.TensorShape = class {
|
|
|
|
|
|
safetensors.Tensor = class {
|
|
|
|
|
|
- constructor(obj, position, stream) {
|
|
|
+ constructor(obj) {
|
|
|
const shape = new safetensors.TensorShape(obj.shape);
|
|
|
this.type = new safetensors.TensorType(obj.dtype, shape);
|
|
|
this.encoding = '<';
|
|
|
- const size = obj.data_offsets[1] - obj.data_offsets[0];
|
|
|
- position += obj.data_offsets[0];
|
|
|
- stream.seek(position);
|
|
|
- this._data = stream.stream(size);
|
|
|
+ this.data = obj.__data__;
|
|
|
}
|
|
|
|
|
|
get values() {
|
|
|
- return this._data instanceof Uint8Array ? this._data : this._data.peek();
|
|
|
+ if (this.data instanceof Uint8Array) {
|
|
|
+ return this.data;
|
|
|
+ }
|
|
|
+ if (this.data && this.data.peek) {
|
|
|
+ return this.data.peek();
|
|
|
+ }
|
|
|
+ return null;
|
|
|
}
|
|
|
};
|
|
|
|
|
|
+safetensors.Container = class {
|
|
|
+
|
|
|
+ static open(context) {
|
|
|
+ const identifier = context.identifier;
|
|
|
+ const stream = context.stream;
|
|
|
+ if (stream.length > 9) {
|
|
|
+ const buffer = stream.peek(9);
|
|
|
+ if (buffer[6] === 0 && buffer[7] === 0 && buffer[8] === 0x7b) {
|
|
|
+ const size = buffer[0] | buffer[1] << 8 | buffer[2] << 16 | buffer [3] << 24 | buffer [3] << 32 | buffer [3] << 40;
|
|
|
+ if (size < stream.length) {
|
|
|
+ return new safetensors.Container(identifier, stream, size);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+
|
|
|
+ constructor(identifier, stream, size) {
|
|
|
+ this.identifier = identifier;
|
|
|
+ this.size = size;
|
|
|
+ this.stream = stream;
|
|
|
+ this.entries = new Map();
|
|
|
+ }
|
|
|
+
|
|
|
+ async read() {
|
|
|
+ const stream = this.stream;
|
|
|
+ const position = stream.position;
|
|
|
+ stream.seek(8);
|
|
|
+ const buffer = stream.read(this.size);
|
|
|
+ const reader = json.TextReader.open(buffer);
|
|
|
+ const obj = reader.read();
|
|
|
+ const offset = stream.position;
|
|
|
+ for (const [key, value] of Object.entries(obj)) {
|
|
|
+ if (key === '__metadata__') {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ const [start, end] = value.data_offsets;
|
|
|
+ stream.seek(offset + start);
|
|
|
+ value.__data__ = stream.stream(end - start);
|
|
|
+ this.entries.set(key, value);
|
|
|
+ }
|
|
|
+ stream.seek(position);
|
|
|
+ delete this.size;
|
|
|
+ delete this.stream;
|
|
|
+ }
|
|
|
+};
|
|
|
|
|
|
safetensors.Error = class extends Error {
|
|
|
|