keras.js 52 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264
  1. var keras = keras || {};
  2. var tfjs = tfjs || {};
  3. var json = require('./json');
  4. var python = require('./python');
  /**
   * Factory that recognizes and opens the Keras family of model files:
   * HDF5 models and weight files, Keras JSON model configs, TensorFlow.js
   * containers and pickled Keras Sequential models.
   */
  keras.ModelFactory = class {

      /**
       * Classifies the content exposed by `context`.
       * @param {object} context - model context; `open(kind)` returns the parsed content or a falsy value.
       * @returns {string|null} 'keras.h5', 'keras.json', 'tfjs.json', 'keras.pickle', or null if unrecognized.
       */
      match(context) {
          const group = context.open('hdf5');
          // hickle also writes HDF5 files; those are not Keras models.
          if (group && group.attributes.get('CLASS') !== 'hickle') {
              return 'keras.h5';
          }
          const json = context.open('json');
          if (json) {
              // MXNet symbol files are JSON as well; reject them explicitly.
              if (json.mxnet_version || (json.nodes && json.arg_nodes && json.heads)) {
                  return null;
              }
              if (json.model_config || (json.class_name && json.config)) {
                  return 'keras.json';
              }
          }
          if (tfjs.Container.open(context)) {
              return 'tfjs.json';
          }
          const pickle = context.open('pkl');
          if (pickle &&
              pickle.__class__ &&
              pickle.__class__.__module__ === 'keras.engine.sequential' &&
              pickle.__class__.__name__ === 'Sequential') {
              return 'keras.pickle';
          }
          return null;
      }

      /**
       * Opens the content previously classified by `match`.
       * @param {object} context - model context.
       * @param {string} match - identifier returned by `match`.
       * @returns {Promise<keras.Model>} resolved once 'keras-metadata.json' has been loaded.
       * @throws {keras.Error} (synchronously) when the content does not have the expected structure.
       */
      open(context, match) {
          const openModel = (format, producer, backend, config, weights) => {
              return context.metadata('keras-metadata.json').then((metadata) => {
                  return new keras.Model(metadata, format, producer, backend, config, weights);
              });
          };
          switch (match) {
              case 'keras.h5': {
                  // Some files nest the actual model under 'model/kerasmodel'.
                  const find_root_group = (root_group) => {
                      const kerasmodel = root_group.group('model/kerasmodel');
                      if (kerasmodel && kerasmodel.attributes.has('model_config')) {
                          return kerasmodel;
                      }
                      return root_group;
                  };
                  // Parses the JSON stored in the 'model_config' attribute, or returns null.
                  const read_model_config = (group) => {
                      if (group.attributes.has('model_config')) {
                          const buffer = group.attributes.get('model_config');
                          const reader = json.TextReader.open(buffer);
                          if (reader) {
                              return reader.read();
                          }
                      }
                      return null;
                  };
                  // Reads attribute `name`, or concatenates the chunked variants
                  // `name0`, `name1`, ... used for attributes too large for one HDF5 attribute.
                  const load_attributes_from_hdf5_group = (group, name) => {
                      if (group.attributes.has(name)) {
                          return group.attributes.get(name);
                      }
                      if (group.attributes.has(name + '0')) {
                          let index = 0;
                          let value = [];
                          while (group.attributes.has(name + index.toString())) {
                              const chunk = group.attributes.get(name + index.toString());
                              value = value.concat(chunk);
                              index++;
                          }
                          return value;
                      }
                      return null;
                  };
                  const weights = new keras.Weights();
                  const group = context.open('hdf5');
                  const root_group = find_root_group(group);
                  const model_config = read_model_config(root_group);
                  if (model_config) {
                      // Full model: architecture in 'model_config', weights under 'model_weights'.
                      const backend = root_group.attributes.get('backend') || '';
                      const version = root_group.attributes.get('keras_version') || '';
                      const format = 'Keras' + (version ? ' v' + version : '');
                      const model_weights_group = root_group.group('model_weights');
                      if (model_weights_group) {
                          const layer_names = load_attributes_from_hdf5_group(model_weights_group, 'layer_names');
                          // 'layer_names' can be absent in malformed files; do not iterate null.
                          if (Array.isArray(layer_names)) {
                              for (const layer_name of layer_names) {
                                  const layer_weights = model_weights_group.group(layer_name);
                                  if (layer_weights) {
                                      const weight_names = load_attributes_from_hdf5_group(layer_weights, 'weight_names');
                                      if (Array.isArray(weight_names) && weight_names.length > 0) {
                                          for (const weight_name of weight_names) {
                                              const weight = layer_weights.group(weight_name);
                                              if (weight && weight.value) {
                                                  const variable = weight.value;
                                                  const tensor = new keras.Tensor(weight_name, variable.shape, variable.type, null, variable.littleEndian ? '<' : '>', variable.data);
                                                  weights.add(layer_name, tensor);
                                              }
                                          }
                                      }
                                  }
                              }
                          }
                      }
                      if (!model_config.class_name) {
                          throw new keras.Error("'class_name' is not present.");
                      }
                      return openModel(format, '', backend, model_config, weights);
                  }
                  const layer_names = load_attributes_from_hdf5_group(root_group, 'layer_names');
                  if (layer_names && Array.isArray(layer_names)) {
                      // Keras weights-only file (model.save_weights): layers at the root.
                      const version = root_group.attributes.get('keras_version') || '';
                      const format = 'Keras Weights' + (version ? ' v' + version : '');
                      const backend = root_group.attributes.get('backend') || '';
                      for (const layer_name of layer_names) {
                          const layer_weights = root_group.group(layer_name);
                          if (layer_weights) {
                              const weight_names = load_attributes_from_hdf5_group(layer_weights, 'weight_names');
                              if (Array.isArray(weight_names) && weight_names.length > 0) {
                                  for (const weight_name of weight_names) {
                                      const weight = layer_weights.group(weight_name);
                                      if (weight && weight.value) {
                                          const variable = weight.value;
                                          // Key the tensor by its parent path, prefixed with the layer name if missing.
                                          const components = weight_name.split('/');
                                          components.pop();
                                          const name = (components.length === 0 || components[0] !== layer_name) ? [ layer_name ].concat(components).join('/') : components.join('/');
                                          const layout = variable.littleEndian ? '<' : '>';
                                          const tensor = new keras.Tensor(weight_name, variable.shape, variable.type, null, layout, variable.data);
                                          weights.add(name, tensor);
                                      }
                                  }
                              }
                          }
                      }
                      return openModel(format, '', backend, null, weights);
                  }
                  // Generic HDF5 weights file: only 'nb_layers' may appear on the root.
                  const rootKeys = new Set(root_group.attributes.keys());
                  rootKeys.delete('nb_layers');
                  if (rootKeys.size > 0 || root_group.value !== null) {
                      throw new keras.Error('File format is not HDF5 Weights.');
                  }
                  const format = 'HDF5 Weights';
                  let weights_group = root_group;
                  // Unwrap a single attribute-free container group around the weights.
                  if (root_group.attributes.size === 0 && root_group.value === null && root_group.groups.size === 1) {
                      const group = root_group.groups.values().next().value;
                      if (group.attributes.size === 0 && group.value === null) {
                          weights_group = group;
                      }
                  }
                  const tensorKeys = new Set([ 'name', 'shape', 'quantization' ]);
                  const groups = Array.from(weights_group.groups.values());
                  // Flat layout: every child is an attribute-free dataset.
                  // `groups` is a Map, so its child count is `size` (a Map has no `length`).
                  if (groups.every((group) => group.attributes.size === 0 && group.groups.size === 0 && group.value !== null)) {
                      for (const group of groups) {
                          const variable = group.value;
                          const layout = variable.littleEndian ? '<' : '>';
                          const tensor = new keras.Tensor(group.name, variable.shape, variable.type, null, layout, variable.type === 'string' ? variable.value : variable.data);
                          weights.add('', tensor);
                      }
                      return openModel(format, '', '', null, weights);
                  }
                  // Module layout: one group per module (optional 'name'/'shape'/'quantization'
                  // attributes) whose children are attribute-free datasets.
                  if (groups.every((group) => group.value === null &&
                          Array.from(group.attributes.keys()).every((key) => tensorKeys.has(key)) &&
                          Array.from(group.groups.values()).every((variable) => variable.attributes.size === 0 && variable.value !== null))) {
                      for (const group of groups) {
                          const moduleName = group.attributes.has('name') ? group.attributes.get('name') : group.name;
                          for (const variableGroup of group.groups.values()) {
                              if (variableGroup.attributes.size !== 0 || variableGroup.groups.size !== 0) {
                                  throw new keras.Error('Variable format is not HDF5 Weights.');
                              }
                              const variable = variableGroup.value;
                              if (!variable) {
                                  throw new keras.Error('Variable value is not HDF5 Weights.');
                              }
                              // Without a module name fall back to the variable's own name
                              // (the previous code dereferenced the falsy module name).
                              const name = moduleName ? [ moduleName, variableGroup.name ].join('/') : variableGroup.name;
                              const layout = variable.littleEndian ? '<' : '>';
                              const tensor = new keras.Tensor(name, variable.shape, variable.type, null, layout, variable.type === 'string' ? variable.value : variable.data);
                              weights.add(moduleName, tensor);
                          }
                      }
                      return openModel(format, '', '', null, weights);
                  }
                  // Arbitrary nesting: recurse through attribute-free container groups;
                  // leaves are datasets carrying at most 'index' / 'need_grad' attributes.
                  const subKeys = new Set([ 'index', 'need_grad' ]);
                  const walk = (group) => {
                      if (group.attributes.size === 0 && group.value === null && group.groups.size > 0) {
                          for (const subGroup of group.groups.values()) {
                              walk(subGroup);
                          }
                          return;
                      }
                      const knownAttributes = Array.from(group.attributes.keys()).every((key) => subKeys.has(key));
                      if (knownAttributes && group.value !== null && group.groups.size === 0) {
                          const variable = group.value;
                          const variableName = group.path;
                          // The module is the parent path of the variable (or the path itself at top level).
                          let moduleName = variableName;
                          const parts = variableName.split('/');
                          if (parts.length > 1) {
                              parts.pop();
                              moduleName = parts.join('/');
                          }
                          const layout = variable.littleEndian ? '<' : '>';
                          const tensor = new keras.Tensor(variableName, variable.shape, variable.type, null, layout, variable.type === 'string' ? variable.value : variable.data);
                          weights.add(moduleName, tensor);
                          return;
                      }
                      throw new keras.Error('Module group format is not HDF5 Weights.');
                  };
                  walk(weights_group);
                  return openModel(format, '', '', null, weights);
              }
              case 'keras.json': {
                  const obj = context.open('json');
                  const format = 'Keras' + (obj.keras_version ? ' v' + obj.keras_version : '');
                  const backend = obj.backend || '';
                  // Either a full saved-model JSON ({ model_config, ... }) or a bare config.
                  const config = obj.model_config ? obj.model_config : obj;
                  const weights = new keras.Weights();
                  return openModel(format, '', backend, config, weights);
              }
              case 'tfjs.json': {
                  const container = tfjs.Container.open(context);
                  return container.open().then(() => {
                      return openModel(container.format, container.producer, container.backend, container.config, container.weights);
                  });
              }
              case 'keras.pickle': {
                  // NOTE(review): this unpickles embedded weight buffers from the input file.
                  const execution = new python.Execution();
                  const obj = context.open('pkl');
                  const decoder = new TextDecoder('utf-8');
                  const format = 'Keras Pickle' + (obj.keras_version ? ' v' + decoder.decode(obj.keras_version) : '');
                  const backend = obj.backend ? decoder.decode(obj.backend) : '';
                  const reader = json.TextReader.open(obj.model_config);
                  const model_config = reader.read();
                  const weights = new keras.Weights();
                  const model_weights_group = obj.model_weights;
                  if (model_weights_group) {
                      const layer_names = model_weights_group.layer_names.map((buffer) => decoder.decode(buffer));
                      for (const layer_name of layer_names) {
                          const layer_weights = model_weights_group[layer_name];
                          if (layer_weights) {
                              const weight_names = layer_weights.weight_names.map((buffer) => decoder.decode(buffer));
                              if (Array.isArray(weight_names) && weight_names.length > 0) {
                                  for (const weight_name of weight_names) {
                                      const buffer = layer_weights[weight_name];
                                      const unpickler = execution.invoke('pickle.Unpickler', [ buffer ]);
                                      const variable = unpickler.load();
                                      const tensor = new keras.Tensor(weight_name, variable.shape, variable.dtype.__name__, null, '<', variable.data);
                                      weights.add(layer_name, tensor);
                                  }
                              }
                          }
                      }
                  }
                  return openModel(format, '', backend, model_config, weights);
              }
              default: {
                  throw new keras.Error("Unsupported Keras format '" + match + "'.");
              }
          }
      }
  };
  /**
   * Top-level Keras model. Holds the format / producer / runtime strings and
   * exactly one graph built from the model config and its weights.
   */
  keras.Model = class {

      /**
       * @param {object} metadata - raw operator metadata; wrapped in a keras.GraphMetadata here.
       * @param {string} format - human-readable format string (e.g. 'Keras v2.4.0').
       * @param {string} producer - producing tool, possibly ''.
       * @param {string} backend - Keras backend name, exposed as `runtime`.
       * @param {object|null} config - model config passed to keras.Graph.
       * @param {keras.Weights} weights - weight tensors passed to keras.Graph.
       */
      constructor(metadata, format, producer, backend, config, weights) {
          this._format = format;
          this._producer = producer;
          this._backend = backend;
          const graphMetadata = new keras.GraphMetadata(metadata);
          const mainGraph = new keras.Graph(graphMetadata, config, weights);
          this._graphs = [ mainGraph ];
      }

      /** Models carry no name of their own. */
      get name() {
          return null;
      }

      /** Models carry no description. */
      get description() {
          return null;
      }

      get format() {
          return this._format;
      }

      get producer() {
          return this._producer;
      }

      /** The backend string read from the file, reported as the runtime. */
      get runtime() {
          return this._backend;
      }

      get graphs() {
          return this._graphs;
      }
  };
  /**
   * A Keras graph built from a model config (Sequential / AllCNN / Functional /
   * Model) or, when no config is available, from the bare weights alone.
   */
  keras.Graph = class {

      /**
       * @param {keras.GraphMetadata} metadata - operator metadata.
       * @param {object|null} config - Keras config object of the form { class_name, config }.
       * @param {keras.Weights} weights - weight tensors, looked up by group and layer name.
       * @param {string} [group] - name prefix used for nested models.
       * @throws {keras.Error} for unsupported model classes, malformed inbound nodes
       *     or ragged constant arrays.
       */
      constructor(metadata, config, weights, group) {
          this._metadata = metadata;
          this._inputs = [];
          this._outputs = [];
          this._nodes = [];
          group = group || '';
          // Creates a node from a shallow copy of `layer` with the resolved inputs/outputs attached.
          const loadNode = (layer, inputs, outputs, weights, group) => {
              layer = Object.assign({}, layer);
              layer.inputs = inputs;
              layer.outputs = outputs;
              return new keras.Node(this._metadata, layer, group, weights);
          };
          // Derives an input tensor type from an input layer's config. The consumed
          // 'dtype' / 'batch_input_shape' entries are deleted so they are not shown again
          // as node attributes.
          const getInputType = (layer) => {
              if (layer && layer.config) {
                  let dataType = '?';
                  let shape = [];
                  const config = layer.config;
                  if (config.dtype) {
                      dataType = config.dtype;
                      delete config.dtype;
                  }
                  if (config.batch_input_shape) {
                      shape = config.batch_input_shape.map((s) => s == null ? '?' : s);
                      delete config.batch_input_shape;
                  }
                  return new keras.TensorType(dataType, new keras.TensorShape(shape));
              }
              return null;
          };
          if (config) {
              this._name = config.name || (config.config && config.config.name ? config.config.name : '');
              // A reference to another layer's output: [ name, node_index, tensor_index (, kwargs) ].
              const is_connection = (item) => {
                  return Array.isArray(item) && (item.length === 3 || item.length === 4) && typeof item[0] === 'string' && typeof item[1] === 'number' && typeof item[2] === 'number';
              };
              // An inline constant: [ '_CONSTANT_VALUE', -1, value (, kwargs) ].
              const is_constant = (item) => {
                  return Array.isArray(item) && (item.length === 3 || item.length === 4) && item[0] === '_CONSTANT_VALUE' && item[1] === -1;
              };
              switch (config.class_name) {
                  case 'AllCNN':
                  case 'Sequential': {
                      config = config.config;
                      const inputName = 'input';
                      let inputType = null;
                      // Name of the value produced by the previous layer (the graph input at first).
                      let argument = inputName;
                      let index = 0;
                      // Older Keras versions store the layer list directly as the config.
                      const layers = config.layers ? config.layers : config;
                      for (const layer of layers) {
                          let name = index.toString();
                          const nodeInputs = [ { name: argument } ];
                          if (index === 0) {
                              inputType = getInputType(layer);
                          }
                          index++;
                          if (layer.config && layer.config.name) {
                              name = layer.config.name;
                          }
                          argument = name;
                          const nodeOutputs = [ argument ];
                          this._nodes.push(loadNode(layer, nodeInputs, nodeOutputs, weights, group));
                      }
                      this._inputs.push(new keras.Parameter(inputName, true, [ new keras.Argument(inputName, inputType, null) ]));
                      // `argument` is always a non-empty name here ('input' or the last layer's name).
                      this._outputs.push(new keras.Parameter(argument, true, [ new keras.Argument(argument, null, null) ]));
                      break;
                  }
                  case 'Functional':
                  case 'Model': {
                      config = config.config;
                      const nodes = new Map();
                      if (config.layers) {
                          for (const layer of config.layers) {
                              layer.inputs = [];
                              layer.outputs = [];
                              layer.args = {};
                              if (layer.name && !nodes.has(layer.name)) {
                                  nodes.set(layer.name, layer);
                              }
                          }
                          // Resolves a connection to a value name ('name' or 'name:index') and
                          // records that output slot on the producing layer. node_index (input_data[1])
                          // is not used.
                          const read_connection = (input_data) => {
                              let name = input_data[0];
                              const node = nodes.get(name);
                              if (node) {
                                  const tensor_index = input_data[2];
                                  if (tensor_index !== 0) {
                                      name += ':' + tensor_index.toString();
                                  }
                                  while (tensor_index >= node.outputs.length) {
                                      node.outputs.push('');
                                  }
                                  node.outputs[tensor_index] = name;
                              }
                              return { name: name };
                          };
                          // Converts an inline constant (scalar or nested array) into { shape, value }.
                          const read_value = (input_data) => {
                              if (!Array.isArray(input_data)) {
                                  return { shape: [], value: [ input_data ] };
                              }
                              const shape = (value) => {
                                  if (value.every((item) => is_constant(item))) {
                                      // Unwrap nested constant markers in place.
                                      for (let i = 0; i < value.length; i++) {
                                          value[i] = value[i][2];
                                      }
                                  }
                                  else if (value.every((item) => Array.isArray(item))) {
                                      const dims = value.map((item) => shape(item));
                                      const dim = dims[0];
                                      for (let i = 1; i < dims.length; i++) {
                                          const other = dims[i];
                                          // Ragged arrays (differing rank or extent) have no tensor shape.
                                          if (other.length !== dim.length || !other.every((size, axis) => size === dim[axis])) {
                                              throw new keras.Error('Invalid array shape.');
                                          }
                                      }
                                      return [ value.length ].concat(dim);
                                  }
                                  return [ value.length ];
                              };
                              const flatten = (input) => input.reduce((a, b) => a.concat(Array.isArray(b) ? flatten(b) : b), []);
                              return { shape: shape(input_data), value: flatten(input_data) };
                          };
                          // Replaces `layer.args` with the decoded keyword arguments of one inbound node.
                          const read_args = (layer, args) => {
                              layer.args = {};
                              for (const [ key, value ] of Object.entries(args || {})) {
                                  layer.args[key] = is_connection(value) ? read_connection(value) : read_value(value);
                              }
                          };
                          for (const layer of config.layers) {
                              if (layer.inbound_nodes) {
                                  for (const inbound_node of layer.inbound_nodes) {
                                      if (is_constant(inbound_node)) {
                                          layer.inputs.push(read_value(inbound_node[2]));
                                          read_args(layer, inbound_node[3]);
                                      }
                                      else if (is_connection(inbound_node)) {
                                          layer.inputs.push(read_connection(inbound_node));
                                          read_args(layer, inbound_node[3]);
                                      }
                                      else if (Array.isArray(inbound_node)) {
                                          for (const input_data of inbound_node) {
                                              if (is_connection(input_data)) {
                                                  layer.inputs.push(read_connection(input_data));
                                              }
                                              else if (Array.isArray(input_data) && input_data.every((item) => is_connection(item))) {
                                                  for (const input of input_data) {
                                                      layer.inputs.push(read_connection(input));
                                                  }
                                              }
                                              else if (Array.isArray(input_data)) {
                                                  layer.inputs.push(read_value(input_data));
                                              }
                                              else {
                                                  throw new keras.Error("Invalid inbound connection '" + JSON.stringify(input_data) + "'.");
                                              }
                                          }
                                      }
                                      else {
                                          throw new keras.Error("Invalid inbound node '" + JSON.stringify(inbound_node) + "'.");
                                      }
                                  }
                              }
                          }
                      }
                      const input_layers = is_connection(config.input_layers) ? [ config.input_layers ] : config.input_layers;
                      if (input_layers) {
                          for (const input_layer of input_layers) {
                              const name = input_layer[0];
                              let type = null;
                              const node = nodes.get(name);
                              // InputLayer nodes become graph inputs and are not rendered as nodes.
                              if (node && node.class_name === 'InputLayer') {
                                  type = getInputType(node);
                                  nodes.delete(name);
                              }
                              const argument = new keras.Argument(name, type, null);
                              const parameter = new keras.Parameter(name, true, [ argument ]);
                              this._inputs.push(parameter);
                          }
                      }
                      const output_layers = is_connection(config.output_layers) ? [ config.output_layers ] : config.output_layers;
                      if (output_layers) {
                          for (const output_layer of output_layers) {
                              let outputName = output_layer[0];
                              const outputNode = nodes.get(outputName);
                              if (outputNode) {
                                  const outputIndex = output_layer[2];
                                  if (outputIndex != 0) {
                                      outputName += ':' + outputIndex.toString();
                                  }
                                  while (outputIndex >= outputNode.outputs.length) {
                                      outputNode.outputs.push('');
                                  }
                                  outputNode.outputs[outputIndex] = outputName;
                              }
                              const argument = new keras.Argument(outputName, null, null);
                              const parameter = new keras.Parameter(outputName, true, [ argument ]);
                              this._outputs.push(parameter);
                          }
                      }
                      if (config.layers) {
                          for (const layer of config.layers) {
                              if (nodes.has(layer.name)) {
                                  this._nodes.push(loadNode(layer, layer.inputs, layer.outputs, weights, group));
                              }
                          }
                      }
                      break;
                  }
                  default:
                      throw new keras.Error('\'' + config.class_name + '\' is not supported.');
              }
          }
          else if (weights) {
              // Weights-only file: one 'Weights' node per module holding at most 6 tensors.
              for (const name of weights.keys()) {
                  if (weights.get('', name).length <= 6) {
                      const layer = { class_name: 'Weights', config: { name: name } };
                      const node = new keras.Node(metadata, layer, '', weights);
                      this._nodes.push(node);
                  }
              }
          }
      }

      get name() {
          return this._name;
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get nodes() {
          return this._nodes;
      }
  };
  /**
   * A named input or output slot of a graph or node, grouping one or more
   * keras.Argument values.
   */
  keras.Parameter = class {

      /**
       * @param {string} name - slot name.
       * @param {boolean} visible - whether the slot is shown.
       * @param {keras.Argument[]} args - values bound to the slot.
       */
      constructor(name, visible, args) {
          Object.assign(this, {
              _name: name,
              _visible: visible,
              _arguments: args
          });
      }

      get name() {
          return this._name;
      }

      get visible() {
          return this._visible;
      }

      get arguments() {
          return this._arguments;
      }
  };
  /**
   * A named value flowing between nodes, optionally backed by a constant
   * initializer tensor.
   */
  keras.Argument = class {

      /**
       * @param {string} name - value identifier; must be a string.
       * @param {object} [type] - tensor type, used only when there is no initializer.
       * @param {keras.Tensor} [initializer] - constant tensor providing the value.
       * @throws {keras.Error} when `name` is not a string.
       */
      constructor(name, type, initializer) {
          const isIdentifier = typeof name === 'string';
          if (!isIdentifier) {
              throw new keras.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
          }
          this._name = name;
          this._type = type || null;
          this._initializer = initializer || null;
      }

      get name() {
          return this._name;
      }

      /** The initializer's type when present, otherwise the declared type. */
      get type() {
          return this._initializer ? this._initializer.type : this._type;
      }

      /** The initializer's quantization when present, otherwise null. */
      get quantization() {
          return this._initializer ? this._initializer.quantization : null;
      }

      get initializer() {
          return this._initializer;
      }
  };
  588. keras.Node = class {
  589. constructor(metadata, layer, group, weights) {
  590. const config = layer.config || {};
  591. const args = layer.args || {};
  592. let inputs = layer.inputs || [];
  593. let outputs = layer.outputs || [];
  594. const name = config && config.name ? config.name : '';
  595. this._group = group || '';
  596. this._name = (this._group ? this._group + '/' : '') + name;
  597. this._inputs = [];
  598. this._outputs = [];
  599. this._attributes = [];
  600. this._chain = [];
  601. let names = [ name ];
  602. let type = layer.class_name;
  603. let model = false;
  604. switch (type) {
  605. case 'Model':
  606. case 'Functional':
  607. case 'Sequential': {
  608. const name = layer.name || (layer.config ? layer.config.name : '');
  609. this._type = new keras.Graph(metadata, layer, weights, (group ? group + '/' : '') + name);
  610. model = true;
  611. if (config) {
  612. delete config.layers;
  613. delete config.input_layers;
  614. delete config.output_layers;
  615. }
  616. this._inputs = [ new keras.Parameter('inputs', true, inputs.map((input) => new keras.Argument(input.name, null, null))) ];
  617. this._outputs = [ new keras.Parameter('outputs', true, outputs.map((name) => new keras.Argument(name, null, null))) ];
  618. inputs = [];
  619. outputs = [];
  620. break;
  621. }
  622. case 'Bidirectional':
  623. case 'TimeDistributed': {
  624. if (config && config.layer) {
  625. const inner = config.layer;
  626. delete config.layer;
  627. this._inner = new keras.Node(metadata, inner, null, null);
  628. if (type == 'Bidirectional' && inner.config.name) {
  629. names = [ name + '/forward_' + inner.config.name, name + '/backward_' + inner.config.name ];
  630. if (!group) {
  631. group = name;
  632. }
  633. }
  634. }
  635. this._type = metadata.type(type) || { name: type };
  636. break;
  637. }
  638. case 'TFOpLambda': {
  639. if (config && config.function) {
  640. type = config.function;
  641. delete config.function;
  642. }
  643. this._type = metadata.type(type) || { name: type };
  644. break;
  645. }
  646. default: {
  647. this._type = metadata.type(type) || { name: type };
  648. break;
  649. }
  650. }
  651. const initializers = {};
  652. if (weights && !model) {
  653. for (const name of names) {
  654. let tensors = weights.get(group, name);
  655. if (tensors.length > 0) {
  656. for (const initializer of tensors) {
  657. inputs.push({ name: initializer.name });
  658. initializers[initializer.name] = initializer;
  659. }
  660. }
  661. else {
  662. tensors = weights.get('', name);
  663. for (const initializer of tensors) {
  664. inputs.push({ name: initializer.name });
  665. initializers[initializer.name] = initializer;
  666. }
  667. }
  668. }
  669. }
  670. if (config && !Array.isArray(config)) {
  671. for (const entry of Object.entries(config)) {
  672. const name = entry[0];
  673. const value = entry[1];
  674. if (name === 'activation' && value !== 'linear') {
  675. if (typeof value === 'string') {
  676. const set = new Map([ [ 'elu', 'ELU' ], [ 'exponential', 'Exponential' ], [ 'hard_sigmoid', 'HardSigmoid' ], [ 'linear', 'Linear' ], [ 'relu', 'ReLU' ], [ 'selu', 'SELU' ], [ 'softmax', 'Softmax'], [ 'sigmoid', 'Sigmoid' ], [ 'softplus', 'SoftPlus' ], [ 'softsign', 'SoftSign' ], [ 'tanh', 'TanH' ] ]);
  677. const type = set.has(value) ? set.get(value) : value;
  678. this.chain.push(new keras.Node(metadata, { class_name: type }, null, null));
  679. }
  680. else if (value && typeof value.class_name === 'string' && value.config) {
  681. const type = value.class_name;
  682. if (!metadata.type(type)) {
  683. metadata.add(type, { name: type, category: 'Activation' });
  684. }
  685. this.chain.push(new keras.Node(metadata, value, null, null));
  686. }
  687. }
  688. if (name !== 'name') {
  689. const attribute = new keras.Attribute(metadata.attribute(type, name), name, value);
  690. this._attributes.push(attribute);
  691. }
  692. }
  693. }
  694. const innerType = this.inner ? this.inner.type : null;
  695. const innerSchema = innerType ? metadata.type(innerType) : null;
  696. let inputIndex = 0;
  697. while (inputs.length > 0) {
  698. let list = false;
  699. let inputName = null;
  700. let visible = true;
  701. if (!innerSchema || inputIndex == 0) {
  702. if (this._type && this._type.inputs && inputIndex < this._type.inputs.length) {
  703. const input = this._type.inputs[inputIndex];
  704. inputName = input.name;
  705. if (type === 'BatchNormalization' && inputName === 'gamma' && config.scale === false) {
  706. inputIndex++;
  707. continue;
  708. }
  709. visible = input.visible == false ? false : true;
  710. if (this._type.inputs[inputIndex].list) {
  711. list = true;
  712. }
  713. }
  714. }
  715. else {
  716. switch (type) {
  717. case 'Bidirectional': {
  718. let innerIndex = inputIndex;
  719. if (innerSchema && innerSchema.inputs) {
  720. if (innerIndex < innerSchema.inputs.length) {
  721. inputName = 'forward_' + innerSchema.inputs[innerIndex].name;
  722. }
  723. else {
  724. innerIndex = innerIndex - innerSchema.inputs.length + 1;
  725. if (innerIndex < innerSchema.inputs.length) {
  726. inputName = 'backward_' + innerSchema.inputs[innerIndex].name;
  727. }
  728. }
  729. }
  730. visible = false;
  731. break;
  732. }
  733. case 'TimeDistributed':
  734. if (innerSchema && innerSchema.inputs && inputIndex < innerSchema.inputs.length) {
  735. inputName = innerSchema.inputs[inputIndex].name;
  736. }
  737. break;
  738. default:
  739. break;
  740. }
  741. }
  742. const input = !list ? [ inputs.shift() ] : inputs.splice(0, inputs.length);
  743. const inputArguments = input.map((input) => {
  744. if (input.name) {
  745. return new keras.Argument(input.name, null, initializers[input.name]);
  746. }
  747. if (input.value !== undefined) {
  748. const tensor = new keras.Tensor('', input.shape, config.dtype || '?', null, '|', input.value);
  749. return new keras.Argument('', null, tensor);
  750. }
  751. throw new keras.Error("Invalid argument '" + JSON.stringify(input.name) + "'.");
  752. });
  753. if (!inputName && inputArguments.length == 1 && inputArguments[0].initializer && inputArguments[0].initializer.name) {
  754. if (names.length === 1 && names[0] === '') {
  755. inputName = inputArguments[0].initializer.name;
  756. }
  757. else {
  758. const parts = inputArguments[0].initializer.name.split('/').pop().split(':').shift().split('_');
  759. const inputName1 = parts.pop();
  760. const inputName2 = parts.length > 0 ? [ parts.pop(), inputName1 ].join('_') : '';
  761. const inputNames = new Set([ 'recurrent_kernel', 'running_mean', 'running_std', 'moving_mean', 'moving_variance', 'depthwise_filter', 'pointwise_filter' ]);
  762. inputName = inputNames.has(inputName2) ? inputName2 : inputName1;
  763. }
  764. }
  765. this._inputs.push(new keras.Parameter(inputName || inputIndex.toString(), visible, inputArguments));
  766. inputIndex++;
  767. }
  768. for (let i = 0; i < outputs.length; i++) {
  769. const output = outputs[i];
  770. const outputName = (this._type && this._type.outputs && i < this._type.outputs.length && this._type.outputs[i] && this._type.outputs[i].name) ? this._type.outputs[i].name : i.toString();
  771. const parameter = new keras.Parameter(outputName, true, [ new keras.Argument(output, null, null) ]);
  772. this._outputs.push(parameter);
  773. }
  774. const inputTypes = new Map((this._type.inputs || []).map((input) => [ input.name, input.type ]));
  775. for (const entry of Object.entries(args)) {
  776. const name = entry[0];
  777. const value = entry[1];
  778. if (name !== 'name') {
  779. if (value.name || (inputTypes.has(name) && inputTypes.get(name) === 'Tensor' && value)) {
  780. if (value.name) {
  781. const argument = new keras.Argument(value.name, null, null);
  782. const parameter = new keras.Parameter(name, true, [ argument ]);
  783. this._inputs.push(parameter);
  784. }
  785. else {
  786. const tensor = new keras.Tensor('', value.shape, config.dtype || '?', null, '|', value.value);
  787. const argument = new keras.Argument('', null, tensor);
  788. const parameter = new keras.Parameter(name, true, [ argument ]);
  789. this._inputs.push(parameter);
  790. }
  791. }
  792. else {
  793. const attribute = new keras.Attribute(metadata.attribute(type, name), name, value);
  794. this._attributes.push(attribute);
  795. }
  796. }
  797. }
  798. if (typeof this.type.name !== 'string' || !this.type.name.split) { // #416
  799. throw new keras.Error("Unsupported node type '" + JSON.stringify(this.type.name) + "'.");
  800. }
  801. }
  802. get type() {
  803. return this._type;
  804. }
  805. get name() {
  806. return this._name;
  807. }
  808. get inputs() {
  809. return this._inputs;
  810. }
  811. get outputs() {
  812. return this._outputs;
  813. }
  814. get attributes() {
  815. return this._attributes;
  816. }
  817. get chain() {
  818. return this._chain;
  819. }
  820. get inner() {
  821. return this._inner;
  822. }
  823. };
  /**
   * A single non-tensor layer argument shown on a node.
   *
   * Values shaped like `{ class_name, config }` are flattened by
   * `keras.Attribute._convert` into `{ __type__: class_name, ...config }`.
   * Visibility: 'trainable' and 'dtype' are always hidden; otherwise an explicit
   * `metadata.visible` wins, and failing that the attribute is hidden when its
   * value equals `metadata.default` (element-wise for arrays).
   *
   * @param {object|null|undefined} metadata - attribute schema; may provide `type`, `visible`, `default`.
   * @param {string} name - attribute name.
   * @param {*} value - raw value from the layer config.
   */
  keras.Attribute = class {
    constructor(metadata, name, value) {
      this._name = name;
      this._value = value;
      if (value && typeof value == 'object' && value.class_name && value.config) {
        this._value = keras.Attribute._convert(value);
      }
      switch (name) {
        case 'trainable':
          this._type = 'boolean';
          this._visible = false;
          break;
        case 'dtype':
          this._visible = false;
          break;
        default: {
          if (metadata) {
            if (metadata.type) {
              this._type = metadata.type;
            }
            if (Object.prototype.hasOwnProperty.call(metadata, 'visible')) {
              this._visible = metadata.visible;
            }
            else if (metadata.default !== undefined) {
              const defaultValue = metadata.default;
              if (Array.isArray(value)) {
                if (Array.isArray(defaultValue)) {
                  // Hidden only when both arrays have the same length and every
                  // item matches the default's item at the same index.
                  this._visible = value.length !== defaultValue.length || !this.value.every((item, index) => item == defaultValue[index]);
                }
                else {
                  // A scalar default is compared against every element.
                  this._visible = !this.value.every((item) => item == defaultValue);
                }
              }
              else {
                this._visible = this.value !== defaultValue;
              }
            }
          }
          break;
        }
      }
    }
    get name() {
      return this._name;
    }
    get type() {
      return this._type;
    }
    get value() {
      return this._value;
    }
    get visible() {
      // Only a value loosely equal to `false` hides the attribute; unset means visible.
      return this._visible == false ? false : true;
    }
    /**
     * Recursively flattens a Keras `{ class_name, config }` object into
     * `{ __type__: class_name, ...converted config entries }`. Arrays and
     * primitives are returned unchanged. A plain object without `class_name` /
     * `config` has its own entries converted rather than being dropped.
     * @param {*} value
     * @returns {*}
     */
    static _convert(value) {
      if (Array.isArray(value) || value !== Object(value)) {
        return value;
      }
      const obj = {};
      if (value.class_name) {
        obj.__type__ = value.class_name;
      }
      const source = (value.class_name || value.config) ? value.config : value;
      if (source) {
        for (const entry of Object.entries(source)) {
          const key = entry[0];
          const item = entry[1];
          obj[key] = keras.Attribute._convert(item);
        }
      }
      return obj;
    }
  };
  /**
   * A constant tensor: either a weight sliced from a shard buffer or an inline value.
   *
   * @param {string} name - tensor name.
   * @param {Array|undefined} shape - dimensions, wrapped in a keras.TensorShape.
   * @param {string} type - data type name, wrapped with the shape in a keras.TensorType.
   * @param {object|null|undefined} quantization - optional descriptor read for `scale` and `min`.
   * @param {string} layout - '|' means `data` already holds the values; otherwise `data` is raw bytes.
   * @param {*} data - values for '|', else a Uint8Array, null, or an object exposing `peek()`.
   */
  keras.Tensor = class {
    constructor(name, shape, type, quantization, layout, data) {
      this._name = name;
      this._type = new keras.TensorType(type, new keras.TensorShape(shape));
      this._quantization = quantization;
      this._layout = layout;
      this._data = data;
    }
    get name() {
      return this._name;
    }
    get type() {
      return this._type;
    }
    get layout() {
      return this._layout;
    }
    /**
     * Human-readable dequantization formula, e.g. "0.5 * (q - 3)", or null when
     * there is no descriptor or both `scale` and `min` are missing/zero.
     */
    get quantization() {
      if (!this._quantization) {
        return null;
      }
      // Missing fields default to 0 *before* the check, so a descriptor that only
      // names a storage dtype does not yield a meaningless "0 * q".
      const scale = this._quantization.scale || 0;
      const min = this._quantization.min || 0;
      if (scale === 0 && min === 0) {
        return null;
      }
      return scale.toString() + ' * ' + (min == 0 ? 'q' : ('(q - ' + min.toString() + ')'));
    }
    /** The tensor contents: inline values for '|', raw bytes otherwise, or null. */
    get values() {
      if (this._layout === '|') {
        return this._data;
      }
      if (this._data === null) {
        return null;
      }
      return this._data instanceof Uint8Array ? this._data : this._data.peek();
    }
  };
  /**
   * Pairs a data type name with a shape object; `toString()` renders the two
   * back to back, e.g. "float32" followed by the shape's own string form.
   */
  keras.TensorType = class {
    constructor(dataType, shape) {
      this._dataType = dataType;
      this._shape = shape;
    }
    get dataType() { return this._dataType; }
    get shape() { return this._shape; }
    toString() {
      const dataType = this._dataType;
      const dimensions = this._shape.toString();
      return dataType + dimensions;
    }
  };
  /**
   * A list of tensor dimensions, rendered as "[d0,d1,...]" (empty string when
   * there are no dimensions).
   */
  keras.TensorShape = class {
    constructor(dimensions) {
      this._dimensions = dimensions;
    }
    get dimensions() {
      return this._dimensions;
    }
    toString() {
      if (!this._dimensions || this._dimensions.length === 0) {
        return '';
      }
      // Unknown dimensions may arrive as null/undefined; render them as '?'
      // instead of throwing on `.toString()`.
      return '[' + this._dimensions.map((dimension) => dimension == null ? '?' : dimension.toString()).join(',') + ']';
    }
  };
  /**
   * Per-graph metadata view: types registered with `add` take precedence over
   * the wrapped metadata's types of the same name, while attribute lookups are
   * always delegated to the wrapped metadata.
   */
  keras.GraphMetadata = class {
    constructor(metadata) {
      this._metadata = metadata;
      this._types = new Map();
    }
    type(name) {
      return this._types.has(name) ? this._types.get(name) : this._metadata.type(name);
    }
    attribute(type, name) {
      const metadata = this._metadata;
      return metadata.attribute(type, name);
    }
    add(type, metadata) {
      const types = this._types;
      types.set(type, metadata);
    }
  };
  /**
   * Weight tensors grouped under a layer key.
   *
   * `add` appends a tensor to the list for its key. `get` either resolves a
   * group path (its first '/'-segment selects the key, then tensor names are
   * matched by prefix) or, without a group, looks the name up as a key and
   * falls back to prefix-matching tensors stored under the empty key.
   */
  keras.Weights = class {
    constructor() {
      this._map = new Map();
    }
    add(layer_name, tensor) {
      if (!this._map.has(layer_name)) {
        this._map.set(layer_name, []);
      }
      this._map.get(layer_name).push(tensor);
    }
    /**
     * @param {string} [group] - '/'-separated group path, or falsy for a direct lookup.
     * @param {string} name - layer/weight name to match.
     * @returns {Array} matching tensors, or an empty array when nothing matches.
     */
    get(group, name) {
      if (group) {
        const list = this._map.get(group.split('/').shift());
        if (list) {
          const match1 = list.filter((tensor) => tensor.name.startsWith(name + '/'));
          if (match1.length > 0) {
            return match1;
          }
          const match2 = list.filter((tensor) => tensor.name.startsWith(group + '/' + name + '/'));
          if (match2.length > 0) {
            return match2;
          }
        }
        return [];
      }
      const match1 = this._map.get(name);
      if (match1 && match1.length > 0) {
        return match1;
      }
      // No group on this path, so unkeyed tensors are matched on `name + '/'` alone.
      const match2 = this._map.get('');
      if (match2 && match2.length > 0) {
        const match3 = match2.filter((tensor) => tensor.name.startsWith(name + '/'));
        if (match3.length > 0) {
          return match3;
        }
      }
      return [];
    }
    keys() {
      return this._map.keys();
    }
  };
  /**
   * Error raised by the Keras loader; its `name` identifies the failure as a
   * Keras model loading problem.
   * @param {string} message - description of what went wrong.
   */
  keras.Error = class extends Error {
    constructor(message) {
      super(message);
      this.name = 'Error loading Keras model.';
    }
  };
  /**
   * Loader for TensorFlow.js model artifacts.
   *
   * The static `open(context)` inspects the JSON content and returns a container
   * for one of three layouts, or null when none matches:
   *   ''         - a model JSON whose `modelTopology` holds a Keras-style topology
   *   'weights'  - a bare weights manifest: an array of `{ weights, paths }` groups
   *   'metadata' - a JSON with `tfjsVersion`; the model is then read from 'model.json'
   * The instance `open()` parses the content, fetches the weight shards and fills
   * `format`, `producer`, `backend`, `config` and `weights`.
   */
  tfjs.Container = class {
    static open(context) {
      const json = context.open('json');
      if (json) {
        if (json.modelTopology && (json.format === 'layers-model' || json.modelTopology.class_name || json.modelTopology.model_config)) {
          return new tfjs.Container(context, '');
        }
        if (Array.isArray(json) && json.every((item) => item.weights && item.paths)) {
          return new tfjs.Container(context, 'weights');
        }
        if (json.tfjsVersion) {
          return new tfjs.Container(context, 'metadata');
        }
      }
      return null;
    }
    constructor(context, type) {
      this._context = context;
      this._type = type;
    }
    get format() {
      return this._format;
    }
    get producer() {
      return this._producer || '';
    }
    get backend() {
      return this._backend || '';
    }
    get config() {
      return this._config;
    }
    get weights() {
      return this._weights;
    }
    /**
     * Parses the detected artifact and loads its weights.
     * @returns {Promise<void>} settles once `weights` has been populated.
     * @throws {tfjs.Error} when the container type is not recognized or the
     *   model JSON has no `modelTopology`.
     */
    open() {
      switch (this._type) {
        case '': {
          const obj = this._context.open('json');
          return this._openModelJson(obj);
        }
        case 'weights': {
          this._format = 'TensorFlow.js Weights';
          this._config = null;
          const obj = this._context.open('json');
          const manifests = Array.from(obj);
          for (const manifest of manifests) {
            for (const weight of manifest.weights) {
              // Group each weight under its name minus the last path segment
              // (e.g. 'dense/kernel' -> 'dense').
              const name = weight.name;
              const index = name.lastIndexOf('/');
              weight.identifier = index === -1 ? name : name.substring(0, index);
            }
          }
          return this._openManifests(manifests);
        }
        case 'metadata': {
          return this._context.request('model.json').then((stream) => {
            const reader = json.TextReader.open(stream);
            const obj = reader.read();
            return this._openModelJson(obj);
          });
        }
        default: {
          throw new tfjs.Error("Unsupported TensorFlow.js format '" + this._type + "'.");
        }
      }
    }
    /**
     * Builds `this._weights` from the manifests. A manifest's shards are
     * concatenated in `paths` order and sliced sequentially per weight; when any
     * of its shards is absent from `shards`, its tensors are created without data.
     * @param {Array<object>} manifests - weight groups `{ paths, weights }`.
     * @param {Map<string, Uint8Array>} shards - shard bytes keyed by path (as returned by `peek()`).
     * @throws {keras.Error} for a weight data type whose element size is unknown.
     */
    _openShards(manifests, shards) {
      this._weights = new keras.Weights();
      // Bytes per element. 'bool' and 'complex64' are valid TF.js weight dtypes.
      const dtype_size_map = new Map([
        [ 'float16', 2 ], [ 'float32', 4 ], [ 'float64', 8 ],
        [ 'int8', 1 ], [ 'int16', 2 ], [ 'int32', 4 ], [ 'int64', 8 ],
        [ 'uint8', 1 ], [ 'uint16', 2 ], [ 'uint32', 4 ], [ 'uint64', 8 ],
        [ 'bool', 1 ], [ 'complex64', 8 ]
      ]);
      for (const manifest of manifests) {
        let buffer = null;
        if (Array.isArray(manifest.paths) && manifest.paths.length > 0 && manifest.paths.every((path) => shards.has(path))) {
          const list = manifest.paths.map((path) => shards.get(path));
          const size = list.reduce((a, b) => a + b.length, 0);
          buffer = new Uint8Array(size);
          let offset = 0;
          for (const item of list) {
            buffer.set(item, offset);
            offset += item.length;
          }
        }
        let offset = 0;
        for (const weight of manifest.weights) {
          // A quantized weight is stored in its quantization dtype, not its logical dtype.
          const dtype = weight.quantization && weight.quantization.dtype ? weight.quantization.dtype : weight.dtype;
          if (!dtype_size_map.has(dtype)) {
            throw new keras.Error("Unsupported weight data type size '" + dtype + "'.");
          }
          const itemsize = dtype_size_map.get(dtype);
          const size = weight.shape.reduce((a, b) => a * b, 1);
          const length = itemsize * size;
          const data = buffer ? buffer.slice(offset, offset + length) : null;
          this._weights.add(weight.identifier, new keras.Tensor(weight.name, weight.shape, dtype, weight.quantization, '<', data));
          offset += length;
        }
      }
    }
    /**
     * Requests every distinct shard path once, then builds the weights.
     * Best effort: on any failure while fetching or reading the shards, the
     * weights are still built, only without tensor data.
     * @param {Array<object>} manifests - weight groups `{ paths, weights }`.
     * @returns {Promise<void>}
     */
    _openManifests(manifests) {
      const shards = new Map();
      for (const manifest of manifests) {
        // Matches `_openShards`, which treats a manifest without a `paths` array as data-less.
        const paths = Array.isArray(manifest.paths) ? manifest.paths : [];
        for (const path of paths) {
          if (!shards.has(path)) {
            const promise = this._context.request(path, null);
            shards.set(path, promise);
          }
        }
      }
      const promises = shards.values();
      return Promise.all(promises).then((streams) => {
        // Map iteration order is insertion order, so `streams` lines up with the keys.
        for (const key of shards.keys()) {
          shards.set(key, streams.shift().peek());
        }
        this._openShards(manifests, shards);
        return;
      }).catch(() => {
        shards.clear();
        this._openShards(manifests, shards);
        return;
      });
    }
    /**
     * Reads format, producer, backend and config from a model JSON, then loads
     * the weights listed in its `weightsManifest`.
     * @param {object} obj - parsed model JSON.
     * @returns {Promise<void>}
     * @throws {tfjs.Error} when `obj` has no `modelTopology`.
     */
    _openModelJson(obj) {
      const modelTopology = obj.modelTopology;
      if (!modelTopology) {
        throw new tfjs.Error('File does not contain a TensorFlow.js model topology.');
      }
      this._format = 'TensorFlow.js ' + (obj.format ? obj.format : 'Keras' + (modelTopology.keras_version ? (' v' + modelTopology.keras_version) : ''));
      this._producer = obj.convertedBy || obj.generatedBy || '';
      this._backend = modelTopology.backend || '';
      // A topology-only file may omit `weightsManifest`; treat that as no weights.
      const manifests = Array.isArray(obj.weightsManifest) ? obj.weightsManifest : [];
      for (const manifest of manifests) {
        for (const weight of manifest.weights) {
          weight.identifier = '';
        }
      }
      this._config = modelTopology.model_config ? modelTopology.model_config : modelTopology;
      return this._openManifests(manifests);
    }
  };
  /**
   * Error raised by the TensorFlow.js container loader; its `name` identifies
   * the failure as a TensorFlow.js model loading problem.
   * @param {string} message - description of what went wrong.
   */
  tfjs.Error = class extends Error {
    constructor(message) {
      super(message);
      this.name = 'Error loading TensorFlow.js model.';
    }
  };
  // When a CommonJS `module` with an `exports` object is present, expose the
  // Keras model factory; otherwise this statement does nothing.
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
    Object.assign(module.exports, { ModelFactory: keras.ModelFactory });
  }