keras.js 71 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480
  1. import * as json from './json.js';
  2. import * as python from './python.js';
// Namespace objects: populated with classes (ModelFactory, Model, Graph, ...)
// by the definitions that follow in this file.
const keras = {};
const tfjs = {};
  5. keras.ModelFactory = class {
  6. async match(context) {
  7. const identifier = context.identifier;
  8. const extension = identifier.lastIndexOf('.') > 0 ? identifier.split('.').pop().toLowerCase() : '';
  9. const group = await context.peek('hdf5');
  10. if (group && group.attributes && group.attributes.get('CLASS') !== 'hickle') {
  11. if (identifier === 'model.weights.h5') {
  12. return context.set('keras.model.weights.h5', group);
  13. }
  14. if (identifier === 'parameter.h5') {
  15. return context.set('hdf5.parameter.h5', group);
  16. }
  17. return context.set('keras.h5', group);
  18. }
  19. const json = await context.peek('json');
  20. if (json) {
  21. if (json.mxnet_version || (json.nodes && json.arg_nodes && json.heads)) {
  22. return null;
  23. }
  24. if (json.model_config || (json.class_name && json.config)) {
  25. return context.set('keras.config.json', json);
  26. }
  27. if (identifier === 'metadata.json' && json.keras_version) {
  28. return context.set('keras.metadata.json', json);
  29. }
  30. }
  31. const container = await tfjs.Container.open(context);
  32. if (container) {
  33. return context.set('tfjs', container);
  34. }
  35. const pickle = await context.peek('pkl');
  36. if (pickle && pickle.__class__ &&
  37. pickle.__class__.__module__ === 'keras.engine.sequential' &&
  38. pickle.__class__.__name__ === 'Sequential') {
  39. return context.set('tfjs.pickle', pickle);
  40. }
  41. // model.weights.npz
  42. const entries = await context.peek('npz');
  43. const regex = /^(__root__|layers\/.+|_layer_checkpoint_dependencies\/.+)\.npy$/;
  44. if (entries instanceof Map && entries.size > 0 && Array.from(entries).every(([name]) => regex.test(name))) {
  45. return context.set('keras.model.weights.npz', entries);
  46. }
  47. // keras_metadata.pb
  48. if (extension === 'pb' && context.stream && context.stream.length > 16) {
  49. const tags = await context.tags('pb');
  50. if (tags.size === 1 && tags.get(1) === 2) {
  51. const stream = context.stream;
  52. const buffer = stream.peek(Math.min(stream.length, 1024));
  53. const content = String.fromCharCode.apply(null, buffer);
  54. if (/root"/.test(content) && /\{\s*"class_name"\s*:/.test(content)) {
  55. return context.set('keras.pb.SavedMetadata');
  56. }
  57. }
  58. }
  59. return null;
  60. }
  61. filter(context, match) {
  62. if (context.type === 'keras.metadata.json' && (match.type === 'keras.config.json' || match.type === 'keras.model.weights.h5' || match.type === 'keras.model.weights.npz')) {
  63. return false;
  64. }
  65. if (context.type === 'keras.config.json' && (match.type === 'keras.model.weights.h5' || match.type === 'keras.model.weights.npz')) {
  66. return false;
  67. }
  68. if (context.type === 'tfjs' && match.type === 'tf.tfjs.weights') {
  69. return false;
  70. }
  71. return true;
  72. }
  73. async open(context) {
  74. const request_json = async (context, name) => {
  75. try {
  76. context = await context.fetch(name);
  77. } catch {
  78. return null;
  79. }
  80. return await context.read('json');
  81. };
  82. const _create_config = (weights_store) => {
  83. const config = {};
  84. config.class_name = 'Model';
  85. config.config = {};
  86. config.config.layers = [];
  87. const snake_to_pascal_case = (name) => {
  88. return name.replace(/(^|_|\d)([a-z])/g, (match, p1, p2) => p1 === '_' ? p2.toUpperCase() : p1 + p2.toUpperCase());
  89. };
  90. for (const [name, value] of weights_store) {
  91. const layer = {};
  92. layer.name = name;
  93. layer.class_name = name.split('/').pop().replace(/_[0-9]+$/, '');
  94. layer.class_name = snake_to_pascal_case(layer.class_name);
  95. layer.config = {};
  96. layer.config.name = name;
  97. layer._trainable_variables = value;
  98. config.config.layers.push(layer);
  99. }
  100. return config;
  101. };
  102. const _load_state = (trackable, weights_store, assets_store, inner_path) => {
  103. inner_path = inner_path || '';
  104. if (trackable && trackable.config && Array.isArray(trackable.config.layers)) {
  105. /* eslint-disable no-use-before-define */
  106. _load_container_state(trackable, weights_store, assets_store, inner_path ? `${inner_path}/layers` : 'layers');
  107. /* eslint-enable no-use-before-define */
  108. } else {
  109. const weights = weights_store.get(inner_path);
  110. if (weights) {
  111. trackable._trainable_variables = weights;
  112. }
  113. }
  114. };
  115. const _load_container_state = (container, weights_store, assets_store, inner_path) => {
  116. const used_names = new Map();
  117. for (const trackable of container.config.layers) {
  118. const pascal_to_snake_case = (name) => {
  119. name = name.replace(/\W+/g, "");
  120. name = name.replace(/(.)([A-Z][a-z]+)/g, (match, p1, p2) => `${p1}_${p2}`);
  121. name = name.replace(/([a-z])([A-Z])/g, (match, p1, p2) => `${p1}_${p2}`);
  122. return name.toLowerCase();
  123. };
  124. let name = pascal_to_snake_case(trackable.class_name);
  125. if (used_names.has(name)) {
  126. const next = used_names.get(name) + 1;
  127. used_names.set(name, next);
  128. name = `${name}_${next}`;
  129. } else {
  130. used_names.set(name, 0);
  131. }
  132. _load_state(trackable, weights_store, assets_store, `${inner_path}/${name}`);
  133. }
  134. };
  135. const read_weights_hdf5 = (group) => {
  136. const weights_store = new Map();
  137. const stack = [[group, '']];
  138. while (stack.length > 0) {
  139. const [group, path] = stack.pop();
  140. if (group.groups instanceof Map) {
  141. const checkpoint = group.groups.get('layers') || group.groups.get('_layer_checkpoint_dependencies');
  142. if (checkpoint) {
  143. for (const [key, layer] of checkpoint.groups) {
  144. const name = `${path ? `${path}/` : ''}layers/${key}`;
  145. stack.push([layer, name]);
  146. const values = [];
  147. for (const vars of layer.groups) {
  148. for (const [name, group] of vars[1].groups) {
  149. const variable = group.value;
  150. if (variable) {
  151. const layout = variable.littleEndian ? '<' : '>';
  152. const tensor = new keras.Tensor(name, variable.shape, variable.type, null, null, layout, variable.data);
  153. values.push(tensor);
  154. }
  155. }
  156. }
  157. if (values.length > 0) {
  158. weights_store.set(name, values);
  159. }
  160. }
  161. }
  162. }
  163. }
  164. return weights_store;
  165. };
  166. const read_weights_numpy = (entries) => {
  167. const weights_store = new Map();
  168. for (const [path, array] of entries) {
  169. const file = path.split('/').map((name) => name === '_layer_checkpoint_dependencies' ? 'layers' : name).join('/');
  170. if (file.endsWith('.npy') && file.startsWith('layers/')) {
  171. if (array.dtype.name === 'object' && array.shape.length === 0 && Array.isArray(array.data) && array.data.length === 1) {
  172. const values = Object.values(array.data[0]).map((array) => {
  173. const stride = array.strides.map((stride) => stride / array.itemsize);
  174. const dataType = array.dtype.__name__;
  175. const values = dataType === 'string' || dataType === 'object' ? array.flatten().tolist() : array.tobytes();
  176. const encoding = dataType === 'string' || dataType === 'object' ? '|' : array.dtype.byteorder;
  177. return new keras.Tensor('', array.shape, dataType, stride, null, encoding, values);
  178. });
  179. if (values.length > 0) {
  180. const name = file.replace(/\.npy$/, '');
  181. weights_store.set(name, values);
  182. }
  183. }
  184. }
  185. }
  186. return weights_store;
  187. };
  188. const request_weights = async (context) => {
  189. const formats = [
  190. ['model.weights.h5', 'hdf5', read_weights_hdf5],
  191. ['model.weights.npz', 'npz', read_weights_numpy],
  192. ];
  193. for (const [name, type, callback] of formats) {
  194. let content = null;
  195. try {
  196. /* eslint-disable no-await-in-loop */
  197. content = await context.fetch(name);
  198. /* eslint-enable no-await-in-loop */
  199. } catch {
  200. // continue regardless of error
  201. }
  202. if (content) {
  203. /* eslint-disable no-await-in-loop */
  204. const obj = await content.peek(type);
  205. /* eslint-enable no-await-in-loop */
  206. if (obj) {
  207. return callback(obj);
  208. }
  209. }
  210. }
  211. return new Map();
  212. };
  213. const open_model = async (format, producer, backend, config, weights) => {
  214. const metadata = await context.metadata('keras-metadata.json');
  215. return new keras.Model(metadata, format, producer, backend, config, weights);
  216. };
  217. switch (context.type) {
  218. case 'keras.config.json': {
  219. const obj = context.value;
  220. const config = obj.model_config ? obj.model_config : obj;
  221. const backend = obj.backend || '';
  222. let version = obj.keras_version ? obj.keras_version : null;
  223. if (!version) {
  224. const metadata = await request_json(context, 'metadata.json');
  225. if (metadata && metadata.keras_version) {
  226. version = metadata.keras_version;
  227. }
  228. }
  229. const format = `Keras${version ? ` v${version}` : ''}`;
  230. const weights_store = await request_weights(context);
  231. _load_state(config, weights_store);
  232. return open_model(format, '', backend, config, null);
  233. }
  234. case 'keras.model.weights.h5': {
  235. const group = context.value;
  236. const weights_store = read_weights_hdf5(group);
  237. const metadata = await request_json(context, 'metadata.json');
  238. let config = await request_json(context, 'config.json');
  239. const name = config ? 'Keras' : 'Keras Weights';
  240. const format = name + (metadata && metadata.keras_version ? ` v${metadata.keras_version}` : '');
  241. if (config) {
  242. _load_state(config, weights_store);
  243. } else {
  244. config = _create_config(weights_store);
  245. }
  246. return await open_model(format, '', '', config, null);
  247. }
  248. case 'keras.model.weights.npz': {
  249. const entries = context.value;
  250. const weights_store = read_weights_numpy(entries);
  251. const metadata = await request_json(context, 'metadata.json');
  252. let config = await request_json(context, 'config.json');
  253. const name = config ? 'Keras' : 'Keras Weights';
  254. const format = name + (metadata && metadata.keras_version ? ` v${metadata.keras_version}` : '');
  255. if (config) {
  256. _load_state(config, weights_store);
  257. } else {
  258. config = _create_config(weights_store);
  259. }
  260. return await open_model(format, '', '', config, null);
  261. }
  262. case 'keras.metadata.json': {
  263. const metadata = context.value;
  264. let config = await request_json(context, 'config.json');
  265. const name = config ? 'Keras' : 'Keras Weights';
  266. const format = name + (metadata.keras_version ? ` v${metadata.keras_version}` : '');
  267. const weights_store = await request_weights(context);
  268. if (!config && (!weights_store || weights_store.size === 0)) {
  269. throw new keras.Error("'config.json' or 'model.weights.*' not present.");
  270. }
  271. if (config) {
  272. _load_state(config, weights_store);
  273. } else {
  274. config = _create_config(weights_store);
  275. }
  276. return await open_model(format, '', '', config, null);
  277. }
  278. case 'hdf5.parameter.h5':
  279. case 'keras.h5': {
  280. const find_root_group = (root_group) => {
  281. const kerasmodel = root_group.group('model/kerasmodel');
  282. if (kerasmodel && kerasmodel.attributes.has('model_config')) {
  283. return kerasmodel;
  284. }
  285. return root_group;
  286. };
  287. const read_model_config = (group) => {
  288. if (group.attributes.has('model_config')) {
  289. const buffer = group.attributes.get('model_config');
  290. const reader = json.TextReader.open(buffer);
  291. if (reader) {
  292. return reader.read();
  293. }
  294. }
  295. return null;
  296. };
  297. const load_attributes_from_hdf5_group = (group, name) => {
  298. if (group.attributes.has(name)) {
  299. return group.attributes.get(name);
  300. }
  301. if (group.attributes.has(`${name}0`)) {
  302. let index = 0;
  303. let value = [];
  304. while (group.attributes.has(name + index.toString())) {
  305. const chunk = group.attributes.get(name + index.toString());
  306. value = value.concat(chunk);
  307. index++;
  308. }
  309. return value;
  310. }
  311. return null;
  312. };
  313. const weights = new keras.Weights();
  314. const group = context.value;
  315. const root_group = find_root_group(group);
  316. const model_config = read_model_config(root_group);
  317. if (model_config) {
  318. const backend = root_group.attributes.get('backend') || '';
  319. const version = root_group.attributes.get('keras_version') || '';
  320. const format = `Keras${version ? ` v${version}` : ''}`;
  321. const model_weights_group = root_group.group('model_weights');
  322. if (model_weights_group) {
  323. const layer_names = load_attributes_from_hdf5_group(model_weights_group, 'layer_names');
  324. for (const layer_name of layer_names) {
  325. const layer_weights = model_weights_group.group(layer_name);
  326. if (layer_weights) {
  327. const weight_names = load_attributes_from_hdf5_group(layer_weights, 'weight_names');
  328. if (Array.isArray(weight_names) && weight_names.length > 0) {
  329. for (const weight_name of weight_names) {
  330. const weight = layer_weights.group(weight_name);
  331. if (weight && weight.value) {
  332. const variable = weight.value;
  333. const tensor = new keras.Tensor(weight_name, variable.shape, variable.type, null, null, variable.littleEndian ? '<' : '>', variable.data);
  334. weights.add(layer_name, tensor);
  335. }
  336. }
  337. }
  338. }
  339. }
  340. }
  341. if (!model_config) {
  342. throw new keras.Error("'model_config' is not present.");
  343. }
  344. if (!model_config.class_name) {
  345. throw new keras.Error("'class_name' is not present.");
  346. }
  347. return open_model(format, '', backend, model_config, weights);
  348. }
  349. const layer_names = load_attributes_from_hdf5_group(root_group, 'layer_names');
  350. if (layer_names && Array.isArray(layer_names)) {
  351. const version = root_group.attributes.get('keras_version') || '';
  352. const format = `Keras Weights${version ? ` v${version}` : ''}`;
  353. const backend = root_group.attributes.get('backend') || '';
  354. for (const layer_name of layer_names) {
  355. const layer_weights = root_group.group(layer_name);
  356. if (layer_weights) {
  357. const weight_names = load_attributes_from_hdf5_group(layer_weights, 'weight_names');
  358. if (Array.isArray(weight_names) && weight_names.length > 0) {
  359. for (const weight_name of weight_names) {
  360. const weight = layer_weights.group(weight_name);
  361. if (weight && weight.value) {
  362. const variable = weight.value;
  363. const components = weight_name.split('/');
  364. components.pop();
  365. const name = (components.length === 0 || components[0] !== layer_name) ? [layer_name].concat(components).join('/') : components.join('/');
  366. const encoding = variable.littleEndian ? '<' : '>';
  367. const tensor = new keras.Tensor(weight_name, variable.shape, variable.type, null, null, encoding, variable.data);
  368. weights.add(name, tensor);
  369. }
  370. }
  371. }
  372. }
  373. }
  374. return open_model(format, '', backend, null, weights);
  375. }
  376. const rootKeys = new Set(root_group.attributes.keys());
  377. rootKeys.delete('nb_layers');
  378. if (rootKeys.size > 0 || root_group.value !== null) {
  379. throw new keras.Error('File format is not HDF5 Weights.');
  380. }
  381. const format = 'HDF5 Weights';
  382. let weights_group = root_group;
  383. if (root_group.attributes.size === 0 && root_group.value === null && root_group.groups.size === 1) {
  384. const group = root_group.groups.values().next().value;
  385. if (group.attributes.size === 0 && group.value === null) {
  386. weights_group = group;
  387. }
  388. }
  389. const tensorKeys = new Set(['name', 'shape', 'quantization']);
  390. const groups = Array.from(weights_group.groups.values());
  391. if (groups.every((group) => group.attributes.size === 0 && group.groups.length === 0 && group.value !== null)) {
  392. for (const group of groups) {
  393. const variable = group.value;
  394. const layout = variable.littleEndian ? '<' : '>';
  395. const tensor = new keras.Tensor(group.name, variable.shape, variable.type, null, null, layout, variable.type === 'string' ? variable.value : variable.data);
  396. weights.add('', tensor);
  397. }
  398. return open_model(format, '', '', null, weights);
  399. }
  400. if (groups.every((group) => group.value === null && Array.from(group.attributes.keys()).filter((key) => !tensorKeys.has(key)).length === 0 && Array.from(group.groups.values()).every((variable) => Object.keys(variable.attributes).length === 0 && variable.value !== null))) {
  401. for (const group of groups) {
  402. const module = group.attributes.has('name') ? group.attributes.get('name') : group.name;
  403. for (const variableGroup of group.groups.values()) {
  404. if (variableGroup.attributes.size !== 0 || variableGroup.groups.size !== 0) {
  405. throw new keras.Error('Variable format is not HDF5 Weights.');
  406. }
  407. const variable = variableGroup.value;
  408. if (!variable) {
  409. throw new keras.Error('Variable value is not HDF5 Weights.');
  410. }
  411. const name = module ? [module, variableGroup.name].join('/') : variableGroup.name;
  412. const layout = variable.littleEndian ? '<' : '>';
  413. const tensor = new keras.Tensor(name, variable.shape, variable.type, null, null, layout, variable.type === 'string' ? variable.value : variable.data);
  414. weights.add(module, tensor);
  415. }
  416. }
  417. return open_model(format, '', '', null, weights);
  418. }
  419. const walk = function(group) {
  420. if (group.attributes.size === 0 && group.value === null && group.groups.size > 0) {
  421. for (const subGroup of group.groups.values()) {
  422. walk(subGroup);
  423. }
  424. return;
  425. }
  426. const subKeys = new Set(['index', 'need_grad']);
  427. const attribtues = Array.from(group.attributes.keys());
  428. const match = attribtues.filter((key) => !subKeys.has(key)).length === 0;
  429. if (match && group.value !== null && group.groups.size === 0) {
  430. const variable = group.value;
  431. const variableName = group.path;
  432. let moduleName = variableName;
  433. const parts = variableName.split('/');
  434. if (parts.length > 1) {
  435. parts.pop();
  436. moduleName = parts.join('/');
  437. }
  438. const layout = variable.littleEndian ? '<' : '>';
  439. const tensor = new keras.Tensor(variableName, variable.shape, variable.type, null, null, layout, variable.type === 'string' ? variable.value : variable.data);
  440. weights.add(moduleName, tensor);
  441. return;
  442. }
  443. throw new keras.Error('Module group format is not HDF5 Weights.');
  444. };
  445. walk(weights_group);
  446. return open_model(format, '', '', null, weights);
  447. }
  448. case 'tfjs': {
  449. const target = context.value;
  450. await target.read();
  451. return open_model(target.format, target.producer, target.backend, target.config, target.weights);
  452. }
  453. case 'keras.pickle': {
  454. const obj = context.value;
  455. const execution = new python.Execution();
  456. const decoder = new TextDecoder('utf-8');
  457. const format = `Keras Pickle${obj.keras_version ? ` v${decoder.decode(obj.keras_version)}` : ''}`;
  458. const backend = obj.backend ? decoder.decode(obj.backend) : '';
  459. const reader = json.TextReader.open(obj.model_config);
  460. const model_config = reader.read();
  461. const weights = new keras.Weights();
  462. const model_weights_group = obj.model_weights;
  463. if (model_weights_group) {
  464. const layer_names = model_weights_group.layer_names.map((buffer) => decoder.decode(buffer));
  465. for (const layer_name of layer_names) {
  466. const layer_weights = model_weights_group[layer_name];
  467. if (layer_weights) {
  468. const weight_names = layer_weights.weight_names.map((buffer) => decoder.decode(buffer));
  469. if (Array.isArray(weight_names) && weight_names.length > 0) {
  470. for (const weight_name of weight_names) {
  471. const buffer = layer_weights[weight_name];
  472. const pickle = execution.__import__('pickle');
  473. const unpickler = new pickle.Unpickler(buffer);
  474. const variable = unpickler.load();
  475. const tensor = new keras.Tensor(weight_name, variable.shape, variable.dtype.__name__, null, null, '<', variable.data);
  476. weights.add(layer_name, tensor);
  477. }
  478. }
  479. }
  480. }
  481. }
  482. return open_model(format, '', backend, model_config, weights);
  483. }
  484. case 'keras.pb.SavedMetadata': {
  485. keras.proto = await context.require('./keras-proto');
  486. const format = 'Keras Saved Metadata';
  487. const reader = await context.read('protobuf.binary');
  488. const saved_metadata = keras.proto.third_party.tensorflow.python.keras.protobuf.SavedMetadata.decode(reader);
  489. if (!saved_metadata || !Array.isArray(saved_metadata.nodes) ||
  490. !saved_metadata.nodes.every((node) => node && typeof node.metadata === 'string' && node.metadata.length > 0)) {
  491. throw new keras.Error('Invalid keras.protobuf.SavedMetadata.');
  492. }
  493. const objects = new Map();
  494. for (const node of saved_metadata.nodes) {
  495. const reader = json.TextReader.open(node.metadata);
  496. node.metadata = reader.read();
  497. objects.set(node.node_path, node);
  498. }
  499. const model_config = objects.get('root').metadata;
  500. return open_model(format, '', '', model_config, null);
  501. }
  502. default: {
  503. throw new keras.Error(`Unsupported Keras format '${context.type}'.`);
  504. }
  505. }
  506. }
  507. };
  508. keras.Model = class {
  509. constructor(metadata, format, producer, backend, config, weights) {
  510. this.format = format;
  511. this.runtime = backend;
  512. this.producer = producer;
  513. metadata = new keras.GraphMetadata(metadata);
  514. this.modules = [new keras.Graph(metadata, config, weights)];
  515. }
  516. };
  517. keras.Graph = class {
  518. constructor(metadata, config, weights, group) {
  519. this.inputs = [];
  520. this.outputs = [];
  521. this.nodes = [];
  522. group = group || '';
  523. const values = new Map();
  524. values.map = (name, type, tensor) => {
  525. if (tensor) {
  526. return new keras.Value(name, type || null, tensor);
  527. }
  528. if (!values.has(name)) {
  529. values.set(name, new keras.Value(name, type || null, tensor || null));
  530. } else if (type || tensor) {
  531. throw new keras.Error(`Duplicate value '${name}'.`);
  532. }
  533. return values.get(name);
  534. };
  535. if (config) {
  536. const getInputType = (layer) => {
  537. if (layer && layer.config) {
  538. let dataType = '?';
  539. let shape = [];
  540. const config = layer.config;
  541. if (config.dtype) {
  542. dataType = config.dtype;
  543. delete config.dtype;
  544. }
  545. if (Array.isArray(config.batch_input_shape)) {
  546. shape = config.batch_input_shape.map((s) => s === null ? '?' : s);
  547. delete config.batch_input_shape;
  548. } else if (config.batch_input_shape &&
  549. config.batch_input_shape.class_name === '__tuple__' &&
  550. Array.isArray(config.batch_input_shape.items)) {
  551. shape = config.batch_input_shape.items.map((s) => s === null ? '?' : s);
  552. delete config.batch_input_shape;
  553. }
  554. return new keras.TensorType(dataType, new keras.TensorShape(shape));
  555. }
  556. return null;
  557. };
  558. this.name = config.name || (config.config && config.config.name ? config.config.name : '');
  559. this.description = config.class_name;
  560. let baseType = config.class_name;
  561. switch (baseType) {
  562. case '__Function__':
  563. this.type = 'function';
  564. break;
  565. case 'Sequential':
  566. case 'Functional':
  567. case 'Model': {
  568. break;
  569. }
  570. case 'Tokenizer': {
  571. config = { config: { layers: [config] } };
  572. baseType = 'Functional';
  573. break;
  574. }
  575. default: {
  576. const layers = Array.from(config.layers ? config.layers : config);
  577. const sequential = layers.every((layer) => layer.inbound_nodes === undefined);
  578. baseType = sequential ? 'Sequential' : 'Functional';
  579. break;
  580. }
  581. }
  582. switch (baseType) {
  583. case 'Sequential': {
  584. config = config.config;
  585. const outputs = null;
  586. let name = 'input';
  587. let index = -1;
  588. const layers = Array.from(config.layers ? config.layers : config);
  589. while (layers.length > 0) {
  590. const layer = layers.shift();
  591. let current = index.toString();
  592. index++;
  593. if (index === 0) {
  594. const type = getInputType(layer);
  595. let remove = false;
  596. if (layer.class_name === 'InputLayer' && layer.config && layer.config.name) {
  597. name = layer.config.name;
  598. remove = true;
  599. }
  600. const value = values.map(name, type);
  601. const argument = new keras.Argument(name, [value]);
  602. this.inputs.push(argument);
  603. if (remove) {
  604. continue;
  605. }
  606. }
  607. const nodeInputs = [{ name }];
  608. if (layer.config && layer.config.name) {
  609. current = layer.config.name;
  610. }
  611. name = current;
  612. let nodeOutputs = [name];
  613. if (index === layers.length) {
  614. if (outputs && outputs.length > 0) {
  615. nodeOutputs = [outputs[0]];
  616. name = null;
  617. }
  618. }
  619. layer.inputs = nodeInputs;
  620. layer.outputs = nodeOutputs;
  621. const node = new keras.Node(metadata, layer, group, weights, values);
  622. this.nodes.push(node);
  623. }
  624. if (name) {
  625. const value = values.map(name);
  626. const argument = new keras.Argument(name, [value]);
  627. this.outputs.push(argument);
  628. }
  629. break;
  630. }
  631. case '__Function__':
  632. case 'Functional':
  633. case 'Model': {
  634. config = config.config;
  635. const nodes = new Map();
  636. if (config.layers) {
  637. const is_constant = (item) => {
  638. return Array.isArray(item) && (item.length === 3 || item.length === 4) && item[0] === '_CONSTANT_VALUE' && item[1] === -1;
  639. };
  640. const is_connection = (item) => {
  641. return Array.isArray(item) && (item.length === 3 || item.length === 4) && typeof item[0] === 'string' && typeof item[1] === 'number' && typeof item[2] === 'number';
  642. };
  643. const read_value = (input_data) => {
  644. if (!Array.isArray(input_data)) {
  645. return input_data;
  646. }
  647. const transform = (value) => {
  648. if (value.every((item) => is_constant(item))) {
  649. for (let i = 0; i < value.length; i++) {
  650. value[i] = value[i][2];
  651. }
  652. } else if (value.every((item) => Array.isArray(item))) {
  653. const dims = value.map((item) => transform(item));
  654. const [dim] = dims;
  655. for (let i = 1; i < dims.length; i++) {
  656. if (dim.length === dims[i].length) {
  657. if (!dims[i].every((value, i) => value === dim[i])) {
  658. throw new python.Error('Invalid array shape.');
  659. }
  660. }
  661. }
  662. return [value.length].concat(dim);
  663. }
  664. return [value.length];
  665. };
  666. const shape = transform(input_data);
  667. const flatten = (input) => input.reduce((a, b) => a.concat(Array.isArray(b) ? flatten(b) : b), []);
  668. const value = flatten(input_data);
  669. return { shape, value };
  670. };
  671. const functional = config.layers.every((layer) => Array.isArray(layer.inbound_nodes));
  672. if (functional) {
  673. const read_connection = (input_data) => {
  674. const [node_name, node_index, tensor_index] = input_data;
  675. const inbound_node_key = `${node_name}[${node_index}]`;
  676. const inbound_node = nodes.get(inbound_node_key);
  677. const tensor_key = `${node_name}[${node_index}][${tensor_index}]`;
  678. if (inbound_node) {
  679. while (tensor_index >= inbound_node.outputs.length) {
  680. inbound_node.outputs.push(undefined);
  681. }
  682. inbound_node.outputs[tensor_index] = tensor_key;
  683. }
  684. return tensor_key;
  685. };
  686. const process_node = (node, inbound_node) => {
  687. if (Array.isArray(inbound_node) && inbound_node.length === 4 && typeof inbound_node[0] === 'string') {
  688. const key = read_connection(inbound_node);
  689. node.inputs.push({ name: key });
  690. for (const [name, value] of Object.entries(inbound_node[3])) {
  691. if (is_connection(value)) {
  692. const key = read_connection(value);
  693. node.inputs.push({ name: key });
  694. } else if (Array.isArray(value)) {
  695. const array = read_value(value);
  696. node.args[name] = array;
  697. } else {
  698. node.args[name] = value;
  699. }
  700. }
  701. } else if (Array.isArray(inbound_node)) {
  702. for (const input_data of inbound_node) {
  703. // [ 'conv2d', 0, 0 ] or [ 'conv2d', 0, 0, {} ]
  704. if (Array.isArray(input_data) && is_connection(input_data)) {
  705. const key = read_connection(input_data);
  706. node.inputs.push({ name: key });
  707. } else if (Array.isArray(input_data) && input_data.every((item) => is_connection(item))) {
  708. for (const input of input_data) {
  709. const key = read_connection(input);
  710. node.inputs.push({ name: key });
  711. }
  712. } else if (Array.isArray(input_data)) {
  713. const value = read_value(input_data);
  714. node.inputs.push(value);
  715. } else {
  716. throw new keras.Error(`Invalid inbound connection '${JSON.stringify(input_data)}'.`);
  717. }
  718. }
  719. } else if (inbound_node && inbound_node.args) {
  720. for (const arg of inbound_node.args) {
  721. if (arg && arg.class_name === '__keras_tensor__' && arg.config && is_connection(arg.config.keras_history)) {
  722. const key = read_connection(arg.config.keras_history);
  723. node.inputs.push({ name: key });
  724. } else if (Array.isArray(arg) && arg.every((arg) => arg && arg.class_name === '__keras_tensor__' && arg.config && is_connection(arg.config.keras_history))) {
  725. for (const input of arg) {
  726. const key = read_connection(input.config.keras_history);
  727. node.inputs.push({ name: key });
  728. }
  729. }
  730. }
  731. }
  732. };
  733. let legacy_format = true;
  734. for (const layer of config.layers) {
  735. if (Array.isArray(layer.inbound_nodes)) {
  736. for (const inbound_node of layer.inbound_nodes) {
  737. if (Array.isArray(inbound_node.args)) {
  738. legacy_format = false;
  739. }
  740. }
  741. }
  742. }
  743. for (const layer of config.layers) {
  744. const class_name = layer.class_name;
  745. let first_index = 0;
  746. if (legacy_format) {
  747. const keys = new Set(Object.keys(layer.config));
  748. const is_functional_config = keys.has('name') && keys.has('layers') && keys.has('input_layers') && keys.has('output_layers');
  749. if (class_name === 'Sequential' ||
  750. (is_functional_config && Array.isArray(layer.config.layers) && layer.config.layers.length > 0 && layer.config.layers[0].class_name === 'InputLayer')) {
  751. first_index++;
  752. }
  753. }
  754. if (Array.isArray(layer.inbound_nodes) && layer.inbound_nodes.length === 0) {
  755. layer.inputs = [];
  756. layer.outputs = [];
  757. layer.args = {};
  758. nodes.set(`${layer.name}[${first_index}]`, layer);
  759. } else if (Array.isArray(layer.inbound_nodes) && layer.inbound_nodes.length === 1) {
  760. layer.inputs = [];
  761. layer.outputs = [];
  762. layer.args = {};
  763. [layer.inbound_node] = layer.inbound_nodes;
  764. nodes.set(`${layer.name}[${first_index}]`, layer);
  765. } else {
  766. let config = {};
  767. switch (class_name) {
  768. case 'Functional':
  769. case 'Sequential':
  770. case 'Model': {
  771. config = layer;
  772. break;
  773. }
  774. default: {
  775. config.class_name = '__Function__';
  776. config.name = layer.name;
  777. config.config = {};
  778. config.config.layers = [{ ...layer }];
  779. delete config.config.layers[0].inbound_nodes;
  780. delete config.config.layers[0].input_layers;
  781. delete config.config.layers[0].output_layers;
  782. break;
  783. }
  784. }
  785. const type = new keras.Graph(metadata, config, weights, '');
  786. for (let i = 0; i < layer.inbound_nodes.length; i++) {
  787. const index = i + first_index;
  788. const key = `${layer.name}[${index}]`;
  789. const node = {};
  790. node.name = key;
  791. node.class_name = '__Function__';
  792. node.config = {};
  793. node.config.name = key;
  794. node.inputs = [];
  795. node.outputs = [];
  796. node.args = {};
  797. node.__type__ = type;
  798. node.inbound_node = layer.inbound_nodes[i];
  799. nodes.set(key, node);
  800. }
  801. }
  802. }
  803. for (const entry of nodes) {
  804. if (entry[1].inbound_node) {
  805. process_node(entry[1], entry[1].inbound_node);
  806. }
  807. }
  808. if (Array.isArray(config.input_layers)) {
  809. if (config.input_layers.length === 3 && typeof config.input_layers[0] === 'string' && Number.isInteger(config.input_layers[1]) && Number.isInteger(config.input_layers[2])) {
  810. config.input_layers = [config.input_layers];
  811. }
  812. for (let i = 0; i < config.input_layers.length; i++) {
  813. const input_data = config.input_layers[i];
  814. const name = read_connection(input_data);
  815. const [node_name, node_index] = input_data;
  816. const inbound_node_key = `${node_name}[${node_index}]`;
  817. const node = nodes.get(inbound_node_key);
  818. let type = null;
  819. if (node && node.class_name === 'InputLayer') {
  820. type = getInputType(node);
  821. nodes.delete(name);
  822. nodes.delete(inbound_node_key);
  823. }
  824. const value = values.map(name, type);
  825. const argument = new keras.Argument(node_name, [value]);
  826. this.inputs.push(argument);
  827. }
  828. }
  829. if (Array.isArray(config.output_layers)) {
  830. if (config.output_layers.length === 3 && typeof config.output_layers[0] === 'string' && Number.isInteger(config.output_layers[1]) && Number.isInteger(config.output_layers[2])) {
  831. config.output_layers = [config.output_layers];
  832. }
  833. for (let i = 0; i < config.output_layers.length; i++) {
  834. const output_data = config.output_layers[i];
  835. const [name] = output_data;
  836. const key = read_connection(output_data);
  837. const value = values.map(key);
  838. const argument = new keras.Argument(name, [value]);
  839. this.outputs.push(argument);
  840. }
  841. }
  842. } else {
  843. for (const layer of config.layers) {
  844. layer.inputs = [];
  845. layer.outputs = [];
  846. layer.args = {};
  847. nodes.set(`${layer.name}[0]`, layer);
  848. }
  849. }
  850. }
  851. for (const entry of nodes) {
  852. const node = new keras.Node(metadata, entry[1], group, weights, values);
  853. this.nodes.push(node);
  854. }
  855. break;
  856. }
  857. default: {
  858. throw new keras.Error(`'${config.class_name}' is not supported.`);
  859. }
  860. }
  861. } else if (weights) {
  862. this.type = 'weights';
  863. for (const name of weights.keys()) {
  864. if (weights.get('', name).length <= 6) {
  865. const layer = { class_name: 'Weights', config: { name } };
  866. const node = new keras.Node(metadata, layer, '', weights, values);
  867. this.nodes.push(node);
  868. }
  869. }
  870. }
  871. }
  872. };
  873. keras.Argument = class {
  874. constructor(name, value, type = null, visible = true) {
  875. this.name = name;
  876. this.value = value;
  877. this.type = type;
  878. this.visible = visible;
  879. }
  880. };
  881. keras.Value = class {
  882. constructor(name, type, initializer = null) {
  883. if (typeof name !== 'string') {
  884. throw new keras.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
  885. }
  886. this.name = name;
  887. this.type = !type && initializer ? initializer.type : type;
  888. this.quantization = initializer && initializer.quantization ? initializer.quantization : null;
  889. this.initializer = initializer;
  890. }
  891. };
  892. keras.Node = class {
  893. constructor(metadata, layer, group, weights, values) {
  894. const config = layer.config || {};
  895. const args = layer.args || {};
  896. let inputs = layer.inputs || [];
  897. let outputs = layer.outputs || [];
  898. const name = config && config.name ? config.name : '';
  899. this.group = group || '';
  900. this.name = (this.group ? `${this.group}/` : '') + name;
  901. this.inputs = [];
  902. this.outputs = [];
  903. this.attributes = [];
  904. this.chain = [];
  905. let names = [name];
  906. let class_name = layer.class_name;
  907. let model = false;
  908. switch (class_name) {
  909. case '__Function__': {
  910. this.type = layer.__type__;
  911. model = true;
  912. break;
  913. }
  914. case 'Model':
  915. case 'Functional':
  916. case 'Sequential': {
  917. const name = layer.name || (layer.config ? layer.config.name : '');
  918. this.type = new keras.Graph(metadata, layer, weights, (group ? `${group}/` : '') + name);
  919. model = true;
  920. if (config) {
  921. delete config.layers;
  922. delete config.input_layers;
  923. delete config.output_layers;
  924. }
  925. this.inputs = [new keras.Argument('inputs', inputs.map((input) => values.map(input.name)))];
  926. this.outputs = [new keras.Argument('outputs', outputs.map((name) => values.map(name)))];
  927. inputs = [];
  928. outputs = [];
  929. break;
  930. }
  931. case 'Wrapper':
  932. case 'Bidirectional':
  933. case 'TimeDistributed': {
  934. if (config && config.layer) {
  935. const inner = config.layer;
  936. delete config.layer;
  937. this.inner = new keras.Node(metadata, inner, null, null, values);
  938. if (class_name === 'Bidirectional' && inner.config.name) {
  939. names = [`${name}/forward_${inner.config.name}`, `${name}/backward_${inner.config.name}`];
  940. if (!group) {
  941. group = name;
  942. }
  943. }
  944. }
  945. this.type = metadata.type(class_name) || { name: class_name };
  946. break;
  947. }
  948. case 'TFOpLambda': {
  949. if (config && config.function) {
  950. class_name = config.function;
  951. delete config.function;
  952. }
  953. this.type = metadata.type(class_name) || { name: class_name };
  954. break;
  955. }
  956. default: {
  957. this.type = metadata.type(class_name) || { name: class_name };
  958. break;
  959. }
  960. }
  961. if (layer._trainable_variables) {
  962. if (inputs.length === 0 && Array.isArray(this.type.inputs) && this.type.inputs.length > 0) {
  963. // weights-only, remove 'input' from type metadata
  964. this.type = { ...this.type };
  965. this.type.inputs = this.type.inputs.slice(1);
  966. }
  967. for (const variable of layer._trainable_variables) {
  968. inputs.push({ name: '', initializer: variable });
  969. }
  970. } else if (weights && !model) {
  971. for (const name of names) {
  972. let tensors = weights.get(group, name);
  973. if (tensors.length > 0) {
  974. for (const initializer of tensors) {
  975. inputs.push({ name: initializer.name, initializer });
  976. }
  977. } else {
  978. tensors = weights.get('', name);
  979. for (const initializer of tensors) {
  980. inputs.push({ name: initializer.name, initializer });
  981. }
  982. }
  983. }
  984. }
  985. const attributes = [];
  986. const convertAttributeValue = (value) => {
  987. if (Array.isArray(value) || value !== Object(value)) {
  988. return value;
  989. }
  990. const obj = {};
  991. if (value.class_name) {
  992. obj.__type__ = value.class_name;
  993. }
  994. if (value.config) {
  995. const config = value.config;
  996. for (const [key, value] of Object.entries(config)) {
  997. obj[key] = convertAttributeValue(value);
  998. }
  999. }
  1000. return obj;
  1001. };
  1002. if (config && !Array.isArray(config)) {
  1003. for (const [name, value] of Object.entries(config)) {
  1004. if (class_name !== 'Activation' && name === 'activation' && value !== 'linear') {
  1005. if (typeof value === 'string') {
  1006. const config = { activation: value };
  1007. const node = new keras.Node(metadata, { class_name: 'Activation', config }, null, null, value);
  1008. this.chain.push(node);
  1009. } else if (value && typeof value.class_name === 'string' && value.config) {
  1010. const type = value.class_name;
  1011. if (!metadata.type(type)) {
  1012. metadata.add(type, { name: type, category: 'Activation' });
  1013. }
  1014. const node = new keras.Node(metadata, value, null, null, value);
  1015. this.chain.push(node);
  1016. }
  1017. }
  1018. if (name !== 'name' && name !== 'batch_input_shape') {
  1019. const schema = metadata.attribute(class_name, name);
  1020. attributes.push([schema, name, value]);
  1021. }
  1022. }
  1023. }
  1024. const innerType = this.inner ? this.inner.type : null;
  1025. const innerMetadata = innerType ? metadata.type(innerType) : null;
  1026. let inputIndex = 0;
  1027. while (inputs.length > 0) {
  1028. let list = false;
  1029. let name = null;
  1030. let visible = true;
  1031. if (!innerMetadata || inputIndex === 0) {
  1032. if (this.type && this.type.inputs && inputIndex < this.type.inputs.length) {
  1033. const input = this.type.inputs[inputIndex];
  1034. name = input.name;
  1035. if (class_name === 'BatchNormalization' && name === 'gamma' && config.scale === false) {
  1036. inputIndex++;
  1037. continue;
  1038. }
  1039. visible = input.visible !== false;
  1040. if (this.type.inputs[inputIndex].list) {
  1041. list = true;
  1042. }
  1043. }
  1044. } else {
  1045. switch (class_name) {
  1046. case 'Bidirectional': {
  1047. let innerIndex = inputIndex;
  1048. if (innerMetadata && innerMetadata.inputs) {
  1049. if (innerIndex < innerMetadata.inputs.length) {
  1050. name = `forward_${innerMetadata.inputs[innerIndex].name}`;
  1051. } else {
  1052. innerIndex = innerIndex - innerMetadata.inputs.length + 1;
  1053. if (innerIndex < innerMetadata.inputs.length) {
  1054. name = `backward_${innerMetadata.inputs[innerIndex].name}`;
  1055. }
  1056. }
  1057. }
  1058. visible = false;
  1059. break;
  1060. }
  1061. case 'TimeDistributed':
  1062. if (innerMetadata && innerMetadata.inputs && inputIndex < innerMetadata.inputs.length) {
  1063. name = innerMetadata.inputs[inputIndex].name;
  1064. }
  1065. break;
  1066. default:
  1067. break;
  1068. }
  1069. }
  1070. const input = list ? inputs.splice(0, inputs.length) : [inputs.shift()];
  1071. const inputArguments = input.map((input) => {
  1072. if (input.name) {
  1073. return values.map(input.name, null, input.initializer);
  1074. }
  1075. if (input.initializer) {
  1076. return values.map(input.name, null, input.initializer);
  1077. }
  1078. if (input.value !== undefined) {
  1079. const tensor = new keras.Tensor('', input.shape, config.dtype || '?', null, null, '|', input.value);
  1080. return values.map('', null, tensor);
  1081. }
  1082. throw new keras.Error(`Invalid argument '${JSON.stringify(input.name)}'.`);
  1083. });
  1084. if (!name && inputArguments.length === 1 && inputArguments[0].initializer && inputArguments[0].initializer.name) {
  1085. if (names.length === 1 && names[0] === '') {
  1086. name = inputArguments[0].initializer.name;
  1087. } else {
  1088. const parts = inputArguments[0].initializer.name.split('/').pop().split(':').shift().split('_');
  1089. const inputName1 = parts.pop();
  1090. const inputName2 = parts.length > 0 ? [parts.pop(), inputName1].join('_') : '';
  1091. const inputNames = new Set(['recurrent_kernel', 'running_mean', 'running_std', 'moving_mean', 'moving_variance', 'depthwise_filter', 'pointwise_filter']);
  1092. name = inputNames.has(inputName2) ? inputName2 : inputName1;
  1093. }
  1094. }
  1095. const argument = new keras.Argument(name || inputIndex.toString(), inputArguments, null, visible);
  1096. this.inputs.push(argument);
  1097. inputIndex++;
  1098. }
  1099. for (let i = 0; i < outputs.length; i++) {
  1100. const output = outputs[i];
  1101. const name = this.type && this.type.outputs && i < this.type.outputs.length && this.type.outputs[i] && this.type.outputs[i].name ? this.type.outputs[i].name : i.toString();
  1102. const argument = new keras.Argument(name, output === undefined || output.length === 0 ? [] : [values.map(output)]);
  1103. this.outputs.push(argument);
  1104. }
  1105. const inputTypes = new Map((this.type.inputs || []).map((input) => [input.name, input.type]));
  1106. for (const [name, arg] of Object.entries(args)) {
  1107. if (name !== 'name') {
  1108. if ((arg && arg.name) || (inputTypes.has(name) && inputTypes.get(name) === 'Tensor' && arg)) {
  1109. if (arg.name) {
  1110. const value = values.map(arg.name);
  1111. const argument = new keras.Argument(name, [value]);
  1112. this.inputs.push(argument);
  1113. } else {
  1114. const tensor = new keras.Tensor('', arg.shape, config.dtype || '?', null, null, '|', arg.value);
  1115. const value = values.map('', null, tensor);
  1116. const argument = new keras.Argument(name, [value]);
  1117. this.inputs.push(argument);
  1118. }
  1119. } else {
  1120. const schema = metadata.attribute(class_name, name);
  1121. this.attributes.push([schema, name, arg]);
  1122. }
  1123. }
  1124. }
  1125. this.attributes = attributes.map(([metadata, name, value]) => {
  1126. let type = null;
  1127. let visible = true;
  1128. if (value && typeof value === 'object' && value.class_name && value.config) {
  1129. value = convertAttributeValue(value);
  1130. }
  1131. switch (name) {
  1132. case 'trainable':
  1133. type = 'boolean';
  1134. visible = false;
  1135. break;
  1136. case 'dtype':
  1137. visible = false;
  1138. break;
  1139. default: {
  1140. if (metadata) {
  1141. type = metadata.type ? metadata.type : type;
  1142. if (metadata.visible === false) {
  1143. visible = false;
  1144. } else if (metadata.default !== undefined) {
  1145. if (Array.isArray(value)) {
  1146. if (Array.isArray(metadata.default)) {
  1147. visible = value.length !== metadata.default || !value.every((item, index) => item === metadata.default[index]);
  1148. } else {
  1149. visible = !value.every((item) => item === metadata.default);
  1150. }
  1151. } else {
  1152. visible = value !== metadata.default;
  1153. }
  1154. }
  1155. }
  1156. break;
  1157. }
  1158. }
  1159. return new keras.Argument(name, value, type, visible);
  1160. });
  1161. if (typeof this.type.name !== 'string' || !this.type.name.split) { // #416
  1162. throw new keras.Error(`Unsupported node type '${JSON.stringify(this.type.name)}'.`);
  1163. }
  1164. }
  1165. };
  1166. keras.Tensor = class {
  1167. constructor(name, shape, type, stride, quantization, encoding, data, location) {
  1168. this.name = name;
  1169. this.type = new keras.TensorType(type, new keras.TensorShape(shape));
  1170. this.stride = stride;
  1171. this.encoding = encoding;
  1172. this._data = data;
  1173. this.location = location;
  1174. if (quantization && (quantization.scale !== 0 || quantization.min !== 0)) {
  1175. this.quantization = {
  1176. type: 'linear',
  1177. scale: [quantization.scale],
  1178. min: [quantization.min]
  1179. };
  1180. }
  1181. }
  1182. get values() {
  1183. if (this.encoding === '|') {
  1184. return this._data;
  1185. }
  1186. if (this._data === null) {
  1187. return null;
  1188. }
  1189. return this._data instanceof Uint8Array ? this._data : this._data.peek();
  1190. }
  1191. };
  1192. keras.TensorType = class {
  1193. constructor(dataType, shape) {
  1194. this.dataType = dataType;
  1195. this.shape = shape;
  1196. }
  1197. toString() {
  1198. return this.dataType + this.shape.toString();
  1199. }
  1200. };
  1201. keras.TensorShape = class {
  1202. constructor(dimensions) {
  1203. this.dimensions = dimensions;
  1204. }
  1205. toString() {
  1206. return this.dimensions && this.dimensions.length > 0 ? (`[${this.dimensions.map((dimension) => dimension.toString()).join(',')}]`) : '';
  1207. }
  1208. };
  1209. keras.GraphMetadata = class {
  1210. constructor(metadata) {
  1211. this._metadata = metadata;
  1212. this._types = new Map();
  1213. }
  1214. type(name) {
  1215. if (this._types.has(name)) {
  1216. return this._types.get(name);
  1217. }
  1218. return this._metadata.type(name);
  1219. }
  1220. attribute(type, name) {
  1221. return this._metadata.attribute(type, name);
  1222. }
  1223. add(type, metadata) {
  1224. this._types.set(type, metadata);
  1225. }
  1226. };
  1227. keras.Weights = class {
  1228. constructor() {
  1229. this._map = new Map();
  1230. }
  1231. get empty() {
  1232. return this._map.size === 0;
  1233. }
  1234. add(layer_name, tensor) {
  1235. if (!this._map.has(layer_name)) {
  1236. this._map.set(layer_name, []);
  1237. }
  1238. this._map.get(layer_name).push(tensor);
  1239. }
  1240. get(group, name) {
  1241. if (group) {
  1242. const list = this._map.get(group.split('/').shift());
  1243. if (list) {
  1244. const match1 = list.filter((tensor) => tensor.name.startsWith(`${name}/`));
  1245. if (match1.length > 0) {
  1246. return match1;
  1247. }
  1248. const match2 = list.filter((tensor) => tensor.name.startsWith(`${group}/${name}/`));
  1249. if (match2.length > 0) {
  1250. return match2;
  1251. }
  1252. }
  1253. } else {
  1254. const match1 = this._map.get(name);
  1255. if (match1 && match1.length > 0) {
  1256. return match1;
  1257. }
  1258. const match2 = this._map.get('');
  1259. if (match2 && match2.length > 0) {
  1260. const match3 = match2.filter((tensor) => tensor.name.startsWith(`${(group ? `${group}/` : '') + name}/`));
  1261. if (match3.length > 0) {
  1262. return match3;
  1263. }
  1264. }
  1265. }
  1266. return [];
  1267. }
  1268. keys() {
  1269. return this._map.keys();
  1270. }
  1271. };
  1272. keras.Error = class extends Error {
  1273. constructor(message) {
  1274. super(message);
  1275. this.name = 'Error loading Keras model.';
  1276. }
  1277. };
tfjs.Container = class {
    // Detects and reads TensorFlow.js model artifacts: a model.json with layer
    // topology, a standalone weights manifest, or binary weight shard files.
    // Returns null from open() when the context does not look like any of these.
    static async open(context) {
        const json = await context.peek('json');
        if (json) {
            // model.json with embedded topology.
            if (json.modelTopology && (json.format === 'layers-model' || json.modelTopology.class_name || json.modelTopology.model_config)) {
                return new tfjs.Container(context, '');
            }
            // Standalone weights manifest: an array of { weights, paths } groups.
            if (Array.isArray(json) && json.every((item) => item.weights && item.paths)) {
                return new tfjs.Container(context, 'weights.json');
            }
            // Metadata file referencing a sibling model.json.
            if (json.tfjsVersion) {
                return new tfjs.Container(context, 'metadata');
            }
        }
        // Binary shard file (e.g. 'group1-shard1of2.bin'); the topology is fetched separately.
        const identifier = context.identifier;
        if (/^.*group\d+-shard\d+of\d+(\.bin)?$/.test(identifier)) {
            return new tfjs.Container(context, 'weights.bin');
        }
        return null;
    }
    constructor(context, type) {
        this.context = context;
        this.type = type;
    }
    // Reads the container according to the type detected by open().
    async read() {
        const context = this.context;
        switch (this.type) {
            case '': {
                const obj = await context.peek('json');
                return this._openModelJson(obj);
            }
            case 'weights.json': {
                this.format = 'TensorFlow.js Weights';
                this.config = null;
                const obj = await context.peek('json');
                const manifests = Array.from(obj);
                for (const manifest of manifests) {
                    for (const weight of manifest.weights) {
                        // Group weights by the tensor name's directory part.
                        const name = weight.name;
                        const index = name.lastIndexOf('/');
                        weight.identifier = index === -1 ? name : name.substring(0, index);
                    }
                }
                return this._openManifests(manifests);
            }
            case 'weights.bin': {
                // Opened via a shard file; fetch the sibling model.json for topology.
                const content = await this.context.fetch('model.json');
                const obj = await content.read('json');
                return this._openModelJson(obj);
            }
            case 'metadata': {
                const content = await this.context.fetch('model.json');
                const obj = await content.read('json');
                return this._openModelJson(obj);
            }
            default: {
                throw new tfjs.Error(`Unsupported TensorFlow.js format '${this.type}'.`);
            }
        }
    }
    // Decodes weight tensors from shard buffers into this.weights.
    // `shards` maps path -> Uint8Array; when a manifest's shards are missing,
    // tensors are still registered but with null data.
    _openShards(manifests, shards) {
        this.weights = new keras.Weights();
        const dtype_size_map = new Map([
            ['float16', 2], ['float32', 4], ['float64', 8],
            ['int8', 1], ['int16', 2], ['int32', 4], ['int64', 8],
            ['uint8', 1], ['uint16', 2], ['uint32', 4], ['uint64', 8]
        ]);
        for (const manifest of manifests) {
            let buffer = null;
            let location = '';
            if (Array.isArray(manifest.paths) && manifest.paths.length > 0 && manifest.paths.every((path) => shards.has(path))) {
                // Concatenate all shards of this manifest into one contiguous buffer.
                const list = manifest.paths.map((path) => shards.get(path));
                location = manifest.paths.join(', ');
                const size = list.reduce((a, b) => a + b.length, 0);
                buffer = new Uint8Array(size);
                let offset = 0;
                for (const item of list) {
                    buffer.set(item, offset);
                    offset += item.length;
                }
            }
            // Weights are laid out back-to-back in manifest order.
            let offset = 0;
            for (const weight of manifest.weights) {
                // Quantized weights are stored in the quantization dtype.
                const dtype = weight.quantization && weight.quantization.dtype ? weight.quantization.dtype : weight.dtype;
                if (!dtype_size_map.has(dtype)) {
                    throw new keras.Error(`Unsupported weight data type size '${dtype}'.`);
                }
                const itemsize = dtype_size_map.get(dtype);
                const size = weight.shape.reduce((a, b) => a * b, 1);
                const length = itemsize * size;
                const data = buffer ? buffer.slice(offset, offset + length) : null;
                // '<' marks little-endian raw byte encoding.
                const tensor = new keras.Tensor(weight.name, weight.shape, dtype, null, weight.quantization, '<', data, location);
                this.weights.add(weight.identifier, tensor);
                offset += length;
            }
        }
    }
    // Fetches all shard files referenced by the manifests, then decodes them.
    async _openManifests(manifests) {
        const shards = new Map();
        for (const manifest of manifests) {
            for (const path of manifest.paths) {
                if (!shards.has(path)) {
                    const promise = this.context.fetch(path);
                    shards.set(path, promise);
                }
            }
        }
        const promises = shards.values();
        try {
            const contexts = await Promise.all(promises);
            for (const key of shards.keys()) {
                // contexts is in insertion order, matching shards.keys().
                const context = contexts.shift();
                const buffer = context.stream.peek();
                shards.set(key, buffer);
            }
            this._openShards(manifests, shards);
        } catch {
            // Best-effort: if any shard fails to load, fall back to topology-only
            // (tensors get null data rather than failing the whole model).
            shards.clear();
            this._openShards(manifests, shards);
        }
    }
    // Reads topology and weight manifests from a parsed model.json object.
    _openModelJson(obj) {
        if (!obj || !obj.modelTopology || (obj.format !== 'layers-model' && !obj.modelTopology.model_config && !obj.modelTopology.config)) {
            throw new tfjs.Error('File format is not TensorFlow.js layers-model.');
        }
        const modelTopology = obj.modelTopology;
        this.format = `TensorFlow.js ${obj.format ? obj.format : `Keras${modelTopology.keras_version ? (` v${modelTopology.keras_version}`) : ''}`}`;
        this.producer = obj.convertedBy || obj.generatedBy || '';
        this.backend = modelTopology.backend || '';
        const manifests = obj.weightsManifest;
        for (const manifest of manifests) {
            for (const weight of manifest.weights) {
                // model.json weights are unscoped; names already carry the layer path.
                weight.identifier = '';
            }
        }
        this.config = modelTopology.model_config ? modelTopology.model_config : modelTopology;
        return this._openManifests(manifests);
    }
};
  1417. tfjs.Error = class extends Error {
  1418. constructor(message) {
  1419. super(message);
  1420. this.name = 'Error loading TensorFlow.js model.';
  1421. }
  1422. };
  1423. export const ModelFactory = keras.ModelFactory;