espresso.js 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452
  // Namespace object that holds every class of the Espresso loader.
  const espresso = {};

  /**
   * Recognizes Espresso model files by identifier suffix and opens them.
   * Three file kinds are handled: '.espresso.net' (JSON with a `layers` array),
   * '.espresso.shape' (JSON with `layer_shapes`) and '.espresso.weights' (binary).
   */
  espresso.ModelFactory = class {

      /**
       * Tags `context` with the Espresso file kind it holds.
       * @param context Object exposing `identifier`, `peek`, `read` and `set`.
       * @returns The result of `context.set(type, value)` on a match, otherwise null.
       */
      async match(context) {
          const identifier = context.identifier.toLowerCase();
          if (identifier.endsWith('.espresso.net')) {
              const obj = await context.peek('json');
              // A net file must carry both a layer list and a format version.
              if (obj && Array.isArray(obj.layers) && obj.format_version) {
                  return context.set('espresso.net', obj);
              }
          }
          if (identifier.endsWith('.espresso.shape')) {
              const obj = await context.peek('json');
              if (obj && obj.layer_shapes) {
                  return context.set('espresso.shape', obj);
              }
          }
          if (identifier.endsWith('.espresso.weights')) {
              // Weights are accepted on suffix alone; the content is not inspected here.
              const target = await context.read('binary');
              return context.set('espresso.weights', target);
          }
          return null;
      }

      /**
       * Decides whether `match` should be kept alongside `context`.
       * @returns false when `match` is one of the companion types listed for
       * the context type, true otherwise.
       */
      filter(context, match) {
          const shadowed = {
              'espresso.net': ['espresso.weights', 'espresso.shape', 'coreml.metadata.mlmodelc'],
              'espresso.shape': ['espresso.weights', 'coreml.metadata.mlmodelc']
          };
          const hidden = Object.hasOwn(shadowed, context.type) ? shadowed[context.type] : [];
          return !hidden.includes(match.type);
      }

      /**
       * Builds an espresso.Model from a previously matched context.
       * @throws {espresso.Error} When `context.type` is not an Espresso file kind.
       */
      async open(context) {
          const metadata = await context.metadata('espresso-metadata.json');
          // The matched file fills its own Reader slot (net, weights, shape);
          // the Reader is expected to locate the remaining files itself.
          let reader = null;
          switch (context.type) {
              case 'espresso.net': {
                  reader = new espresso.Reader(context.value, null, null);
                  break;
              }
              case 'espresso.weights': {
                  reader = new espresso.Reader(null, context.value, null);
                  break;
              }
              case 'espresso.shape': {
                  reader = new espresso.Reader(null, null, context.value);
                  break;
              }
              default: {
                  throw new espresso.Error(`Unsupported Espresso format '${context.type}'.`);
              }
          }
          await reader.read(context);
          return new espresso.Model(metadata, reader);
      }
  };
  /**
   * Top-level model object: copies the format string and optional version and
   * description from the reader, wraps the reader in a single graph, and
   * collects the reader's properties as metadata.
   */
  espresso.Model = class {

      constructor(metadata, reader) {
          this.format = reader.format;
          this.metadata = [];
          // The model always exposes exactly one graph built from the reader.
          const graph = new espresso.Graph(metadata, reader);
          this.modules = [graph];
          const { version, description, properties } = reader;
          // version and description are only defined on the model when truthy.
          if (version) {
              this.version = version;
          }
          if (description) {
              this.description = description;
          }
          this.metadata.push(...properties);
      }
  };
  /**
   * Graph view over a reader: materializes an espresso.Value for every entry
   * in `reader.values`, then converts the reader's inputs, outputs and nodes
   * into espresso.Argument / espresso.Node instances.
   */
  espresso.Graph = class {

      constructor(metadata, reader) {
          this.name = '';
          this.type = reader.type;
          // Attach an espresso.Value to each value entry that does not have one yet.
          for (const entry of reader.values.values()) {
              if (!entry.value) {
                  const { name, type, description, initializer } = entry;
                  entry.value = new espresso.Value(name, type, description, initializer);
              }
          }
          // Unwraps each entry's `.value` and builds the public argument.
          const toArgument = (argument) => {
              const values = argument.value.map((entry) => entry.value);
              return new espresso.Argument(argument.name, values, null, argument.visible);
          };
          this.inputs = reader.inputs.map((argument) => toArgument(argument));
          this.outputs = reader.outputs.map((argument) => toArgument(argument));
          // NOTE(review): the nested graphs below are constructed with a single
          // argument, so `reader` is undefined inside the nested constructor and
          // reading `reader.type` throws. Confirm whether 'loop' / 'branch'
          // layers are expected to reach this path.
          for (const { type, attributes } of reader.nodes) {
              if (type === 'loop') {
                  attributes.conditionNetwork = new espresso.Graph(attributes.conditionNetwork);
                  attributes.bodyNetwork = new espresso.Graph(attributes.bodyNetwork);
              } else if (type === 'branch') {
                  attributes.ifBranch = new espresso.Graph(attributes.ifBranch);
                  attributes.elseBranch = new espresso.Graph(attributes.elseBranch);
              }
          }
          this.nodes = reader.nodes.map((node) => new espresso.Node(metadata, node));
      }
  };
  /**
   * Named input, output or attribute.
   * `type` falls back to null when falsy; `visible` is true unless the caller
   * passes exactly `false`.
   */
  espresso.Argument = class {

      constructor(name, value, type, visible) {
          const isHidden = visible === false;
          this.name = name;
          this.value = value;
          this.type = type ? type : null;
          this.visible = !isHidden;
      }
  };
  /**
   * A named value flowing between nodes, optionally backed by an initializer.
   * When no explicit type is given, the initializer's type is used instead.
   * @throws {espresso.Error} When `name` is not a string.
   */
  espresso.Value = class {

      constructor(name, type, description, initializer) {
          if (typeof name !== 'string') {
              throw new espresso.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
          }
          this.name = name;
          if (initializer) {
              // An explicit type wins; otherwise inherit it from the initializer.
              this.type = type ? type : initializer.type;
              this.description = description ? description : null;
              this.initializer = initializer;
              this.quantization = initializer.quantization;
          } else {
              this.type = type;
              this.description = description ? description : null;
              this.initializer = null;
              this.quantization = null;
          }
      }
  };
  /**
   * A single layer of the graph, built from the plain node object produced by
   * the reader (`type`, `name`, `description`, `inputs`, `outputs`,
   * `attributes`, `chain`).
   * @throws {espresso.Error} When the node object has no type.
   */
  espresso.Node = class {

      constructor(metadata, obj) {
          if (!obj.type) {
              // Raised as espresso.Error (a subclass of Error) like every other
              // failure in this loader.
              throw new espresso.Error('Undefined node type.');
          }
          const type = metadata.type(obj.type);
          // Copy the metadata entry so the name override below does not alter it.
          this.type = type ? { ...type } : { name: obj.type };
          // Only the segment after the last ':' of the type is used as the name.
          this.type.name = obj.type.split(':').pop();
          this.name = obj.name || '';
          this.description = obj.description || '';
          const toArgument = (argument) => {
              const values = argument.value.map((value) => value.value);
              return new espresso.Argument(argument.name, values, null, argument.visible);
          };
          this.inputs = (obj.inputs || []).map((argument) => toArgument(argument));
          this.outputs = (obj.outputs || []).map((argument) => toArgument(argument));
          // Native BigInt has no standard toNumber(); use one if the host
          // provides it, otherwise convert with Number().
          const toNumber = (value) => {
              return typeof value.toNumber === 'function' ? value.toNumber() : Number(value);
          };
          this.attributes = Object.entries(obj.attributes || {}).map(([name, value]) => {
              const schema = metadata.attribute(obj.type, name);
              const type = schema && schema.type ? schema.type : null;
              let content = value;
              let visible = true;
              if (schema) {
                  if (schema.visible === false) {
                      visible = false;
                  } else if (schema.default !== undefined) {
                      // Normalize to plain numbers before comparing with the
                      // schema default; the normalized value is what gets stored.
                      if (Array.isArray(content)) {
                          content = content.map((item) => Number(item));
                      }
                      if (typeof content === 'bigint') {
                          content = toNumber(content);
                      }
                      // Attributes equal to their default are hidden.
                      if (JSON.stringify(schema.default) === JSON.stringify(content)) {
                          visible = false;
                      }
                  }
              }
              return new espresso.Argument(name, content, type, visible);
          });
          if (Array.isArray(obj.chain)) {
              // Chained entries are turned into nodes of their own.
              this.chain = obj.chain.map((item) => new espresso.Node(metadata, item));
          }
      }
  };
  /**
   * Tensor payload holder: stores the type, the raw data (as `values`), the
   * quantization info and a category label. `encoding` is always '<'.
   */
  espresso.Tensor = class {

      constructor(type, data, quantization, category) {
          Object.assign(this, {
              type,
              values: data,
              quantization,
              category,
              encoding: '<'
          });
      }
  };
  /**
   * Element data type plus shape. A missing shape defaults to an empty
   * espresso.TensorShape.
   */
  espresso.TensorType = class {

      constructor(dataType, shape) {
          this.dataType = dataType;
          this.shape = shape ? shape : new espresso.TensorShape([]);
      }

      /**
       * Compares data type and shape with another type-like object.
       * Falsy operands are returned as-is, as a short-circuited `&&` chain would.
       */
      equals(obj) {
          if (!obj) {
              return obj;
          }
          if (this.dataType !== obj.dataType) {
              return false;
          }
          if (!this.shape) {
              return this.shape;
          }
          return this.shape.equals(obj.shape);
      }

      /** Data type immediately followed by the shape's string form. */
      toString() {
          const suffix = this.shape.toString();
          return this.dataType + suffix;
      }
  };
  /**
   * Ordered list of tensor dimensions. BigInt dimensions are converted to
   * plain numbers on construction.
   */
  espresso.TensorShape = class {

      constructor(dimensions) {
          this.dimensions = dimensions.map((dim) => {
              if (typeof dim !== 'bigint') {
                  return dim;
              }
              // Native BigInt has no standard toNumber(); use one if the host
              // provides it, otherwise convert with Number().
              return typeof dim.toNumber === 'function' ? dim.toNumber() : Number(dim);
          });
      }

      /**
       * True when `obj` has the same number of dimensions and every dimension
       * is strictly equal. A falsy `obj` is returned as-is.
       */
      equals(obj) {
          return obj && Array.isArray(obj.dimensions) && Array.isArray(this.dimensions) &&
              this.dimensions.length === obj.dimensions.length &&
              obj.dimensions.every((value, index) => this.dimensions[index] === value);
      }

      /**
       * Formats the shape as '[d0,d1,...]', or '' when there are no dimensions.
       * A null or undefined dimension is rendered as '?' instead of throwing.
       */
      toString() {
          if (!Array.isArray(this.dimensions) || this.dimensions.length === 0) {
              return '';
          }
          const parts = this.dimensions.map((dimension) => {
              return dimension === null || dimension === undefined ? '?' : dimension.toString();
          });
          return `[${parts.join(',')}]`;
      }
  };
  /**
   * Loads an Espresso model from its net (JSON), shape (JSON) and weights
   * (binary) parts and turns it into plain node descriptions. After `read()`
   * the instance exposes `format`, `properties`, `inputs`, `outputs`, `nodes`
   * and `values`.
   */
  espresso.Reader = class {

      /**
       * Any of the three parts may be null; `read()` then tries to fetch it.
       */
      constructor(net, weights, shape) {
          // Stored as [net, shape, weights]; read() unpacks in this same order.
          this.targets = [net, shape, weights];
      }

      /**
       * Resolves missing parts through `context.fetch`, parses the blob table
       * and converts every layer of the net into a node description.
       * A missing net file is fatal; missing shape or weights files are tolerated.
       * @throws {espresso.Error} When a layer keeps an unrecognized 'blob_*' key
       * or references a non-integer blob identifier.
       */
      async read(context) {
          this.format = 'Espresso';
          this.properties = [];
          this.inputs = [];
          this.outputs = [];
          this.nodes = [];
          let [net, shape, weights] = this.targets;
          delete this.targets;
          // Sibling files share the identifier and differ only in this suffix.
          const suffix = /\.espresso\.(net|weights|shape)$/i;
          if (!net) {
              const name = context.identifier.replace(suffix, '.espresso.net');
              const content = await context.fetch(name);
              net = await content.read('json');
          }
          this.shapes = new Map();
          if (!shape) {
              const name = context.identifier.replace(suffix, '.espresso.shape');
              try {
                  const content = await context.fetch(name);
                  shape = await content.read('json');
              } catch {
                  // continue regardless of error
              }
          }
          if (shape && shape.layer_shapes) {
              for (const [name, value] of Object.entries(shape.layer_shapes)) {
                  const dimensions = [value.n, value.k, value.w, value.h];
                  this.shapes.set(name, new espresso.TensorShape(dimensions));
              }
          }
          this.blobs = new Map();
          if (!weights) {
              // The net may name its storage file; otherwise use the sibling weights file.
              const name = net && net.storage ? net.storage : context.identifier.replace(suffix, '.espresso.weights');
              try {
                  const content = await context.fetch(name);
                  weights = await content.read('binary');
              } catch {
                  // continue regardless of error
              }
          }
          if (weights) {
              const reader = weights;
              // Header: an entry count followed by (key, size) pairs, all uint64.
              const length = reader.uint64().toNumber();
              for (let i = 0; i < length; i++) {
                  const key = reader.uint64().toNumber();
                  const size = reader.uint64().toNumber();
                  this.blobs.set(key, size);
              }
              // Payloads follow in header order; each size is replaced by its buffer.
              for (const [key, size] of this.blobs) {
                  const buffer = reader.read(size);
                  this.blobs.set(key, buffer);
              }
          }
          this.values = new Map();
          if (net.format_version) {
              // format_version is rendered as major = version / 100, minor = version % 100.
              const major = Math.floor(net.format_version / 100);
              const minor = net.format_version % 100;
              this.format += ` v${major}.${minor}`;
          }
          if (net && Array.isArray(net.layers)) {
              for (const layer of net.layers) {
                  const type = layer.type;
                  // Work on a copy: consumed keys are deleted so that what remains
                  // becomes the node's attributes.
                  const data = { ...layer };
                  const top = layer.top.split(',').map((name) => this._value(name));
                  const bottom = layer.bottom.split(',').map((name) => this._value(name));
                  const obj = {
                      type,
                      name: layer.name,
                      attributes: data,
                      inputs: [{ name: 'inputs', value: bottom }],
                      outputs: [{ name: 'outputs', value: top }],
                      chain: []
                  };
                  switch (type) {
                      case 'convolution':
                      case 'deconvolution': {
                          this._weights(obj, data, [data.C, data.K, data.Nx, data.Ny]);
                          if (data.has_biases) {
                              obj.inputs.push(this._initializer('biases', data.blob_biases, 'float32', [data.C]));
                          }
                          delete data.has_biases;
                          delete data.blob_biases;
                          break;
                      }
                      case 'batchnorm': {
                          obj.inputs.push(this._initializer('params', data.blob_batchnorm_params, 'float32', [4, data.C]));
                          delete data.blob_batchnorm_params;
                          break;
                      }
                      case 'inner_product': {
                          this._weights(obj, data, [data.nC, data.nB]);
                          if (data.has_biases) {
                              obj.inputs.push(this._initializer('biases', data.blob_biases, 'float32', [data.nC]));
                          }
                          delete data.has_biases;
                          delete data.blob_biases;
                          break;
                      }
                      case 'instancenorm_1d':
                      case 'dynamic_dequantize': {
                          this._weights(obj, data, null);
                          break;
                      }
                      default: {
                          break;
                      }
                  }
                  // Any 'blob_*' key still present was not consumed above.
                  const blobs = Object.keys(data).filter((key) => key.startsWith('blob_'));
                  if (blobs.length > 0) {
                      throw new espresso.Error(`Unknown blob '${blobs.join(',')}' for type '${type}'.`);
                  }
                  // Activation / normalization flags become chained nodes.
                  if (data.has_prelu) {
                      obj.chain.push({ type: 'prelu' });
                  }
                  if (data.fused_relu || data.has_relu) {
                      obj.chain.push({ type: 'relu' });
                  }
                  if (data.fused_tanh || data.has_tanh) {
                      obj.chain.push({ type: 'tanh' });
                  }
                  if (data.has_batch_norm) {
                      obj.chain.push({ type: 'batch_norm' });
                  }
                  // Remaining named weights are attached as float32 initializers.
                  if (data.weights) {
                      for (const [name, identifier] of Object.entries(data.weights)) {
                          obj.inputs.push(this._initializer(name, identifier, 'float32', null));
                      }
                      delete data.weights;
                  }
                  delete data.name;
                  delete data.type;
                  delete data.top;
                  delete data.bottom;
                  delete data.fused_tanh;
                  delete data.fused_relu;
                  delete data.has_prelu;
                  delete data.has_relu;
                  delete data.has_tanh;
                  delete data.has_batch_norm;
                  this.nodes.push(obj);
              }
          }
          // Lookup tables are only needed while reading.
          delete this.shapes;
          delete this.blobs;
      }

      /**
       * Returns the shared value entry for `name`, creating it on first use.
       * The entry is typed float32 when the shape file lists a shape for it.
       */
      _value(name) {
          if (!this.values.has(name)) {
              const shape = this.shapes.get(name);
              const type = shape ? new espresso.TensorType('float32', shape) : null;
              this.values.set(name, { name, type });
          }
          return this.values.get(name);
      }

      /**
       * Attaches the weights of a layer to `obj.inputs` and removes the consumed
       * keys from `data`. Checked in order: `blob_weights` (float32),
       * `blob_weights_f16` (float16), then the named entries of `data.weights`.
       */
      _weights(obj, data, dimensions) {
          if (data.blob_weights !== undefined) {
              obj.inputs.push(this._initializer('weights', data.blob_weights, 'float32', dimensions));
              delete data.blob_weights;
              return;
          }
          if (data.blob_weights_f16 !== undefined) {
              obj.inputs.push(this._initializer('weights', data.blob_weights_f16, 'float16', dimensions));
              delete data.blob_weights_f16;
              return;
          }
          const keys = ['wBeta', 'wGamma', 'W_S8', 'W_int8', 'W_t_int8'];
          for (const key of keys) {
              if (data.weights && data.weights[key] !== undefined) {
                  let dataType = 'float32';
                  let name = key;
                  // The '_S8' / '_int8' suffix marks int8 data and is stripped from the name.
                  if (key.endsWith('_S8')) {
                      dataType = 'int8';
                      name = key.replace(/_S8$/, '');
                  } else if (key.endsWith('_int8')) {
                      dataType = 'int8';
                      name = key.replace(/_int8$/, '');
                  }
                  obj.inputs.push(this._initializer(name, data.weights[key], dataType, dimensions));
                  delete data.weights[key];
              }
          }
      }

      /**
       * Builds an input argument backed by the blob with the given identifier.
       * Without explicit `dimensions`, a flat shape is derived from the blob's
       * byte length, or ['?'] when the blob is missing.
       * @throws {espresso.Error} When `identifier` is not an integer.
       */
      _initializer(name, identifier, dataType, dimensions) {
          if (!Number.isInteger(identifier)) {
              throw new espresso.Error(`Invalid '${identifier}' blob identifier.`);
          }
          dataType = dataType || 'float32';
          const blob = this.blobs.get(identifier);
          if (!dimensions) {
              // Bytes per element: float32 is 4, float16 is 2, anything else
              // (int8 here) is 1.
              let itemsize = 1;
              switch (dataType) {
                  case 'float32': {
                      itemsize = 4;
                      break;
                  }
                  case 'float16': {
                      itemsize = 2;
                      break;
                  }
                  default: {
                      break;
                  }
              }
              dimensions = blob ? [blob.length / itemsize] : ['?'];
          }
          const shape = new espresso.TensorShape(dimensions);
          const type = new espresso.TensorType(dataType, shape);
          const value = {};
          const initializer = new espresso.Tensor(type, blob, null, 'Blob');
          value.value = new espresso.Value(`${identifier}\nblob`, type, null, initializer);
          return { name, value: [value] };
      }
  };
  /**
   * Error type raised by the Espresso loader. Every instance carries the same
   * fixed `name`.
   */
  espresso.Error = class extends Error {

      constructor(message) {
          super(message);
          const label = 'Error loading Espresso model.';
          this.name = label;
      }
  };
  // Public entry point of the module: the factory class from the namespace.
  export const { ModelFactory } = espresso;