Bladeren bron

Add Safetensors support (#1113)

Lutz Roeder 2 jaren geleden
bovenliggende
commit
91b7f2d06a
4 gewijzigde bestanden met toevoegingen van 146 en 1 verwijderingen
  1. 1 0
      package.json
  2. 1 1
      source/base.js
  3. 143 0
      source/safetensors.js
  4. 1 0
      source/view.js

+ 1 - 0
package.json

@@ -139,6 +139,7 @@
             { "ext": "pth",         "name": "PyTorch Model"            },
             { "ext": "ptl",         "name": "PyTorch Model"            },
             { "ext": "rknn",        "name": "RKNN Model"               },
+            { "ext": "safetensors", "name": "Safetensors Checkpoint"   },
             { "ext": "t7",          "name": "Torch Model"              },
             { "ext": "tfl",         "name": "TensorFlow Lite Model"    },
             { "ext": "tflite",      "name": "TensorFlow Lite Model"    },

+ 1 - 1
source/base.js

@@ -1090,7 +1090,7 @@ base.Metadata = class {
             'mlnet', 'mar',  'meta', 'nn', 'ngf', 'hn', 'har',
             'param', 'params',
             'paddle', 'pdiparams', 'pdmodel', 'pdopt', 'pdparams', 'nb',
-            'pkl', 'joblib',
+            'pkl', 'joblib', 'safetensors',
             'ptl', 't7',
             'dlc', 'uff', 'armnn',
             'mnn', 'ms', 'ncnn', 'om', 'tm', 'mge', 'tmfile', 'tnnproto', 'xmodel', 'kmodel', 'rknn',

+ 143 - 0
source/safetensors.js

@@ -0,0 +1,143 @@
+
+var safetensors = {};
+var json = require('./json');
+
safetensors.ModelFactory = class {

    /**
     * Detects a Safetensors file: an 8-byte little-endian header length
     * followed by a JSON header whose first byte is '{'.
     * @returns {{size: number}|string} match context with the JSON header
     *     size, or '' when the stream is not a Safetensors file.
     */
    match(context) {
        const stream = context.stream;
        if (stream.length > 9) {
            const buffer = stream.peek(9);
            // Bytes 6-7 of the uint64 size must be zero for any plausible
            // header, and the header itself must open with '{' (0x7b).
            if (buffer[6] === 0 && buffer[7] === 0 && buffer[8] === 0x7b) {
                // Assemble the low 6 bytes of the little-endian size using
                // arithmetic: JavaScript bitwise operators truncate to
                // 32 bits (shift counts are taken mod 32), so '<< 32' and
                // '<< 40' would silently corrupt the value.
                const size =
                    buffer[0] +
                    buffer[1] * 0x100 +
                    buffer[2] * 0x10000 +
                    buffer[3] * 0x1000000 +
                    buffer[4] * 0x100000000 +
                    buffer[5] * 0x10000000000;
                if (size < stream.length) {
                    return { size: size };
                }
            }
        }
        return '';
    }

    /**
     * Parses the JSON header and builds the model; tensor data is read
     * lazily from the stream by safetensors.Tensor. Restores the stream
     * position to 0 before returning.
     */
    async open(context, match) {
        const stream = context.stream;
        stream.seek(8);
        const buffer = stream.read(match.size);
        const reader = json.TextReader.open(buffer);
        const obj = reader.read();
        // stream.position now marks the start of the raw tensor payload.
        const model = new safetensors.Model(obj, stream.position, stream);
        stream.seek(0);
        return model;
    }
};
+
safetensors.Model = class {

    // A safetensors checkpoint maps onto exactly one graph.
    constructor(obj, position, stream) {
        const graph = new safetensors.Graph(obj, position, stream);
        this.format = 'Safetensors';
        this.graphs = [ graph ];
    }
};
+
safetensors.Graph = class {

    // Builds one node per tensor entry in the JSON header; the reserved
    // '__metadata__' key holds free-form strings, not a tensor, and is
    // skipped.
    constructor(obj, position, stream) {
        this.inputs = [];
        this.outputs = [];
        this.nodes = Object.entries(obj)
            .filter(([ key ]) => key !== '__metadata__')
            .map(([ key, value ]) => new safetensors.Node(key, value, position, stream));
    }
};
+
safetensors.Argument = class {

    // Pairs a parameter name with its list of values.
    constructor(name, value) {
        Object.assign(this, { name, value });
    }
};
+
safetensors.Value = class {

    // A named value backed by an initializer tensor.
    constructor(name, value) {
        Object.assign(this, { name, initializer: value });
    }

    // The type is derived from the backing tensor.
    get type() {
        return this.initializer.type;
    }
};
+
safetensors.Node = class {

    // Groups a tensor under its module path: a key like 'a.b.weight'
    // becomes a node named 'a.b' with an input argument named 'weight'
    // (a key without dots yields an empty node name).
    constructor(key, value, position, stream) {
        const separator = key.lastIndexOf('.');
        this.name = separator === -1 ? '' : key.substring(0, separator);
        this.type = { name: 'Module' };
        const parameter = separator === -1 ? key : key.substring(separator + 1);
        const tensor = new safetensors.Tensor(value, position, stream);
        const values = [ new safetensors.Value(key, tensor) ];
        this.inputs = [ new safetensors.Argument(parameter, values) ];
        this.outputs = [];
        this.attributes = [];
    }
};
+
safetensors.TensorType = class {

    /**
     * Maps a safetensors dtype tag to the viewer's data type name.
     * Generalized from the original F16/F32-only mapping to the full
     * dtype set defined by the safetensors format specification.
     * @param {string} dtype - dtype tag from the JSON header (e.g. 'F32').
     * @param {safetensors.TensorShape} shape - tensor shape.
     * @throws {safetensors.Error} for unrecognized dtype tags.
     */
    constructor(dtype, shape) {
        switch (dtype) {
            case 'BOOL': this.dataType = 'boolean'; break;
            case 'U8':   this.dataType = 'uint8'; break;
            case 'I8':   this.dataType = 'int8'; break;
            case 'U16':  this.dataType = 'uint16'; break;
            case 'I16':  this.dataType = 'int16'; break;
            case 'F16':  this.dataType = 'float16'; break;
            case 'BF16': this.dataType = 'bfloat16'; break;
            case 'U32':  this.dataType = 'uint32'; break;
            case 'I32':  this.dataType = 'int32'; break;
            case 'F32':  this.dataType = 'float32'; break;
            case 'U64':  this.dataType = 'uint64'; break;
            case 'I64':  this.dataType = 'int64'; break;
            case 'F64':  this.dataType = 'float64'; break;
            default: throw new safetensors.Error("Unsupported data type '" + dtype + "'.");
        }
        this.shape = shape;
    }

    toString() {
        return this.dataType + this.shape.toString();
    }
};
+
safetensors.TensorShape = class {

    // Holds the list of tensor dimensions.
    constructor(dimensions) {
        this.dimensions = dimensions;
    }

    // Renders e.g. [2,3,4].
    toString() {
        const items = this.dimensions.map((dim) => dim.toString());
        return '[' + items.join(',') + ']';
    }
};
+
safetensors.Tensor = class {

    // Reads the raw tensor payload from the stream. 'data_offsets' in the
    // JSON header are [begin, end) byte offsets relative to 'position',
    // the start of the data section.
    constructor(obj, position, stream) {
        const shape = new safetensors.TensorShape(obj.shape);
        this.type = new safetensors.TensorType(obj.dtype, shape);
        this.layout = '<'; // little-endian raw bytes
        const [ begin, end ] = obj.data_offsets;
        stream.seek(position + begin);
        this.values = stream.read(end - begin);
    }
};
+
+
safetensors.Error = class extends Error {

    /**
     * Error type for Safetensors loading failures.
     * @param {string} message - human-readable description.
     */
    constructor(message) {
        super(message);
        this.name = 'Error loading Safetensors model.';
    }
};
+
// Register the factory for CommonJS consumers (Node/Electron builds);
// the 'typeof module' check must come first so browser environments,
// where 'module' is undefined, skip the export without a ReferenceError.
if (typeof module !== 'undefined' && typeof module.exports === 'object') {
    module.exports.ModelFactory = safetensors.ModelFactory;
}

+ 1 - 0
source/view.js

@@ -4939,6 +4939,7 @@ view.ModelFactoryService = class {
         this.register('./onednn', [ '.json']);
         this.register('./mlir', [ '.mlir']);
         this.register('./hailo', [ '.hn', '.har' ]);
+        this.register('./safetensors', [ '.safetensors' ]);
     }
 
     register(id, factories, containers) {