2
0

safetensors.js 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. import * as json from './json.js';
  // Namespace object that collects every class defined in this module.
  const safetensors = {};

  /**
   * Recognises and opens Safetensors models: either a single `.safetensors` file,
   * or a JSON index whose `weight_map` maps tensor names to `.safetensors` shard files.
   */
  safetensors.ModelFactory = class {
    /**
     * Decides whether `context` holds a Safetensors file or a sharded-model index.
     * @param {object} context - provides `stream`/`identifier`, `peek('json')` and `set(type, value)`.
     * @returns {Promise<*>} whatever `context.set(...)` returns on a match, otherwise null.
     */
    async match(context) {
      const container = safetensors.Reader.open(context);
      if (container) {
        return context.set('safetensors', container);
      }
      // A sharded checkpoint index: every weight_map value must name a .safetensors shard.
      const obj = await context.peek('json');
      if (obj && obj.weight_map) {
        const entries = Object.entries(obj.weight_map);
        if (entries.length > 0 && entries.every(([, value]) => typeof value === 'string' && value.endsWith('.safetensors'))) {
          return context.set('safetensors.json', entries);
        }
      }
      return null;
    }

    /**
     * Builds a Model from a context previously accepted by `match`.
     * @param {object} context - context whose `type` and `value` were set by `match`.
     * @returns {Promise<safetensors.Model>}
     * @throws {safetensors.Error} for an unknown context type, or when a shard named by
     *   the index is not a Safetensors file.
     */
    async open(context) {
      switch (context.type) {
        case 'safetensors': {
          const container = context.value;
          await container.read();
          return new safetensors.Model(container.entries);
        }
        case 'safetensors.json': {
          const weight_map = new Map(context.value);
          const files = Array.from(new Set(weight_map.values()));
          const shards = await Promise.all(files.map((name) => context.fetch(name)));
          // Reader.open returns null for anything that is not a Safetensors file; report
          // which shard is bad instead of failing later with a TypeError on null.
          const containers = shards.map((shard, index) => {
            const container = safetensors.Reader.open(shard);
            if (!container) {
              throw new safetensors.Error(`Invalid Safetensors file '${files[index]}'.`);
            }
            return container;
          });
          await Promise.all(containers.map((container) => container.read()));
          // Keep only the tensors the index actually lists.
          const entries = new Map();
          for (const container of containers) {
            for (const [key, value] of container.entries) {
              if (weight_map.has(key)) {
                entries.set(key, value);
              }
            }
          }
          return new safetensors.Model(entries);
        }
        default: {
          throw new safetensors.Error(`Unsupported Safetensors format '${context.type}'.`);
        }
      }
    }
  };
  /**
   * Top-level model: a single module that holds every tensor entry.
   */
  safetensors.Model = class {
    /**
     * @param {Iterable<[string, object]>} entries - tensor name / header entry pairs.
     */
    constructor(entries) {
      const root = new safetensors.Module(entries);
      this.format = 'Safetensors';
      this.modules = [root];
    }
  };
  /**
   * Groups tensor entries into nodes by the dotted prefix of their names:
   * `encoder.layer.weight` becomes input `weight` of node `encoder.layer`, and a
   * name without any dot becomes an input of the node named ''.
   */
  safetensors.Module = class {
    /**
     * @param {Iterable<[string, object]>} entries - tensor name / header entry pairs;
     *   an entry keyed `__metadata__` is ignored.
     */
    constructor(entries) {
      this.inputs = [];
      this.outputs = [];
      const groups = new Map();
      for (const [key, value] of Array.from(entries)) {
        if (key !== '__metadata__') {
          const dot = key.lastIndexOf('.');
          const prefix = dot === -1 ? '' : key.substring(0, dot);
          const suffix = key.substring(dot + 1);
          const members = groups.get(prefix) || [];
          members.push([suffix, key, value]);
          // Re-setting an existing key keeps its original insertion position.
          groups.set(prefix, members);
        }
      }
      this.nodes = Array.from(groups, ([prefix, members]) => new safetensors.Node(prefix, members));
    }
  };
  /** A named slot bound to a list of values (Node binds a one-element list per tensor). */
  safetensors.Argument = class {
    /**
     * @param {string} name - slot name.
     * @param {Array} value - values bound to the slot.
     */
    constructor(name, value) {
      Object.assign(this, { name, value });
    }
  };
  /** A named value whose type is taken from, and whose initializer is, the wrapped tensor. */
  safetensors.Value = class {
    /**
     * @param {string} name - full tensor name.
     * @param {object} value - tensor providing a `type` property.
     */
    constructor(name, value) {
      const { type } = value;
      this.name = name;
      this.type = type;
      this.initializer = value;
    }
  };
  /** A 'Module' node whose inputs are the tensors that share one name prefix. */
  safetensors.Node = class {
    /**
     * @param {string} name - shared prefix of the tensor names.
     * @param {Iterable<[string, string, object]>} values - [input name, full tensor name, header entry] triples.
     */
    constructor(name, values) {
      this.name = name;
      this.type = { name: 'Module' };
      this.inputs = Array.from(values, ([input, identifier, entry]) => {
        const tensor = new safetensors.Tensor(entry);
        const wrapped = new safetensors.Value(identifier, tensor);
        return new safetensors.Argument(input, [wrapped]);
      });
      this.outputs = [];
      this.attributes = [];
    }
  };
  // Safetensors dtype tag -> data type name. A Map (not a plain object) so that
  // inherited keys such as 'toString' never count as a known dtype.
  const safetensorsDataTypes = new Map([
    ['I8', 'int8'],
    ['I16', 'int16'],
    ['I32', 'int32'],
    ['I64', 'int64'],
    ['U8', 'uint8'],
    ['U16', 'uint16'],
    ['U32', 'uint32'],
    ['U64', 'uint64'],
    ['BF16', 'bfloat16'],
    ['F16', 'float16'],
    ['F32', 'float32'],
    ['F64', 'float64'],
    ['BOOL', 'boolean'],
    ['F8_E4M3', 'float8e4m3fn'],
    ['F8_E5M2', 'float8e5m2'],
  ]);

  /** Element data type plus shape of a tensor; prints as e.g. `float32[2,3]`. */
  safetensors.TensorType = class {
    /**
     * @param {string} dtype - Safetensors dtype tag from the header.
     * @param {object} shape - shape object with a `toString()` method.
     * @throws {safetensors.Error} when `dtype` is not a supported tag.
     */
    constructor(dtype, shape) {
      const dataType = safetensorsDataTypes.get(dtype);
      if (dataType === undefined) {
        throw new safetensors.Error(`Unsupported data type '${dtype}'.`);
      }
      this.dataType = dataType;
      this.shape = shape;
    }

    toString() {
      return `${this.dataType}${this.shape.toString()}`;
    }
  };
  /** Tensor dimensions; prints as `[d0,d1,...]` (an empty list prints as `[]`). */
  safetensors.TensorShape = class {
    /** @param {Array} dimensions - axis sizes taken from the header's `shape` array. */
    constructor(dimensions) {
      this.dimensions = dimensions;
    }

    toString() {
      const body = this.dimensions.map((size) => size.toString()).join(',');
      return `[${body}]`;
    }
  };
  /** A tensor built from one header entry of a Safetensors file. */
  safetensors.Tensor = class {
    /**
     * @param {object} obj - header entry with `dtype` and `shape`, plus the `__data__`
     *   that Reader.read attaches.
     */
    constructor(obj) {
      this.type = new safetensors.TensorType(obj.dtype, new safetensors.TensorShape(obj.shape));
      // NOTE(review): '<' presumably marks the raw bytes as little-endian — confirm against the consumer.
      this.encoding = '<';
      this.data = obj.__data__;
    }

    /**
     * Raw tensor bytes: the data itself when it is already a Uint8Array, the result of
     * `data.peek()` when the data offers one, otherwise null.
     */
    get values() {
      const data = this.data;
      if (data instanceof Uint8Array) {
        return data;
      }
      return data && data.peek ? data.peek() : null;
    }
  };
  /**
   * Reads a Safetensors file: an 8-byte little-endian header length, a JSON header
   * mapping tensor names to header entries (with `data_offsets`), then the tensor bytes.
   */
  safetensors.Reader = class {
    /**
     * Returns a Reader when `context.stream` starts like a Safetensors file, otherwise null.
     * @param {object} context - provides `identifier` and `stream`.
     * @returns {safetensors.Reader|null}
     */
    static open(context) {
      const identifier = context.identifier;
      const stream = context.stream;
      // A context without a stream, or one too short to hold a header, cannot match.
      if (stream && stream.length > 9) {
        const buffer = stream.peek(9);
        // Bytes 0-7 hold the header length as a little-endian u64. The reference
        // implementation caps headers at 100 MB, so bytes 4-7 must be zero; byte 8
        // must be the '{' that opens the JSON header.
        if (buffer[4] === 0 && buffer[5] === 0 && buffer[6] === 0 && buffer[7] === 0 && buffer[8] === 0x7b) {
          const size = (buffer[0] | buffer[1] << 8 | buffer[2] << 16 | buffer[3] << 24) >>> 0;
          // The 8-byte length prefix plus the whole header must fit inside the stream.
          if (8 + size <= stream.length) {
            return new safetensors.Reader(identifier, stream, size);
          }
        }
      }
      return null;
    }

    constructor(identifier, stream, size) {
      this.identifier = identifier;
      this.size = size;
      this.stream = stream;
      this.entries = new Map();
    }

    /**
     * Parses the JSON header and fills `entries` with one header object per tensor
     * (`__metadata__` is skipped). Each object gets a `__data__` sub-stream over its bytes.
     * The stream position is restored afterwards, also when parsing fails.
     * @throws {safetensors.Error} if a tensor's `data_offsets` are malformed or fall outside the file.
     */
    async read() {
      const stream = this.stream;
      const position = stream.position;
      try {
        stream.seek(8);
        const buffer = stream.read(this.size);
        const reader = json.TextReader.open(buffer);
        const obj = reader.read();
        // data_offsets are relative to the first byte after the header.
        const offset = stream.position;
        for (const [key, value] of Object.entries(obj)) {
          if (key === '__metadata__') {
            continue;
          }
          const offsets = value ? value.data_offsets : null;
          const [start, end] = Array.isArray(offsets) ? offsets : [];
          if (!Number.isInteger(start) || !Number.isInteger(end) || start < 0 || end < start || offset + end > stream.length) {
            throw new safetensors.Error(`Invalid data offsets for tensor '${key}'.`);
          }
          stream.seek(offset + start);
          value.__data__ = stream.stream(end - start);
          this.entries.set(key, value);
        }
      } finally {
        stream.seek(position);
      }
      delete this.size;
      delete this.stream;
    }
  };
  /** Error thrown when a Safetensors file or index cannot be loaded. */
  safetensors.Error = class extends Error {
    /** @param {string} message - description of what went wrong. */
    constructor(message) {
      super(message);
      Object.assign(this, { name: 'Error loading Safetensors model.' });
    }
  };
  // Public entry point of this module.
  const { ModelFactory } = safetensors;
  export { ModelFactory };