// keras.js (line-number gutter and size header from the original extraction removed)
  1. import * as json from './json.js';
  2. import * as python from './python.js';
// Namespace object for the Keras reader; classes (ModelFactory, Model, Graph, …) are attached below.
const keras = {};
// Namespace object for TensorFlow.js support (tfjs.Container, …); populated later in this file.
const tfjs = {};
  5. keras.ModelFactory = class {
  6. async match(context) {
  7. const identifier = context.identifier;
  8. const extension = identifier.lastIndexOf('.') > 0 ? identifier.split('.').pop().toLowerCase() : '';
  9. const group = await context.peek('hdf5');
  10. if (group && group.attributes && group.attributes.get('CLASS') !== 'hickle') {
  11. if (identifier === 'model.weights.h5') {
  12. return context.set('keras.model.weights.h5', group);
  13. }
  14. if (identifier === 'parameter.h5') {
  15. return context.set('hdf5.parameter.h5', group);
  16. }
  17. return context.set('keras.h5', group);
  18. }
  19. const json = await context.peek('json');
  20. if (json) {
  21. if (json.mxnet_version || (json.nodes && json.arg_nodes && json.heads)) {
  22. return null;
  23. }
  24. if (json.model_config || (json.class_name && json.config)) {
  25. return context.set('keras.config.json', json);
  26. }
  27. if (identifier === 'metadata.json' && json.keras_version) {
  28. return context.set('keras.metadata.json', json);
  29. }
  30. }
  31. const container = await tfjs.Container.open(context);
  32. if (container) {
  33. return context.set('tfjs', container);
  34. }
  35. const pickle = await context.peek('pkl');
  36. if (pickle && pickle.__class__ &&
  37. pickle.__class__.__module__ === 'keras.engine.sequential' &&
  38. pickle.__class__.__name__ === 'Sequential') {
  39. return context.set('tfjs.pickle', pickle);
  40. }
  41. // model.weights.npz
  42. const entries = await context.peek('npz');
  43. const regex = /^(__root__|layers\/.+|_layer_checkpoint_dependencies\/.+)\.npy$/;
  44. if (entries instanceof Map && entries.size > 0 && Array.from(entries).every(([name]) => regex.test(name))) {
  45. return context.set('keras.model.weights.npz', entries);
  46. }
  47. // keras_metadata.pb
  48. if (extension === 'pb' && context.stream && context.stream.length > 16) {
  49. const tags = await context.tags('pb');
  50. if (tags.size === 1 && tags.get(1) === 2) {
  51. const stream = context.stream;
  52. const buffer = stream.peek(Math.min(stream.length, 1024));
  53. const content = String.fromCharCode.apply(null, buffer);
  54. if (/root"/.test(content) && /\{\s*"class_name"\s*:/.test(content)) {
  55. return context.set('keras.pb.SavedMetadata');
  56. }
  57. }
  58. }
  59. return null;
  60. }
  61. filter(context, match) {
  62. if (context.type === 'keras.metadata.json' && (match.type === 'keras.config.json' || match.type === 'keras.model.weights.h5' || match.type === 'keras.model.weights.npz')) {
  63. return false;
  64. }
  65. if (context.type === 'keras.config.json' && (match.type === 'keras.model.weights.h5' || match.type === 'keras.model.weights.npz')) {
  66. return false;
  67. }
  68. if (context.type === 'tfjs' && match.type === 'tf.tfjs.weights') {
  69. return false;
  70. }
  71. return true;
  72. }
  73. async open(context) {
  74. const request_json = async (context, name) => {
  75. try {
  76. context = await context.fetch(name);
  77. } catch {
  78. return null;
  79. }
  80. return await context.read('json');
  81. };
  82. const _create_config = (weights_store) => {
  83. const config = {};
  84. config.class_name = 'Model';
  85. config.config = {};
  86. config.config.layers = [];
  87. const snake_to_pascal_case = (name) => {
  88. return name.replace(/(^|_|\d)([a-z])/g, (match, p1, p2) => p1 === '_' ? p2.toUpperCase() : p1 + p2.toUpperCase());
  89. };
  90. for (const [name, value] of weights_store) {
  91. const layer = {};
  92. layer.name = name;
  93. layer.class_name = name.split('/').pop().replace(/_[0-9]+$/, '');
  94. layer.class_name = snake_to_pascal_case(layer.class_name);
  95. layer.config = {};
  96. layer.config.name = name;
  97. layer._trainable_variables = value;
  98. config.config.layers.push(layer);
  99. }
  100. return config;
  101. };
  102. const _load_state = (trackable, weights_store, assets_store, inner_path) => {
  103. inner_path = inner_path || '';
  104. if (trackable && trackable.config && Array.isArray(trackable.config.layers)) {
  105. /* eslint-disable no-use-before-define */
  106. _load_container_state(trackable, weights_store, assets_store, inner_path ? `${inner_path}/layers` : 'layers');
  107. /* eslint-enable no-use-before-define */
  108. } else {
  109. const weights = weights_store.get(inner_path);
  110. if (weights) {
  111. trackable._trainable_variables = weights;
  112. }
  113. }
  114. };
  115. const _load_container_state = (container, weights_store, assets_store, inner_path) => {
  116. const used_names = new Map();
  117. for (const trackable of container.config.layers) {
  118. const pascal_to_snake_case = (name) => {
  119. name = name.replace(/\W+/g, "");
  120. name = name.replace(/(.)([A-Z][a-z]+)/g, (match, p1, p2) => `${p1}_${p2}`);
  121. name = name.replace(/([a-z])([A-Z])/g, (match, p1, p2) => `${p1}_${p2}`);
  122. return name.toLowerCase();
  123. };
  124. let name = pascal_to_snake_case(trackable.class_name);
  125. if (used_names.has(name)) {
  126. const next = used_names.get(name) + 1;
  127. used_names.set(name, next);
  128. name = `${name}_${next}`;
  129. } else {
  130. used_names.set(name, 0);
  131. }
  132. _load_state(trackable, weights_store, assets_store, `${inner_path}/${name}`);
  133. }
  134. };
  135. const read_weights_hdf5 = (group) => {
  136. const weights_store = new Map();
  137. const stack = [[group, '']];
  138. while (stack.length > 0) {
  139. const [group, path] = stack.pop();
  140. if (group.groups instanceof Map) {
  141. const checkpoint = group.groups.get('layers') || group.groups.get('_layer_checkpoint_dependencies');
  142. if (checkpoint) {
  143. for (const [key, layer] of checkpoint.groups) {
  144. const name = `${path ? `${path}/` : ''}layers/${key}`;
  145. stack.push([layer, name]);
  146. const values = [];
  147. for (const vars of layer.groups) {
  148. for (const [name, group] of vars[1].groups) {
  149. const variable = group.value;
  150. if (variable) {
  151. const layout = variable.littleEndian ? '<' : '>';
  152. const tensor = new keras.Tensor(name, variable.shape, variable.type, null, null, layout, variable.data);
  153. values.push(tensor);
  154. }
  155. }
  156. }
  157. if (values.length > 0) {
  158. weights_store.set(name, values);
  159. }
  160. }
  161. }
  162. }
  163. }
  164. return weights_store;
  165. };
  166. const read_weights_numpy = (entries) => {
  167. const weights_store = new Map();
  168. for (const [path, array] of entries) {
  169. const file = path.split('/').map((name) => name === '_layer_checkpoint_dependencies' ? 'layers' : name).join('/');
  170. if (file.endsWith('.npy') && file.startsWith('layers/')) {
  171. if (array.dtype.name === 'object' && array.shape.length === 0 && Array.isArray(array.data) && array.data.length === 1) {
  172. const values = Object.values(array.data[0]).map((array) => {
  173. const stride = array.strides.map((stride) => stride / array.itemsize);
  174. const dataType = array.dtype.__name__;
  175. const values = dataType === 'string' || dataType === 'object' ? array.flatten().tolist() : array.tobytes();
  176. const encoding = dataType === 'string' || dataType === 'object' ? '|' : array.dtype.byteorder;
  177. return new keras.Tensor('', array.shape, dataType, stride, null, encoding, values);
  178. });
  179. if (values.length > 0) {
  180. const name = file.replace(/\.npy$/, '');
  181. weights_store.set(name, values);
  182. }
  183. }
  184. }
  185. }
  186. return weights_store;
  187. };
  188. const request_weights = async (context) => {
  189. const formats = [
  190. ['model.weights.h5', 'hdf5', read_weights_hdf5],
  191. ['model.weights.npz', 'npz', read_weights_numpy],
  192. ];
  193. for (const [name, type, callback] of formats) {
  194. let content = null;
  195. try {
  196. /* eslint-disable no-await-in-loop */
  197. content = await context.fetch(name);
  198. /* eslint-enable no-await-in-loop */
  199. } catch {
  200. // continue regardless of error
  201. }
  202. if (content) {
  203. /* eslint-disable no-await-in-loop */
  204. const obj = await content.peek(type);
  205. /* eslint-enable no-await-in-loop */
  206. if (obj) {
  207. return callback(obj);
  208. }
  209. }
  210. }
  211. return new Map();
  212. };
  213. const open_model = async (format, producer, backend, config, weights) => {
  214. const metadata = await context.metadata('keras-metadata.json');
  215. return new keras.Model(metadata, format, producer, backend, config, weights);
  216. };
  217. switch (context.type) {
  218. case 'keras.config.json': {
  219. const obj = context.value;
  220. const config = obj.model_config ? obj.model_config : obj;
  221. const backend = obj.backend || '';
  222. let version = obj.keras_version ? obj.keras_version : null;
  223. if (!version) {
  224. const metadata = await request_json(context, 'metadata.json');
  225. if (metadata && metadata.keras_version) {
  226. version = metadata.keras_version;
  227. }
  228. }
  229. const format = `Keras${version ? ` v${version}` : ''}`;
  230. const weights_store = await request_weights(context);
  231. _load_state(config, weights_store);
  232. return open_model(format, '', backend, config, null);
  233. }
  234. case 'keras.model.weights.h5': {
  235. const group = context.value;
  236. const weights_store = read_weights_hdf5(group);
  237. const metadata = await request_json(context, 'metadata.json');
  238. let config = await request_json(context, 'config.json');
  239. const name = config ? 'Keras' : 'Keras Weights';
  240. const format = name + (metadata && metadata.keras_version ? ` v${metadata.keras_version}` : '');
  241. if (config) {
  242. _load_state(config, weights_store);
  243. } else {
  244. config = _create_config(weights_store);
  245. }
  246. return await open_model(format, '', '', config, null);
  247. }
  248. case 'keras.model.weights.npz': {
  249. const entries = context.value;
  250. const weights_store = read_weights_numpy(entries);
  251. const metadata = await request_json(context, 'metadata.json');
  252. let config = await request_json(context, 'config.json');
  253. const name = config ? 'Keras' : 'Keras Weights';
  254. const format = name + (metadata && metadata.keras_version ? ` v${metadata.keras_version}` : '');
  255. if (config) {
  256. _load_state(config, weights_store);
  257. } else {
  258. config = _create_config(weights_store);
  259. }
  260. return await open_model(format, '', '', config, null);
  261. }
  262. case 'keras.metadata.json': {
  263. const metadata = context.value;
  264. let config = await request_json(context, 'config.json');
  265. const name = config ? 'Keras' : 'Keras Weights';
  266. const format = name + (metadata.keras_version ? ` v${metadata.keras_version}` : '');
  267. const weights_store = await request_weights(context);
  268. if (!config && (!weights_store || weights_store.size === 0)) {
  269. throw new keras.Error("'config.json' or 'model.weights.*' not present.");
  270. }
  271. if (config) {
  272. _load_state(config, weights_store);
  273. } else {
  274. config = _create_config(weights_store);
  275. }
  276. return await open_model(format, '', '', config, null);
  277. }
  278. case 'hdf5.parameter.h5':
  279. case 'keras.h5': {
  280. const find_root_group = (root_group) => {
  281. const kerasmodel = root_group.group('model/kerasmodel');
  282. if (kerasmodel && kerasmodel.attributes.has('model_config')) {
  283. return kerasmodel;
  284. }
  285. return root_group;
  286. };
  287. const read_model_config = (group) => {
  288. if (group.attributes.has('model_config')) {
  289. const buffer = group.attributes.get('model_config');
  290. const reader = json.TextReader.open(buffer);
  291. if (reader) {
  292. return reader.read();
  293. }
  294. }
  295. return null;
  296. };
  297. const load_attributes_from_hdf5_group = (group, name) => {
  298. if (group.attributes.has(name)) {
  299. return group.attributes.get(name);
  300. }
  301. if (group.attributes.has(`${name}0`)) {
  302. let index = 0;
  303. let value = [];
  304. while (group.attributes.has(name + index.toString())) {
  305. const chunk = group.attributes.get(name + index.toString());
  306. value = value.concat(chunk);
  307. index++;
  308. }
  309. return value;
  310. }
  311. return null;
  312. };
  313. const weights = new keras.Weights();
  314. const group = context.value;
  315. const root_group = find_root_group(group);
  316. const model_config = read_model_config(root_group);
  317. if (model_config) {
  318. const backend = root_group.attributes.get('backend') || '';
  319. const version = root_group.attributes.get('keras_version') || '';
  320. const format = `Keras${version ? ` v${version}` : ''}`;
  321. const model_weights_group = root_group.group('model_weights');
  322. if (model_weights_group) {
  323. const layer_names = load_attributes_from_hdf5_group(model_weights_group, 'layer_names');
  324. for (const layer_name of layer_names) {
  325. const layer_weights = model_weights_group.group(layer_name);
  326. if (layer_weights) {
  327. const weight_names = load_attributes_from_hdf5_group(layer_weights, 'weight_names');
  328. if (Array.isArray(weight_names) && weight_names.length > 0) {
  329. for (const weight_name of weight_names) {
  330. const weight = layer_weights.group(weight_name);
  331. if (weight && weight.value) {
  332. const variable = weight.value;
  333. const tensor = new keras.Tensor(weight_name, variable.shape, variable.type, null, null, variable.littleEndian ? '<' : '>', variable.data);
  334. weights.add(layer_name, tensor);
  335. }
  336. }
  337. }
  338. }
  339. }
  340. }
  341. if (!model_config) {
  342. throw new keras.Error("'model_config' is not present.");
  343. }
  344. if (!model_config.class_name) {
  345. throw new keras.Error("'class_name' is not present.");
  346. }
  347. return open_model(format, '', backend, model_config, weights);
  348. }
  349. const layer_names = load_attributes_from_hdf5_group(root_group, 'layer_names');
  350. if (layer_names && Array.isArray(layer_names)) {
  351. const version = root_group.attributes.get('keras_version') || '';
  352. const format = `Keras Weights${version ? ` v${version}` : ''}`;
  353. const backend = root_group.attributes.get('backend') || '';
  354. for (const layer_name of layer_names) {
  355. const layer_weights = root_group.group(layer_name);
  356. if (layer_weights) {
  357. const weight_names = load_attributes_from_hdf5_group(layer_weights, 'weight_names');
  358. if (Array.isArray(weight_names) && weight_names.length > 0) {
  359. for (const weight_name of weight_names) {
  360. const weight = layer_weights.group(weight_name);
  361. if (weight && weight.value) {
  362. const variable = weight.value;
  363. const components = weight_name.split('/');
  364. components.pop();
  365. const name = (components.length === 0 || components[0] !== layer_name) ? [layer_name].concat(components).join('/') : components.join('/');
  366. const encoding = variable.littleEndian ? '<' : '>';
  367. const tensor = new keras.Tensor(weight_name, variable.shape, variable.type, null, null, encoding, variable.data);
  368. weights.add(name, tensor);
  369. }
  370. }
  371. }
  372. }
  373. }
  374. return open_model(format, '', backend, null, weights);
  375. }
  376. const rootKeys = new Set(root_group.attributes.keys());
  377. rootKeys.delete('nb_layers');
  378. if (rootKeys.size > 0 || root_group.value !== null) {
  379. throw new keras.Error('File format is not HDF5 Weights.');
  380. }
  381. const format = 'HDF5 Weights';
  382. let weights_group = root_group;
  383. if (root_group.attributes.size === 0 && root_group.value === null && root_group.groups.size === 1) {
  384. const group = root_group.groups.values().next().value;
  385. if (group.attributes.size === 0 && group.value === null) {
  386. weights_group = group;
  387. }
  388. }
  389. const tensorKeys = new Set(['name', 'shape', 'quantization']);
  390. const groups = Array.from(weights_group.groups.values());
  391. if (groups.every((group) => group.attributes.size === 0 && group.groups.length === 0 && group.value !== null)) {
  392. for (const group of groups) {
  393. const variable = group.value;
  394. const layout = variable.littleEndian ? '<' : '>';
  395. const tensor = new keras.Tensor(group.name, variable.shape, variable.type, null, null, layout, variable.type === 'string' ? variable.value : variable.data);
  396. weights.add('', tensor);
  397. }
  398. return open_model(format, '', '', null, weights);
  399. }
  400. if (groups.every((group) => group.value === null && Array.from(group.attributes.keys()).filter((key) => !tensorKeys.has(key)).length === 0 && Array.from(group.groups.values()).every((variable) => Object.keys(variable.attributes).length === 0 && variable.value !== null))) {
  401. for (const group of groups) {
  402. const moduleName = group.attributes.has('name') ? group.attributes.get('name') : group.name;
  403. for (const variableGroup of group.groups.values()) {
  404. if (variableGroup.attributes.size !== 0 || variableGroup.groups.size !== 0) {
  405. throw new keras.Error('Variable format is not HDF5 Weights.');
  406. }
  407. const variable = variableGroup.value;
  408. if (!variable) {
  409. throw new keras.Error('Variable value is not HDF5 Weights.');
  410. }
  411. const name = moduleName ? [moduleName, variableGroup.name].join('/') : moduleName.name;
  412. const layout = variable.littleEndian ? '<' : '>';
  413. const tensor = new keras.Tensor(name, variable.shape, variable.type, null, null, layout, variable.type === 'string' ? variable.value : variable.data);
  414. weights.add(moduleName, tensor);
  415. }
  416. }
  417. return open_model(format, '', '', null, weights);
  418. }
  419. const walk = function(group) {
  420. if (group.attributes.size === 0 && group.value === null && group.groups.size > 0) {
  421. for (const subGroup of group.groups.values()) {
  422. walk(subGroup);
  423. }
  424. return;
  425. }
  426. const subKeys = new Set(['index', 'need_grad']);
  427. const attribtues = Array.from(group.attributes.keys());
  428. const match = attribtues.filter((key) => !subKeys.has(key)).length === 0;
  429. if (match && group.value !== null && group.groups.size === 0) {
  430. const variable = group.value;
  431. const variableName = group.path;
  432. let moduleName = variableName;
  433. const parts = variableName.split('/');
  434. if (parts.length > 1) {
  435. parts.pop();
  436. moduleName = parts.join('/');
  437. }
  438. const layout = variable.littleEndian ? '<' : '>';
  439. const tensor = new keras.Tensor(variableName, variable.shape, variable.type, null, null, layout, variable.type === 'string' ? variable.value : variable.data);
  440. weights.add(moduleName, tensor);
  441. return;
  442. }
  443. throw new keras.Error('Module group format is not HDF5 Weights.');
  444. };
  445. walk(weights_group);
  446. return open_model(format, '', '', null, weights);
  447. }
  448. case 'tfjs': {
  449. const target = context.value;
  450. await target.read();
  451. return open_model(target.format, target.producer, target.backend, target.config, target.weights);
  452. }
  453. case 'keras.pickle': {
  454. const obj = context.value;
  455. const execution = new python.Execution();
  456. const decoder = new TextDecoder('utf-8');
  457. const format = `Keras Pickle${obj.keras_version ? ` v${decoder.decode(obj.keras_version)}` : ''}`;
  458. const backend = obj.backend ? decoder.decode(obj.backend) : '';
  459. const reader = json.TextReader.open(obj.model_config);
  460. const model_config = reader.read();
  461. const weights = new keras.Weights();
  462. const model_weights_group = obj.model_weights;
  463. if (model_weights_group) {
  464. const layer_names = model_weights_group.layer_names.map((buffer) => decoder.decode(buffer));
  465. for (const layer_name of layer_names) {
  466. const layer_weights = model_weights_group[layer_name];
  467. if (layer_weights) {
  468. const weight_names = layer_weights.weight_names.map((buffer) => decoder.decode(buffer));
  469. if (Array.isArray(weight_names) && weight_names.length > 0) {
  470. for (const weight_name of weight_names) {
  471. const buffer = layer_weights[weight_name];
  472. const unpickler = execution.invoke('pickle.Unpickler', [buffer]);
  473. const variable = unpickler.load();
  474. const tensor = new keras.Tensor(weight_name, variable.shape, variable.dtype.__name__, null, null, '<', variable.data);
  475. weights.add(layer_name, tensor);
  476. }
  477. }
  478. }
  479. }
  480. }
  481. return open_model(format, '', backend, model_config, weights);
  482. }
  483. case 'keras.pb.SavedMetadata': {
  484. keras.proto = await context.require('./keras-proto');
  485. const format = 'Keras Saved Metadata';
  486. const reader = await context.read('protobuf.binary');
  487. const saved_metadata = keras.proto.third_party.tensorflow.python.keras.protobuf.SavedMetadata.decode(reader);
  488. if (!saved_metadata || !Array.isArray(saved_metadata.nodes) ||
  489. !saved_metadata.nodes.every((node) => node && typeof node.metadata === 'string' && node.metadata.length > 0)) {
  490. throw new keras.Error('Invalid keras.protobuf.SavedMetadata.');
  491. }
  492. const objects = new Map();
  493. for (const node of saved_metadata.nodes) {
  494. const reader = json.TextReader.open(node.metadata);
  495. node.metadata = reader.read();
  496. objects.set(node.node_path, node);
  497. }
  498. const model_config = objects.get('root').metadata;
  499. return open_model(format, '', '', model_config, null);
  500. }
  501. default: {
  502. throw new keras.Error(`Unsupported Keras format '${context.type}'.`);
  503. }
  504. }
  505. }
  506. };
  507. keras.Model = class {
  508. constructor(metadata, format, producer, backend, config, weights) {
  509. this.format = format;
  510. this.runtime = backend;
  511. this.producer = producer;
  512. metadata = new keras.GraphMetadata(metadata);
  513. this.modules = [new keras.Graph(metadata, config, weights)];
  514. }
  515. };
  516. keras.Graph = class {
  517. constructor(metadata, config, weights, group) {
  518. this.inputs = [];
  519. this.outputs = [];
  520. this.nodes = [];
  521. group = group || '';
  522. const values = new Map();
  523. values.map = (name, type, tensor) => {
  524. if (tensor) {
  525. return new keras.Value(name, type || null, tensor);
  526. }
  527. if (!values.has(name)) {
  528. values.set(name, new keras.Value(name, type || null, tensor || null));
  529. } else if (type || tensor) {
  530. throw new keras.Error(`Duplicate value '${name}'.`);
  531. }
  532. return values.get(name);
  533. };
  534. if (config) {
  535. const getInputType = (layer) => {
  536. if (layer && layer.config) {
  537. let dataType = '?';
  538. let shape = [];
  539. const config = layer.config;
  540. if (config.dtype) {
  541. dataType = config.dtype;
  542. delete config.dtype;
  543. }
  544. if (Array.isArray(config.batch_input_shape)) {
  545. shape = config.batch_input_shape.map((s) => s === null ? '?' : s);
  546. delete config.batch_input_shape;
  547. } else if (config.batch_input_shape &&
  548. config.batch_input_shape.class_name === '__tuple__' &&
  549. Array.isArray(config.batch_input_shape.items)) {
  550. shape = config.batch_input_shape.items.map((s) => s === null ? '?' : s);
  551. delete config.batch_input_shape;
  552. }
  553. return new keras.TensorType(dataType, new keras.TensorShape(shape));
  554. }
  555. return null;
  556. };
  557. this.name = config.name || (config.config && config.config.name ? config.config.name : '');
  558. this.description = config.class_name;
  559. let baseType = config.class_name;
  560. switch (baseType) {
  561. case '__Function__':
  562. this.type = 'function';
  563. break;
  564. case 'Sequential':
  565. case 'Functional':
  566. case 'Model': {
  567. break;
  568. }
  569. case 'Tokenizer': {
  570. config = { config: { layers: [config] } };
  571. baseType = 'Functional';
  572. break;
  573. }
  574. default: {
  575. const layers = Array.from(config.layers ? config.layers : config);
  576. const sequential = layers.every((layer) => layer.inbound_nodes === undefined);
  577. baseType = sequential ? 'Sequential' : 'Functional';
  578. break;
  579. }
  580. }
  581. switch (baseType) {
  582. case 'Sequential': {
  583. config = config.config;
  584. const outputs = null;
  585. let name = 'input';
  586. let index = -1;
  587. const layers = Array.from(config.layers ? config.layers : config);
  588. while (layers.length > 0) {
  589. const layer = layers.shift();
  590. let current = index.toString();
  591. index++;
  592. if (index === 0) {
  593. const type = getInputType(layer);
  594. let remove = false;
  595. if (layer.class_name === 'InputLayer' && layer.config && layer.config.name) {
  596. name = layer.config.name;
  597. remove = true;
  598. }
  599. const value = values.map(name, type);
  600. const argument = new keras.Argument(name, [value]);
  601. this.inputs.push(argument);
  602. if (remove) {
  603. continue;
  604. }
  605. }
  606. const nodeInputs = [{ name }];
  607. if (layer.config && layer.config.name) {
  608. current = layer.config.name;
  609. }
  610. name = current;
  611. let nodeOutputs = [name];
  612. if (index === layers.length) {
  613. if (outputs && outputs.length > 0) {
  614. nodeOutputs = [outputs[0]];
  615. name = null;
  616. }
  617. }
  618. layer.inputs = nodeInputs;
  619. layer.outputs = nodeOutputs;
  620. const node = new keras.Node(metadata, layer, group, weights, values);
  621. this.nodes.push(node);
  622. }
  623. if (name) {
  624. const value = values.map(name);
  625. const argument = new keras.Argument(name, [value]);
  626. this.outputs.push(argument);
  627. }
  628. break;
  629. }
  630. case '__Function__':
  631. case 'Functional':
  632. case 'Model': {
  633. config = config.config;
  634. const nodes = new Map();
  635. if (config.layers) {
  636. const is_constant = (item) => {
  637. return Array.isArray(item) && (item.length === 3 || item.length === 4) && item[0] === '_CONSTANT_VALUE' && item[1] === -1;
  638. };
  639. const is_connection = (item) => {
  640. return Array.isArray(item) && (item.length === 3 || item.length === 4) && typeof item[0] === 'string' && typeof item[1] === 'number' && typeof item[2] === 'number';
  641. };
  642. const read_value = (input_data) => {
  643. if (!Array.isArray(input_data)) {
  644. return input_data;
  645. }
  646. const transform = (value) => {
  647. if (value.every((item) => is_constant(item))) {
  648. for (let i = 0; i < value.length; i++) {
  649. /* eslint-disable prefer-destructuring */
  650. value[i] = value[i][2];
  651. /* eslint-enable prefer-destructuring */
  652. }
  653. } else if (value.every((item) => Array.isArray(item))) {
  654. const dims = value.map((item) => transform(item));
  655. const [dim] = dims;
  656. for (let i = 1; i < dims.length; i++) {
  657. if (dim.length === dims[i].length) {
  658. if (!dims[i].every((value, i) => value === dim[i])) {
  659. throw new python.Error('Invalid array shape.');
  660. }
  661. }
  662. }
  663. return [value.length].concat(dim);
  664. }
  665. return [value.length];
  666. };
  667. const shape = transform(input_data);
  668. const flatten = (input) => input.reduce((a, b) => a.concat(Array.isArray(b) ? flatten(b) : b), []);
  669. const value = flatten(input_data);
  670. return { shape, value };
  671. };
  672. const functional = config.layers.every((layer) => Array.isArray(layer.inbound_nodes));
  673. const layers = new Map();
  674. if (functional) {
  675. const read_connection = (input_data) => {
  676. const [node_name, node_index, tensor_index] = input_data;
  677. const inbound_node_key = `${node_name}[${node_index}]`;
  678. const inbound_node = nodes.get(inbound_node_key);
  679. const tensor_key = `${node_name}[${node_index}][${tensor_index}]`;
  680. if (inbound_node) {
  681. while (tensor_index >= inbound_node.outputs.length) {
  682. inbound_node.outputs.push(undefined);
  683. }
  684. inbound_node.outputs[tensor_index] = tensor_key;
  685. }
  686. return tensor_key;
  687. };
  688. const process_node = (node, inbound_node) => {
  689. if (Array.isArray(inbound_node) && inbound_node.length === 4 && typeof inbound_node[0] === 'string') {
  690. const key = read_connection(inbound_node);
  691. node.inputs.push({ name: key });
  692. for (const [name, value] of Object.entries(inbound_node[3])) {
  693. if (is_connection(value)) {
  694. const key = read_connection(value);
  695. node.inputs.push({ name: key });
  696. } else if (Array.isArray(value)) {
  697. const array = read_value(value);
  698. node.args[name] = array;
  699. } else {
  700. node.args[name] = value;
  701. }
  702. }
  703. } else if (Array.isArray(inbound_node)) {
  704. for (const input_data of inbound_node) {
  705. // [ 'conv2d', 0, 0 ] or [ 'conv2d', 0, 0, {} ]
  706. if (Array.isArray(input_data) && is_connection(input_data)) {
  707. const key = read_connection(input_data);
  708. node.inputs.push({ name: key });
  709. } else if (Array.isArray(input_data) && input_data.every((item) => is_connection(item))) {
  710. for (const input of input_data) {
  711. const key = read_connection(input);
  712. node.inputs.push({ name: key });
  713. }
  714. } else if (Array.isArray(input_data)) {
  715. const value = read_value(input_data);
  716. node.inputs.push(value);
  717. } else {
  718. throw new keras.Error(`Invalid inbound connection '${JSON.stringify(input_data)}'.`);
  719. }
  720. }
  721. } else if (inbound_node && inbound_node.args) {
  722. for (const arg of inbound_node.args) {
  723. if (arg && arg.class_name === '__keras_tensor__' && arg.config && is_connection(arg.config.keras_history)) {
  724. const key = read_connection(arg.config.keras_history);
  725. node.inputs.push({ name: key });
  726. } else if (Array.isArray(arg) && arg.every((arg) => arg && arg.class_name === '__keras_tensor__' && arg.config && is_connection(arg.config.keras_history))) {
  727. for (const input of arg) {
  728. const key = read_connection(input.config.keras_history);
  729. node.inputs.push({ name: key });
  730. }
  731. }
  732. }
  733. }
  734. };
  735. let legacy_format = true;
  736. for (const layer of config.layers) {
  737. if (Array.isArray(layer.inbound_nodes)) {
  738. for (const inbound_node of layer.inbound_nodes) {
  739. if (Array.isArray(inbound_node.args)) {
  740. legacy_format = false;
  741. }
  742. }
  743. }
  744. }
  745. for (const layer of config.layers) {
  746. const class_name = layer.class_name;
  747. let first_index = 0;
  748. if (legacy_format) {
  749. const keys = new Set(Object.keys(layer.config));
  750. const is_functional_config = keys.has('name') && keys.has('layers') && keys.has('input_layers') && keys.has('output_layers');
  751. if (class_name === 'Sequential' ||
  752. (is_functional_config && Array.isArray(layer.config.layers) && layer.config.layers.length > 0 && layer.config.layers[0].class_name === 'InputLayer')) {
  753. first_index++;
  754. }
  755. }
  756. layers.set(layer.name, layers);
  757. if (Array.isArray(layer.inbound_nodes) && layer.inbound_nodes.length === 0) {
  758. layer.inputs = [];
  759. layer.outputs = [];
  760. layer.args = {};
  761. nodes.set(`${layer.name}[${first_index}]`, layer);
  762. } else if (Array.isArray(layer.inbound_nodes) && layer.inbound_nodes.length === 1) {
  763. layer.inputs = [];
  764. layer.outputs = [];
  765. layer.args = {};
  766. [layer.inbound_node] = layer.inbound_nodes;
  767. nodes.set(`${layer.name}[${first_index}]`, layer);
  768. } else {
  769. let config = {};
  770. switch (class_name) {
  771. case 'Functional':
  772. case 'Sequential':
  773. case 'Model': {
  774. config = layer;
  775. break;
  776. }
  777. default: {
  778. config.class_name = '__Function__';
  779. config.name = layer.name;
  780. config.config = {};
  781. config.config.layers = [{ ...layer }];
  782. delete config.config.layers[0].inbound_nodes;
  783. delete config.config.layers[0].input_layers;
  784. delete config.config.layers[0].output_layers;
  785. break;
  786. }
  787. }
  788. const type = new keras.Graph(metadata, config, weights, '');
  789. for (let i = 0; i < layer.inbound_nodes.length; i++) {
  790. const index = i + first_index;
  791. const key = `${layer.name}[${index}]`;
  792. const node = {};
  793. node.name = key;
  794. node.class_name = '__Function__';
  795. node.config = {};
  796. node.config.name = key;
  797. node.inputs = [];
  798. node.outputs = [];
  799. node.args = {};
  800. node.__type__ = type;
  801. node.inbound_node = layer.inbound_nodes[i];
  802. nodes.set(key, node);
  803. }
  804. }
  805. }
  806. for (const entry of nodes) {
  807. if (entry[1].inbound_node) {
  808. process_node(entry[1], entry[1].inbound_node);
  809. }
  810. }
  811. if (Array.isArray(config.input_layers)) {
  812. if (config.input_layers.length === 3 && typeof config.input_layers[0] === 'string' && Number.isInteger(config.input_layers[1]) && Number.isInteger(config.input_layers[2])) {
  813. config.input_layers = [config.input_layers];
  814. }
  815. for (let i = 0; i < config.input_layers.length; i++) {
  816. const input_data = config.input_layers[i];
  817. const name = read_connection(input_data);
  818. const [node_name, node_index] = input_data;
  819. const inbound_node_key = `${node_name}[${node_index}]`;
  820. const node = nodes.get(inbound_node_key);
  821. let type = null;
  822. if (node && node.class_name === 'InputLayer') {
  823. type = getInputType(node);
  824. nodes.delete(name);
  825. nodes.delete(inbound_node_key);
  826. }
  827. const value = values.map(name, type);
  828. const argument = new keras.Argument(node_name, [value]);
  829. this.inputs.push(argument);
  830. }
  831. }
  832. if (Array.isArray(config.output_layers)) {
  833. if (config.output_layers.length === 3 && typeof config.output_layers[0] === 'string' && Number.isInteger(config.output_layers[1]) && Number.isInteger(config.output_layers[2])) {
  834. config.output_layers = [config.output_layers];
  835. }
  836. for (let i = 0; i < config.output_layers.length; i++) {
  837. const output_data = config.output_layers[i];
  838. const [name] = output_data;
  839. const key = read_connection(output_data);
  840. const value = values.map(key);
  841. const argument = new keras.Argument(name, [value]);
  842. this.outputs.push(argument);
  843. }
  844. }
  845. } else {
  846. for (const layer of config.layers) {
  847. layer.inputs = [];
  848. layer.outputs = [];
  849. layer.args = {};
  850. nodes.set(`${layer.name}[0]`, layer);
  851. }
  852. }
  853. }
  854. for (const entry of nodes) {
  855. const node = new keras.Node(metadata, entry[1], group, weights, values);
  856. this.nodes.push(node);
  857. }
  858. break;
  859. }
  860. default: {
  861. throw new keras.Error(`'${config.class_name}' is not supported.`);
  862. }
  863. }
  864. } else if (weights) {
  865. this.type = 'weights';
  866. for (const name of weights.keys()) {
  867. if (weights.get('', name).length <= 6) {
  868. const layer = { class_name: 'Weights', config: { name } };
  869. const node = new keras.Node(metadata, layer, '', weights, values);
  870. this.nodes.push(node);
  871. }
  872. }
  873. }
  874. }
  875. };
  876. keras.Argument = class {
  877. constructor(name, value, type, visible) {
  878. this.name = name;
  879. this.value = value;
  880. this.type = type || null;
  881. this.visible = visible !== false;
  882. }
  883. };
  884. keras.Value = class {
  885. constructor(name, type, initializer) {
  886. if (typeof name !== 'string') {
  887. throw new keras.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
  888. }
  889. this.name = name;
  890. this.type = !type && initializer ? initializer.type : type;
  891. this.quantization = initializer && initializer.quantization ? initializer.quantization : null;
  892. this.initializer = initializer || null;
  893. }
  894. };
  895. keras.Node = class {
  896. constructor(metadata, layer, group, weights, values) {
  897. const config = layer.config || {};
  898. const args = layer.args || {};
  899. let inputs = layer.inputs || [];
  900. let outputs = layer.outputs || [];
  901. const name = config && config.name ? config.name : '';
  902. this.group = group || '';
  903. this.name = (this.group ? `${this.group}/` : '') + name;
  904. this.inputs = [];
  905. this.outputs = [];
  906. this.attributes = [];
  907. this.chain = [];
  908. let names = [name];
  909. let class_name = layer.class_name;
  910. let model = false;
  911. switch (class_name) {
  912. case '__Function__': {
  913. this.type = layer.__type__;
  914. model = true;
  915. break;
  916. }
  917. case 'Model':
  918. case 'Functional':
  919. case 'Sequential': {
  920. const name = layer.name || (layer.config ? layer.config.name : '');
  921. this.type = new keras.Graph(metadata, layer, weights, (group ? `${group}/` : '') + name);
  922. model = true;
  923. if (config) {
  924. delete config.layers;
  925. delete config.input_layers;
  926. delete config.output_layers;
  927. }
  928. this.inputs = [new keras.Argument('inputs', inputs.map((input) => values.map(input.name)))];
  929. this.outputs = [new keras.Argument('outputs', outputs.map((name) => values.map(name)))];
  930. inputs = [];
  931. outputs = [];
  932. break;
  933. }
  934. case 'Wrapper':
  935. case 'Bidirectional':
  936. case 'TimeDistributed': {
  937. if (config && config.layer) {
  938. const inner = config.layer;
  939. delete config.layer;
  940. this.inner = new keras.Node(metadata, inner, null, null, values);
  941. if (class_name === 'Bidirectional' && inner.config.name) {
  942. names = [`${name}/forward_${inner.config.name}`, `${name}/backward_${inner.config.name}`];
  943. if (!group) {
  944. group = name;
  945. }
  946. }
  947. }
  948. this.type = metadata.type(class_name) || { name: class_name };
  949. break;
  950. }
  951. case 'TFOpLambda': {
  952. if (config && config.function) {
  953. class_name = config.function;
  954. delete config.function;
  955. }
  956. this.type = metadata.type(class_name) || { name: class_name };
  957. break;
  958. }
  959. default: {
  960. this.type = metadata.type(class_name) || { name: class_name };
  961. break;
  962. }
  963. }
  964. if (layer._trainable_variables) {
  965. if (inputs.length === 0 && Array.isArray(this.type.inputs) && this.type.inputs.length > 0) {
  966. // weights-only, remove 'input' from type metadata
  967. this.type = { ...this.type };
  968. this.type.inputs = this.type.inputs.slice(1);
  969. }
  970. for (const variable of layer._trainable_variables) {
  971. inputs.push({ name: '', initializer: variable });
  972. }
  973. } else if (weights && !model) {
  974. for (const name of names) {
  975. let tensors = weights.get(group, name);
  976. if (tensors.length > 0) {
  977. for (const initializer of tensors) {
  978. inputs.push({ name: initializer.name, initializer });
  979. }
  980. } else {
  981. tensors = weights.get('', name);
  982. for (const initializer of tensors) {
  983. inputs.push({ name: initializer.name, initializer });
  984. }
  985. }
  986. }
  987. }
  988. const attributes = [];
  989. const convertAttributeValue = (value) => {
  990. if (Array.isArray(value) || value !== Object(value)) {
  991. return value;
  992. }
  993. const obj = {};
  994. if (value.class_name) {
  995. obj.__type__ = value.class_name;
  996. }
  997. if (value.config) {
  998. const config = value.config;
  999. for (const [key, value] of Object.entries(config)) {
  1000. obj[key] = convertAttributeValue(value);
  1001. }
  1002. }
  1003. return obj;
  1004. };
  1005. if (config && !Array.isArray(config)) {
  1006. for (const [name, value] of Object.entries(config)) {
  1007. if (class_name !== 'Activation' && name === 'activation' && value !== 'linear') {
  1008. if (typeof value === 'string') {
  1009. const config = { activation: value };
  1010. const node = new keras.Node(metadata, { class_name: 'Activation', config }, null, null, value);
  1011. this.chain.push(node);
  1012. } else if (value && typeof value.class_name === 'string' && value.config) {
  1013. const type = value.class_name;
  1014. if (!metadata.type(type)) {
  1015. metadata.add(type, { name: type, category: 'Activation' });
  1016. }
  1017. const node = new keras.Node(metadata, value, null, null, value);
  1018. this.chain.push(node);
  1019. }
  1020. }
  1021. if (name !== 'name' && name !== 'batch_input_shape') {
  1022. const schema = metadata.attribute(class_name, name);
  1023. attributes.push([schema, name, value]);
  1024. }
  1025. }
  1026. }
  1027. const innerType = this.inner ? this.inner.type : null;
  1028. const innerMetadata = innerType ? metadata.type(innerType) : null;
  1029. let inputIndex = 0;
  1030. while (inputs.length > 0) {
  1031. let list = false;
  1032. let name = null;
  1033. let visible = true;
  1034. if (!innerMetadata || inputIndex === 0) {
  1035. if (this.type && this.type.inputs && inputIndex < this.type.inputs.length) {
  1036. const input = this.type.inputs[inputIndex];
  1037. name = input.name;
  1038. if (class_name === 'BatchNormalization' && name === 'gamma' && config.scale === false) {
  1039. inputIndex++;
  1040. continue;
  1041. }
  1042. visible = input.visible !== false;
  1043. if (this.type.inputs[inputIndex].list) {
  1044. list = true;
  1045. }
  1046. }
  1047. } else {
  1048. switch (class_name) {
  1049. case 'Bidirectional': {
  1050. let innerIndex = inputIndex;
  1051. if (innerMetadata && innerMetadata.inputs) {
  1052. if (innerIndex < innerMetadata.inputs.length) {
  1053. name = `forward_${innerMetadata.inputs[innerIndex].name}`;
  1054. } else {
  1055. innerIndex = innerIndex - innerMetadata.inputs.length + 1;
  1056. if (innerIndex < innerMetadata.inputs.length) {
  1057. name = `backward_${innerMetadata.inputs[innerIndex].name}`;
  1058. }
  1059. }
  1060. }
  1061. visible = false;
  1062. break;
  1063. }
  1064. case 'TimeDistributed':
  1065. if (innerMetadata && innerMetadata.inputs && inputIndex < innerMetadata.inputs.length) {
  1066. name = innerMetadata.inputs[inputIndex].name;
  1067. }
  1068. break;
  1069. default:
  1070. break;
  1071. }
  1072. }
  1073. const input = list ? inputs.splice(0, inputs.length) : [inputs.shift()];
  1074. const inputArguments = input.map((input) => {
  1075. if (input.name) {
  1076. return values.map(input.name, null, input.initializer);
  1077. }
  1078. if (input.initializer) {
  1079. return values.map(input.name, null, input.initializer);
  1080. }
  1081. if (input.value !== undefined) {
  1082. const tensor = new keras.Tensor('', input.shape, config.dtype || '?', null, null, '|', input.value);
  1083. return values.map('', null, tensor);
  1084. }
  1085. throw new keras.Error(`Invalid argument '${JSON.stringify(input.name)}'.`);
  1086. });
  1087. if (!name && inputArguments.length === 1 && inputArguments[0].initializer && inputArguments[0].initializer.name) {
  1088. if (names.length === 1 && names[0] === '') {
  1089. name = inputArguments[0].initializer.name;
  1090. } else {
  1091. const parts = inputArguments[0].initializer.name.split('/').pop().split(':').shift().split('_');
  1092. const inputName1 = parts.pop();
  1093. const inputName2 = parts.length > 0 ? [parts.pop(), inputName1].join('_') : '';
  1094. const inputNames = new Set(['recurrent_kernel', 'running_mean', 'running_std', 'moving_mean', 'moving_variance', 'depthwise_filter', 'pointwise_filter']);
  1095. name = inputNames.has(inputName2) ? inputName2 : inputName1;
  1096. }
  1097. }
  1098. const argument = new keras.Argument(name || inputIndex.toString(), inputArguments, null, visible);
  1099. this.inputs.push(argument);
  1100. inputIndex++;
  1101. }
  1102. for (let i = 0; i < outputs.length; i++) {
  1103. const output = outputs[i];
  1104. const name = this.type && this.type.outputs && i < this.type.outputs.length && this.type.outputs[i] && this.type.outputs[i].name ? this.type.outputs[i].name : i.toString();
  1105. const argument = new keras.Argument(name, output === undefined || output.length === 0 ? [] : [values.map(output)]);
  1106. this.outputs.push(argument);
  1107. }
  1108. const inputTypes = new Map((this.type.inputs || []).map((input) => [input.name, input.type]));
  1109. for (const [name, arg] of Object.entries(args)) {
  1110. if (name !== 'name') {
  1111. if ((arg && arg.name) || (inputTypes.has(name) && inputTypes.get(name) === 'Tensor' && arg)) {
  1112. if (arg.name) {
  1113. const value = values.map(arg.name);
  1114. const argument = new keras.Argument(name, [value]);
  1115. this.inputs.push(argument);
  1116. } else {
  1117. const tensor = new keras.Tensor('', arg.shape, config.dtype || '?', null, null, '|', arg.value);
  1118. const value = values.map('', null, tensor);
  1119. const argument = new keras.Argument(name, [value]);
  1120. this.inputs.push(argument);
  1121. }
  1122. } else {
  1123. const schema = metadata.attribute(class_name, name);
  1124. this.attributes.push([schema, name, arg]);
  1125. }
  1126. }
  1127. }
  1128. this.attributes = attributes.map(([metadata, name, value]) => {
  1129. let type = null;
  1130. let visible = true;
  1131. if (value && typeof value === 'object' && value.class_name && value.config) {
  1132. value = convertAttributeValue(value);
  1133. }
  1134. switch (name) {
  1135. case 'trainable':
  1136. type = 'boolean';
  1137. visible = false;
  1138. break;
  1139. case 'dtype':
  1140. visible = false;
  1141. break;
  1142. default: {
  1143. if (metadata) {
  1144. type = metadata.type ? metadata.type : type;
  1145. if (metadata.visible === false) {
  1146. visible = false;
  1147. } else if (metadata.default !== undefined) {
  1148. if (Array.isArray(value)) {
  1149. if (Array.isArray(metadata.default)) {
  1150. visible = value.length !== metadata.default || !value.every((item, index) => item === metadata.default[index]);
  1151. } else {
  1152. visible = !value.every((item) => item === metadata.default);
  1153. }
  1154. } else {
  1155. visible = value !== metadata.default;
  1156. }
  1157. }
  1158. }
  1159. break;
  1160. }
  1161. }
  1162. return new keras.Argument(name, value, type, visible);
  1163. });
  1164. if (typeof this.type.name !== 'string' || !this.type.name.split) { // #416
  1165. throw new keras.Error(`Unsupported node type '${JSON.stringify(this.type.name)}'.`);
  1166. }
  1167. }
  1168. };
  1169. keras.Tensor = class {
  1170. constructor(name, shape, type, stride, quantization, encoding, data, location) {
  1171. this.name = name;
  1172. this.type = new keras.TensorType(type, new keras.TensorShape(shape));
  1173. this.stride = stride;
  1174. this.encoding = encoding;
  1175. this._data = data;
  1176. this.location = location;
  1177. if (quantization && (quantization.scale !== 0 || quantization.min !== 0)) {
  1178. this.quantization = {
  1179. type: 'linear',
  1180. scale: [quantization.scale],
  1181. min: [quantization.min]
  1182. };
  1183. }
  1184. }
  1185. get values() {
  1186. if (this.encoding === '|') {
  1187. return this._data;
  1188. }
  1189. if (this._data === null) {
  1190. return null;
  1191. }
  1192. return this._data instanceof Uint8Array ? this._data : this._data.peek();
  1193. }
  1194. };
  1195. keras.TensorType = class {
  1196. constructor(dataType, shape) {
  1197. this.dataType = dataType;
  1198. this.shape = shape;
  1199. }
  1200. toString() {
  1201. return this.dataType + this.shape.toString();
  1202. }
  1203. };
  1204. keras.TensorShape = class {
  1205. constructor(dimensions) {
  1206. this.dimensions = dimensions;
  1207. }
  1208. toString() {
  1209. return this.dimensions && this.dimensions.length > 0 ? (`[${this.dimensions.map((dimension) => dimension.toString()).join(',')}]`) : '';
  1210. }
  1211. };
  1212. keras.GraphMetadata = class {
  1213. constructor(metadata) {
  1214. this._metadata = metadata;
  1215. this._types = new Map();
  1216. }
  1217. type(name) {
  1218. if (this._types.has(name)) {
  1219. return this._types.get(name);
  1220. }
  1221. return this._metadata.type(name);
  1222. }
  1223. attribute(type, name) {
  1224. return this._metadata.attribute(type, name);
  1225. }
  1226. add(type, metadata) {
  1227. this._types.set(type, metadata);
  1228. }
  1229. };
  1230. keras.Weights = class {
  1231. constructor() {
  1232. this._map = new Map();
  1233. }
  1234. get empty() {
  1235. return this._map.size === 0;
  1236. }
  1237. add(layer_name, tensor) {
  1238. if (!this._map.has(layer_name)) {
  1239. this._map.set(layer_name, []);
  1240. }
  1241. this._map.get(layer_name).push(tensor);
  1242. }
  1243. get(group, name) {
  1244. if (group) {
  1245. const list = this._map.get(group.split('/').shift());
  1246. if (list) {
  1247. const match1 = list.filter((tensor) => tensor.name.startsWith(`${name}/`));
  1248. if (match1.length > 0) {
  1249. return match1;
  1250. }
  1251. const match2 = list.filter((tensor) => tensor.name.startsWith(`${group}/${name}/`));
  1252. if (match2.length > 0) {
  1253. return match2;
  1254. }
  1255. }
  1256. } else {
  1257. const match1 = this._map.get(name);
  1258. if (match1 && match1.length > 0) {
  1259. return match1;
  1260. }
  1261. const match2 = this._map.get('');
  1262. if (match2 && match2.length > 0) {
  1263. const match3 = match2.filter((tensor) => tensor.name.startsWith(`${(group ? `${group}/` : '') + name}/`));
  1264. if (match3.length > 0) {
  1265. return match3;
  1266. }
  1267. }
  1268. }
  1269. return [];
  1270. }
  1271. keys() {
  1272. return this._map.keys();
  1273. }
  1274. };
// Error type for failures while loading a Keras model; the 'name' field is
// displayed as the error heading.
keras.Error = class extends Error {

    constructor(message) {
        super(message);
        this.name = 'Error loading Keras model.';
    }
};
// Detects and reads TensorFlow.js model artifacts: model.json topology,
// weight manifests and binary weight shards.
tfjs.Container = class {

    // Probes the context and returns a container tagged with the artifact
    // flavor ('', 'weights.json', 'metadata' or 'weights.bin'), or null when
    // the content is not a TensorFlow.js artifact.
    static async open(context) {
        const json = await context.peek('json');
        if (json) {
            if (json.modelTopology && (json.format === 'layers-model' || json.modelTopology.class_name || json.modelTopology.model_config)) {
                return new tfjs.Container(context, '');
            }
            // A bare array of weight manifests ({ weights, paths } entries).
            if (Array.isArray(json) && json.every((item) => item.weights && item.paths)) {
                return new tfjs.Container(context, 'weights.json');
            }
            if (json.tfjsVersion) {
                return new tfjs.Container(context, 'metadata');
            }
        }
        const identifier = context.identifier;
        // Shard file names follow the 'group1-shard1of2[.bin]' convention.
        if (/^.*group\d+-shard\d+of\d+(\.bin)?$/.test(identifier)) {
            return new tfjs.Container(context, 'weights.bin');
        }
        return null;
    }

    constructor(context, type) {
        this.context = context;
        this.type = type;
    }

    // Loads the artifact determined by open(); populates format, producer,
    // backend, config and weights.
    async read() {
        const context = this.context;
        switch (this.type) {
            case '': {
                // model.json opened directly.
                const obj = await context.peek('json');
                return this._openModelJson(obj);
            }
            case 'weights.json': {
                this.format = 'TensorFlow.js Weights';
                this.config = null;
                const obj = await context.peek('json');
                const manifests = Array.from(obj);
                for (const manifest of manifests) {
                    for (const weight of manifest.weights) {
                        // Group weights by their path prefix (text before the last '/').
                        const name = weight.name;
                        const index = name.lastIndexOf('/');
                        weight.identifier = index === -1 ? name : name.substring(0, index);
                    }
                }
                return this._openManifests(manifests);
            }
            case 'weights.bin': {
                // Opened from a shard file: locate the sibling model.json.
                const content = await this.context.fetch('model.json');
                const obj = await content.read('json');
                return this._openModelJson(obj);
            }
            case 'metadata': {
                // Opened from a metadata file: locate the sibling model.json.
                const content = await this.context.fetch('model.json');
                const obj = await content.read('json');
                return this._openModelJson(obj);
            }
            default: {
                throw new tfjs.Error(`Unsupported TensorFlow.js format '${this.type}'.`);
            }
        }
    }

    // Slices the concatenated shard buffers into keras.Tensor weights.
    // 'shards' maps path -> Uint8Array; when a manifest's shards are missing
    // the tensors are created without data (structure-only view).
    _openShards(manifests, shards) {
        this.weights = new keras.Weights();
        const dtype_size_map = new Map([
            ['float16', 2], ['float32', 4], ['float64', 8],
            ['int8', 1], ['int16', 2], ['int32', 4], ['int64', 8],
            ['uint8', 1], ['uint16', 2], ['uint32', 4], ['uint64', 8]
        ]);
        for (const manifest of manifests) {
            let buffer = null;
            let location = '';
            if (Array.isArray(manifest.paths) && manifest.paths.length > 0 && manifest.paths.every((path) => shards.has(path))) {
                // Concatenate all shards of this manifest into one buffer.
                const list = manifest.paths.map((path) => shards.get(path));
                location = manifest.paths.join(', ');
                const size = list.reduce((a, b) => a + b.length, 0);
                buffer = new Uint8Array(size);
                let offset = 0;
                for (const item of list) {
                    buffer.set(item, offset);
                    offset += item.length;
                }
            }
            let offset = 0;
            for (const weight of manifest.weights) {
                // Quantized weights are stored using the quantization dtype.
                const dtype = weight.quantization && weight.quantization.dtype ? weight.quantization.dtype : weight.dtype;
                if (!dtype_size_map.has(dtype)) {
                    throw new keras.Error(`Unsupported weight data type size '${dtype}'.`);
                }
                const itemsize = dtype_size_map.get(dtype);
                const size = weight.shape.reduce((a, b) => a * b, 1);
                const length = itemsize * size;
                const data = buffer ? buffer.slice(offset, offset + length) : null;
                // '<' marks little-endian raw byte encoding.
                const tensor = new keras.Tensor(weight.name, weight.shape, dtype, null, weight.quantization, '<', data, location);
                this.weights.add(weight.identifier, tensor);
                offset += length;
            }
        }
    }

    // Fetches all shard files referenced by the manifests, then builds the
    // weights. On any fetch failure, deliberately falls back to a data-less
    // (structure-only) view rather than failing the whole model load.
    async _openManifests(manifests) {
        const shards = new Map();
        for (const manifest of manifests) {
            for (const path of manifest.paths) {
                if (!shards.has(path)) {
                    const promise = this.context.fetch(path);
                    shards.set(path, promise);
                }
            }
        }
        const promises = shards.values();
        try {
            const contexts = await Promise.all(promises);
            for (const key of shards.keys()) {
                // Replace each pending fetch with the shard's raw bytes.
                const context = contexts.shift();
                const buffer = context.stream.peek();
                shards.set(key, buffer);
            }
            this._openShards(manifests, shards);
        } catch {
            // Best effort: show the model without weight data.
            shards.clear();
            this._openShards(manifests, shards);
        }
    }

    // Reads a parsed model.json object: validates the layers-model layout,
    // records format/producer/backend, then resolves the weight manifests.
    _openModelJson(obj) {
        if (!obj || !obj.modelTopology || (obj.format !== 'layers-model' && !obj.modelTopology.model_config && !obj.modelTopology.config)) {
            throw new tfjs.Error('File format is not TensorFlow.js layers-model.');
        }
        const modelTopology = obj.modelTopology;
        this.format = `TensorFlow.js ${obj.format ? obj.format : `Keras${modelTopology.keras_version ? (` v${modelTopology.keras_version}`) : ''}`}`;
        this.producer = obj.convertedBy || obj.generatedBy || '';
        this.backend = modelTopology.backend || '';
        const manifests = obj.weightsManifest;
        for (const manifest of manifests) {
            for (const weight of manifest.weights) {
                // model.json weights are not grouped by identifier.
                weight.identifier = '';
            }
        }
        this.config = modelTopology.model_config ? modelTopology.model_config : modelTopology;
        return this._openManifests(manifests);
    }
};
// Error type for failures while loading a TensorFlow.js model; the 'name'
// field is displayed as the error heading.
tfjs.Error = class extends Error {

    constructor(message) {
        super(message);
        this.name = 'Error loading TensorFlow.js model.';
    }
};
// Public entry point consumed by the model loader registry.
export const ModelFactory = keras.ModelFactory;